# Inputs 
* `embeddings_folder`: Path to the pickle file containing the entity embeddings that will be used (despite the name, this is a file, not a folder).
* `input_batch_size`: Batch size for context model inference.
* `model_path`: Path to the context model that will be used.
* `model_path_m`: Path to the linear layer that will be used.
* `device`: Whether GPU or CPU will be used.
* `train_fname`: Path to `train.jsonl` or `dev.jsonl`
* `path_to_entity_pool`: Path to `entity_pool.pkl`. This is a dictionary whose keys are entity IDs and whose values are Python sets. Each set contains the entity IDs that are assumed to refer to the same entity as the key; hence, each set should contain at least the key itself. This dictionary addresses the issue of duplicate entity IDs for the same entity in some Knowledge Bases. If you do not have this issue, just create a dictionary where each key is an entity ID and each value is a Python set containing only that entity ID.
* `out_filename`: Filename of the pickle file that will contain the hard negatives. Convention is: "DATASET_hard_negatives_ROUND.pkl". For example, after first round for training data, the filename would be "train_hard_negatives_1.pkl"

# Outputs 
Hard negatives are written to the file specified by `out_filename`. They are stored in a list whose length equals the number of instances in the input dataset. Each element is itself a list containing the hard negative entities for that instance, ordered from hardest to easiest. If 3 hard negatives are going to be used, the first 3 entities in the list should be used.

In [None]:
import random
import torch
import time
import numpy as np
import pickle
import json
from torch.utils.data import DataLoader, RandomSampler, SequentialSampler, TensorDataset
from transformers import get_linear_schedule_with_warmup
from transformers import  BertTokenizerFast, BertModel
from queue import PriorityQueue
import annoy

In [None]:
# Pickle file with the entity embeddings (a file path, despite the name).
embeddings_folder = "entity_embeds_1.pkl"
# Batch size used when running the context model over the mentions.
input_batch_size = 256
# Serialized context model.
model_path = "randomneg_ctxt_model.pt"
# Serialized linear layer applied on top of the embeddings.
model_path_m = "randomneg_m.pt"
# Seed shared by all random number generators.
seed = 0
# Device on which the context model runs.
device = "cuda"
# JSON-lines dataset to mine hard negatives for.
train_fname = "train.jsonl"
# Pickled dict mapping an entity id to the set of ids treated as the same entity.
path_to_entity_pool = "entity_pool.pkl"
# Destination pickle for the mined hard negatives.
out_filename = "train_hard_negatives_1.pkl"

In [None]:
# NOTE: pickle.load can execute arbitrary code; only point these paths at
# trusted files.

# Entity pool, read from the pickle at `path_to_entity_pool`.
with open(path_to_entity_pool, mode="rb") as pool_file:
    entity_pool = pickle.load(pool_file)

# Entity embeddings, read from the pickle at `embeddings_folder`.
with open(embeddings_folder, mode="rb") as embeds_file:
    entity_emebeddings = pickle.load(embeds_file)

In [None]:
# Seed Python's, NumPy's and PyTorch's generators (in that order) with the
# same value.
for seed_fn in (random.seed, np.random.seed, torch.manual_seed):
    seed_fn(seed)

In [None]:
ENT_START_TAG = "[unused0]"
ENT_END_TAG = "[unused1]"


