In [1]:
import json
from tqdm import tqdm
import torch
from sentence_transformers import util

In [2]:
# Identifiers of the two datasets; they are spliced into the input file paths.
big, small = "memoryalpha", "stexpanded"

# Select which embedding / community input folders are read.
embeddings, communities = "dogtag_bgelarge", "leiden"

# Weights of the individual similarity terms in the combined ranking score.
node_sim_weight, neighbor_sim_weight, community_sim_weight = 1.0, 0.5, 0.2

In [5]:
# Locations of all input files, derived from the dataset / embedding /
# community identifiers chosen in the configuration cell.

# Per-dataset mapping files.
mapping_file_big = f"./_input/mappings/{big}.json"
mapping_file_small = f"./_input/mappings/{small}.json"

# Community assignment and community embeddings of the big dataset.
communities_big_file = f"./_input/communities/{communities}/{big}.txt"
communitiy_embeddings_big_file = f"./_input/community_embeddings/{communities}_{embeddings}/{big}.json"

# Embeddings of the small dataset (node-based and URL-based).
node_embeddings_small_file = f"./_input/node_embeddings/{embeddings}/{small}.json"
url_embeddings_small_file = f"./_input/url_embeddings/{small}.json"

# Reference data used for evaluation.
exact_match_file = f"./_input/exact_match/{big}-{small}.json"
gold_pairs_file = f"./_input/gold_pairs/{big}-{small}.txt"

# Precomputed candidate lists (note the small-big order in the file name).
top10pairs_file = f"./_input/top10pairs/{embeddings}/{small}-{big}.json"

In [6]:
# Load every input file used by the later cells.

# Communities of the big dataset: one community per line, node ids separated
# by whitespace. The line position is used later as the community index, so a
# blank line is kept as an empty community rather than dropped (dropping it
# would shift the index of every following community).
with open(communities_big_file, encoding="utf-8") as cbf:
    big_communities = [{int(num) for num in line.split()} for line in cbf]

# Gold pairs: one pair per line, ids separated by ';'. Blank lines (e.g. a
# trailing newline at the end of the file) are skipped instead of crashing
# on int('').
with open(gold_pairs_file, encoding="utf-8") as gpf:
    gold_pairs = [
        [int(num) for num in line.strip().split(";")]
        for line in gpf
        if line.strip()
    ]

with open(exact_match_file, encoding="utf-8") as file:
    exact_match = json.load(file)

# Gold pairs that are not contained in the exact-match list.
# A set of tuples gives constant-time membership tests instead of scanning
# the whole exact_match list once per gold pair.
# NOTE(review): assumes exact_match is a list of flat [id, id] lists - confirm.
_exact_match_pairs = {tuple(pair) for pair in exact_match}
gold_not_exact = [
    [p[0], p[1]]
    for p in gold_pairs
    if (p[0], p[1]) not in _exact_match_pairs
]

# Community embeddings: each entry is stored as a JSON object; only its
# values (in file order) are kept, as a plain list.
with open(communitiy_embeddings_big_file, encoding="utf-8") as cebf:
    community_embeddings_big = {
        k: list(v.values()) for k, v in json.load(cebf).items()
    }

with open(top10pairs_file, encoding="utf-8") as file:
    top10pairs = json.load(file)

with open(node_embeddings_small_file, encoding="utf-8") as nesf:
    node_embeddings_small = json.load(nesf)

with open(url_embeddings_small_file, encoding="utf-8") as uesf:
    url_embeddings_small = json.load(uesf)

In [7]:
# Combine both embedding sources of the small dataset. When an id is present
# in both, the entry from node_embeddings_small is the one that is kept.
merged_node_embeddings_small = {**url_embeddings_small, **node_embeddings_small}

# Rebuild the dict so that its keys are ordered by their integer value.
merged_node_embeddings_small = {
    str(node_id): merged_node_embeddings_small[str(node_id)]
    for node_id in sorted(int(key) for key in merged_node_embeddings_small)
}

# Baseline prediction: for every key of top10pairs, the id stored in the
# first entry of its candidate list.
top1dict = {
    int(small_id): int(candidates[0][0])
    for small_id, candidates in top10pairs.items()
}

In [8]:
# Map every node of the big dataset to the embedding of the community it
# belongs to. A community is matched to its embedding by its position in
# big_communities (used as the string key into community_embeddings_big).
node_to_community_embeddings = dict()
for index, community in enumerate(big_communities):
    for node in community:
        node_to_community_embeddings[node] = community_embeddings_big[str(index)]

# Re-rank the candidates of every node by a weighted sum of the stored node
# similarity (candidate[1]) and the similarity between the node's embedding
# and the embedding of the candidate's community.
# NOTE(review): neighbor_sim_weight from the config cell is not used here.
top1dict_reordered = dict()

for k in tqdm(top10pairs.keys()):

    node_embeds = torch.Tensor(merged_node_embeddings_small[str(k)])

    # Keep only candidates that have a community embedding. The filtered list
    # is stored because semantic_search reports `corpus_id` as a position in
    # the corpus tensor built from it; using that position to index the
    # unfiltered candidate list would select the wrong candidate whenever an
    # earlier entry had been filtered out.
    candidates = [
        item
        for item in top10pairs[str(k)]
        if int(item[0]) in node_to_community_embeddings
    ]
    if not candidates:
        # Nothing to rank: this node gets no re-ranked prediction.
        continue

    big_torch_embeds = torch.Tensor(
        [node_to_community_embeddings[int(item[0])] for item in candidates]
    )

    # top_k is set explicitly so that every candidate is scored even if a
    # candidate list is longer than semantic_search's default of 10.
    community_order = util.semantic_search(
        node_embeds, big_torch_embeds, top_k=len(candidates)
    )

    # Start below any possible score so that a best candidate is always
    # selected, even when all combined scores are zero or negative.
    best_score = float("-inf")

    for item in community_order[0]:
        candidate = candidates[item['corpus_id']]
        score = (
                node_sim_weight * candidate[1] +
                community_sim_weight * item['score']
        )
        if score > best_score:
            best_score = score
            top1dict_reordered[int(k)] = int(candidate[0])

100%|██████████| 15524/15524 [00:06<00:00, 2278.78it/s]


In [9]:
# Baseline evaluation: count the gold pairs for which looking up the second
# entry in top1dict yields the first entry.
# To evaluate only on gold_not_exact, use it in place of gold_pairs both for
# all_pairs and for the iteration below.
all_pairs = len(gold_pairs)
found = sum(
    1
    for gold_pair in tqdm(gold_pairs)
    if top1dict.get(gold_pair[1]) == gold_pair[0]
)

print(f"Gold pairs in top 1: {found / all_pairs * 100}% ({found})")

100%|██████████| 1779/1779 [00:00<00:00, 766793.42it/s]

Gold pairs in top 1: 78.13378302417088% (1390)





In [10]:
# Evaluation after re-ranking: same check as for the baseline, but against
# top1dict_reordered.
# To evaluate only on gold_not_exact, use it in place of gold_pairs both for
# all_pairs and for the iteration below.
all_pairs = len(gold_pairs)

hits = [
    top1dict_reordered.get(gold_pair[1]) == gold_pair[0]
    for gold_pair in tqdm(gold_pairs)
]
found = hits.count(True)

print(f"Gold pairs in reordered top 1: {found / all_pairs * 100}% ({found})")

100%|██████████| 1779/1779 [00:00<00:00, 1299489.17it/s]

Gold pairs in reordered top 1: 78.58347386172007% (1398)



