In [None]:
import pickle
import os
import torch
import random
import numpy as np
from tqdm.notebook import tqdm

from codes.query_solver import GeometricSolver
from codes.triplets import TripletsEngine

# --- Configuration ---------------------------------------------------------
EMBEDDING_DIM = 512
# NOTE(review): machine-specific absolute paths; adjust them for your checkout.
MODEL_PATH = "/home/cc/phd/KGEmbeddings/models/TransE_FB15k_0/"
DICTS_DIR = "/home/cc/phd/KGEmbeddings/data/FB15k/"
QUERIES_PATH = 'queries-set2.pkl'
# Use the GPU when PyTorch can see one; otherwise fall back to the CPU instead
# of failing on a hard-coded 'cuda'.
# NOTE(review): assumes GeometricSolver accepts device='cpu' -- confirm.
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'

# --- Pre-computed complex queries -------------------------------------------
# SECURITY: pickle.load can execute arbitrary code; only open trusted files.
with open(QUERIES_PATH, 'rb') as f:
    loaded_dict = pickle.load(f)

# `queries` and `results` are iterated in lockstep by the recall cell below,
# which reads the expected answers of each query from result[-1].
queries = loaded_dict['queries']
results = loaded_dict['results']

# --- Knowledge graph and query solver ---------------------------------------
kg = TripletsEngine(DICTS_DIR, from_splits=True)
qs = GeometricSolver(MODEL_PATH, "transe", EMBEDDING_DIM, k_neighbors=50, k_results=25, device=DEVICE)

In [2]:
# Disabled scratch cell: everything below is commented out and does not run.
# It holds two ad-hoc spot checks -- one on a randomly chosen complex query
# from `queries`, one on a randomly chosen key of `kg.h2t` -- each printing
# how many of the returned results are among the expected targets.

# # Complex query

# idx = np.random.randint(len(queries))

# qs.set_k(k_neighbors=100, k_results=15)
# res = qs.execute_query(queries[idx], proj_mode="inter", agg_mode="union")

# founds = [ f for f in res if f in results[idx][1] ]
# print(f"Query ({idx}): {queries[idx]}")
# print(f"Found: {len(founds)} on {len(results[idx][1])} targets from {len(res)} results")

# # Simple query
# simple_query = random.choice(list(kg.h2t.keys()))

# qs.set_k(k_neighbors=10, k_results=25)
# res = qs.execute_query([simple_query], proj_mode="inter", agg_mode="union")

# founds = [ f for f in res if f in kg.h2t[simple_query]]
# print(f"Query: {simple_query}")
# print(f"Found: {len(founds)} on {len(kg.h2t[simple_query])} targets from {len(res)} results")

In [3]:
# MRR and Hits@K for the one-hop ("simple") queries.
#
# Every key of kg.h2t is executed as a single-element query and each of its
# known targets kg.h2t[key] is scored against the returned list `res`.
# NOTE(review): assumes `res` is ordered best-first and holds integer entity
# ids -- confirm against GeometricSolver.execute_query.
qs.set_k(k_neighbors=50)
metrics = {
            "mrr": [],
            "hits3": [],
            "hits5": [],
            "hits10": [],
        }

for key in tqdm(kg.h2t.keys()):

    res = qs.execute_query([key], proj_mode="inter", agg_mode="union")
    targets = kg.h2t[key]

    # Filtered rank of a target = 1 + number of *incorrect* entities returned
    # ahead of it; other correct answers ranked higher are not counted.
    # Ranking inside the list of hits alone would give the first hit rank 1
    # wherever it sits in `res`.
    founds = []
    ranks = {}
    wrong_before = 0
    for entity in res:
        if entity in targets:
            founds.append(entity)
            # setdefault keeps the first (best) position if an id is repeated.
            ranks.setdefault(int(entity), wrong_before + 1)
        else:
            wrong_before += 1

    for t in targets:
        rank = ranks.get(int(t))
        if rank is None:
            # Target not retrieved: it must count as a miss, otherwise the
            # averages only describe the targets that were found. The 0.0 is
            # a lower bound on the true reciprocal rank / hit value.
            metrics['mrr'].append(0.0)
            metrics['hits3'].append(0.0)
            metrics['hits5'].append(0.0)
            metrics['hits10'].append(0.0)
        else:
            metrics['mrr'].append(1.0 / rank)
            metrics['hits3'].append(1.0 if rank <= 3 else 0.0)
            metrics['hits5'].append(1.0 if rank <= 5 else 0.0)
            metrics['hits10'].append(1.0 if rank <= 10 else 0.0)

# Averages are taken over every (query, target) pair; the count printed below
# is the number of queries that were executed.
print(f"Average MRR over {len(kg.h2t.keys())} simple queries: {np.mean(metrics['mrr'])}")
print(f"Average Hits@K over {len(kg.h2t.keys())} simple queries: 3: {np.mean(metrics['hits3'])}, 5: {np.mean(metrics['hits5'])}, 10: {np.mean(metrics['hits10'])}")

  0%|          | 0/174097 [00:00<?, ?it/s]

KeyboardInterrupt: 

In [4]:
# Recall of the solver on the pre-computed complex queries.
qs.set_k(k_neighbors=50, k_results=25)

# The rank-metric buckets are reset here but this cell does not fill them:
# only recall is computed for the complex queries.
metrics = {name: [] for name in ("mrr", "hits3", "hits5", "hits10")}

recalls = []

for query, result in tqdm(zip(queries, results), total=len(queries)):

    res = qs.execute_query(query, proj_mode="inter", agg_mode="union")

    # The last element of each result entry holds the expected answers.
    expected = result[-1]
    founds = [f for f in res if f in expected]

    # A query with no expected answers contributes a recall of 0.0.
    if len(expected) > 0:
        recalls.append(len(founds) / len(expected))
    else:
        recalls.append(0.0)

print(f"Average Recall over {len(queries)} complex queries (2p1): {np.mean(recalls)}")

  0%|          | 0/3803 [00:00<?, ?it/s]

Average Recall over 3803 complex queries (2p1): 0.530452474183132
