In [1]:
import json
import torch
from sentence_transformers import util
import os

In [2]:
# --- Configuration -----------------------------------------------------------
# Identifiers of the two graphs being compared. The names suggest `big` is the
# larger graph and `small` the smaller one; they are used only to build the
# file paths below.
big = "memoryalpha"
small = "stexpanded"

# Number of best-scoring candidates to keep per community.
top_k = 10

# Inputs: per-graph mapping files and the exact-match file for the pair.
mapping_file_big = f"./data/triples_v2/{big}_mapping.json"
mapping_file_small = f"./data/triples_v2/{small}_mapping.json"
exact_match_file = f"./data/exact_match/{big}-{small}.json"

# Inputs: community membership files, one per graph.
communities_big_file = f"./results/communities_leiden/{big}/final.txt"
communities_small_file = f"./results/communities_leiden/{small}/final.txt"

# Inputs: community embeddings, one JSON file per graph.
# NOTE(review): "communitiy" is misspelled, but the names are referenced by
# later cells, so they are kept as-is here.
communitiy_embeddings_big_file = f"./results/_community_embeddings/{big}.json"
communitiy_embeddings_small_file = f"./results/_community_embeddings/{small}.json"

# Output directory; the directory name encodes the chosen top_k.
matched_communities_path = f"./results/_matched_communities_leiden/top_{top_k}_pairs/"

In [3]:
# Load the three JSON inputs for this graph pair. All files are opened up
# front, so a missing file fails before anything is parsed.
with open(exact_match_file) as exact_fh, open(mapping_file_big) as big_map_fh, open(mapping_file_small) as small_map_fh:
    mapping_big = json.load(big_map_fh)
    mapping_small = json.load(small_map_fh)
    exact_match = json.load(exact_fh)

In [4]:
def _read_communities(path):
    """Parse a community file into a list of sets of integer IDs.

    Each non-blank line of the file is read as one community whose members
    are whitespace-separated integers.

    Args:
        path: Path of the text file to read.

    Returns:
        A list with one ``set`` of ``int`` per non-blank line, in file order.

    Raises:
        OSError: If the file cannot be opened.
        ValueError: If a token on a line is not a valid integer.
    """
    communities = []
    with open(path) as handle:
        for line in handle:
            # split() with no argument splits on any run of whitespace, so
            # doubled spaces, tabs and trailing blanks no longer yield an
            # empty token that int() would reject.
            tokens = line.split()
            if not tokens:
                # Blank line (e.g. at the end of the file): no community.
                continue
            communities.append({int(token) for token in tokens})
    return communities


small_communities = _read_communities(communities_small_file)
big_communities = _read_communities(communities_big_file)

In [5]:
# Read the per-community embedding files. Each top-level JSON value is itself
# a mapping; only its values are kept, as a list in their stored order.
# NOTE(review): presumably each inner mapping is index -> float component, so
# every entry becomes one flat vector - confirm against the file producer.
with open(communitiy_embeddings_big_file) as big_fh, open(communitiy_embeddings_small_file) as small_fh:
    community_embeddings_big = {
        community_id: list(components.values())
        for community_id, components in json.load(big_fh).items()
    }
    community_embeddings_small = {
        community_id: list(components.values())
        for community_id, components in json.load(small_fh).items()
    }

In [6]:
# community_embeddings_big['0']

In [7]:
# embeddings_old_file = "./results/embeddings/stexpanded_lab_altlab_type_abs_comment_BAAI_bge-large-en-v1.5.json"
# with open(embeddings_old_file) as eof:
#     embeddings_old = json.load(eof)

In [8]:
# embeddings_old['0']

In [9]:
# Stack the embeddings into tensors. Row i corresponds to the i-th key of the
# source dict (dicts keep insertion order), which is what lets corpus_id be
# mapped back to a community key later on.
big_torch_embeds = torch.Tensor([*community_embeddings_big.values()])
small_torch_embeds = torch.Tensor([*community_embeddings_small.values()])

# Forward direction: each big-graph community queries the small-graph ones.
pair_top_k = util.semantic_search(
    query_embeddings=big_torch_embeds,
    corpus_embeddings=small_torch_embeds,
    top_k=top_k,
)

# Reverse direction: each small-graph community queries the big-graph ones.
reverse_pair_top_k = util.semantic_search(
    query_embeddings=small_torch_embeds,
    corpus_embeddings=big_torch_embeds,
    top_k=top_k,
)

In [10]:
# Spot-check: hits returned for the big-graph embedding at row 16; each hit is
# a dict holding a `corpus_id` (row in the small-graph tensor) and a `score`.
pair_top_k[16]

[{'corpus_id': 83, 'score': 0.7707483172416687},
 {'corpus_id': 846, 'score': 0.7667902708053589},
 {'corpus_id': 1092, 'score': 0.7583621144294739},
 {'corpus_id': 698, 'score': 0.7561346292495728},
 {'corpus_id': 829, 'score': 0.7486347556114197},
 {'corpus_id': 807, 'score': 0.7475183010101318},
 {'corpus_id': 1, 'score': 0.7466740608215332},
 {'corpus_id': 280, 'score': 0.7466054558753967},
 {'corpus_id': 12, 'score': 0.7463914155960083},
 {'corpus_id': 278, 'score': 0.7446763515472412}]

In [11]:
# Spot-check: hits returned for the small-graph embedding at row 25; here
# `corpus_id` refers to a row in the big-graph tensor.
reverse_pair_top_k[25]

[{'corpus_id': 330, 'score': 0.8028998970985413},
 {'corpus_id': 433, 'score': 0.7985575795173645},
 {'corpus_id': 1879, 'score': 0.7983618378639221},
 {'corpus_id': 332, 'score': 0.7968781590461731},
 {'corpus_id': 1315, 'score': 0.7959682941436768},
 {'corpus_id': 331, 'score': 0.7954684495925903},
 {'corpus_id': 1893, 'score': 0.7942922115325928},
 {'corpus_id': 1887, 'score': 0.7924901843070984},
 {'corpus_id': 478, 'score': 0.7912763357162476},
 {'corpus_id': 1899, 'score': 0.7904046773910522}]

In [12]:
# Community keys in the same order as the tensor rows, so a hit's `corpus_id`
# (a row index) can be translated back into the community key.
g1_keys = list(community_embeddings_big.keys())
g2_keys = list(community_embeddings_small.keys())

# big-graph community key -> [[small-graph community key, score], ...]
forward_dict = {
    str(query_key): [[g2_keys[hit["corpus_id"]], hit["score"]] for hit in hits]
    for query_key, hits in zip(g1_keys, pair_top_k)
}

# small-graph community key -> [[big-graph community key, score], ...]
backward_dict = {
    str(query_key): [[g1_keys[hit["corpus_id"]], hit["score"]] for hit in hits]
    for query_key, hits in zip(g2_keys, reverse_pair_top_k)
}

In [13]:
# Persist both match directions as JSON.
# The output directory name depends on top_k, so it may not exist yet (fresh
# checkout or a new top_k value); create it instead of failing in open().
os.makedirs(matched_communities_path, exist_ok=True)

# Forward matches: big-graph community -> ranked small-graph candidates.
with open(os.path.join(matched_communities_path, f"{big}-{small}_top_{top_k}_pairs.json"), "w") as f:
    json.dump(forward_dict, f)

# Backward matches: small-graph community -> ranked big-graph candidates.
with open(os.path.join(matched_communities_path, f"{small}-{big}_top_{top_k}_pairs.json"), "w") as f:
    json.dump(backward_dict, f)