In [1]:
import json
from tqdm import tqdm
import torch
from sentence_transformers import util
import numpy as np

In [2]:
# Names of the two graphs being aligned. Both are used to build the input file
# paths; nodes of `small` are the queries and `big` supplies the candidates.
small, big = "stexpanded", "memoryalpha"

# Name of the embedding run; selects the sub-directory the embeddings are read from.
embeddings = "dogtag_bgelarge"

# Number of candidates requested per query from the semantic search.
top = 1_000

In [3]:
# All inputs live under ./_input; file names are derived from the graph names
# and the embedding run chosen in the configuration cell.

# Per-graph mapping files.
mappings_file_small = f"./_input/mappings/{small}.json"
mappings_file_big = f"./_input/mappings/{big}.json"

# Per-node embeddings, one file per graph.
node_embeddings_small_file = f"./_input/node_embeddings/{embeddings}/{small}.json"
node_embeddings_big_file = f"./_input/node_embeddings/{embeddings}/{big}.json"

# Per-neighborhood embeddings, one file per graph.
neighborhood_embeddings_small_file = f"./_input/neighborhood_embeddings/{embeddings}/{small}.json"
neighborhood_embeddings_big_file = f"./_input/neighborhood_embeddings/{embeddings}/{big}.json"

# Pair-level inputs are named "<big>-<small>".
exact_match_file = f"./_input/exact_match/{big}-{small}.json"
gold_pairs_file = f"./_input/gold_pairs/{big}-{small}.txt"

In [4]:
def _load_json(path):
    """Parse and return the JSON document stored at ``path``."""
    # Be explicit about the encoding so the result does not depend on the
    # platform's default locale (JSON interchange files are UTF-8).
    with open(path, encoding="utf-8") as handle:
        return json.load(handle)


def _load_mappings(path):
    """Load a mappings file and return two lookup dictionaries.

    The file is a JSON object; it is inverted so that the stringified value
    becomes the key.

    Returns:
        (by_id, by_label): ``by_id`` maps ``str(value) -> key`` of the file,
        ``by_label`` is the inverse of ``by_id``.
        NOTE(review): presumably file keys are labels and values are numeric
        ids -- confirm against the file format.
    """
    raw = _load_json(path)
    by_id = {str(v): k for k, v in raw.items()}
    by_label = {v: k for k, v in by_id.items()}
    return by_id, by_label


def _load_node_embeddings(path, by_label):
    """Load node embeddings and re-key them through ``by_label``.

    Raises:
        KeyError: if an embedding key has no entry in ``by_label``.
    """
    return {by_label[k]: v for k, v in _load_json(path).items()}


mappings_small, mappings_small_reversed = _load_mappings(mappings_file_small)
mappings_big, mappings_big_reversed = _load_mappings(mappings_file_big)

# Node embeddings are re-keyed via the reversed mappings so they share the
# string keys used by the rest of the notebook.
node_embeddings_small = _load_node_embeddings(node_embeddings_small_file, mappings_small_reversed)
node_embeddings_big = _load_node_embeddings(node_embeddings_big_file, mappings_big_reversed)

# Neighborhood embeddings are used with the keys found in the file.
neighborhood_embeddings_small = _load_json(neighborhood_embeddings_small_file)
neighborhood_embeddings_big = _load_json(neighborhood_embeddings_big_file)

# Gold pairs: one pair per line, integers separated by ';'.
gold_pairs = []
with open(gold_pairs_file, encoding="utf-8") as gpf:
    for line in gpf:
        line = line.strip()
        if not line:
            # A blank (e.g. trailing) line would otherwise crash int('').
            continue
        gold_pairs.append([int(num) for num in line.split(";")])

exact_match = _load_json(exact_match_file)

In [5]:
# Split the gold pairs by whether their first two ids also appear in the
# exact-match list. A set of tuples gives constant-time membership tests
# instead of scanning the whole exact-match list for every gold pair.
# NOTE(review): assumes exact_match is a list of [id, id] lists -- confirm.
exact_match_lookup = {tuple(pair) for pair in exact_match}

gold_exact = list()
gold_not_exact = list()

for p in gold_pairs:
    if (p[0], p[1]) in exact_match_lookup:
        gold_exact.append([p[0], p[1]])
    else:
        gold_not_exact.append([p[0], p[1]])

# Queries are the *node* embeddings of the small graph (the neighborhood
# embeddings of the small graph are loaded but not used here).
# keys() and values() iterate in the same order, so the two lists stay aligned.
node_ids_small_list = list(node_embeddings_small.keys())
node_embeddings_small_list = list(node_embeddings_small.values())

# The search corpus is the *neighborhood* embeddings of the big graph.
neighborhood_ids_big_list = list(neighborhood_embeddings_big.keys())
neighborhood_embeddings_big_list = list(neighborhood_embeddings_big.values())

In [6]:
# Query matrix: node embeddings of the small graph (row i belongs to
# node_ids_small_list[i]). The neighborhood embeddings of the small graph
# are a possible alternative query source but are not used here.
tensor_small = torch.Tensor(node_embeddings_small_list)

# Corpus matrix: neighborhood embeddings of the big graph (row j belongs to
# neighborhood_ids_big_list[j]).
tensor_big = torch.Tensor(neighborhood_embeddings_big_list)

# One result list per query row, each holding up to `top` hits that carry
# a 'corpus_id' (row index into the corpus) and a 'score'.
neighborhood_order = util.semantic_search(
    query_embeddings=tensor_small,
    corpus_embeddings=tensor_big,
    top_k=top,
)

