# Inputs
* `inputs`: List of samples. Each sample should be a dictionary with fields `"mention"`, `"context_left"` and `"context_right"`. See notebook `BiEncoder RandomNegative Training.ipynb` for explanations of these fields.
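* `path`: Directory prefix under which the files needed to run this notebook are stored. File names are appended directly to this string, so end it with a path separator (e.g. `data/`). Required files: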
    * `commonness.json` (See `Train GBM Reranker.ipynb`)
    * `link_prob.json` (See `Train GBM Reranker.ipynb`)
    * `popularity.json` (See `Train GBM Reranker.ipynb`)
    * `lgbm12.pkl` (Obtained from `Train GBM Reranker.ipynb`)
    * `entities.pkl` (See `Train GBM Reranker.ipynb`)
    * `entity_embeds_4.pkl` (Run `Compute Entity Embeddings.ipynb` with the latest candidate entity encoder model)
    * `hardneg_ctxt_model_4.pt` (Latest mention encoder model)
    * `hardneg_m_4.pt` (Latest linear scoring layer)
* `threshold`: NIL mention detection threshold. Obtained from `Train GBM Reranker.ipynb`
* `device`: Device to run the context model, "cuda" or "cpu"
* `input_batch_size`: batch size for the context model

# Output

The predictions are shown in a dataframe at the end of the notebook.

In [None]:
import random
import torch
import time
import numpy as np
from fuzzywuzzy import fuzz
import pickle
import json
from torch.utils.data import DataLoader, RandomSampler, SequentialSampler, TensorDataset
from transformers import get_linear_schedule_with_warmup
from transformers import  BertTokenizerFast, BertModel
import lightgbm as lgb
import annoy
import pandas as pd

In [None]:
# Input
# Samples to link: one dict per mention with the keys "mention",
# "context_left" and "context_right".
inputs = []

# Prefix prepended verbatim to every file name loaded further down.
path = ""

# Minimum reranker score a prediction needs; anything lower is reported as 'None' (NIL).
threshold = 0.5

# Where the context (mention) encoder runs, and how many samples it sees per batch.
device = "cuda"
input_batch_size = 32

In [None]:
#Functions
# Marker tokens wrapped around the mention inside the encoder input.
ENT_START_TAG = "[unused0]"
ENT_END_TAG = "[unused1]"


def get_context_representation(
    sample,
    tokenizer,
    max_seq_length,
    mention_key="mention",
    context_key="context",
    ent_start_token=ENT_START_TAG,
    ent_end_token=ENT_END_TAG,
):
    """Build the fixed-length encoder input for one mention in context.

    The token sequence has the layout
    ``[CLS] left context [Ms] mention [Me] right context [SEP]`` where
    ``[Ms]``/``[Me]`` are ``ent_start_token``/``ent_end_token``.  The ids are
    right-padded with 0 up to ``max_seq_length``.

    Args:
        sample: Mapping holding the mention text under ``mention_key`` and the
            surrounding text under ``context_key + "_left"`` and
            ``context_key + "_right"``.
        tokenizer: Object providing ``tokenize(text)`` and
            ``convert_tokens_to_ids(tokens)``.
        max_seq_length: Length of the returned id list, including ``[CLS]``,
            ``[SEP]`` and padding.
        mention_key: Key of the mention text in ``sample``.
        context_key: Prefix of the two context keys in ``sample``.
        ent_start_token: Token inserted before the mention.
        ent_end_token: Token inserted after the mention.

    Returns:
        Dict with ``"tokens"`` (the unpadded token list) and ``"ids"`` (the
        padded id list of length ``max_seq_length``).

    Raises:
        AssertionError: If the sequence cannot fit, i.e. ``max_seq_length`` is
            too small to hold ``[CLS]``, ``[SEP]`` and the two mention markers.
    """
    # Slots available once [CLS] and [SEP] are accounted for.
    budget = max_seq_length - 2

    # mention_tokens = [Ms] mention [Me]
    mention_tokens = []
    if sample[mention_key]:
        mention_tokens = tokenizer.tokenize(sample[mention_key])
        # An over-long mention is cut so that the two markers still fit;
        # previously such a sample always tripped the length assertion.
        mention_tokens = mention_tokens[: max(budget - 2, 0)]
        mention_tokens = [ent_start_token] + mention_tokens + [ent_end_token]

    context_left = tokenizer.tokenize(sample[context_key + "_left"])
    context_right = tokenizer.tokenize(sample[context_key + "_right"])

    # Split the remaining room evenly between both sides (right side gets the
    # odd slot), then hand unused room on one side over to the other.
    free = budget - len(mention_tokens)
    left_quota = max(free // 2, 0)
    right_quota = max(free - free // 2, 0)
    left_add = len(context_left)
    right_add = len(context_right)
    if left_add <= left_quota:
        if right_add > right_quota:
            right_quota += left_quota - left_add
    elif right_add <= right_quota:
        left_quota += right_quota - right_add

    # Keep the tokens closest to the mention.  The start index is computed
    # explicitly because ``context_left[-0:]`` would return the whole list
    # instead of nothing when the left quota is zero.
    left_start = max(left_add - left_quota, 0)
    context_tokens = (
        context_left[left_start:] + mention_tokens + context_right[:right_quota]
    )

    # context_tokens = [CLS] left context [Ms] mention [Me] right context [SEP]
    context_tokens = ["[CLS]"] + context_tokens + ["[SEP]"]
    input_ids = tokenizer.convert_tokens_to_ids(context_tokens)
    padding = [0] * (max_seq_length - len(input_ids))
    input_ids += padding
    assert len(input_ids) == max_seq_length

    return {
        "tokens": context_tokens,
        "ids": input_ids,
    }


def select_field(data, key1, key2=None):
    """Collect one field from every example in ``data``.

    Returns ``example[key1]`` for each example, or ``example[key1][key2]``
    when ``key2`` is given.
    """
    outer_values = (example[key1] for example in data)
    if key2 is not None:
        return [value[key2] for value in outer_values]
    return list(outer_values)
def process_mention_data_2(
    samples,
    tokenizer,
    max_context_length=64,
    mention_key="mention",
    context_key="context",
    ent_start_token="[unused0]",
    ent_end_token="[unused1]",
):
    """Tokenize mention samples into a tensor dataset for the context encoder.

    Args:
        samples: Iterable of sample dicts as accepted by
            ``get_context_representation``.
        tokenizer: Tokenizer forwarded to ``get_context_representation``.
        max_context_length: Length of every encoded sequence (default 64).
        mention_key: Key of the mention text in each sample.
        context_key: Prefix of the ``*_left`` / ``*_right`` context keys.
        ent_start_token: Marker token placed before the mention.
        ent_end_token: Marker token placed after the mention.

    Returns:
        A pair ``(data, tensor_data)``: ``data`` is a dict with
        ``"context_vecs"`` (a ``torch.long`` tensor built from the per-sample
        id lists) and ``"sample"`` (the input samples as a list);
        ``tensor_data`` is a ``TensorDataset`` wrapping the same tensor.
    """
    # Materialise once so a one-shot iterable is not consumed twice.
    all_samples = list(samples)

    processed_samples = [
        {
            "context": get_context_representation(
                sample,
                tokenizer,
                max_context_length,
                mention_key,
                context_key,
                ent_start_token,
                ent_end_token,
            )
        }
        for sample in all_samples
    ]

    context_vecs = torch.tensor(
        select_field(processed_samples, "context", "ids"), dtype=torch.long,
    )
    data = {
        "context_vecs": context_vecs,
        "sample": all_samples,
    }

    tensor_data = TensorDataset(context_vecs)
    return data, tensor_data
def get_thresholded_preds(th,pred,scores):
    """Turn scored predictions into labels, masking low-confidence ones.

    For each position the prediction is returned as a string when its score is
    at least ``th``; otherwise the literal string 'None' is returned.
    """
    return [
        str(candidate) if scores[pos] >= th else 'None'
        for pos, candidate in enumerate(pred)
    ]

In [None]:
#Load files
# NOTE(security): pickle.load and torch.load can execute code embedded in the
# file being read; only point `path` at artifacts you created or trust.

# Statistics looked up per mention / entity when building reranker features.
with open(path+'commonness.json','r',encoding='utf-8') as f:
    commonness = json.load(f)
with open(path+'link_prob.json','r',encoding='utf-8') as f:
    link_probability = json.load(f)
with open(path+'popularity.json','r',encoding='utf-8') as f:
    popularity = json.load(f)

# Pickled reranker, entity embeddings and entity label table.
with open(path+'lgbm12.pkl','rb') as f:
    model_lgb = pickle.load(f)
with open(path+'entity_embeds_4.pkl',"rb") as f:
    entity_emebeddings=pickle.load(f)
with open(path+'entities.pkl','rb') as f:
    entity_labels=pickle.load(f)

# map_location remaps the stored tensors while loading, so checkpoints saved
# on a GPU can still be opened when `device` is "cpu" or no GPU is present.
ctxt_model = torch.load(path+'hardneg_ctxt_model_4.pt', map_location=device).to(device)
# The scoring layer is only read on the CPU, so it is loaded straight there.
m = torch.load(path+'hardneg_m_4.pt', map_location='cpu').to('cpu')
tokenizer = BertTokenizerFast.from_pretrained('bert-base-cased')
ctxt_model.eval()
m.eval()

In [None]:
#Create dataset
# Encode every input sample; train_data keeps the raw samples next to the id tensor.
train_data, train_tensor_data = process_mention_data_2(inputs, tokenizer)

# A sequential sampler keeps batches in input order, so the i-th embedding
# computed later belongs to inputs[i].
train_sampler = SequentialSampler(train_tensor_data)
train_dataloader = DataLoader(
    train_tensor_data,
    batch_size=input_batch_size,
    sampler=train_sampler,
)

In [None]:
#Get mention embeddings
# One vector per input sample, appended in dataloader (i.e. input) order.
ctxt_model.eval()
print(len(train_dataloader))
mention_embeddings = []
with torch.no_grad():
    start = time.time()
    # Each batch is a one-element sequence holding the token-id tensor.
    for step, (context_input,) in enumerate(train_dataloader):
        if step%10==0:
            print("Step: ",step," ",time.time()-start)
        # Take position 0 (the [CLS] slot) of the model's first output
        # (presumably the last hidden states of a BERT-style encoder - confirm).
        ctxt_rep = ctxt_model(context_input.to(device))[0][:,0,:]
        # Move the whole batch to host memory in one transfer rather than one
        # copy per row, then store the rows individually.
        mention_embeddings.extend(ctxt_rep.detach().cpu().numpy())

In [None]:
#Build annoy index for nearest neighbor search
# Row 1 of the first parameter of `m`; the original author labels it the
# positive-class parameter.
m_second_param = list(m.parameters())[0][1].detach().numpy()

# Annoy items are addressed by integer id, so number the entities and keep the
# mapping back to their original keys.
keys_map = dict(enumerate(entity_emebeddings))

# Scale every entity vector element-wise by the parameter row, so a plain dot
# product with a mention embedding yields the weighted score.
entity_emebeddings_with_m = {
    item_id: np.multiply(m_second_param, entity_emebeddings[entity_key])
    for item_id, entity_key in keys_map.items()
}

t = annoy.AnnoyIndex(768, 'dot')
t.set_seed(0)
for item_id, scaled_vector in entity_emebeddings_with_m.items():
    t.add_item(item_id, scaled_vector)
t.build(1000, n_jobs=-1)

In [None]:
#Get top 12 candidates
num_cands=12

# cands[i] holds (score, entity key) pairs for mention i, best score first.
cands = []
start = time.time()
for i, this_ment_embed in enumerate(mention_embeddings):
    if i%100 == 0:
        print(i, " ",time.time()-start)

    # Ask the index for the num_cands nearest items and their distances.
    neighbour_ids, neighbour_dists = t.get_nns_by_vector(
        this_ment_embed,
        num_cands,
        search_k=len(entity_emebeddings_with_m),
        include_distances=True,
    )

    # Squash each returned distance x with the logistic function
    # 1 - 1/(1 + e^x) and pair it with the entity's original key.
    scored = [
        (1- 1/(1 + np.exp(dist)), keys_map[item_id])
        for item_id, dist in zip(neighbour_ids, neighbour_dists)
    ]
    cands.append(sorted(scored, key=lambda pair: pair[0], reverse=True))

In [None]:
#Get LGBM preds
# Feature columns, one entry per (mention, candidate) pair.
bert_scores = []
fw_scores2 = []
unique_id = []
entity_ = []
commonness_ = []
popularity_ = []
link_probability_ = []

for i, sample in enumerate(inputs):
    candidates = cands[i]
    this_mention = sample['mention']
    # These lookups depend only on the mention, so do them once per mention.
    mention_lower = this_mention.lower()
    mention_commonness = commonness.get(mention_lower,{})
    mention_link_prob = link_probability.get(mention_lower,0.)

    # Iterate over the candidates actually retrieved; the index may return
    # fewer than num_cands items.
    for bert_score, entity in candidates:
        entity_key = str(entity)

        #Get FW score: best fuzzy match between the mention and any entity label
        fw_score2 = 0
        for lbl in entity_labels[entity_key]['Labels']:
            fw_score2 = max(fw_score2,fuzz.token_sort_ratio(this_mention,lbl)/100)
        fw_scores2.append(fw_score2)

        #Get BERT score
        bert_scores.append(bert_score)

        commonness_.append(mention_commonness.get(entity_key,0.))
        popularity_.append(popularity.get(entity_key,0.))
        link_probability_.append(mention_link_prob)

        unique_id.append(i)
        entity_.append(entity)

df=pd.DataFrame({'ID':unique_id,'Commonness':commonness_,'BERT':bert_scores,
                 'Popularity':popularity_,'Link_Probability':link_probability_,
                 'FW2':fw_scores2,'Entity':entity_})
# Column order below is the order handed to the reranker; keep it unchanged.
preds = model_lgb.predict(df[['Commonness', 'BERT', 'FW2','Popularity' ,'Link_Probability']])

df['Score'] = preds
# Keep the best-scoring candidate of every mention; copy so the column
# assignments below write to an independent frame.
best_rows = df.groupby('ID').Score.idxmax().values
temp = df.loc[best_rows, ['ID','Score','Entity']].copy()
#Get Entities
entities = temp.Entity.values
#Get scores
scores = temp.Score.values
#Apply threshold
temp['Entity']=get_thresholded_preds(threshold,entities,scores)
#Add mention and context, looked up through ID so each row gets its own sample
picked = [inputs[sample_id] for sample_id in temp['ID']]
temp['Mention']=[item['mention'] for item in picked]
temp['Left_Context']=[item['context_left'] for item in picked]
temp['Right_Context']=[item['context_right'] for item in picked]
temp.drop(['ID'],axis=1,inplace=True)
temp.reset_index(drop=True,inplace=True)

In [None]:
#Visualize the results
# Show the first rows of the prediction table: Score, Entity ('None' when the
# score is below the threshold), Mention, Left_Context, Right_Context.
temp.head()