In [None]:
import pickle
import os
import torch
import random
import numpy as np
from tqdm.notebook import tqdm

from codes.query_solver import GeometricSolver
from codes.triplets import TripletsEngine

# ---------------------------------------------------------------------------
# Experiment configuration.
# NOTE(review): the model/data locations are absolute, machine-specific paths;
# adjust them before running this notebook elsewhere.
# ---------------------------------------------------------------------------
EMBEDDING_DIM = 512  # forwarded to GeometricSolver as its third positional argument
# Alternative model directory, kept for quick switching:
# MODEL_PATH = "/home/cc/phd/KGEmbeddings/models/TransE_FB15k_0/"
MODEL_PATH = "/home/cc/phd/KGEmbeddings/models/RotatE_FB15k_0/"
MODEL_NAME = "rotate"
DICTS_DIR = "/home/cc/phd/KGEmbeddings/data/FB15k/"

# Load the pre-generated queries together with their expected results.
# NOTE(review): unpickling can execute arbitrary code -- only load files from
# a trusted source.
with open('queries/FBK15k/queries-medium.pkl', 'rb') as pkl_file:
    loaded_dict = pickle.load(pkl_file)

queries, results = loaded_dict['queries'], loaded_dict['results']

# Knowledge-graph helper (provides the `h2t` mapping used below) and the solver
# under evaluation.
# NOTE(review): device is fixed to 'cuda'; presumably a GPU is required -- confirm.
kg = TripletsEngine(DICTS_DIR, from_splits=True)
qs = GeometricSolver(
    MODEL_PATH,
    MODEL_NAME,
    EMBEDDING_DIM,
    h2t=kg.h2t,
    k_neighbors=50,
    k_results=25,
    device='cuda',
)

In [None]:
# Complex query

# idx = np.random.randint(len(queries))

# qs.set_k(k_neighbors=100, k_results=15)
# res = qs.execute_query(queries[idx], proj_mode="inter", agg_mode="union", trues=results[idx])

# founds = [ f for f in res if f in results[idx][1] ]
# print(f"Query ({idx}): {queries[idx]}")
# print(f"Found: {len(founds)} on {len(results[idx][1])} targets from {len(res)} results")

# Simple query
# simple_query = random.choice(list(kg.h2t.keys()))

# qs.set_k(k_neighbors=10, k_results=25)
# res = qs.execute_query([simple_query], proj_mode="inter", agg_mode="union")

# founds = [ f for f in res if f in kg.h2t[simple_query]]
# print(f"Query: {simple_query}")
# print(f"Found: {len(founds)} on {len(kg.h2t[simple_query])} targets from {len(res)} results")

# len(qs.get_metrics()['mrr'])

In [None]:
# Simple-query evaluation: run one single-key query per entry of `kg.h2t` and
# accumulate MRR / Hits@K over the targets that the solver managed to return.
qs.set_k(k_neighbors=50)
metrics = {name: [] for name in ("mrr", "hits3", "hits5", "hits10")}

for head in tqdm(kg.h2t.keys()):
    targets = kg.h2t[head]

    answers = qs.execute_query([head], proj_mode="inter", agg_mode="union")

    # Keep only the returned items that are known targets, preserving the
    # order in which the solver returned them.
    # NOTE(review): ranks below are positions inside this filtered list, i.e.
    # among correct hits only, not among all returned results -- confirm that
    # this is the intended ranking protocol.
    hit_ids = torch.tensor([a for a in answers if a in targets])

    for target in targets:
        positions = (hit_ids == target).nonzero(as_tuple=True)[0]
        # NOTE(review): targets that were not returned add nothing to the
        # metric lists, so the averages cover found targets only.
        if len(positions) == 0:
            continue

        # 1-based rank; `.item()` requires the target to occur exactly once.
        rank = positions.item() + 1
        metrics['mrr'].append(1.0 / rank)
        for cutoff in (3, 5, 10):
            metrics[f'hits{cutoff}'].append(1.0 if rank <= cutoff else 0.0)

