In [1]:
from ontolearn.knowledge_base import KnowledgeBase
from utils.dataloader import NCESDataLoaderInference
from torch.utils.data import DataLoader
from torch.nn.utils.rnn import pad_sequence
import torch.nn.functional as F
import pandas as pd

In [2]:
from argparse import Namespace
import json
import torch, pandas as pd
# Read the run configuration from settings.json and expose its keys as
# attributes (e.g. args.batch_size, args.num_workers) instead of dict entries.
with open("settings.json") as setting:
    args = Namespace(**json.load(setting))

In [3]:
import numpy as np, time
from collections import defaultdict

In [4]:
def before_pad(arg):
    """Return the leading elements of ``arg`` that come before the first 'PAD'.

    Iteration stops at the first element equal to ``'PAD'``; that element and
    everything after it are dropped. If no ``'PAD'`` is present, all elements
    are returned. The result is always a new list.
    """
    exhausted = object()  # unique marker for "iterator ran out"
    kept = []
    stream = iter(arg)
    while True:
        token = next(stream, exhausted)
        if token is exhausted or token == 'PAD':
            return kept
        kept.append(token)

In [5]:
def map_to_token(model, idx_array):
    """Look up ``idx_array`` in the model's inverse vocabulary.

    Returns ``model.inv_vocab[idx_array]``; the kind of indexing supported
    (single index, array of indices, ...) is whatever ``inv_vocab`` itself
    supports.
    """
    inverse_vocabulary = model.inv_vocab
    return inverse_vocabulary[idx_array]

In [6]:
def get_data(kb, embeddings, kwargs):
    """Build the inference ``DataLoader`` for the test learning problems of ``kb``.

    Args:
        kb: Name of the knowledge base; selects ``datasets/{kb}/Test_data/Data.json``.
        embeddings: Entity embeddings, passed through to ``NCESDataLoaderInference``.
        kwargs: Settings object; ``batch_size`` and ``num_workers`` are read here and
            the whole object is forwarded to ``NCESDataLoaderInference``.

    Returns:
        A non-shuffled ``DataLoader`` whose batches are ``(pos, neg)`` pairs of
        zero-padded, batch-first tensors.
    """
    with open(f"datasets/{kb}/Test_data/Data.json", "r") as file:
        learning_problems = list(json.load(file).items())
    test_dataset = NCESDataLoaderInference(learning_problems, embeddings, kwargs)
    num_examples = test_dataset.num_examples

    def _as_matrix(emb):
        # Anything that is not already 2-D is flattened into a single row.
        return emb if emb.ndim == 2 else emb.reshape(1, -1)

    def _pad_and_stack(emb_list):
        # Zero-pad the first entry along its row axis up to num_examples rows
        # before pad_sequence aligns the rest of the batch to the longest entry.
        # NOTE(review): assumes no entry has more than num_examples rows — confirm.
        first = emb_list[0]
        emb_list[0] = F.pad(first, (0, 0, 0, num_examples - first.shape[0]), "constant", 0)
        return pad_sequence(emb_list, batch_first=True, padding_value=0)

    def collate_batch(batch):
        pos_emb_list, neg_emb_list = [], []
        for pos_emb, neg_emb in batch:
            pos_emb_list.append(_as_matrix(pos_emb))
            neg_emb_list.append(_as_matrix(neg_emb))
        padded_pos = _pad_and_stack(pos_emb_list)
        padded_neg = _pad_and_stack(neg_emb_list)
        return padded_pos, padded_neg

    print("Number of learning problems: ", len(test_dataset))
    return DataLoader(
        test_dataset,
        batch_size=kwargs.batch_size,
        num_workers=kwargs.num_workers,
        collate_fn=collate_batch,
        shuffle=False,
    )

In [7]:
def get_ensemble_prediction(models, x1, x2):
    for i,model in enumerate(models):
        model.eval()
        if i == 0:
            _, scores = model(x1, x2)
        else:
            _, sc = model(x1, x2)
            scores = scores + sc
    scores = scores/len(models)
    prediction = model.inv_vocab[scores.argmax(1)]
    return prediction, scores

In [8]:
# Knowledge base to run inference on; it selects the dataset sub-directory.
kb = "carcinogenesis"
# ConEx entity embeddings; the unnamed first CSV column becomes the index.
embeddings = (
    pd.read_csv(f"datasets/{kb}/Embeddings/ConEx_entity_embeddings.csv")
    .set_index('Unnamed: 0')
)
# To override the batch size loaded from settings.json, set args.batch_size here (e.g. 4).
args.knowledge_base_path = f"datasets/{kb}/{kb}.owl"
dataloader = get_data(kb, embeddings, args)



Number of learning problems:  111


In [9]:
# Take only the first batch of the dataloader. pos_emb / neg_emb are the
# notebook-level names that predict_single and predict_ensemble pass to the models.
pos_emb, neg_emb = next(iter(dataloader))

