In [19]:
import json
import torch
from sentence_transformers import util
import os

In [20]:
# Identifiers of the two datasets being matched; they select the input files
# below and appear in the names of the output files.
big = "memoryalpha"
small = "stexpanded"

# Number of hits requested per query in the semantic search; also part of the
# output directory name.
top_k = 10

# Per-dataset mapping files and the exact-match file for the dataset pair.
mapping_file_big = f"./data/triples_v2/{big}_mapping.json"
mapping_file_small = f"./data/triples_v2/{small}_mapping.json"
exact_match_file = f"./data/exact_match/{big}-{small}.json"

# Community files (level 0) of each dataset.
communities_big_file = f"./results/communities_leiden/{big}/level_0.txt"
communities_small_file = f"./results/communities_leiden/{small}/level_0.txt"

# Community embedding files of each dataset.
communitiy_embeddings_big_file = f"./results/_community_embeddings/{big}.json"
communitiy_embeddings_small_file = f"./results/_community_embeddings/{small}.json"

# Directory that receives the matched-community result files.
matched_communities_path = f"./results/_matched_communities_leiden/top_{top_k}_pairs/"

In [21]:
# Load the three JSON inputs configured above. The files are opened with an
# explicit UTF-8 encoding so decoding does not depend on the platform's
# locale default (which differs between operating systems).
with open(exact_match_file, encoding="utf-8") as emf, \
        open(mapping_file_big, encoding="utf-8") as mfb, \
        open(mapping_file_small, encoding="utf-8") as mfs:
    mapping_big = json.load(mfb)
    mapping_small = json.load(mfs)
    exact_match = json.load(emf)

In [22]:
def _read_communities(path):
    """Parse a community file into a list of integer sets.

    Each line of the file is treated as one community, written as
    whitespace-separated integers.

    Args:
        path: Location of the text file to read.

    Returns:
        A list with one ``set`` of ints per line, in file order.

    Raises:
        OSError: If the file cannot be opened.
        ValueError: If a token on a line is not a valid integer.
    """
    with open(path) as community_file:
        # ``split()`` without arguments tolerates repeated spaces, tabs and
        # surrounding whitespace. A blank line produces an empty set instead
        # of an error, so list positions keep matching line numbers.
        return [{int(token) for token in line.split()} for line in community_file]


small_communities = _read_communities(communities_small_file)
big_communities = _read_communities(communities_big_file)

In [23]:
# Each embeddings file is a JSON object whose values are themselves objects.
# Only the inner values are kept, as a list in their stored order
# (presumably outer key = community id and inner values = vector components;
# confirm against the script that writes these files).
with open(communitiy_embeddings_big_file) as cebf, open(communitiy_embeddings_small_file) as cesf:
    community_embeddings_big = {
        key: list(inner.values()) for key, inner in json.load(cebf).items()
    }
    community_embeddings_small = {
        key: list(inner.values()) for key, inner in json.load(cesf).items()
    }

In [24]:
# community_embeddings_big['0']

In [25]:
# embeddings_old_file = "./results/embeddings/stexpanded_lab_altlab_type_abs_comment_BAAI_bge-large-en-v1.5.json"
# with open(embeddings_old_file) as eof:
#     embeddings_old = json.load(eof)

In [26]:
# embeddings_old['0']

In [27]:
# Build one float32 tensor per dataset from the embedding lists; row i belongs
# to the i-th key of the corresponding dict. ``torch.tensor`` with an explicit
# dtype replaces the legacy ``torch.Tensor`` constructor and yields the same
# float32 values. Assumes all vectors have the same length.
big_torch_embeds = torch.tensor(list(community_embeddings_big.values()), dtype=torch.float32)
small_torch_embeds = torch.tensor(list(community_embeddings_small.values()), dtype=torch.float32)

# Search in both directions: first argument is the query set, second the
# corpus. Each result is a list (one entry per query) of up to ``top_k`` hits
# of the form {'corpus_id': ..., 'score': ...}, where ``corpus_id`` is a row
# position in the corpus tensor.
pair_top_k = util.semantic_search(big_torch_embeds, small_torch_embeds, top_k=top_k)
reverse_pair_top_k = util.semantic_search(small_torch_embeds, big_torch_embeds, top_k=top_k)

In [28]:
# Spot check: hits returned for the query at row 16 of big_torch_embeds.
pair_top_k[16]

[{'corpus_id': 25, 'score': 0.9999995231628418},
 {'corpus_id': 6, 'score': 0.6928253173828125},
 {'corpus_id': 8, 'score': 0.6265830397605896},
 {'corpus_id': 2, 'score': 0.5902332663536072},
 {'corpus_id': 11, 'score': 0.574837327003479},
 {'corpus_id': 5, 'score': 0.5597947239875793},
 {'corpus_id': 1, 'score': 0.5595784783363342},
 {'corpus_id': 9, 'score': 0.5594757199287415},
 {'corpus_id': 0, 'score': 0.557192862033844},
 {'corpus_id': 12, 'score': 0.5565229058265686}]

In [29]:
# Spot check in the reverse direction: hits returned for the query at row 25
# of small_torch_embeds.
reverse_pair_top_k[25]

[{'corpus_id': 16, 'score': 0.9999995231628418},
 {'corpus_id': 13, 'score': 0.5901646018028259},
 {'corpus_id': 2, 'score': 0.5879038572311401},
 {'corpus_id': 9, 'score': 0.5656415820121765},
 {'corpus_id': 6, 'score': 0.5614608526229858},
 {'corpus_id': 0, 'score': 0.5612963438034058},
 {'corpus_id': 10, 'score': 0.5603558421134949},
 {'corpus_id': 12, 'score': 0.5521239042282104},
 {'corpus_id': 4, 'score': 0.549360990524292},
 {'corpus_id': 5, 'score': 0.5492114424705505}]

In [30]:
# Keys in the same order in which the embedding rows were stacked, so a row
# position returned by the search can be translated back into its key.
g1_keys = list(community_embeddings_big.keys())
g2_keys = list(community_embeddings_small.keys())

# big -> small: for every key of the big dataset, the list of
# [key in small dataset, score] pairs in the order returned by the search.
forward_dict = {
    str(query_key): [[g2_keys[hit["corpus_id"]], hit["score"]] for hit in hits]
    for query_key, hits in zip(g1_keys, pair_top_k)
}

# small -> big: the same structure for the reverse search.
backward_dict = {
    str(query_key): [[g1_keys[hit["corpus_id"]], hit["score"]] for hit in hits]
    for query_key, hits in zip(g2_keys, reverse_pair_top_k)
}

In [31]:
# The output directory depends on ``top_k`` and may not exist yet (fresh
# checkout or a new ``top_k`` value); create it so opening the files for
# writing cannot fail with FileNotFoundError.
os.makedirs(matched_communities_path, exist_ok=True)

# One file per search direction: "<big>-<small>" holds the big -> small
# matches, "<small>-<big>" the reverse ones.
with open(os.path.join(matched_communities_path, f"{big}-{small}_top_{top_k}_pairs.json"), "w") as f:
    json.dump(forward_dict, f)

with open(os.path.join(matched_communities_path, f"{small}-{big}_top_{top_k}_pairs.json"), "w") as f:
    json.dump(backward_dict, f)