In [None]:
#!/usr/bin/python3
import sys
sys.path.insert(0, '/home/cc/phd/KGEmbeddings/codes')

import numpy as np
import torch
import os
from collections import defaultdict
import random
from tqdm.notebook import tqdm
import pickle 

from codes.model import KGEModel
from codes.dataloader import TrainDataset, TestDataset
from codes.triplets import TripletsEngine

# ---------------------------------------------------------------------------
# Configuration: dataset, model location, seeds, device and the argument
# dictionary that is later turned into an attribute-style `args` object.
# ---------------------------------------------------------------------------
EMBEDDING_DIM = 512
DATA = "umls"
MODEL_PATH = f"/home/cc/phd/KGEmbeddings/models/TransE_{DATA}_0"
# MODEL_PATH = "/home/cc/phd/KGEmbeddings/models/RotatE_FB15k_0/"

# DICTS_DIR = '/home/cc/phd/KGEmbeddings/data/FB15k'
DICTS_DIR = f'/home/cc/phd/KGEmbeddings/data/{DATA}'

# Seed every RNG the notebook draws from. Seeding only `random` left the
# NumPy-based triplet sampling (np.random.randint) different on every run.
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Pre-trained embeddings exported next to the checkpoint; kept as CPU tensors.
entity_embedding = torch.from_numpy(np.load(os.path.join(MODEL_PATH, 'entity_embedding.npy')))
relation_embedding = torch.from_numpy(np.load(os.path.join(MODEL_PATH, 'relation_embedding.npy')))

# One embedding row per entity / relation.
number_of_entities = entity_embedding.shape[0]
number_of_relations = relation_embedding.shape[0]

args = {
    "model": "TransE",
    "hidden_dim": EMBEDDING_DIM,
    "gamma": 24.0,
    "double_entity_embedding": False,
    "double_relation_embedding": False,
    "do_train": False,
    "test_batch_size": 512,
    "cpu_num": 32,
    # Follow the detected device instead of hard-coding True, so the flag is
    # not set on machines without a GPU.
    # NOTE(review): presumably read by KGEModel's test code - confirm.
    "cuda": device == 'cuda',
    "test_log_steps": 1000,
    "nentity": number_of_entities,
    "nrelation": number_of_relations,
    "mode": "tail-batch",
    "device": device
}

class DictToObject:
    """Expose the entries of a dictionary as instance attributes.

    Every ``key: value`` pair of the mapping handed to the constructor is
    stored on the instance, so ``DictToObject({"a": 1}).a`` evaluates to ``1``.
    """

    def __init__(self, dictionary):
        # Each item is a (name, value) pair, which is exactly what setattr
        # expects after the target object.
        for entry in dictionary.items():
            setattr(self, *entry)

# Switch from the plain dict to attribute access (args.model, args.gamma, ...).
args = DictToObject(args)

kge_model = KGEModel(
    model_name=args.model,
    nentity=number_of_entities,
    nrelation=number_of_relations,
    hidden_dim=args.hidden_dim,
    gamma=args.gamma,
    double_entity_embedding=args.double_entity_embedding,
    double_relation_embedding=args.double_relation_embedding
).to(device)

print("Loading checkpoint...")
# map_location remaps the stored tensors onto the device selected above, so a
# checkpoint written on a GPU machine can still be loaded on a CPU-only one.
checkpoint = torch.load(os.path.join(MODEL_PATH, 'checkpoint'), map_location=device)
init_step = checkpoint['step']
kge_model.load_state_dict(checkpoint['model_state_dict'])

# Training-only state; skipped while args.do_train is False.
if args.do_train:
    current_learning_rate = checkpoint['current_learning_rate']
    warm_up_steps = checkpoint['warm_up_steps']
    # optimizer.load_state_dict(checkpoint['optimizer_state_dict'])

# TODO: persist the indexing structures instead of rebuilding them every run
# The FB15k dataset is read from "txt" files, every other dataset from "csv".
kg = TripletsEngine(DICTS_DIR, ext="txt" if DATA == "FB15k" else "csv", from_splits=True)

In [None]:
# Chat with an interesting discussion on this topic: https://chatgpt.com/c/68c0325c-18c8-832b-b7fe-3eb459d9c9b8
# TODO: implement predict for RotatE