def get_context_representation(
    sample,
    tokenizer,
    max_seq_length,
    mention_key="mention",
    context_key="context",
    ent_start_token=ENT_START_TAG,
    ent_end_token=ENT_END_TAG,
):
    """Build the fixed-length token sequence for one mention in its context.

    The sequence has the layout
    ``[CLS] left context [Ms] mention [Me] right context [SEP]`` and its id
    list is right-padded with ``0`` up to ``max_seq_length``.

    Args:
        sample: Mapping with the entries ``sample[mention_key]``,
            ``sample[context_key + "_left"]`` and
            ``sample[context_key + "_right"]``.
        tokenizer: Object providing ``tokenize(text)`` and
            ``convert_tokens_to_ids(tokens)``.
        max_seq_length: Exact length of the returned id list.
        mention_key: Key of the mention text in ``sample``.
        context_key: Prefix of the left/right context keys in ``sample``.
        ent_start_token: Marker inserted before the mention tokens.
        ent_end_token: Marker inserted after the mention tokens.

    Returns:
        A dict with ``"tokens"`` (the unpadded token list, including
        ``[CLS]``/``[SEP]``) and ``"ids"`` (the padded id list).

    Raises:
        AssertionError: If ``max_seq_length`` is too small to hold even the
            special tokens (and the mention markers, when a mention exists).
    """
    # mention_tokens = [Ms] mention [Me]
    mention_tokens = []
    if sample[mention_key]:
        # [CLS], [SEP] and the two markers take four positions; clip an
        # over-long mention so the sequence can still fit.
        mention_budget = max(max_seq_length - 4, 0)
        mention_tokens = tokenizer.tokenize(sample[mention_key])[:mention_budget]
        mention_tokens = [ent_start_token] + mention_tokens + [ent_end_token]

    context_left = tokenizer.tokenize(sample[context_key + "_left"])
    context_right = tokenizer.tokenize(sample[context_key + "_right"])

    # Positions left for context once [CLS], [SEP] and the mention are placed.
    # Clamped so a long mention can never produce negative quotas.
    context_budget = max(max_seq_length - len(mention_tokens) - 2, 0)
    left_quota = context_budget // 2
    right_quota = context_budget - left_quota

    # Hand the unused part of one side's quota to the other side.
    left_add = len(context_left)
    right_add = len(context_right)
    if left_add <= left_quota:
        if right_add > right_quota:
            right_quota += left_quota - left_add
    elif right_add <= right_quota:
        left_quota += right_quota - right_add

    # A plain `context_left[-left_quota:]` would return the WHOLE list when
    # left_quota is 0, so the zero case has to be handled explicitly.
    left_window = context_left[-left_quota:] if left_quota > 0 else []
    context_tokens = left_window + mention_tokens + context_right[:right_quota]

    # context_tokens = [CLS] left context [Ms] mention [Me] right context [SEP]
    context_tokens = ["[CLS]"] + context_tokens + ["[SEP]"]
    input_ids = tokenizer.convert_tokens_to_ids(context_tokens)
    input_ids += [0] * (max_seq_length - len(input_ids))
    assert len(input_ids) == max_seq_length

    return {
        "tokens": context_tokens,
        "ids": input_ids,
    }


def select_field(data, key1, key2=None):
    """Collect one (optionally nested) field from every example in `data`.

    Returns a list with ``example[key1]`` for each example, or
    ``example[key1][key2]`` when `key2` is given.
    """
    picked = []
    for example in data:
        value = example[key1]
        if key2 is not None:
            value = value[key2]
        picked.append(value)
    return picked
def process_mention_data_2(
    samples,
    tokenizer,
    max_context_length=64,
    mention_key="mention",
    context_key="context",
    ent_start_token="[unused0]",
    ent_end_token="[unused1]",
):
    """Tokenize every sample and pack the context ids into a tensor dataset.

    Args:
        samples: Iterable of sample mappings accepted by
            `get_context_representation`.
        tokenizer: Tokenizer forwarded to `get_context_representation`.
        max_context_length: Length of each id sequence (default 64, the
            value that used to be hard-coded).
        mention_key: Key of the mention text in each sample.
        context_key: Prefix of the left/right context keys in each sample.
        ent_start_token: Marker placed before the mention tokens.
        ent_end_token: Marker placed after the mention tokens.

    Returns:
        A tuple ``(data, tensor_data)`` where ``data`` is a dict holding the
        ``torch.long`` tensor of ids under ``"context_vecs"`` and the list of
        input samples under ``"sample"``, and ``tensor_data`` is a
        ``TensorDataset`` wrapping that same tensor.
    """
    all_samples = []
    context_ids = []
    # Single pass, so `samples` may also be a one-shot iterator.
    for sample in samples:
        context = get_context_representation(
            sample,
            tokenizer,
            max_context_length,
            mention_key,
            context_key,
            ent_start_token,
            ent_end_token,
        )
        context_ids.append(context["ids"])
        all_samples.append(sample)

    context_vecs = torch.tensor(context_ids, dtype=torch.long)
    data = {
        "context_vecs": context_vecs,
        "sample": all_samples,
    }

    tensor_data = TensorDataset(context_vecs)
    return data, tensor_data

In [None]:
# NOTE: torch.load unpickles the stored objects; only load trusted checkpoints.
# NOTE(review): PyTorch >= 2.6 defaults torch.load to weights_only=True, which
# rejects fully pickled modules such as these - pass weights_only=False there
# if loading fails.
#
# map_location remaps the stored tensors while loading, so a checkpoint saved
# on a GPU can still be opened when that GPU is unavailable (the linear layer
# `m` in particular is only ever used on the CPU).
ctxt_model = torch.load(model_path, map_location=device).to(device)
m = torch.load(model_path_m, map_location='cpu').to('cpu')
tokenizer = BertTokenizerFast.from_pretrained('bert-base-cased')
# Switch both modules to evaluation mode.
ctxt_model.eval()
m.eval()

