In [38]:
import numpy as np
import torch
from sklearn.model_selection import train_test_split
from sklearn.utils import shuffle

from models.Complex import Complex
from models.ConvE import ConvE, ConvE_args
from models.DistMult import DistMult
from utils.evaluation_metrics import SRR, auprc_auroc_ap
from utils.loaders import get_onehots, load_data
from utils.path_manage import get_files

In [48]:
def _squeeze_keep_batch(tensor):
    """``torch.squeeze`` that never collapses a lone element into a 0-d tensor."""
    squeezed = torch.squeeze(tensor)
    # A batch holding a single id squeezes to 0-d; keep it as shape (1,).
    return squeezed.reshape(1) if squeezed.dim() == 0 else squeezed


def _train_one_epoch(model, optimiser, train_data, batches):
    """Run one pass of mini-batch optimisation over ``train_data``.

    ``load_data`` splits the data into ``batches`` chunks of objects,
    subjects and relationships; for each chunk the model scores the
    (object, relation) pairs and is optimised with ``model.loss`` against
    the subject ids.
    """
    model.train()
    objects, subjects, relationships = load_data(train_data, batches)

    for index in range(batches):
        obj = torch.LongTensor(objects[index])
        rel = torch.LongTensor(relationships[index])
        subj = _squeeze_keep_batch(torch.LongTensor(subjects[index]))

        optimiser.zero_grad()
        # Call the module (not .forward) so nn.Module hooks still run.
        pred = model(e1=obj, rel=rel)
        loss = model.loss(pred, subj)
        loss.backward()
        optimiser.step()


def _evaluate(model, val_data, num_entities=None):
    """Score ``val_data`` as one batch, print the metrics and return them.

    Returns ``(sum_reciprocal_rank, auroc, auprc, ap)``, where
    ``sum_reciprocal_rank`` is a 1-element tensor holding the *sum* (not the
    mean) of the reciprocal ranks produced by ``SRR``.
    """
    model.eval()
    objects, subjects, relationships = load_data(val_data, batch_number=1)

    # Squeeze away the singleton batch axis, then add a column axis; for
    # single-id rows this gives (n, 1), and also works when n == 1.
    obj = _squeeze_keep_batch(torch.LongTensor(objects)).unsqueeze(1)
    rel = _squeeze_keep_batch(torch.LongTensor(relationships)).unsqueeze(1)
    targets = _squeeze_keep_batch(torch.LongTensor(subjects)).unsqueeze(1)

    # Inference only: no_grad avoids building an autograd graph over the
    # whole validation set on every epoch.
    with torch.no_grad():
        predictions = model(e1=obj, rel=rel)
        # zeros(1) + srr keeps the result a 1-element tensor, as before.
        sum_reciprocal_rank = torch.zeros(1) + SRR(predictions, targets)

        print('mean reciprocal rank is...', sum_reciprocal_rank / len(val_data))

        if num_entities is None:
            # NOTE(review): assumes the model scores every entity for each
            # query, so the last prediction axis is the entity count — the
            # one-hot targets must match it for auprc_auroc_ap anyway.
            num_entities = predictions.shape[-1]
        one_hots = get_onehots(targets, num_entities)
        auprc, auroc, ap = auprc_auroc_ap(one_hots, predictions)

    print('auroc is...', auroc)
    print('auprc is...', auprc)
    print('ap@50 is...', ap)

    return sum_reciprocal_rank, auroc, auprc, ap


def main(model, optimiser, train_data, val_data, epochs, batches, num_entities=None):
    """Train ``model`` for ``epochs`` epochs, validating after each one.

    Parameters
    ----------
    model : torch.nn.Module
        Called as ``model(e1=..., rel=...)`` and must expose ``loss(pred, target)``.
    optimiser : torch.optim.Optimizer
        Optimiser over ``model``'s parameters.
    train_data, val_data :
        Data in whatever form ``load_data`` accepts.
    epochs : int
        Number of training epochs.
    batches : int
        Number of mini-batches ``load_data`` splits ``train_data`` into.
    num_entities : int, optional
        Width of the one-hot targets used for AUPRC/AUROC. Defaults to the
        last dimension of the model's predictions.

    Returns
    -------
    tuple
        ``(sum_reciprocal_rank, auroc, auprc, ap)`` from the final epoch, or
        four ``None`` values when ``epochs`` is 0.
    """
    # Initialised so that epochs == 0 returns cleanly instead of raising.
    results = (None, None, None, None)
    for _ in range(epochs):
        _train_one_epoch(model, optimiser, train_data, batches)
        results = _evaluate(model, val_data, num_entities)
    return results