def predict(head_id, relation_id, tail_id, entity_embeddings, relation_embeddings, mode = "tail-batch", top_k=10):
    """Return the entities closest to the translated query embedding.

    The query point is ``tail - relation`` when ``mode == "head-batch"`` and
    ``head + relation`` for any other mode; every entity is then ranked by its
    L2 distance to that point.

    Args:
        head_id: Row index of the head entity (only read outside "head-batch").
        relation_id: Row index of the relation.
        tail_id: Row index of the tail entity (only read in "head-batch").
        entity_embeddings: 2-D tensor with one row per entity.
        relation_embeddings: Tensor indexed by ``relation_id``; its rows must
            be broadcast-compatible with the entity rows.
        mode: "head-batch" to predict heads, anything else to predict tails.
        top_k: Number of nearest entities requested. It is clamped to the
            number of entities, because ``torch.topk`` raises when asked for
            more elements than exist.

    Returns:
        ``(best_ids, best_distances)``: indices of the nearest entities ordered
        from closest to farthest, and their distances.
    """
    rel = relation_embeddings[relation_id]

    if mode == "head-batch":
        target = entity_embeddings[tail_id] - rel
    else:
        target = entity_embeddings[head_id] + rel

    # L2 distance from the query point to every entity.
    distances = torch.norm(entity_embeddings - target, p=2, dim=1)

    # Small knowledge graphs can have fewer entities than the requested k.
    k = min(int(top_k), distances.shape[0])

    # Negate so that topk (which picks the largest values) yields the smallest
    # distances.
    best_ids = torch.topk(-distances, k).indices
    return best_ids, distances[best_ids]

def flatten(list_of_lists):
    """Concatenate the items of every inner iterable into a single flat list.

    Only one level of nesting is removed; the relative order of the items is
    preserved.
    """
    flat = []
    for chunk in list_of_lists:
        flat.extend(chunk)
    return flat

def intersection(list_of_lists):
    """Return the set of items that appear in every inner collection.

    An empty outer sequence yields an empty set.
    """
    if not list_of_lists:
        return set()

    # Seed with the first collection, then narrow it with all the others.
    common = set(list_of_lists[0])
    return common.intersection(*list_of_lists[1:])

In [None]:
# Per-query answer lookup handed to kge_model.single_test_step: each key of
# kg.t2h / kg.h2t is mapped to a tensor of its known answers on `device`.
# The sampling / metric containers that used to be initialised here were dead
# stores: the evaluation cell re-creates all of them before first use.
if args.mode == 'head-batch':
    adj = {key: torch.tensor(answers, device=device) for key, answers in kg.t2h.items()}
else:
    adj = {key: torch.tensor(answers, device=device) for key, answers in kg.h2t.items()}

In [None]:
# Link-prediction evaluation on a random sample of triplets: the model's own
# ranking metrics (single_test_step) plus a recall figure computed from the
# raw embeddings via predict().
n = 10000
# A locally seeded generator keeps the sample identical on every re-run of
# this cell and leaves NumPy's global RNG state untouched.
ids = np.random.RandomState(42).randint(0, len(kg.triplets), size=n)
# ids = list(range(n))
metrics = {name: [] for name in ('MRR', 'HITS@1', 'HITS@3', 'HITS@10', 'HITS@25')}
recall = []

# triple_set = set(map(tuple, kg.triplets))

kge_model.eval()

# predict() ranks with torch.topk, which cannot return more items than there
# are entities.
max_top_k = entity_embedding.shape[0]

for triplet_idx in tqdm(ids):
    target_head, target_relation, target_tail = kg.triplets[triplet_idx]

    # All known correct answers for this query, not just the sampled one.
    if args.mode == 'head-batch':
        targets = kg.t2h[(target_relation, target_tail)]
    else:
        targets = kg.h2t[(target_head, target_relation)]

    try:
        res = kge_model.single_test_step(kge_model, adj, (target_head, target_relation, target_tail), args)
    except AssertionError:
        # Skip triplets rejected by the model's own checks; nothing has been
        # recorded for them yet, so the metric lists stay aligned.
        print("WARNING: triple ", (target_head, target_relation, target_tail))
        continue

    for name in metrics:
        metrics[name].append(res[name])

    # Retrieve 1.5x the number of answers (at least 15), capped at the
    # number of entities.
    top_k = min(max_top_k, max(15, int(len(targets)*1.5)))
    top_ids, dists = predict(int(target_head), int(target_relation), int(target_tail), entity_embedding, relation_embedding, mode=args.mode, top_k=top_k)

    # Fraction of the known answers that appear among the retrieved entities.
    recall.append(torch.isin(top_ids, torch.tensor(targets)).sum().item() / len(targets))