In [10]:
def predict_single(model_name):
    """Load one trained model and run it on the current ``pos_emb`` / ``neg_emb`` batch.

    Args:
        model_name: Architecture name used in the checkpoint file name
            ``datasets/{kb}/Model_weights/ConEx_{model_name}.pt``.

    Returns:
        Whatever the loaded model returns for ``model(pos_emb, neg_emb)``; the
        notebook unpacks it as ``(predictions, scores)``.

    Reads the notebook-level globals ``kb``, ``pos_emb`` and ``neg_emb``.
    """
    # NOTE(security): torch.load unpickles the file, which can execute arbitrary
    # code; only load checkpoints from a trusted source.
    # NOTE: PyTorch >= 2.6 defaults torch.load to weights_only=True, which rejects
    # fully pickled modules; pass weights_only=False there for trusted files.
    model = torch.load(f"datasets/{kb}/Model_weights/ConEx_{model_name}.pt", map_location=torch.device('cpu'))
    # Match get_ensemble_prediction: evaluate in eval mode so train-only layers
    # (dropout, batch-norm updates) do not affect the prediction.
    model.eval()
    print(f"Predictions with {model_name}")
    # Inference only: no autograd graph is needed.
    with torch.no_grad():
        return model(pos_emb, neg_emb)

## View some predictions

### Single model

In [11]:
predictions, scores = predict_single("SetTransformer")

Predictions with SetTransformer


In [12]:
predictions[0]

array(['Bromine', ' ', '⊔', ' ', 'Sulfide', 'PAD', 'PAD', 'PAD', 'PAD',
       'PAD', 'PAD', 'PAD', 'PAD', 'PAD', 'PAD', 'PAD', 'PAD', 'PAD',
       'PAD', 'PAD', 'PAD', 'PAD', 'PAD', 'PAD', 'PAD', 'PAD', 'PAD',
       'PAD', 'PAD', 'PAD', 'PAD', 'PAD', 'PAD', 'PAD', 'PAD', 'PAD',
       'PAD', 'PAD', 'PAD', 'PAD', 'PAD', 'PAD', 'PAD', 'PAD', 'PAD',
       'PAD', 'PAD', 'PAD'], dtype=object)

In [13]:
predictions[1]

array(['Di227', ' ', '⊔', ' ', '(', '∃', ' ', 'hasStructure', '.',
       'Amine', ')', 'PAD', 'PAD', 'PAD', 'PAD', 'PAD', 'PAD', 'PAD',
       'PAD', 'PAD', 'PAD', 'PAD', 'PAD', 'PAD', 'PAD', 'PAD', 'PAD',
       'PAD', 'PAD', 'PAD', 'PAD', 'PAD', 'PAD', 'PAD', 'PAD', 'PAD',
       'PAD', 'PAD', 'PAD', 'PAD', 'PAD', 'PAD', 'PAD', 'PAD', 'PAD',
       'PAD', 'PAD', 'PAD'], dtype=object)

### Ensemble

In [14]:
def predict_ensemble(model_names):
    """Load several trained models and return their ensemble prediction.

    Each name in ``model_names`` selects the checkpoint
    ``datasets/{kb}/Model_weights/ConEx_{name}.pt``, loaded onto the CPU. The
    loaded models are handed to ``get_ensemble_prediction`` together with the
    notebook-level ``pos_emb`` / ``neg_emb`` batch, and its result is returned.
    """
    cpu = torch.device('cpu')
    models = []
    for name in model_names:
        checkpoint = f"datasets/{kb}/Model_weights/ConEx_{name}.pt"
        models.append(torch.load(checkpoint, map_location=cpu))
    print("Predictions with Ensemble model")
    return get_ensemble_prediction(models, pos_emb, neg_emb)

In [15]:
predictions_ens, scores = predict_ensemble(["SetTransformer", "GRU"])

Predictions with Ensemble model


In [16]:
predictions_ens[0]

array(['Bromine', ' ', '⊔', ' ', 'Sulfide', 'PAD', 'PAD', 'PAD', 'PAD',
       'PAD', 'PAD', 'PAD', 'PAD', 'PAD', 'PAD', 'PAD', 'PAD', 'PAD',
       'PAD', 'PAD', 'PAD', 'PAD', 'PAD', 'PAD', 'PAD', 'PAD', 'PAD',
       'PAD', 'PAD', 'PAD', 'PAD', 'PAD', 'PAD', 'PAD', 'PAD', 'PAD',
       'PAD', 'PAD', 'PAD', 'PAD', 'PAD', 'PAD', 'PAD', 'PAD', 'PAD',
       'PAD', 'PAD', 'PAD'], dtype=object)

In [17]:
predictions_ens[1]

array(['Halide10', ' ', '⊔', ' ', '(', '∃', ' ', 'hasStructure', '.',
       'Amine', ')', 'PAD', 'PAD', 'PAD', 'PAD', 'PAD', 'PAD', 'PAD',
       'PAD', 'PAD', 'PAD', 'PAD', 'PAD', 'PAD', 'PAD', 'PAD', 'PAD',
       'PAD', 'PAD', 'PAD', 'PAD', 'PAD', 'PAD', 'PAD', 'PAD', 'PAD',
       'PAD', 'PAD', 'PAD', 'PAD', 'PAD', 'PAD', 'PAD', 'PAD', 'PAD',
       'PAD', 'PAD', 'PAD'], dtype=object)