In [None]:
# Read the JSON-lines dataset: every line is parsed into one sample.
with open(train_fname, mode="r", encoding="utf-8") as jsonl_file:
    train_samples = [json.loads(raw_line.strip()) for raw_line in jsonl_file]
print(len(train_samples))

In [None]:
# Gold entity id ('label_id') of every sample, in dataset order.
correct_entities = [sample['label_id'] for sample in train_samples]

train_data, train_tensor_data = process_mention_data_2(train_samples, tokenizer)

# A sequential sampler keeps the batches in dataset order, so the embeddings
# computed from them line up with `correct_entities`.
train_sampler = SequentialSampler(train_tensor_data)
train_dataloader = DataLoader(
    train_tensor_data,
    sampler=train_sampler,
    batch_size=input_batch_size,
)

In [None]:
# Encode every mention with the context model and keep one vector per mention.
ctxt_model.eval()
print(len(train_dataloader))
mention_embeddings = []
with torch.no_grad():
    start = time.time()
    for step, context_input in enumerate(train_dataloader):
        if step%10==0:
            print("Step: ",step," ",time.time()-start)
        # The dataset wraps a single tensor, so each batch is a 1-element list.
        context_input = context_input[0]
        # First model output, sequence position 0 (where [CLS] is placed).
        ctxt_rep = ctxt_model(context_input.to(device))[0][:,0,:]
        # Move the whole batch to the host once rather than once per row;
        # extend() then appends one 1-D numpy array per mention.
        mention_embeddings.extend(ctxt_rep.detach().cpu().numpy())

In [None]:
# Row 1 of the first parameter of `m` (per the original note: the parameters
# for the positive class), as a numpy array.
m_second_param = list(m.parameters())[0][1].detach().numpy()

# Annoy items are keyed by consecutive integers; `keys_map` translates those
# integers back to the original entity ids.
entity_emebeddings_with_m = {}
keys_map = {}

t = annoy.AnnoyIndex(768, 'dot')
t.set_seed(0)

for annoy_id, (entity_id, embedding) in enumerate(entity_emebeddings.items()):
    # Element-wise scaling of the entity embedding by the selected row of `m`.
    scaled = np.multiply(m_second_param, embedding)
    entity_emebeddings_with_m[annoy_id] = scaled
    keys_map[annoy_id] = entity_id
    t.add_item(annoy_id, scaled)

t.build(1000, n_jobs=-1)

In [None]:
num_hard_negs = 30


def _collect_hard_negatives(ranked_entities, gold_ids):
    """Return the candidates that are ranked above the first gold entity.

    `ranked_entities` is iterated from best to worst score. When `gold_ids`
    contains ``None`` (no gold entity), every candidate is returned.
    Otherwise candidates are collected until one whose ``str()`` form is in
    `gold_ids` is met; that candidate and everything after it are dropped.
    """
    if None in gold_ids:
        return list(ranked_entities)
    hard_negs = []
    for entity in ranked_entities:
        if str(entity) in gold_ids:
            break
        hard_negs.append(entity)
    return hard_negs


hard_negatives = []
start = time.time()
# Loop over mentions
for i, this_ment_embed in enumerate(mention_embeddings):
    if i % 100 == 0:
        print(i, " ",time.time()-start)

    # All ids that count as the correct entity; {None} when there is no label.
    if correct_entities[i] is None:
        this_corr_ent = {None}
    else:
        this_corr_ent = entity_pool[str(int(correct_entities[i]))]

    # Top `num_hard_negs` neighbours of the mention; search_k is set to the
    # number of indexed entities.
    neighbor_ids, neighbor_dists = t.get_nns_by_vector(
        this_ment_embed,
        num_hard_negs,
        search_k=len(entity_emebeddings_with_m),
        include_distances=True,
    )
    # Sigmoid of the value annoy reports as distance (for the 'dot' metric
    # this should be the inner product - confirm against the annoy docs).
    scores = [1 - 1 / (1 + np.exp(dist)) for dist in neighbor_dists]
    returned_entities = [keys_map[idx] for idx in neighbor_ids]
    # Highest score first; the sort is stable, so ties keep annoy's order.
    merged = sorted(
        zip(scores, returned_entities), key=lambda tup: tup[0], reverse=True
    )

    # No score threshold is applied; an empty list means that the correct
    # entity was ranked first.
    hard_negatives.append(
        _collect_hard_negatives((entity for _, entity in merged), this_corr_ent)
    )

with open(out_filename, 'wb') as f:
    pickle.dump(hard_negatives, f)