In [40]:
# Load the triples and lookup tables via the project's path helper.
data, lookup, ASD_dictionary, BCE_dictionary = get_files()
# NOTE(review): halving assumes `lookup` stores two entries per entity
# (e.g. name->id and id->name) — confirm against utils.path_manage.
entities = len(lookup) // 2

In [41]:
batches = 5   # training mini-batches per epoch (passed to load_data)
epochs = 10
N_SAMPLES = 100  # rows kept for this experiment; raise for a fuller run

# random_state pins the shuffle so the subsample, and every split and
# metric derived from it, is reproducible across kernel restarts.
x = shuffle(data, random_state=1)
# A random N_SAMPLES-row subsample of the data (not a held-out test set);
# it is later split into train/val/test.
test_data = x[:N_SAMPLES]

In [42]:
# 80/10/10 split: hold out 10% as test, then take 1/9 of the remaining 90%
# (i.e. 10% of the whole) as validation. 1/9 replaces the approximation 0.1111.
X_train_val, X_test = train_test_split(test_data, test_size=0.1, random_state=1)
X_train, X_val = train_test_split(X_train_val, test_size=1 / 9, random_state=1)

In [43]:
# Default ConvE hyper-parameter bundle from models.ConvE (contents not visible
# here — presumably dropout/convolution settings; confirm in ConvE_args).
args = ConvE_args()

In [44]:
# Alternative models with the same constructor shape, for comparison:
#   DistMult(num_entities=entities, embedding_dim=100, num_relations=4)
#   Complex(num_entities=entities, embedding_dim=100, num_relations=4)
# NOTE(review): num_relations=4 is assumed to be the number of relation types
# in this dataset — confirm against the data.

# Seed torch so weight initialisation (and hence training results) is reproducible.
torch.manual_seed(1)
model = ConvE(args=args, embedding_dim=200, num_entities=entities, num_relations=4)
optimiser = torch.optim.Adam(model.parameters(), lr=0.01)

70555 4


In [47]:
# Train and validate; the unpacked metrics come from the final epoch's validation pass.
total_sum_reciporical_rank, auroc, auprc, ap = main(
    model=model,
    optimiser=optimiser,
    train_data=X_train,
    val_data=X_val,
    epochs=epochs,
    batches=batches,
)

mean reciporical rank is... tensor([0.0063])
auroc is... 0.7255230036567736
auprc is... 0.00017846182152123834
ap@50 is... 0.000325080969565733
mean reciporical rank is... tensor([0.0202])
auroc is... 0.7076012699492586
auprc is... 0.00030272511232883796
ap@50 is... 0.0004978417684235109
mean reciporical rank is... tensor([0.1029])
auroc is... 0.7103224480539729
auprc is... 0.0013815144285051393
ap@50 is... 0.0024002820517194934
mean reciporical rank is... tensor([0.1335])
auroc is... 0.7073590441364062
auprc is... 0.0073595181265391565
ap@50 is... 0.013968976540632293
mean reciporical rank is... tensor([0.0649])
auroc is... 0.7248542960002268
auprc is... 0.005333839211592712
ap@50 is... 0.009659618084023897
mean reciporical rank is... tensor([0.0549])
auroc is... 0.7348279332142756
auprc is... 0.003011246852842489
ap@50 is... 0.004805235038595428
mean reciporical rank is... tensor([0.0361])
auroc is... 0.7416432803242906
auprc is... 0.0019768342003107344
ap@50 is... 0.0029491657447225