print(f"Average MRR over {n} random triplets: {np.mean(metrics['MRR'])}")
print(f"Average HITS@1, HITS@3, HITS@10, HITS@25 over {n} random triplets: {np.mean(metrics['HITS@1'])}, {np.mean(metrics['HITS@3'])}, {np.mean(metrics['HITS@10'])}, {np.mean(metrics['HITS@25'])}")
print(f"Average Recall over {n} random triplets: {np.mean(recall)}")

In [None]:
### OPT1
# Average MRR over 1000 random triplets: 0.8073085733254713
# Average HITS@1, HITS@3, HITS@10, HITS@25 over 1000 random triplets: 0.701, 0.904, 0.972, 0.989
# Average Recall over 1000 random triplets: 0.9757404544860918

### OPT3
# Average MRR over 1000 random triplets: 0.405323558654472
# Average HITS@1, HITS@3, HITS@10, HITS@25 over 1000 random triplets: 0.27, 0.465, 0.689, 0.805
# Average Recall over 1000 random triplets: 0.9757404544860918


### prime
# Average MRR over 1000 random triplets: 0.07644271752140451
# Average HITS@1, HITS@3, HITS@10, HITS@25 over 1000 random triplets: 0.0, 0.094, 0.197, 0.378
# Average Recall over 1000 random triplets: 0.4471293999767067

In [None]:
# def find_queries(h2t, indexing_dict, n_queries=100):
#     queries = []
#     results = []

#     for node in tqdm(indexing_dict.keys()):
#         if indexing_dict[node]['count'] < 4 or indexing_dict[node]['count'] > 500:
#             continue

#         if len(queries) >= n_queries:
#             break

#         elements = indexing_dict[node]['in']
#         relations = np.unique(elements[:, 1])
#         np.random.seed(266)
#         np.random.shuffle(relations)

#         query = []
#         result = []

#         try:
#             r1, r2 = relations[:2]
#             h1 = elements[elements[:, 1] ==  r1].squeeze()[0]
#             h2 = elements[elements[:, 1] ==  r2].squeeze()[0]

#             t1 = h2t[(h1, r1)]
#             t2 = h2t[(h2, r2)]

#             target_tails = set(t1 + t2 + [node])
#             target_tails.discard(node)

#             query.append([(h1, r1), (h2, r2)])
#             result.append(target_tails)

#             acc = np.empty((0, 2), dtype=np.int64)
#             for tt in target_tails:
#                 acc = np.vstack([acc, indexing_dict[tt]['out']])

#             h = acc[:, 0]
#             r = acc[:, 1]

#             # Condition 1: count rows per relation ---
#             unique_r, r_counts = np.unique(r, return_counts=True)
#             mask1 = r_counts >= 2   # at least 2 edges

#             # Condition 2: count distinct h per relation ---
#             # drop duplicates by (r,h)
#             unique_rh = np.unique(acc, axis=0)
#             _, rh_counts = np.unique(unique_rh[:, 1], return_counts=True)
#             mask2 = rh_counts >= 2  # at least 2 different h

#             # Align arrays
#             valid_r = np.intersect1d(unique_r[mask1], np.unique(unique_rh[:, 1])[mask2])

#             if len(valid_r) > 0:
#                 chosen_r = valid_r[0]   
#                 # or np.random.choice(valid_r)
#                 filtered_targets = np.unique(acc[r == chosen_r][:, 0])
#             else:
#                 continue

#             filtered_targets = set(filtered_targets)
#             filtered_targets.discard(node)
            
#             query.append(chosen_r)
#             result.append(set(filtered_targets))

#             queries.append(query)
#             results.append(result)
#         except:
#             continue

#     return queries, results

