In [None]:
import pickle
import os
import torch
import random
import numpy as np
from tqdm.notebook import tqdm

from codes.query_solver import GeometricSolver
from codes.triplets import TripletsEngine

# Root of the KGEmbeddings checkout; model and data directories are derived from it.
PATH = "/home/cc/phd/KGEmbeddings/"
EMBEDDING_DIM = 512
# Dataset name. It selects both the model directory and the data directory below,
# e.g. DATA = "FB15k" resolves to models/TransE_FB15k_0 and data/FB15k.
DATA = "umls"
MODEL_PATH = f"{PATH}models/TransE_{DATA}_0"
MODEL_NAME = "transe"
DICTS_DIR = f"{PATH}data/{DATA}"
MODE = "tail-batch"  # head-batch or tail-batch

# Use the GPU when one is visible; otherwise fall back to the CPU so the notebook
# can still be executed on a machine without CUDA.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# FB15k is loaded with ext="txt"; every other dataset with ext="csv".
kg = TripletsEngine(DICTS_DIR, ext="txt" if DATA == "FB15k" else "csv", from_splits=True)
qs = GeometricSolver(
    MODEL_PATH,
    MODEL_NAME,
    EMBEDDING_DIM,
    h2t=kg.h2t,
    k_neighbors=50,
    k_results=25,
    device=DEVICE,
)

In [None]:
def recall_at_k(pred, true, k):
    """Return the fraction of relevant items found in the top-``k`` predictions.

    Args:
        pred: Ranked sequence of predictions (assumed to be ordered best first).
        true: Collection of relevant items.
        k: Cut-off rank. A value <= 0 means the whole ranking is considered.

    Returns:
        ``hits / len(true)`` where ``hits`` counts the entries of the truncated
        ranking that are contained in ``true``; ``1.0`` when ``true`` is empty.
    """
    if len(true) == 0:
        return 1.0

    # Truncate at exactly k. Using a wider window (e.g. k + 10) would make
    # recall@1, recall@5 and recall@10 indistinguishable for single-answer queries.
    pred_k = pred[:k] if k > 0 else pred

    hits = sum(1 for p in pred_k if p in true)
    return hits / len(true)

def map_at_k(pred, true, k):
    """Return the share of the top-``k`` predictions that are relevant.

    Note: despite its name, the value computed here is precision@k
    (``hits / k``), not an average precision.

    Args:
        pred: Ranked sequence of predictions (assumed to be ordered best first).
        true: Collection of relevant items.
        k: Cut-off rank. A value <= 0 means the whole ranking is considered,
            in which case the score is normalised by ``len(pred)``.

    Returns:
        ``hits / k`` for ``k > 0``; ``hits / len(pred)`` for ``k <= 0``
        (``0.0`` if ``pred`` is empty); ``1.0`` when ``true`` is empty.
    """
    if len(true) == 0:
        return 1.0

    if k > 0:
        pred_k = pred[:k]
        denominator = k
    else:
        # No cut-off: dividing by k would raise ZeroDivisionError (k == 0) or
        # yield a negative score (k < 0), so normalise by the list length instead.
        pred_k = pred
        denominator = len(pred_k)

    if denominator == 0:
        return 0.0

    hits = sum(1 for p in pred_k if p in true)
    return hits / denominator

# (Re)configure the solver's neighbour / result counts before evaluating.
# NOTE(review): k_results=25 is below the largest cut-off (50) scored below —
# confirm the solver can actually return 50 candidates per query.
qs.set_k(k_neighbors=50, k_results=25)

# One list of per-query scores for every cut-off.
recalls = {f"recall{k}": [] for k in (1, 5, 10, 25, 50)}
maps = {f"MAP@{k}": [] for k in (1, 5, 10, 25, 50)}

# Adjacency lookup matching the prediction direction selected by MODE.
adj = kg.t2h if MODE == 'head-batch' else kg.h2t

In [None]:
# Start from a clean slate so that re-running this cell does not accumulate
# scores from a previous run (the solver's own metrics are reset the same way).
qs._reset_metrics()
for scores in (*recalls.values(), *maps.values()):
    scores.clear()

# Evaluation subset: the first 10000 triplets of the training split.
eval_triplets = kg.triplets[kg.train_set][:10000]

for triplet in tqdm(eval_triplets, desc="Evaluating triplets"):
    h, r, t = triplet
    if MODE == 'head-batch':
        # Predict the head: every other known head for (t, r) is passed as to_remove.
        to_remove = set(adj.get((t, r), []))
        to_remove.discard(h)
        query = (t, r)
        true = h
    else:
        # Predict the tail: every other known tail for (h, r) is passed as to_remove.
        to_remove = set(adj.get((h, r), []))
        to_remove.discard(t)
        query = (h, r)
        true = t

    pred = qs.execute_search_step(query, true, to_remove, mode=MODE)

    # Queries for which the solver returned nothing are not scored for Recall/MAP.
    if len(pred) > 0:
        for k in [1, 5, 10, 25, 50]:
            recalls[f"recall{k}"].append(recall_at_k(pred, [true], k))
            maps[f'MAP@{k}'].append(map_at_k(pred, [true], k))

metrics = qs.get_metrics()

# Report the number of triplets actually evaluated rather than the size of the
# whole knowledge graph.
# NOTE(review): MRR / Hits@K are presumably recorded once per
# execute_search_step call — confirm against GeometricSolver.
n_queries = len(eval_triplets)
n_scored = len(recalls["recall1"])

print(f"Average MRR over {n_queries} triplets: {np.mean(metrics['mrr'])}")
print(
    f"Average Hits@K over {n_queries} triplets: 1: {np.mean(metrics['hits1'])}, 3: {np.mean(metrics['hits3'])}, "
    f"5: {np.mean(metrics['hits5'])}, 10: {np.mean(metrics['hits10'])}, 25: {np.mean(metrics['hits25'])}"
)
print(
    f"Average Recall@K over {n_scored} triplets: 1: {np.mean(recalls['recall1'])}, 5: {np.mean(recalls['recall5'])}, 10: {np.mean(recalls['recall10'])}, "
    f"25: {np.mean(recalls['recall25'])}, 50: {np.mean(recalls['recall50'])}"
)
print(
    f"Average MAP@K over {n_scored} triplets: 1: {np.mean(maps['MAP@1'])}, 5: {np.mean(maps['MAP@5'])}, 10: {np.mean(maps['MAP@10'])}, "
    f"25: {np.mean(maps['MAP@25'])}, 50: {np.mean(maps['MAP@50'])}"
)