# print(f"Average Recall over {len(queries)} complex queries (flat method): {np.mean(metrics['recall'])}")
print(f"Average MRR over {len(kg.h2t.keys())} simple queries: {np.mean(metrics['mrr'])}")
print(f"Average Hits@K over {len(kg.h2t.keys())} simple queries: 3: {np.mean(metrics['hits3'])}, 5: {np.mean(metrics['hits5'])}, 10: {np.mean(metrics['hits10'])}")

In [None]:
def recall_at_k(pred, true, k):
    """Return the fraction of `true` items that appear among the leading predictions.

    Args:
        pred: Ordered sequence of predictions (must support slicing).
        true: Collection of expected items (must support `len` and `in`).
        k: Requested cutoff. A value <= 0 disables the cutoff and uses all of
            `pred`. For a positive value the window actually inspected is
            `max(k, len(true) + 10)` items, so it is never narrower than the
            number of expected items plus ten.

    Returns:
        Number of inspected predictions found in `true`, divided by
        `len(true)`; `1.0` when `true` is empty. Each matching prediction is
        counted, so repeated predictions can push the value above 1.
    """
    n_true = len(true)
    # Nothing to retrieve: treated as perfect recall.
    if n_true == 0:
        return 1.0

    window = pred if k <= 0 else pred[:max(k, n_true + 10)]

    matched = 0
    for candidate in window:
        if candidate in true:
            matched += 1

    return matched / n_true

# Complex-query evaluation: run every loaded query through the solver and
# collect recall at several cutoffs plus the uncapped recall.
qs.set_k(k_neighbors=50, k_results=25)
recalls = {
    "recall": [],
    "recall5": [],
    "recall10": [],
    "recall25": [],
    "recall50": [],
}

for query, result in tqdm(zip(queries, results), total=len(queries)):

    res = qs.execute_query(query, proj_mode="inter", agg_mode="union", trues=result)

    # Queries for which the solver returned nothing are skipped, so the recall
    # lists may hold fewer entries than `len(queries)` (the count printed below).
    if len(res) > 0:
        # The last element of `result` is used as the set of expected answers.
        targets = result[-1]

        for k in [5, 10, 25, 50]:
            recalls[f"recall{k}"].append(recall_at_k(res, targets, k))

        # Uncapped recall (k=0 disables the cutoff). Recorded once per query;
        # it used to sit inside the loop above and was appended four times.
        recalls["recall"].append(recall_at_k(res, targets, 0))

# NOTE(review): presumably the ranking metrics accumulated by the solver during
# the `execute_query(..., trues=...)` calls above -- confirm whether they also
# include queries executed in earlier cells.
metrics = qs.get_metrics()

print(f"Average Recall over {len(queries)} complex queries (2p1): {np.mean(recalls['recall'])}")
print(f"Average MRR over {len(queries)} complex queries (2pi): {np.mean(metrics['mrr'])}")
print(f"Average Recall@K over {len(queries)} complex queries (2p1): 5: {np.mean(recalls['recall5'])}, 10: {np.mean(recalls['recall10'])}, \
25: {np.mean(recalls['recall25'])}, 50: {np.mean(recalls['recall50'])}")
print(f"Average Hits@K over {len(queries)} complex queries (2pi): 1: {np.mean(metrics['hits1'])}, 3: {np.mean(metrics['hits3'])}, \
5: {np.mean(metrics['hits5'])}, 10: {np.mean(metrics['hits10'])}, 25: {np.mean(metrics['hits25'])}")

In [None]:
# Average Recall over 4067 complex queries (2p1): 0.6109156931334205
# Average MRR over 174097 complex queries (2pi): 0.23883582081424812
# Average Hits@K over 174097 complex queries (2pi): 3: 0.2702445393466257, 5: 0.37974775362910174, 10: 0.5578162826755162