In [3]:
import pandas as pd
import numpy as np

import torch
from transformers import BertTokenizer, BertConfig, BertModel

device = 'cuda' if torch.cuda.is_available() else 'cpu'

In [4]:
from dataset import RelationsDS

data = RelationsDS(root='./data').to(device=device)
data[0]

Data(edge_index=[2, 111276], edge_type=[111276], token_ids=[91294, 138], token_mask=[91294, 138], token_type_ids=[91294, 138])

# BERT encoder

In [5]:
g = data[0]
g

Data(edge_index=[2, 111276], edge_type=[111276], token_ids=[91294, 138], token_mask=[91294, 138], token_type_ids=[91294, 138])

In [4]:
model = BertModel.from_pretrained('bert-base-uncased').to(device)

In [7]:
with torch.no_grad():
    idx = 0
    id = g.token_ids[idx].unsqueeze(0)
    mask = g.token_mask[idx].unsqueeze(0)
    type_ids = g.token_type_ids[idx].unsqueeze(0)
    print(id.shape, mask.shape, type_ids.shape)

    encoding = model(id, mask, type_ids)

torch.Size([1, 138]) torch.Size([1, 138]) torch.Size([1, 138])


In [5]:
with torch.no_grad():
    num = 10
    id = g.token_ids[:num]
    mask = g.token_mask[:num]
    type_ids = g.token_type_ids[:num]
    print(id.shape, mask.shape, type_ids.shape)

    encoding = model(id, mask, type_ids)

torch.Size([10, 138]) torch.Size([10, 138]) torch.Size([10, 138])


In [6]:
lhs = encoding.last_hidden_state
lhs = lhs.flatten(start_dim=1)
print(lhs.shape, type(lhs))
print(lhs.mean(), lhs.std())

torch.Size([10, 105984]) <class 'torch.Tensor'>
tensor(-0.0086, device='cuda:0') tensor(0.4080, device='cuda:0')


# Model

adapted from [Online Link Prediction with Graph Neural Networks](https://medium.com/stanford-cs224w/online-link-prediction-with-graph-neural-networks-46c1054f2aa4)

In [6]:
from model import Model, LinkPredictor

In [7]:
with torch.no_grad():
    batch_size = 10 

    g = g.to(device)
    tokens = {
        'token_ids': g.token_ids[:batch_size],
        'token_mask': g.token_mask[:batch_size],
        'token_type_ids': g.token_type_ids[:batch_size]
    }

    model = Model(input_dim=105984,
                  hidden_dim=256,
                  output_dim=256,
                  num_layers=2     ).to(device)
    link_pred = LinkPredictor(input_dim=256, 
                              hidden_dim=128,
                              output_dim=1,
                              num_layers=3  ).to(device)
    
    node_embeddings = model(tokens, g.edge_index) 

/opt/conda/conda-bld/pytorch_1716905971214/work/aten/src/ATen/native/cuda/ScatterGatherKernel.cu:144: operator(): block: [15,0,0], thread: [96,0,0] Assertion `idx_dim >= 0 && idx_dim < index_size && "index out of bounds"` failed.
/opt/conda/conda-bld/pytorch_1716905971214/work/aten/src/ATen/native/cuda/ScatterGatherKernel.cu:144: operator(): block: [15,0,0], thread: [97,0,0] Assertion `idx_dim >= 0 && idx_dim < index_size && "index out of bounds"` failed.
/opt/conda/conda-bld/pytorch_1716905971214/work/aten/src/ATen/native/cuda/ScatterGatherKernel.cu:144: operator(): block: [15,0,0], thread: [98,0,0] Assertion `idx_dim >= 0 && idx_dim < index_size && "index out of bounds"` failed.
/opt/conda/conda-bld/pytorch_1716905971214/work/aten/src/ATen/native/cuda/ScatterGatherKernel.cu:144: operator(): block: [15,0,0], thread: [99,0,0] Assertion `idx_dim >= 0 && idx_dim < index_size && "index out of bounds"` failed.
/opt/conda/conda-bld/pytorch_1716905971214/work/aten/src/ATen/native/cuda/Scatte

RuntimeError: CUDA error: device-side assert triggered
CUDA kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be incorrect.
For debugging consider passing CUDA_LAUNCH_BLOCKING=1.
Compile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.
