In [1]:
import json
import torch
from sentence_transformers import util
import os

In [2]:
# --- Configuration -----------------------------------------------------------
# Names of the two graphs being matched and the number of hits kept per query.
big = "memoryalpha"
small = "stexpanded"

top_k = 10

# Input files: per-graph mappings and the exact-match file for the pair.
mapping_file_big = f"./data/triples_v2/{big}_mapping.json"
mapping_file_small = f"./data/triples_v2/{small}_mapping.json"
exact_match_file = f"./data/exact_match/{big}-{small}.json"

# Community assignments (one file per graph).
communities_big_file = f"./results/communities_leiden/{big}/final.txt"
communities_small_file = f"./results/communities_leiden/{small}/final.txt"

# Community-level embeddings (one file per graph).
communitiy_embeddings_big_file = f"./results/community_embeddings/{big}.json"
communitiy_embeddings_small_file = f"./results/community_embeddings/{small}.json"

# Node-level embeddings (one file per graph).
node_embeddings_big_file = f"./results/embeddings/{big}_lab_altlab_type_abs_comment_BAAI_bge-large-en-v1.5.json"
node_embeddings_small_file = f"./results/embeddings/{small}_lab_altlab_type_abs_comment_BAAI_bge-large-en-v1.5.json"

# Output directory; the value of top_k is part of the directory name.
matched_communities_path = "./results/matched_communities_nodes_leiden/top_" + str(top_k) + "_pairs/"

In [3]:
# Load the two mapping files and the exact-match file for the graph pair.
# The encoding is pinned to UTF-8 (the JSON interchange encoding) instead of
# relying on the platform default, which is not UTF-8 on every system.
with open(exact_match_file, encoding="utf-8") as emf, \
        open(mapping_file_big, encoding="utf-8") as mfb, \
        open(mapping_file_small, encoding="utf-8") as mfs:
    mapping_big = json.load(mfb)
    mapping_small = json.load(mfs)
    exact_match = json.load(emf)

In [4]:
def _read_communities(path):
    """Parse a community file into a list of sets of integers.

    Every non-blank line of the file becomes one set, built from the
    whitespace-separated integers on that line (presumably node ids -- confirm
    against the code that writes these files).

    Blank lines are skipped, so a trailing newline or an empty line no longer
    raises ``ValueError``. As a consequence, list positions equal line numbers
    only when the file contains no blank lines in the middle.

    Args:
        path: Path of the text file to read.

    Returns:
        A list with one ``set[int]`` per non-blank line, in file order.

    Raises:
        ValueError: If a token on a line is not a valid integer.
        OSError: If the file cannot be opened.
    """
    communities = []
    with open(path, encoding="utf-8") as handle:
        for line in handle:
            # split() without an argument tolerates tabs, repeated spaces and
            # the trailing newline, none of which split(" ") handles.
            tokens = line.split()
            if not tokens:
                continue
            communities.append({int(token) for token in tokens})
    return communities


small_communities = _read_communities(communities_small_file)
big_communities = _read_communities(communities_big_file)

In [5]:
# Load the community embeddings of both graphs. The encoding is pinned to
# UTF-8 instead of relying on the platform default.
with open(communitiy_embeddings_big_file, encoding="utf-8") as cebf, \
        open(communitiy_embeddings_small_file, encoding="utf-8") as cesf:
    _community_embeddings_big_raw = json.load(cebf)
    _community_embeddings_small_raw = json.load(cesf)

# Each top-level value is itself a JSON object; keep only its values, in file
# order, so every key maps to a plain list. The raw payloads get their own
# names so the public variables are only ever bound to the converted form.
community_embeddings_big = {
    key: list(components.values())
    for key, components in _community_embeddings_big_raw.items()
}
community_embeddings_small = {
    key: list(components.values())
    for key, components in _community_embeddings_small_raw.items()
}

In [6]:
# Load the node embeddings of both graphs. The encoding is pinned to UTF-8
# instead of relying on the platform default.
with open(node_embeddings_big_file, encoding="utf-8") as nebf, \
        open(node_embeddings_small_file, encoding="utf-8") as nesf:
    node_embeddings_big = json.load(nebf)
    node_embeddings_small = json.load(nesf)

In [7]:
def _stack_embeddings(embeddings):
    """Return the values of ``embeddings``, in insertion order, as one torch.Tensor."""
    return torch.Tensor(list(embeddings.values()))


