# Inputs
All inputs are set in the configuration code cell that follows the imports.
* `embeddings_folder`: Path of the output file for the entity embeddings (despite the variable name, this is a file, not a folder). The naming convention is "entity_embeds_ROUNDNUMBER.pkl"; for the first round, it would be "entity_embeds_1.pkl".
* `input_batch_size`: Batch size to be used for embedding calculation. Depends on available memory.
* `device`: Device string to run the model on: `'cuda'` for GPU or `'cpu'` for CPU.
* `model_path`: Path to the entity model. For the first round, this is the path to "randomneg_cand_model.pt".
* `entity_rep_file`: Path to `entity_representations.pkl`. See the notebook `BiEncoder RandomNegative Training.ipynb` for a detailed explanation.
* `seed`: Seed for the Python, NumPy and PyTorch random number generators.
# Outputs
A dictionary pickled to the file given by the variable `embeddings_folder`. Keys are entity IDs and values are the corresponding embedding vectors (NumPy arrays).

In [None]:
import random
import torch
import time
import numpy as np
import pickle
import json
from torch.utils.data import DataLoader, RandomSampler, SequentialSampler, TensorDataset
from transformers import  BertTokenizerFast, BertModel

In [None]:
embeddings_folder = "entity_embeds_1.pkl"  # output *file* path (despite the name)
input_batch_size = 64
model_path = 'randomneg_cand_model.pt'
seed = 0
# Prefer the GPU, but fall back to CPU when no CUDA device is available so the
# notebook still runs (slowly) instead of failing at the first `.to(device)`.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
entity_rep_file = 'entity_representations.pkl'

In [None]:
# Load the precomputed entity representations (a pickled dict).
# pickle can execute arbitrary code, so only load files from a trusted source.
with open(entity_rep_file, mode='rb') as rep_handle:
    entity_dict = pickle.load(rep_handle)

In [None]:
# Seed the Python, NumPy and PyTorch RNGs so runs are reproducible.
for _seed_rng in (random.seed, np.random.seed, torch.manual_seed):
    _seed_rng(seed)

In [None]:
def select_field(data, key1, key2=None):
    """Pluck one field from every record in ``data``.

    Args:
        data: iterable of subscriptable records.
        key1: key looked up in each record.
        key2: optional second key, looked up in ``record[key1]``.

    Returns:
        A list with ``record[key1]`` (or ``record[key1][key2]`` when
        ``key2`` is given) for each record, in iteration order.
    """
    def pick(record):
        value = record[key1]
        return value if key2 is None else value[key2]

    return [pick(record) for record in data]
def process_entity_data(entity_dict):
    """Convert the entity-representation dict into tensors for batched encoding.

    Args:
        entity_dict: mapping from entity ID (anything ``int()`` accepts, e.g.
            an int or a numeric string) to a record whose ``'ids'`` entry holds
            that entity's token IDs. All ``'ids'`` sequences must have the same
            length so they can be stacked into a single tensor (presumably
            they are padded/truncated upstream; confirm).

    Returns:
        ``(data, tensor_data)`` where ``data`` is a dict with
        ``"entity_id"`` (1-D long tensor of entity IDs) and
        ``"entity_inputs"`` (long tensor of token IDs, one row per entity),
        and ``tensor_data`` is ``TensorDataset(entity_inputs, entity_id)``.
        Rows follow ``entity_dict`` iteration order.
    """
    entity_ids = []
    entity_inputs = []
    # Single pass over the dict; no intermediate per-entity record dicts.
    for key, record in entity_dict.items():
        entity_ids.append(int(key))
        entity_inputs.append(record['ids'])

    cand_vecs = torch.tensor(entity_inputs, dtype=torch.long)
    label_idx = torch.tensor(entity_ids, dtype=torch.long)
    data = {
        "entity_id": label_idx,
        "entity_inputs": cand_vecs,
    }
    tensor_data = TensorDataset(cand_vecs, label_idx)
    return data, tensor_data

In [None]:
# Load the candidate (entity) encoder. The checkpoint appears to be a whole
# pickled model object (it is moved, put in eval mode and called directly
# later), not a state_dict.
# - weights_only=False: PyTorch >= 2.6 defaults to weights_only=True, which
#   refuses to unpickle arbitrary model objects. This executes pickled code,
#   so only load checkpoints from a trusted source.
# - map_location=device: lets a checkpoint saved on GPU be loaded when
#   `device` is 'cpu'.
try:
    cand_model = torch.load(model_path, map_location=device, weights_only=False)
except TypeError:
    # Older PyTorch (< 1.13) has no `weights_only` keyword.
    cand_model = torch.load(model_path, map_location=device)
cand_model = cand_model.to(device)

In [None]:
data, tensor_data = process_entity_data(entity_dict)

# Inference only: each entity is embedded exactly once and results are keyed
# by entity ID, so shuffling buys nothing. A sequential sampler keeps batch
# order (and the output dict's insertion order) deterministic and does not
# consume the torch RNG.
sampler = SequentialSampler(tensor_data)
dataloader = DataLoader(tensor_data, sampler=sampler, batch_size=input_batch_size)

In [None]:
# Encode every entity and map its ID to the model's position-0 output vector.
cand_model.eval()  # inference mode (e.g. disables dropout)
entity_emebeddings = dict()  # entity ID (int) -> numpy embedding vector
print(len(dataloader))  # number of batches
with torch.no_grad():
    start = time.time()
    for step, batch in enumerate(dataloader):
        if step%10==0:
            print("Step: ",step," ",time.time()-start)
        candidate_input, e_ids = batch
        # First element of the model output, position 0 along dim 1, for every
        # row. Assuming cand_model is a HF BertModel, this is the [CLS] vector
        # of last_hidden_state -- confirm.
        # NOTE(review): no attention_mask is passed, so padded positions are
        # attended to; presumably this matches how the model was trained -- confirm.
        cand_rep = cand_model(candidate_input.to(device))[0][:, 0, :]
        # One device->host copy per batch instead of one per row; no detach()
        # needed under torch.no_grad().
        cand_rep = cand_rep.cpu().numpy()
        # tolist() yields plain Python ints, so the pickled dict's keys do not
        # depend on numpy scalar types.
        for e_id, vec in zip(e_ids.tolist(), cand_rep):
            entity_emebeddings[e_id] = vec

In [None]:
# Persist the ID -> embedding dict to the configured output file.
with open(embeddings_folder, mode="wb") as out_file:
    pickle.dump(entity_emebeddings, out_file)