In [7]:
# For every small-graph node id, keep its retrieved candidates as
# (big-graph neighborhood id, retrieval score) tuples, in the order in which
# the search returned them. Corpus row indices are translated back to ids.
top_dict = {
    node_id: [
        (neighborhood_ids_big_list[hit['corpus_id']], hit['score'])
        for hit in hits
    ]
    for node_id, hits in zip(node_ids_small_list, neighborhood_order)
}

In [9]:
# Weights for a blended score (neighborhood score + node score). The blend is
# currently disabled: only the node-level score below is used for re-ranking.
node_sim_weight = 1
neighborhood_sim_weight = 0.2

# Cache of big-graph node embeddings converted to NumPy arrays. The same
# candidate is typically retrieved for many queries; converting its Python
# list once instead of on every np.dot call removes the dominant cost of this
# loop without changing any computed value.
_node_vectors_big = dict()

# Same keys and candidate ids as top_dict, but every candidate is re-scored
# with the dot product of the two *node* embeddings.
top_dict_reordered = dict()
for k, v in tqdm(top_dict.items()):
    # Convert the query vector once per query rather than once per candidate.
    embedding1 = np.asarray(node_embeddings_small[k])
    items_list = list()
    for item in v:
        candidate_id = item[0]
        embedding2 = _node_vectors_big.get(candidate_id)
        if embedding2 is None:
            embedding2 = np.asarray(node_embeddings_big[candidate_id])
            _node_vectors_big[candidate_id] = embedding2
        # NOTE(review): a plain dot product equals cosine similarity only if
        # the embeddings are unit-normalized -- confirm for this embedding run.
        cosine_sim = np.dot(embedding1, embedding2)
        # Blended alternative (disabled):
        # neighborhood_sim_weight * item[1] + node_sim_weight * cosine_sim
        new_value = float(cosine_sim)
        items_list.append((candidate_id, new_value))
    top_dict_reordered[k] = items_list

100%|██████████| 15514/15514 [19:54<00:00, 12.99it/s]  


In [10]:
# Order every candidate list by its re-computed score, best first.
# Only existing keys are reassigned, so iterating over items() here is safe.
for node_id, candidates in top_dict_reordered.items():
    top_dict_reordered[node_id] = sorted(
        candidates,
        key=lambda candidate: candidate[1],
        reverse=True,
    )

In [14]:
def _count_skipped(pairs, ranking):
    """Count pairs whose second id (as a string) has no entry in ``ranking``."""
    return sum(1 for pair in pairs if ranking.get(str(pair[1])) is None)


def _count_hits(pairs, ranking, k):
    """Count pairs whose first id is among the first ``k`` candidates.

    Args:
        pairs: iterable of sequences; ``pair[1]`` is looked up in ``ranking``
            and ``pair[0]`` is the id searched for (both compared as strings).
        ranking: dict mapping an id string to a list of candidates whose
            first element is the candidate id.
        k: number of leading candidates to inspect.

    Pairs without an entry in ``ranking`` are ignored. Candidate lists that
    are shorter than ``k`` (or empty) are handled by slicing, so no
    IndexError is raised.
    """
    hits = 0
    for pair in pairs:
        candidates = ranking.get(str(pair[1]))
        if candidates is None:
            continue
        wanted = str(pair[0])
        if any(candidate[0] == wanted for candidate in candidates[:k]):
            hits += 1
    return hits


def _format_hits(found, total):
    """Format ``found`` as 'N (P%)' with five decimals; 0% when ``total`` is 0."""
    # Guard the division so an empty subset reports 0% instead of crashing.
    percent = found / total * 100 if total else 0.0
    return str(found) + " (" + f"{percent:.5f}" + "%)"


def _report(header, pairs):
    """Print count, skipped pairs and hit rates for one set of gold pairs."""
    total = len(pairs)
    print(header)
    print("Count:           " + str(total))
    print("Skipped:         " + str(_count_skipped(pairs, top_dict)))
    # Hit rates in the retrieval order produced by the semantic search.
    print("In Top 1000:     " + _format_hits(_count_hits(pairs, top_dict, 1000), total))
    print("In Top 100:      " + _format_hits(_count_hits(pairs, top_dict, 100), total))
    print("In Top 10:       " + _format_hits(_count_hits(pairs, top_dict, 10), total))
    print("Top 1:           " + _format_hits(_count_hits(pairs, top_dict, 1), total))
    # Top-1 hit rate after re-scoring and sorting the same candidates.
    print("Reordered:       " + _format_hits(_count_hits(pairs, top_dict_reordered, 1), total))


print("############### SETTINGS ################")
print("From:            " + small)
print("To:              " + big)
print("Embeddings:      " + embeddings)

_report("############ ALL GOLD PAIRS #############", gold_pairs)
_report("############## EXACT MATCH ##############", gold_exact)
_report("############ NOT EXACT MATCH ############", gold_not_exact)

############### SETTINGS ################
From:            stexpanded
To:              memoryalpha
Embeddings:      dogtag_bgelarge
############ ALL GOLD PAIRS #############
Count:           1779
Skipped:         0
In Top 1000:     1600 (89.93817%)
In Top 100:      1292 (72.62507%)
In Top 10:       676 (37.99888%)
Top 1:           109 (6.12704%)
Reordered:       1347 (75.71669%)
############## EXACT MATCH ##############
Count:           1617
Skipped:         0
In Top 1000:     1471 (90.97093%)
In Top 100:      1187 (73.40754%)
In Top 10:       617 (38.15708%)
Top 1:           100 (6.18429%)
Reordered:       1262 (78.04576%)
############ NOT EXACT MATCH ############
Count:           162
Skipped:         0
In Top 1000:     129 (79.62963%)
In Top 100:      105 (64.81481%)
In Top 10:       59 (36.41975%)
Top 1:           9 (5.55556%)
Reordered:       85 (52.46914%)