In [18]:
predictions_ens, scores = predict_ensemble(["LSTM", "GRU"])

Predictions with Ensemble model


In [19]:
predictions_ens[0]

array(['Bromine', ' ', '⊔', ' ', 'Sulfide', 'PAD', 'PAD', 'PAD', 'PAD',
       'PAD', 'PAD', 'PAD', 'PAD', 'PAD', 'PAD', 'PAD', 'PAD', 'PAD',
       'PAD', 'PAD', 'PAD', 'PAD', 'PAD', 'PAD', 'PAD', 'PAD', 'PAD',
       'PAD', 'PAD', 'PAD', 'PAD', 'PAD', 'PAD', 'PAD', 'PAD', 'PAD',
       'PAD', 'PAD', 'PAD', 'PAD', 'PAD', 'PAD', 'PAD', 'PAD', 'PAD',
       'PAD', 'PAD', 'PAD'], dtype=object)

In [20]:
predictions_ens[1]

array(['Halide10', ' ', '⊔', ' ', '(', '∃', ' ', 'hasStructure', '.',
       'Amine', ')', 'PAD', 'PAD', 'PAD', 'PAD', 'PAD', 'PAD', 'PAD',
       'PAD', 'PAD', 'PAD', 'PAD', 'PAD', 'PAD', 'PAD', 'PAD', 'PAD',
       'PAD', 'PAD', 'PAD', 'PAD', 'PAD', 'PAD', 'PAD', 'PAD', 'PAD',
       'PAD', 'PAD', 'PAD', 'PAD', 'PAD', 'PAD', 'PAD', 'PAD', 'PAD',
       'PAD', 'PAD', 'PAD'], dtype=object)

In [21]:
predictions_ens, scores = predict_ensemble(["LSTM", "SetTransformer"])

Predictions with Ensemble model


In [22]:
predictions_ens[0]

array(['Bromine', ' ', '⊔', ' ', 'Sulfide', 'PAD', 'PAD', 'PAD', 'PAD',
       'PAD', 'PAD', 'PAD', 'PAD', 'PAD', 'PAD', 'PAD', 'PAD', 'PAD',
       'PAD', 'PAD', 'PAD', 'PAD', 'PAD', 'PAD', 'PAD', 'PAD', 'PAD',
       'PAD', 'PAD', 'PAD', 'PAD', 'PAD', 'PAD', 'PAD', 'PAD', 'PAD',
       'PAD', 'PAD', 'PAD', 'PAD', 'PAD', 'PAD', 'PAD', 'PAD', 'PAD',
       'PAD', 'PAD', 'PAD'], dtype=object)

In [23]:
predictions_ens[1]

array(['Di227', ' ', '⊔', ' ', '(', '∃', ' ', 'hasStructure', '.',
       'Amine', ')', 'PAD', 'PAD', 'PAD', 'PAD', 'PAD', 'PAD', 'PAD',
       'PAD', 'PAD', 'PAD', 'PAD', 'PAD', 'PAD', 'PAD', 'PAD', 'PAD',
       'PAD', 'PAD', 'PAD', 'PAD', 'PAD', 'PAD', 'PAD', 'PAD', 'PAD',
       'PAD', 'PAD', 'PAD', 'PAD', 'PAD', 'PAD', 'PAD', 'PAD', 'PAD',
       'PAD', 'PAD', 'PAD'], dtype=object)

In [24]:
predictions_ens, scores = predict_ensemble(["LSTM", "GRU", "SetTransformer"])

Predictions with Ensemble model


In [25]:
predictions_ens[0]

array(['Bromine', ' ', '⊔', ' ', 'Sulfide', 'PAD', 'PAD', 'PAD', 'PAD',
       'PAD', 'PAD', 'PAD', 'PAD', 'PAD', 'PAD', 'PAD', 'PAD', 'PAD',
       'PAD', 'PAD', 'PAD', 'PAD', 'PAD', 'PAD', 'PAD', 'PAD', 'PAD',
       'PAD', 'PAD', 'PAD', 'PAD', 'PAD', 'PAD', 'PAD', 'PAD', 'PAD',
       'PAD', 'PAD', 'PAD', 'PAD', 'PAD', 'PAD', 'PAD', 'PAD', 'PAD',
       'PAD', 'PAD', 'PAD'], dtype=object)

In [26]:
predictions_ens[1]

array(['Halide10', ' ', '⊔', ' ', '(', '∃', ' ', 'hasStructure', '.',
       'Amine', ')', 'PAD', 'PAD', 'PAD', 'PAD', 'PAD', 'PAD', 'PAD',
       'PAD', 'PAD', 'PAD', 'PAD', 'PAD', 'PAD', 'PAD', 'PAD', 'PAD',
       'PAD', 'PAD', 'PAD', 'PAD', 'PAD', 'PAD', 'PAD', 'PAD', 'PAD',
       'PAD', 'PAD', 'PAD', 'PAD', 'PAD', 'PAD', 'PAD', 'PAD', 'PAD',
       'PAD', 'PAD', 'PAD'], dtype=object)