In [1]:
import pandas as pd
import numpy as np

import torch
from transformers import BertTokenizer, BertConfig, BertModel

# Pick the compute device once: use CUDA when PyTorch reports it as available,
# otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

In [2]:
from dataset import RelationsDS

# Build the RelationsDS dataset rooted at ./data and call .to(device=device) on it
# so its contents are placed on the device selected above.
data = RelationsDS(root='./data').to(device=device)
# Show the first element of the dataset as the cell output.
data[0]

Data(edge_index=[2, 111276], edge_type=[111276], token_ids=[91294, 138], token_mask=[91294, 138], token_type_ids=[91294, 138])

# BERT encoder

In [16]:
from torch.utils.data import DataLoader

# Work with the first element of the dataset, placed on the target device.
g = data[0].to(device)

# Combine the three per-sequence token tensors along a new axis 1, so a single
# tensor can be batched by one DataLoader: index 0 = token_ids, 1 = token_mask,
# 2 = token_type_ids.
tokens = torch.stack(
    (g.token_ids, g.token_mask, g.token_type_ids),
    dim=1,
)
tokens.shape

torch.Size([91294, 3, 138])

In [11]:
# Load the pretrained 'bert-base-uncased' encoder and place it on `device`.
# NOTE: the name `model` is rebound to the project's `Model` class instance in the
# "Model" section further down, so the BERT cells below only work while this
# binding is still the active one.
model = BertModel.from_pretrained('bert-base-uncased').to(device)

In [36]:
# Scratch arithmetic: 91294 is the first dimension of `tokens` shown above;
# 91294 / 2 = 45647 and 45647 / 7 = 6521.
# NOTE(review): presumably a search for a batch size that divides the input count evenly — confirm.
91294/2, 45647/7

(45647.0, 6521.0)

In [42]:
# Encode every token sequence with BERT in mini-batches and collect the hidden
# state at sequence position 0 of each input into `x`.
batch_size = 1024
token_loader = DataLoader(tokens, batch_size=batch_size)

encoding_size = 768
# One row per input sequence; filled in batch by batch below.
x = torch.zeros((tokens.shape[0], encoding_size), dtype=torch.float)

with torch.no_grad():
    for batch_num, token_batch in enumerate(token_loader):
        # Report progress every 10th batch. The earlier form
        # `batch_num % 10*batch_size == 0` parsed as `(batch_num % 10) * batch_size == 0`
        # and only gave the intended result by accident.
        if batch_num % 10 == 0:
            print(f'batch {batch_num}/{np.ceil(tokens.shape[0]/batch_size)}')

        # `tokens` stacks [token_ids, token_mask, token_type_ids] along dim 1;
        # pass them by keyword so the mapping to BERT's inputs is explicit.
        batch_encoding = model(input_ids=token_batch[:, 0, :],
                               attention_mask=token_batch[:, 1, :],
                               token_type_ids=token_batch[:, 2, :]).last_hidden_state[:, 0, :]

        # Size the destination slice from the batch actually received. The final
        # batch is usually shorter than batch_size; clamping the end index to
        # `tokens.shape[0]-1` made the slice one row too short, which raises a
        # shape-mismatch error and would never fill the last row of `x`.
        batch_start = batch_num * batch_size
        batch_end = batch_start + token_batch.shape[0]
        x[batch_start:batch_end, :] = batch_encoding

batch 0/90.0
batch 10/90.0


KeyboardInterrupt: 

In [7]:
# Sanity check: push a single example (a batch of one) through the encoder
# without tracking gradients.
with torch.no_grad():
    idx = 0
    # Named `input_ids` instead of `id` so the builtin id() is not shadowed for
    # the rest of the kernel session. unsqueeze(0) adds the leading batch axis.
    input_ids = g.token_ids[idx].unsqueeze(0)
    mask = g.token_mask[idx].unsqueeze(0)
    type_ids = g.token_type_ids[idx].unsqueeze(0)
    print(input_ids.shape, mask.shape, type_ids.shape)

    # Keyword arguments make the mapping to the encoder's inputs explicit.
    encoding = model(input_ids=input_ids,
                     attention_mask=mask,
                     token_type_ids=type_ids)

torch.Size([1, 138]) torch.Size([1, 138]) torch.Size([1, 138])


In [5]:
# Sanity check: push the first `num` examples through the encoder as one batch,
# without tracking gradients. `encoding` is inspected in the following cell.
with torch.no_grad():
    num = 10
    # Named `input_ids` instead of `id` so the builtin id() is not shadowed for
    # the rest of the kernel session.
    input_ids = g.token_ids[:num]
    mask = g.token_mask[:num]
    type_ids = g.token_type_ids[:num]
    print(input_ids.shape, mask.shape, type_ids.shape)

    # Keyword arguments make the mapping to the encoder's inputs explicit.
    encoding = model(input_ids=input_ids,
                     attention_mask=mask,
                     token_type_ids=type_ids)

torch.Size([10, 138]) torch.Size([10, 138]) torch.Size([10, 138])


In [10]:
# Take the hidden state at sequence position 0 for every example in the batch.
# The slice is already 2-D, so flatten(start_dim=1) leaves its shape unchanged.
lhs = encoding.last_hidden_state[:, 0, :].flatten(start_dim=1)

# Quick look at the result: shape/type first, then summary statistics.
print(lhs.shape, type(lhs))
print(lhs.mean(), lhs.std())

torch.Size([10, 768]) <class 'torch.Tensor'>
tensor(-0.0102, device='cuda:0') tensor(0.5358, device='cuda:0')


# Model

Adapted from [Online Link Prediction with Graph Neural Networks](https://medium.com/stanford-cs224w/online-link-prediction-with-graph-neural-networks-46c1054f2aa4).

In [3]:
from model import Model, LinkPredictor

In [4]:
# First element of the dataset on the target device, plus the stacked token tensor
# (index 0 = token_ids, 1 = token_mask, 2 = token_type_ids along axis 1).
g = data[0].to(device)
tokens = torch.stack(
    (g.token_ids, g.token_mask, g.token_type_ids),
    dim=1,
)

# Project model: 768-dimensional inputs, 256-dimensional outputs.
# NOTE: this rebinds `model`, which the BERT section above uses for the BertModel.
model = Model(
    input_dim=768,
    hidden_dim=256,
    output_dim=256,
    num_layers=2,
).to(device)

# Link predictor: takes the model's 256-dimensional outputs and produces one value.
link_pred = LinkPredictor(
    input_dim=256,
    hidden_dim=128,
    output_dim=1,
    num_layers=3,
).to(device)

In [9]:
# Run the model's encode_inputs() over the stacked token tensor with gradient
# tracking disabled; the result is kept as the input encodings for the model.
with torch.no_grad():
    node_embeddings = model.encode_inputs(tokens) 

In [18]:
# Write the computed encodings to disk.
# NOTE(review): presumably this is the file model.load_input_encodings() reads back — confirm.
torch.save(node_embeddings, 'data/input_encodings.pt')

In [5]:
# Fetch the input encodings through the model's own loader and place them on `device`.
node_embeddings = model.load_input_encodings().to(device)
# Forward pass of the model over the encodings and the graph's edge_index.
emb = model(node_embeddings, g.edge_index)

In [10]:
# Run the link predictor on rows 0 and 1 of `emb`; the displayed result is a single
# sigmoid-activated value.
# NOTE(review): presumably these rows are the embeddings of nodes 0 and 1 — confirm.
pred = link_pred(emb[0], emb[1])
pred


tensor([0.5217], device='cuda:0', grad_fn=<SigmoidBackward0>)