big_torch_embeds = _stack_embeddings(community_embeddings_big)
small_torch_embeds = _stack_embeddings(community_embeddings_small)
big_torch_node_embeds = _stack_embeddings(node_embeddings_big)
small_torch_node_embeds = _stack_embeddings(node_embeddings_small)

# Forward direction: nodes of the big graph are the queries, communities of
# the small graph are the corpus.
pair_top_k = util.semantic_search(
    query_embeddings=big_torch_node_embeds,
    corpus_embeddings=small_torch_embeds,
    top_k=top_k,
)
# Reverse direction: nodes of the small graph against communities of the big graph.
reverse_pair_top_k = util.semantic_search(
    query_embeddings=small_torch_node_embeds,
    corpus_embeddings=big_torch_embeds,
    top_k=top_k,
)

In [8]:
# Inspect the hits for the first query, i.e. row 0 of big_torch_node_embeds.
# Each corpus_id is a row index into small_torch_embeds.
pair_top_k[0]

[{'corpus_id': 293, 'score': 0.7225422263145447},
 {'corpus_id': 507, 'score': 0.7115445733070374},
 {'corpus_id': 860, 'score': 0.7056660652160645},
 {'corpus_id': 82, 'score': 0.7054846286773682},
 {'corpus_id': 276, 'score': 0.7053211331367493},
 {'corpus_id': 452, 'score': 0.7018076777458191},
 {'corpus_id': 847, 'score': 0.7013729810714722},
 {'corpus_id': 870, 'score': 0.6981565952301025},
 {'corpus_id': 817, 'score': 0.6973156332969666},
 {'corpus_id': 1147, 'score': 0.6972415447235107}]

In [9]:
# Inspect the hits for the first query, i.e. row 0 of small_torch_node_embeds.
# Each corpus_id is a row index into big_torch_embeds.
reverse_pair_top_k[0]

[{'corpus_id': 2158, 'score': 0.6724958419799805},
 {'corpus_id': 469, 'score': 0.6719353795051575},
 {'corpus_id': 97, 'score': 0.6714133024215698},
 {'corpus_id': 2292, 'score': 0.6701398491859436},
 {'corpus_id': 413, 'score': 0.6693156361579895},
 {'corpus_id': 1639, 'score': 0.6688380241394043},
 {'corpus_id': 319, 'score': 0.667682409286499},
 {'corpus_id': 1467, 'score': 0.6658586263656616},
 {'corpus_id': 85, 'score': 0.6651350855827332},
 {'corpus_id': 1535, 'score': 0.6643020510673523}]

In [10]:
def _label_hits(query_keys, hits_per_query, corpus_keys):
    """Translate positional search hits into ``{query key: [[corpus key, score], ...]}``.

    ``hits_per_query`` is paired with ``query_keys`` positionally; each hit's
    ``corpus_id`` is used as an index into ``corpus_keys``.
    """
    return {
        str(query_key): [[corpus_keys[hit["corpus_id"]], hit["score"]] for hit in hits]
        for query_key, hits in zip(query_keys, hits_per_query)
    }


# Key lists in the same order the tensors were built from the dict values.
g1_keys = list(community_embeddings_big)
g2_keys = list(community_embeddings_small)
g1_node_keys = list(node_embeddings_big)
g2_node_keys = list(node_embeddings_small)

# Big-graph nodes -> small-graph communities, and the reverse direction.
forward_dict = _label_hits(g1_node_keys, pair_top_k, g2_keys)
backward_dict = _label_hits(g2_node_keys, reverse_pair_top_k, g1_keys)

In [11]:
# Make sure the output directory exists; open(..., "w") does not create
# missing parent directories and would raise FileNotFoundError on a fresh run.
os.makedirs(matched_communities_path, exist_ok=True)

# Forward matches: keys come from the big graph's node embeddings.
forward_path = os.path.join(matched_communities_path, f"{big}-{small}_top_{top_k}_pairs.json")
with open(forward_path, "w", encoding="utf-8") as f:
    json.dump(forward_dict, f)

# Backward matches: keys come from the small graph's node embeddings.
backward_path = os.path.join(matched_communities_path, f"{small}-{big}_top_{top_k}_pairs.json")
with open(backward_path, "w", encoding="utf-8") as f:
    json.dump(backward_dict, f)