def find_queries(h2t, indexing_dict, n_queries=100):
    """Build two-step queries anchored on the nodes of ``indexing_dict``.

    For every node whose ``'count'`` lies in ``[4, 500]``, up to five ordered
    pairs of distinct relations found in the node's ``'in'`` array are tried.
    For each pair ``(r1, r2)`` one head per relation is picked and the tails
    of ``(h1, r1)`` and ``(h2, r2)`` in ``h2t`` are merged (the anchor node is
    removed). A follow-up relation is then chosen among the ``'out'`` rows of
    those tails: the first relation that links at least two distinct
    first-column values.

    Args:
        h2t: Mapping ``(head, relation) -> list of tails``.
        indexing_dict: Mapping ``node -> {'count', 'in', 'out'}`` where
            ``'in'`` / ``'out'`` are 2-D integer arrays whose column 1 holds
            relation ids. NOTE(review): column 0 is treated as the head id
            for ``'in'`` - confirm against the code that builds this index.
        n_queries: Maximum number of queries to return.

    Returns:
        ``(queries, results)``, two parallel lists. Each query is
        ``[[(h1, r1), (h2, r2)], chosen_relation]`` and each result is
        ``[first_step_targets, final_targets]`` (both sets).
    """
    queries = []
    results = []

    for node in tqdm(indexing_dict.keys()):
        if len(queries) >= n_queries:
            break

        if indexing_dict[node]['count'] < 4 or indexing_dict[node]['count'] > 500:
            continue

        elements = indexing_dict[node]['in']
        relations = np.unique(elements[:, 1])

        # Private generator: produces the same shuffles as reseeding the
        # global NumPy RNG with 266 for every node, without clobbering the
        # global random state used elsewhere in the notebook.
        rng = np.random.RandomState(266)
        rng.shuffle(relations)

        # All ordered pairs of distinct relations, in random order.
        pairs = np.array(np.meshgrid(relations, relations)).T.reshape(-1, 2)
        pairs = pairs[pairs[:, 0] != pairs[:, 1]]
        rng.shuffle(pairs)

        for r1, r2 in pairs[:5]:
            if len(queries) >= n_queries:
                break

            try:
                # First row carrying each relation. Explicit [0, 0] indexing
                # returns a scalar head id even when several rows match
                # (squeeze()[0] returned a whole row in that case).
                h1 = elements[elements[:, 1] == r1][0, 0]
                h2 = elements[elements[:, 1] == r2][0, 0]

                t1 = h2t[(h1, r1)]
                t2 = h2t[(h2, r2)]

                target_tails = set(t1 + t2 + [node])
                target_tails.discard(node)

                # Stack the 'out' rows of every first-step target in one go;
                # the leading empty array keeps the (0, 2) shape when there
                # are no targets.
                acc = np.vstack(
                    [np.empty((0, 2), dtype=np.int64)]
                    + [indexing_dict[tt]['out'] for tt in target_tails]
                )
                r = acc[:, 1]

                # Condition 1: relations that occur on at least 2 rows.
                unique_r, r_counts = np.unique(r, return_counts=True)
                mask1 = r_counts >= 2

                # Condition 2: relations with at least 2 distinct column-0
                # values (duplicate rows are dropped first).
                unique_rh = np.unique(acc, axis=0)
                dedup_r, rh_counts = np.unique(unique_rh[:, 1], return_counts=True)
                mask2 = rh_counts >= 2

                valid_r = np.intersect1d(unique_r[mask1], dedup_r[mask2])
                if len(valid_r) == 0:
                    continue

                chosen_r = valid_r[0]
                # or np.random.choice(valid_r)
                filtered_targets = set(np.unique(acc[r == chosen_r][:, 0]))
                filtered_targets.discard(node)
            except (KeyError, IndexError, TypeError, ValueError):
                # Best effort: a pair whose lookups or array shapes do not
                # work out is skipped. Interrupts are no longer swallowed.
                continue

            # A fresh pair of lists per query: sharing one list across the
            # pairs of a node made every stored query alias the same object.
            queries.append([[(h1, r1), (h2, r2)], chosen_r])
            results.append([target_tails, filtered_targets])

    return queries, results

In [None]:
# Generate up to 10000 queries with find_queries.
# NOTE(review): `h2t` and `indexing_dict` are not assigned in any cell shown in
# this notebook, so this call appears to rely on state left in the kernel and
# would raise NameError on a fresh "Restart & Run All". `kg.h2t` looks like the
# intended first argument - TODO confirm, and define `indexing_dict` explicitly.
queries, results = find_queries(h2t, indexing_dict, n_queries=10000)

In [None]:
# queries, results = find_queries(h2t, indexing_dict, n_queries=10000)

# save_dict = {
#     'queries': queries,
#     'results': results
# }

# with open('queries-set2.pkl', 'wb') as f:
#     pickle.dump(save_dict, f)

# Reload the query set stored in 'queries-set2.pkl'; this replaces any
# `queries` / `results` already present in the kernel.
# NOTE: pickle.load executes arbitrary code from the file - only open query
# files that were produced by this notebook.
with open('queries-set2.pkl', 'rb') as f:
    loaded_dict = pickle.load(f)

