In [1]:
import json
from tqdm import tqdm
import torch
from sentence_transformers import util

In [2]:
big = "memoryalpha"
small = "stexpanded"

mapping_file_big = "./data/triples_v2/" + big + "_mapping.json"
mapping_file_small = "./data/triples_v2/" + small + "_mapping.json"

communities_big_file = "./results/communities_leiden/" + big + "/final.txt"
communitiy_embeddings_big_file = "./results/community_embeddings/" + big + ".json"

node_embeddings_small_file = "./results/embeddings/" + small + "_lab_altlab_type_abs_comment_BAAI_bge-large-en-v1.5.json"
url_embeddings_small_file = "./results/url_embeddings_short/" + small + "_url_BAAI_bge-large-en-v1.5.json"

exact_match_file = "./data/exact_match/" + big + "-" + small + "-formatted.json"
gold_pairs_file = "./data/gold_pairs/" + big + "-" + small + "_formatted.txt"

top10pairs_file = "./results/top10pairs/" + small + "-" + big + "_top10pairs.json"

In [3]:
big_communities = []
gold_pairs = []

with open(communities_big_file) as cbf, open(gold_pairs_file) as gpf:

    for line in cbf:
        numbers_set = {int(num) for num in line.strip().split(" ")}
        big_communities.append(numbers_set)

    for line in gpf:
        numbers_list = [int(num) for num in line.strip().split(";")]
        gold_pairs.append(numbers_list)

with open(exact_match_file) as file:
    exact_match = json.load(file)

gold_not_exact = list()
for p in gold_pairs:
    if [p[0], p[1]] not in exact_match:
        gold_not_exact.append([p[0], p[1]])

with open(communitiy_embeddings_big_file) as cebf:
    community_embeddings_big = json.load(cebf)
    community_embeddings_big = {k: list(v.values()) for k, v in community_embeddings_big.items()}

with open(top10pairs_file) as file:
    top10pairs = json.load(file)

with open(node_embeddings_small_file) as nesf:
    node_embeddings_small = json.load(nesf)

with open(url_embeddings_small_file) as uesf:
    url_embeddings_small = json.load(uesf)

merged_node_embeddings_small = {key: node_embeddings_small[key] if key in node_embeddings_small else url_embeddings_small[key] for key in set(node_embeddings_small) | set(url_embeddings_small)}
merged_node_embeddings_small = {str(k): merged_node_embeddings_small[str(k)] for k in sorted(map(int, merged_node_embeddings_small.keys()))}

In [4]:
top1dict = dict()

for k, v in top10pairs.items():
    top1dict[int(k)] = int(v[0][0])

In [5]:
found = 0
all_pairs = len(gold_pairs)
# all_pairs = len(gold_not_exact)

for gold_pair in tqdm(gold_pairs):
# for gold_pair in tqdm(gold_not_exact):
    if top1dict.get(gold_pair[1]) == gold_pair[0]:
        found += 1
    else:
        print(gold_pair)

print("Gold pairs in top 1: " + str(found / all_pairs * 100) + "%")

100%|██████████| 1779/1779 [00:00<00:00, 874960.93it/s]

[6769, 10376]
[2119, 7607]
[25058, 11045]
[34463, 2440]
[11827, 24714]
[13880, 18851]
[2989, 14320]
[20240, 3826]
[1741, 988]
[9762, 2178]
[24211, 8534]
[94172, 1056]
[23303, 3930]
[5532, 3501]
[1356, 5476]
[3348, 10196]
[3906, 13771]
[7115, 18135]
[24604, 706]
[2954, 29958]
[20041, 15760]
[10222, 11011]
[20630, 3261]
[16103, 3405]
[3607, 6102]
[56273, 10628]
[4221, 1715]
[4793, 2687]
[21854, 15914]
[490, 6563]
[1320, 5587]
[4316, 5621]
[74405, 4349]
[2460, 1509]
[3431, 1708]
[76846, 3898]
[18114, 2736]
[162583, 6404]
[18799, 2053]
[135333, 28525]
[41628, 19673]
[8552, 4046]
[68249, 4939]
[5709, 3561]
[47273, 22403]
[3285, 4200]
[2014, 4778]
[9960, 14704]
[44868, 8140]
[5330, 27260]
[2616, 423]
[23533, 6558]
[13717, 1068]
[18486, 8032]
[267, 6672]
[23435, 1591]
[45046, 6474]
[31864, 7684]
[109654, 17919]
[18514, 14095]
[835, 3949]
[17892, 22911]
[6412, 17548]
[11979, 3296]
[42269, 22117]
[14741, 8133]
[2060, 6757]
[7222, 1526]
[6693, 3571]
[2124, 8542]
[3487, 27357]
[13456, 16889]
[538




In [6]:
node_to_community_embeddings = dict()
index = 0
for community in big_communities:
    for node in community:
        node_to_community_embeddings[node] = community_embeddings_big[str(index)]
    index += 1

In [7]:
node_embeds = torch.Tensor(merged_node_embeddings_small['2440'])
compare_list = [node_to_community_embeddings[int(item[0])] for item in top10pairs['0']]
big_torch_embeds = torch.Tensor(compare_list)

community_order = util.semantic_search(node_embeds, big_torch_embeds)

In [8]:
top10pairs['2440']

[['73628', 0.7748620510101318],
 ['34463', 0.7659156918525696],
 ['139583', 0.710580050945282],
 ['92236', 0.7008298635482788],
 ['22527', 0.6984464526176453],
 ['93865', 0.6762269735336304],
 ['15660', 0.6759642362594604],
 ['54670', 0.6698982119560242],
 ['9430', 0.6697186827659607],
 ['67271', 0.6663190722465515]]

In [9]:
community_order

[[{'corpus_id': 8, 'score': 0.6746936440467834},
  {'corpus_id': 1, 'score': 0.6725383996963501},
  {'corpus_id': 9, 'score': 0.6617224216461182},
  {'corpus_id': 5, 'score': 0.6589083671569824},
  {'corpus_id': 4, 'score': 0.652053713798523},
  {'corpus_id': 2, 'score': 0.6348870992660522},
  {'corpus_id': 7, 'score': 0.6225403547286987},
  {'corpus_id': 0, 'score': 0.6225403547286987},
  {'corpus_id': 6, 'score': 0.6205182075500488},
  {'corpus_id': 3, 'score': 0.6166427135467529}]]