In [1]:
import json
import torch
from sentence_transformers import util
from transformers import AutoModelForSequenceClassification, AutoTokenizer

In [2]:
# Names of the two datasets being matched; they are interpolated into every
# input/output file name built in the next cell.
small, big = "stexpanded", "memoryalpha"

# Sub-directory of ./_input/node_embeddings/ that the vectors are read from.
embeddings = "dogtag_bgelarge"

# How many nearest candidates are retrieved per `small` node (passed as top_k).
top = 100

In [3]:
# Every path is derived from the dataset names (`small`, `big`) and the
# embedding variant (`embeddings`) chosen in the configuration cell.

# Mapping files, one per dataset.
mappings_file_small = f"./_input/mappings/{small}.json"
mappings_file_big = f"./_input/mappings/{big}.json"

# Node embeddings, stored per embedding variant and dataset.
node_embeddings_small_file = f"./_input/node_embeddings/{embeddings}/{small}.json"
node_embeddings_big_file = f"./_input/node_embeddings/{embeddings}/{big}.json"

# Dogtag files, one per dataset.
dogtags_small_file = f"./_input/dogtags/{small}.json"
dogtags_big_file = f"./_input/dogtags/{big}.json"

# Files keyed by the "<small>-<big>" dataset pair.
exact_match_file = f"./_input/exact_match/{small}-{big}.json"
gold_pairs_file = f"./_input/gold_pairs/{small}-{big}.txt"

# Destination of the pairs written by the reranking cell.
output_file = f"./_input/found_pairs/{small}-{big}.txt"

In [4]:
def _load_json(path):
    """Parse the JSON document stored at `path`.

    The file is opened with an explicit UTF-8 encoding so the result does not
    depend on the platform's locale default.
    """
    with open(path, encoding="utf-8") as handle:
        return json.load(handle)


# The mapping files are inverted on load: the stored values (stringified) become
# the keys of `mappings_*`, and `mappings_*_reversed` goes back the other way.
mappings_small = {str(v): k for k, v in _load_json(mappings_file_small).items()}
mappings_small_reversed = {v: k for k, v in mappings_small.items()}

mappings_big = {str(v): k for k, v in _load_json(mappings_file_big).items()}
mappings_big_reversed = {v: k for k, v in mappings_big.items()}

# Re-key the embeddings through the reversed mappings so that they share the
# key space of `mappings_small` / `mappings_big`.
# A KeyError here means an embedding key is missing from the mapping file.
node_embeddings_small = {
    mappings_small_reversed[k]: v
    for k, v in _load_json(node_embeddings_small_file).items()
}
node_embeddings_big = {
    mappings_big_reversed[k]: v
    for k, v in _load_json(node_embeddings_big_file).items()
}

# Dogtags are used as loaded; the reranking cell indexes them with the values
# of `mappings_small` / `mappings_big`.
dogtags_small = _load_json(dogtags_small_file)
dogtags_big = _load_json(dogtags_big_file)

In [5]:
# Split each {node id: vector} dict into two index-aligned lists: position i of
# the id list names the node whose vector sits at position i of the vector list.
# Dict iteration order is stable, so keys and values line up.
node_ids_small_list = list(node_embeddings_small)
node_embeddings_small_list = list(node_embeddings_small.values())

node_ids_big_list = list(node_embeddings_big)
node_embeddings_big_list = list(node_embeddings_big.values())

In [10]:
# Build float32 tensors with the `torch.tensor` factory rather than the legacy
# `torch.Tensor(...)` constructor, so the dtype is stated explicitly.
tensor_small = torch.tensor(node_embeddings_small_list, dtype=torch.float32)
tensor_big = torch.tensor(node_embeddings_big_list, dtype=torch.float32)

# For each `small` vector, retrieve up to `top` hits from the `big` vectors.
# Each hit is a dict read downstream via its 'corpus_id' and 'score' keys.
node_order = util.semantic_search(tensor_small, tensor_big, top_k=top)

In [11]:
# Map every `small` node id to its retrieved candidates as (big node id, score)
# tuples, translating each hit's corpus position back into a node id.
top_dict = {}
for node_id, hits in zip(node_ids_small_list, node_order):
    top_dict[node_id] = [
        (node_ids_big_list[hit['corpus_id']], hit['score'])
        for hit in hits
    ]

In [12]:
# Tokenizer and sequence-classification model come from the same checkpoint.
reranker_name = 'BAAI/bge-reranker-large'
tokenizer = AutoTokenizer.from_pretrained(reranker_name)
model = AutoModelForSequenceClassification.from_pretrained(reranker_name)
# Switch to inference mode; as the last expression this also displays the model.
model.eval()

XLMRobertaForSequenceClassification(
  (roberta): XLMRobertaModel(
    (embeddings): XLMRobertaEmbeddings(
      (word_embeddings): Embedding(250002, 1024, padding_idx=1)
      (position_embeddings): Embedding(514, 1024, padding_idx=1)
      (token_type_embeddings): Embedding(1, 1024)
      (LayerNorm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): XLMRobertaEncoder(
      (layer): ModuleList(
        (0-23): 24 x XLMRobertaLayer(
          (attention): XLMRobertaAttention(
            (self): XLMRobertaSdpaSelfAttention(
              (query): Linear(in_features=1024, out_features=1024, bias=True)
              (key): Linear(in_features=1024, out_features=1024, bias=True)
              (value): Linear(in_features=1024, out_features=1024, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): XLMRobertaSelfOutput(
              (dense): Linear(in_features=1024, ou

In [15]:
# Upper bound on how many `small` nodes are reranked in this run.
# NOTE(review): a limit of 3 looks like a leftover debugging cap — confirm, and
# set to None to process every node.
max_nodes = 3

# For each `small` node, score all retrieved candidates with the reranker and
# write the single best one as "<small name>###<big name>###<score>".
with open(output_file, "w", encoding="utf-8") as file:
    for node in node_ids_small_list[:max_nodes]:
        # Iterate over the candidates actually retrieved instead of assuming
        # exactly `top` of them: retrieval yields fewer hits when the big
        # corpus has fewer than `top` entries, which used to raise IndexError.
        id_list = [candidate_id for candidate_id, _ in top_dict[node]]
        if not id_list:
            # No candidates: argmax over an empty score tensor would raise.
            continue

        # The text for the `small` node is the same in every pair; compute it once.
        query_text = str(dogtags_small[mappings_small[node]])
        pairs = [
            [query_text, str(dogtags_big[mappings_big[candidate_id]])]
            for candidate_id in id_list
        ]

        with torch.no_grad():
            inputs = tokenizer(pairs, padding=True, truncation=True, return_tensors='pt', max_length=512)
            # Flatten the logits to one score per (small, candidate) pair.
            scores = model(**inputs, return_dict=True).logits.view(-1).float()

        max_index_int = int(torch.argmax(scores).item())
        max_value_float = float(scores[max_index_int].item())

        file.write(
            mappings_small[node] + "###" +
            mappings_big[id_list[max_index_int]] + "###" +
            str(max_value_float) + "\n"
        )
        # Flush after every node so partial results survive an interrupted run.
        file.flush()