queries, results = loaded_dict['queries'], loaded_dict['results']

In [None]:
# Evaluation of the two-step queries ("flat" method): retrieve candidates for
# both first-step (head, relation) pairs, intersect them, then expand every
# surviving candidate with the final relation and score the retrieved entities
# against the expected final targets.
K_NEIGHBORS = 50   # candidates retrieved for each first-step pair
K_RESULTS = 25     # minimum candidates retrieved in the final step

metrics = {
    "recall": [],
    "mrr": [],
    "hits3": [],
    "hits5": [],
    "hits10": [],
}

# predict() ranks with torch.topk, which cannot return more items than there
# are entities.
max_top_k = entity_embedding.shape[0]

# Iterate over every query so the averages match the "over len(queries)"
# wording of the summary lines below (the loop used to stop after one query).
for idx in tqdm(range(len(queries))):
    query = queries[idx]
    result = results[idx]

    head_1, rel_1 = query[0][0]
    head_2, rel_2 = query[0][1]

    final_targets = list(result[1])
    if not final_targets:
        # Nothing to retrieve; recall would divide by zero.
        continue
    final_rel = int(query[1])

    # if args.mode == 'head-batch':
    #     targets1 = r2h[(rel_1, head_1)]
    #     targets2 = r2h[(rel_2, head_2)]
    # else:
    #     targets1 = h2t[(head_1, rel_1)]
    #     targets2 = h2t[(head_2, rel_2)]

    # adapt_retrival = int(((len(targets1) + len(targets2))/2)*1.5)
    adapt_retrival = K_NEIGHBORS
    first_step_k = min(max_top_k, max(K_NEIGHBORS, adapt_retrival))

    ids1, _ = predict(int(head_1), int(rel_1), int(head_1), entity_embedding, relation_embedding,
                    mode=args.mode, top_k=first_step_k)

    ids2, _ = predict(int(head_2), int(rel_2), int(head_2), entity_embedding, relation_embedding,
                    mode=args.mode, top_k=first_step_k)

    # Entities retrieved by both first-step lookups.
    heads_inter = torch.from_numpy(np.intersect1d(ids1.cpu().numpy(), ids2.cpu().numpy()))

    final_step_k = min(max_top_k, max(K_RESULTS, int(len(final_targets)*1.5)))
    query_finds = []

    for h in heads_inter:
        candidate_ids, dists = predict(int(h), final_rel, int(h), entity_embedding, relation_embedding,
                    mode=args.mode, top_k=final_step_k)

        for t in final_targets:
            # topk indices are unique, so a target matches at most one position.
            hit_positions = (candidate_ids == t).nonzero(as_tuple=True)[0]
            if hit_positions.numel():
                rank = hit_positions[0].item() + 1
                metrics['mrr'].append(1.0 / rank)
                metrics['hits3'].append(1.0 if rank <= 3 else 0.0)
                metrics['hits5'].append(1.0 if rank <= 5 else 0.0)
                metrics['hits10'].append(1.0 if rank <= 10 else 0.0)

        query_finds.append(candidate_ids.cpu().numpy())

    finals = flatten(query_finds)

    # Count each target once, however many intermediate heads retrieved it.
    # Counting duplicates let recall exceed 1 and made the number of penalty
    # entries negative.
    found_targets = {int(f) for f in finals} & {int(t) for t in final_targets}
    number_of_founds = len(found_targets)

    # One zero per target that no candidate list contained.
    penalize = [0.0] * (len(final_targets) - number_of_founds)

    metrics['mrr'].extend(penalize)
    metrics['hits3'].extend(penalize)
    metrics['hits5'].extend(penalize)
    metrics['hits10'].extend(penalize)
    metrics['recall'].append(number_of_founds / len(final_targets))

print(f"Average Recall over {len(queries)} complex queries (flat method): {np.mean(metrics['recall'])}")
print(f"Average MRR over {len(queries)} complex queries (flat method): {np.mean(metrics['mrr'])}")
print(f"Average Hits@K over {len(queries)} complex queries (flat method): {np.mean(metrics['hits3'])}, {np.mean(metrics['hits5'])}, {np.mean(metrics['hits10'])}")

Average Recall over 3803 complex queries (flat method): 0.6561470308664061
Average MRR over 3803 complex queries (flat method): 0.07793474108706189
Average Hits@K over 3803 complex queries (flat method): 0.06498228078953368, 0.102213346991583, 0.1832569004260308