In [1]:
from torch import cuda
from torch.optim import Adam 
import torch

import pandas as pd
import torchkge
from torchkge.models import TransEModel
from torchkge.sampling import BernoulliNegativeSampler
from torchkge.utils import MarginLoss, DataLoader
# from torchkge.utils.datasets import load_fb15k237
from torchkge.utils.datasets import load_wikidatasets

from tqdm.autonotebook import tqdm

savedobj_path = '../saved_obj/'

In [2]:
humans_kge_train, humans_kge_val, humans_kge_test  = load_wikidatasets('humans', limit_=50, data_home='C:/Users/Mathew/torchkge_data/')

In [3]:
# Define some hyper-parameters for training
emb_dim = 650
lr = 0.0004
n_epochs = 10
b_size = 32
margin = 0.5

In [10]:
saved_50_model = model = TransEModel(emb_dim, humans_kge_train.n_ent, humans_kge_train.n_rel, dissimilarity_type='L2')
saved_50_model.load_state_dict(torch.load(savedobj_path+'humans_transe_sdict.pt'))

<All keys matched successfully>

In [4]:
# Define the model and criterion
model = TransEModel(emb_dim, humans_kge_train.n_ent, humans_kge_train.n_rel, dissimilarity_type='L2')
criterion = MarginLoss(margin)

In [5]:
# Move everything to CUDA if available
if cuda.is_available():
    cuda.empty_cache()
    model.cuda()
    criterion.cuda()

In [50]:
# Define the torch optimizer to be used
optimizer = Adam(model.parameters(), lr=lr, weight_decay=1e-5)

sampler = BernoulliNegativeSampler(humans_kge_train)
dataloader = DataLoader(humans_kge_train, batch_size=b_size, use_cuda='all')

iterator = tqdm(range(n_epochs), unit='epoch')
for epoch in iterator:
    running_loss = 0.0
    for i, batch in enumerate(dataloader):
        h, t, r = batch[0], batch[1], batch[2]
        n_h, n_t = sampler.corrupt_batch(h, t, r)

        optimizer.zero_grad()

        # forward + backward + optimize
        pos, neg = model(h, t, n_h, n_t, r)
        loss = criterion(pos, neg)
        loss.backward()
        optimizer.step()

        running_loss += loss.item()
    iterator.set_description(
        'Epoch {} | mean loss: {:.5f}'.format(epoch + 1,
                                              running_loss / len(dataloader)))

model.normalize_parameters()

Epoch 10 | mean loss: 0.33868: 100%|██████████| 10/10 [23:59<00:00, 143.96s/epoch]


In [63]:
torch.save(model.state_dict(), savedobj_path+'humans_transe_sdict.pt')

In [11]:
from torchkge.evaluation import LinkPredictionEvaluator

# Link prediction evaluation on validation set.
evaluator = LinkPredictionEvaluator(saved_50_model, humans_kge_val)
evaluator.evaluate(b_size=32, k_max=10)
evaluator.print_results()



Link prediction evaluation: 100%|██████████| 238/238 [11:41<00:00,  2.95s/batch]Hit@10 : 0.241 		 Filt. Hit@10 : 0.575
Mean Rank : 253 	 Filt. Mean Rank : 233
MRR : 0.098 		 Filt. MRR : 0.264



In [17]:
# Define the torch optimizer to be used
corrupt_prop = 0.1
corrupt_num = round(corrupt_prop*b_size)
optimizer = Adam(model.parameters(), lr=lr, weight_decay=1e-5)

sampler = BernoulliNegativeSampler(humans_kge_train)
dataloader = DataLoader(humans_kge_train, batch_size=b_size, use_cuda='all')

iterator = tqdm(range(2), unit='epoch')
for epoch in iterator:
    running_loss = 0.0
    for i, batch in enumerate(dataloader):
        
        # print(f"batch0_len: {len(batch[0])}")
        h0, t0, r0 = batch[0], batch[1], batch[2]
        hf, tf = sampler.corrupt_batch(h0[-10:], t0[-10:], r0[-10:])
        print(h0[:22].size())
        print(hf.size())
        h = torch.cat((h0[:22], hf))
        t = torch.cat((t0[:22], tf))
        print(h.size())
        # print(h, t, r)
        n_h, n_t = sampler.corrupt_batch(h, t, r0)
        # h2 = torch.cat((h, n_h))
        # print(h2)
        

        optimizer.zero_grad()

        # forward + backward + optimize
        pos, neg = model(h, t, n_h, n_t, r0)
        loss = criterion(pos, neg)
        loss.backward()
        optimizer.step()

        running_loss += loss.item()
    iterator.set_description(
        'Epoch {} | mean loss: {:.5f}'.format(epoch + 1,
                                              running_loss / len(dataloader)))

model.normalize_parameters()

0%|          | 0/2 [00:00<?, ?epoch/s]torch.Size([22])
torch.Size([10])
torch.Size([32])
torch.Size([22])
torch.Size([10])
torch.Size([32])
torch.Size([22])
torch.Size([10])
torch.Size([32])
torch.Size([22])
torch.Size([10])
torch.Size([32])
torch.Size([22])
torch.Size([10])
torch.Size([32])
torch.Size([22])
torch.Size([10])
torch.Size([32])
torch.Size([22])
torch.Size([10])
torch.Size([32])
torch.Size([22])
torch.Size([10])
torch.Size([32])
torch.Size([22])
torch.Size([10])
torch.Size([32])
torch.Size([22])
torch.Size([10])
torch.Size([32])
torch.Size([22])
torch.Size([10])
torch.Size([32])
torch.Size([22])
torch.Size([10])
torch.Size([32])
torch.Size([22])
torch.Size([10])
torch.Size([32])
torch.Size([22])
torch.Size([10])
torch.Size([32])
torch.Size([22])
torch.Size([10])
torch.Size([32])
torch.Size([22])
torch.Size([10])
torch.Size([32])
torch.Size([22])
torch.Size([10])
torch.Size([32])
torch.Size([22])
torch.Size([10])
torch.Size([32])
torch.Size([22])
torch.Size([10])
torch.Size

KeyboardInterrupt: 

In [18]:
round(0.1*b_size)

3

In [26]:
humans_kge_train[2]

(2, 2562, 18)

In [7]:
xx = torch.load(savedobj_path+'humans_transe_sdict.pt')

In [14]:
xx['rel_emb.weight'].size()

torch.Size([50, 650])