# Inputs
See third cell. Most important inputs are:
* `INP_PATH`: A path.
    * Resulting models are saved here.
    * Old models should be here.
    * `entity_pool.pkl` should be here
    * `train.jsonl` and `dev.jsonl` should be here.
    * `entity_representations.pkl` should be here.
* `base_bert_model_ctxt`: Name of the context model of previous round
* `base_bert_model_cand`: Name of the entity model of previous round
* `base_m`: Name of the linear layer of previous round
* `num_random_neg_cands`: Number of random negative candidate entities to sample per mention
* `hard_neg_cands_train_path`: Path to hard negatives for training set (output of `Hard Negative Mining.ipynb` for training set) 
* `hard_neg_cands_valid_path`: Path to hard negatives for dev set (output of `Hard Negative Mining.ipynb` for dev set)
* `ROUND_NUMBER`: Which training round is this? (2, 3, 4)
The rest of the inputs are explained in the comments and no change is required.

# Outputs
Trained models:
* "hardneg_ctxt_model_ROUND_NUMBER.pt"
* "hardneg_cand_model_ROUND_NUMBER.pt"
* "hardneg_m_ROUND_NUMBER.pt"

In [None]:
%pip install transformers==3.5.1

In [None]:
import random
import torch
import time
import numpy as np
import pickle
import json
from torch.utils.data import DataLoader, RandomSampler, SequentialSampler, TensorDataset
from transformers import get_linear_schedule_with_warmup,BertTokenizerFast, BertModel

In [None]:
# --- Locations -------------------------------------------------------------
# Prefix prepended to every file name below (inputs and saved models).
INP_PATH = ""
# Artifacts produced by the previous training round.
base_bert_model_ctxt = INP_PATH + "randomneg_ctxt_model_1.pt"
base_bert_model_cand = INP_PATH + "randomneg_cand_model_1.pt"
base_m = INP_PATH + "randomneg_m_1.pt"
# Mined hard negatives for the training and the dev mentions.
hard_neg_cands_train_path = INP_PATH + "train_hardnegs_1.pkl"
hard_neg_cands_valid_path = INP_PATH + "dev_hardnegs_1.pkl"

# --- Run settings ----------------------------------------------------------
# Seed for the random, numpy and torch generators.
seed = 0
# 'cuda' or 'cpu'.
device = 'cuda'
# Index of this training round; used in the names of the saved models.
ROUND_NUMBER = 2

# --- Sequence lengths ------------------------------------------------------
# Maximum length of a mention context (default: 64).
max_context_length = 64
# Maximum length of an entity representation (default: 256).
max_cand_length = 256

# --- Optimisation ----------------------------------------------------------
# Mini-batch size while training (default: 16).
train_batch_size = 16
# Mini-batches accumulated per optimizer update; the effective batch size is
# train_batch_size * grad_acc_steps (default: 4).
grad_acc_steps = 4
# Number of passes over the training data (default: 2).
num_train_epochs = 2
# Mini-batch size while evaluating (default: 256).
eval_batch_size = 256
# Maximum gradient norm used when clipping gradients (default: 1.0).
grad_norm = 1.0
# Random negative entities drawn per mention, on top of the hard negatives.
num_random_neg_cands = 2

In [None]:
# JSON-lines files holding the training mentions and the mentions used to
# monitor validation loss/accuracy.
train_fname = INP_PATH + "train.jsonl"
monitor_fname = INP_PATH + "dev.jsonl"

In [None]:
# The helpers in this cell come from the BLINK repository.
# Marker tokens placed directly before and after the mention tokens.
ENT_START_TAG, ENT_END_TAG = "[unused0]", "[unused1]"

def select_field(data, key1, key2=None):
    """Collect one field from every example in ``data``.

    Returns ``example[key1]`` for each example, or ``example[key1][key2]``
    when ``key2`` is given, as a list in the order of ``data``.
    """
    picked = []
    for example in data:
        value = example[key1]
        picked.append(value if key2 is None else value[key2])
    return picked

def get_context_representation(
    sample,
    tokenizer,
    max_seq_length,
    mention_key="mention",
    context_key="context",
    ent_start_token=ENT_START_TAG,
    ent_end_token=ENT_END_TAG,
):
    """Encode a mention and its surrounding context as a fixed-length id list.

    The encoded sequence has the layout
    ``[CLS] left context [Ms] mention [Me] right context [SEP]`` followed by
    zero padding up to ``max_seq_length``. The room left after the mention and
    the two special tokens is split between the two context sides; a side that
    does not need its whole share hands the remainder to the other side.

    Args:
        sample: Mapping providing ``mention_key``, ``context_key + "_left"``
            and ``context_key + "_right"``.
        tokenizer: Object exposing ``tokenize(text)`` and
            ``convert_tokens_to_ids(tokens)``.
        max_seq_length: Required length of the returned ``"ids"`` list.
        mention_key: Key of the mention text in ``sample``.
        context_key: Prefix of the left/right context keys in ``sample``.
        ent_start_token: Token inserted before the mention tokens.
        ent_end_token: Token inserted after the mention tokens.

    Returns:
        Dict with ``"tokens"`` (the unpadded tokens including ``[CLS]`` and
        ``[SEP]``) and ``"ids"`` (the token ids padded with 0).

    Raises:
        AssertionError: If the sequence cannot be brought to exactly
            ``max_seq_length`` ids, e.g. because the mention alone is too long.
    """
    # mention_tokens = [Ms] mention [Me]
    mention_tokens = []
    if sample[mention_key]:
        mention_tokens = (
            [ent_start_token]
            + tokenizer.tokenize(sample[mention_key])
            + [ent_end_token]
        )

    context_left = tokenizer.tokenize(sample[context_key + "_left"])
    context_right = tokenizer.tokenize(sample[context_key + "_right"])

    # Initial budgets: half of the remaining room each, minus [CLS] / [SEP].
    left_quota = (max_seq_length - len(mention_tokens)) // 2 - 1
    right_quota = max_seq_length - len(mention_tokens) - left_quota - 2
    left_add = len(context_left)
    right_add = len(context_right)
    if left_add <= left_quota:
        if right_add > right_quota:
            # Left side is short: give its unused budget to the right side.
            right_quota += left_quota - left_add
    else:
        if right_add <= right_quota:
            # Right side is short: give its unused budget to the left side.
            left_quota += right_quota - right_add

    # A quota of 0 must keep nothing. The plain slice ``tokens[-0:]`` would
    # return the whole list, and a negative quota would slice from the wrong
    # end, so non-positive quotas are handled explicitly.
    kept_left = context_left[-left_quota:] if left_quota > 0 else []
    kept_right = context_right[:right_quota] if right_quota > 0 else []

    # context_tokens = [CLS] left context [Ms] mention [Me] right context [SEP]
    context_tokens = ["[CLS]"] + kept_left + mention_tokens + kept_right + ["[SEP]"]
    input_ids = tokenizer.convert_tokens_to_ids(context_tokens)
    # Pad with id 0 up to the requested length.
    input_ids += [0] * (max_seq_length - len(input_ids))
    assert len(input_ids) == max_seq_length

    return {
        "tokens": context_tokens,
        "ids": input_ids,
    }


def get_candidate_representation(label_idx):
    """Return the stored tokens and ids of one entity.

    ``label_idx`` is converted to ``str`` and looked up in the module-level
    ``entity_dict``; the entry's ``'tokens'`` and ``'ids'`` fields are
    returned under the keys ``"tokens"`` and ``"ids"``.
    """
    entry = entity_dict[str(label_idx)]
    return {"tokens": entry['tokens'], "ids": entry['ids']}
  
def to_bert_input(token_idx, dev_name):
    """Build the segment ids and padding mask that go with ``token_idx``.

    Args:
        token_idx: 2D integer tensor of token ids; id 0 is treated as padding.
        dev_name: Kept for backward compatibility. The returned tensors are
            created on the device of ``token_idx``, so any device string
            (``'cpu'``, ``'cuda'``, ``'cuda:0'``, ...) behaves the same.

    Returns:
        Tuple ``(token_idx, segment_idx, mask)``: the input unchanged, an
        all-zero long tensor of the same shape, and a long tensor holding 1
        where ``token_idx`` is non-zero and 0 elsewhere.
    """
    # zeros_like / .long() allocate on token_idx's own device, so no explicit
    # branching on the device name (or legacy torch.cuda.LongTensor
    # constructor) is needed.
    segment_idx = torch.zeros_like(token_idx, dtype=torch.long)
    mask = (token_idx != 0).long()
    return token_idx, segment_idx, mask

In [None]:
def process_mention_data(
    samples,
    tokenizer,
    max_context_length,
    max_cand_length,
    mention_key="mention",
    context_key="context",
    ent_start_token=ENT_START_TAG,
    ent_end_token=ENT_END_TAG
):
    """Turn mentions into (context, candidate, label) training records.

    Each sample contributes one record with label 1 for its gold entity
    (skipped when ``sample["label_id"]`` is ``None``, i.e. a NIL mention) and
    one record with label 0 for every id in ``sample["negative_cands"]``.
    All records of a sample share the same encoded context.

    Args:
        samples: Iterable of sample dicts with ``"label_id"`` and
            ``"negative_cands"`` plus the mention/context fields.
        tokenizer: Tokenizer handed to ``get_context_representation``.
        max_context_length: Length of every encoded context.
        max_cand_length: Accepted but not used by this function.
        mention_key, context_key, ent_start_token, ent_end_token: Forwarded to
            ``get_context_representation``.

    Returns:
        ``(data, tensor_data)``: ``data`` is a dict holding the three tensors
        and the per-record source samples; ``tensor_data`` is a
        ``TensorDataset`` over (context ids, candidate ids, labels).
    """
    records = []
    record_sources = []

    def _emit(context_repr, entity_id, target, source):
        # One record pairs the shared context with a single candidate entity.
        candidate_repr = get_candidate_representation(entity_id)
        records.append({
            "context": context_repr,
            "label": candidate_repr,
            "label_idx": target,
            "sample": source,
        })
        record_sources.append(source)

    for sample in samples:
        context_repr = get_context_representation(
            sample,
            tokenizer,
            max_context_length,
            mention_key,
            context_key,
            ent_start_token,
            ent_end_token,
        )

        # Positive pair; NIL mentions (no gold entity) have none.
        if sample["label_id"] is not None:
            _emit(context_repr, int(sample["label_id"]), 1, sample)

        # Negative pairs.
        for negative_id in sample["negative_cands"]:
            _emit(context_repr, negative_id, 0, sample)

    context_vecs = torch.tensor(
        [record["context"]["ids"] for record in records], dtype=torch.long,
    )
    cand_vecs = torch.tensor(
        [record["label"]["ids"] for record in records], dtype=torch.long,
    )
    label_idx = torch.tensor(
        [record["label_idx"] for record in records], dtype=torch.long,
    )

    data = {
        "context_vecs": context_vecs,
        "cand_vecs": cand_vecs,
        "label_idx": label_idx,
        "sample": record_sources,
    }
    return data, TensorDataset(context_vecs, cand_vecs, label_idx)

In [None]:
# Entity representations; the keys are entity ids as strings.
with open(INP_PATH + 'entity_representations.pkl', 'rb') as f:
    entity_dict = pickle.load(f)

# Hard negatives, indexed by the position of the mention in its data file.
with open(hard_neg_cands_train_path, 'rb') as f:
    hard_neg_cands_train = pickle.load(f)
with open(hard_neg_cands_valid_path, 'rb') as f:
    hard_neg_cands_valid = pickle.load(f)

# Seed the torch, numpy and Python random number generators.
torch.manual_seed(seed)
np.random.seed(seed)
random.seed(seed)

In [None]:
# Restore the context model, the entity model and the linear layer of the
# previous round. map_location puts the stored tensors directly on `device`,
# so checkpoints that were saved on a GPU can also be restored on a machine
# without one.
# NOTE: torch.load unpickles whole objects; only load checkpoints you trust.
ctxt_model = torch.load(base_bert_model_ctxt, map_location=device).to(device)
cand_model = torch.load(base_bert_model_cand, map_location=device).to(device)
m = torch.load(base_m, map_location=device).to(device)
tokenizer = BertTokenizerFast.from_pretrained('bert-base-cased')

In [None]:
# Read the training mentions: one JSON object per line.
with open(train_fname, mode="r", encoding="utf-8") as file:
    train_samples = [json.loads(line.strip()) for line in file]
print(len(train_samples))

In [None]:
# Looked up by a mention's label_id during negative sampling to obtain the
# entity ids that must not be drawn as negatives for that mention.
with open(INP_PATH + 'entity_pool.pkl', 'rb') as f:
    entity_pool = pickle.load(f)

In [None]:
# Attach negative candidates to every training mention: its mined hard
# negatives followed by `num_random_neg_cands` randomly drawn entities.

# The entity id list is identical for every mention, so build it once.
all_entity_ids = list(entity_dict.keys())

for i, sample in enumerate(train_samples):
    if i % 1000 == 0:
        print(i)

    # Start from a copy of the hard negatives of this mention.
    neg_samples = list(hard_neg_cands_train[i])

    # Ids that must not be drawn at random: the hard negatives already chosen
    # and, for non-NIL mentions, the ids listed for the gold entity in
    # entity_pool. A set makes the exclusion a single pass and tolerates ids
    # that occur twice or in both groups (list.remove would raise ValueError).
    excluded = {str(e_id) for e_id in neg_samples}
    if sample['label_id'] is not None:
        excluded.update(str(e_id) for e_id in entity_pool[sample['label_id']])

    # Filtering keeps the original id order, so a fixed seed still produces
    # the same random draws as removing the ids one by one.
    e_ids = [e_id for e_id in all_entity_ids if e_id not in excluded]

    # Add random negatives
    neg_samples += list(np.random.choice(e_ids, num_random_neg_cands, replace=False))

    sample['negative_cands'] = [int(x) for x in neg_samples]

In [None]:
# Encode the training mentions with their candidates and serve the resulting
# records in random order.
train_data, train_tensor_data = process_mention_data(
    samples=train_samples,
    tokenizer=tokenizer,
    max_context_length=max_context_length,
    max_cand_length=max_cand_length,
)
train_sampler = RandomSampler(train_tensor_data)
train_dataloader = DataLoader(
    train_tensor_data,
    sampler=train_sampler,
    batch_size=train_batch_size,
)

In [None]:
# Load eval data: one JSON object per line.
valid_samples = []
with open(monitor_fname, mode="r", encoding="utf-8") as file:
    for line in file:
        valid_samples.append(json.loads(line.strip()))
print(len(valid_samples))

# Attach negative candidates to every validation mention: its mined hard
# negatives followed by `num_random_neg_cands` randomly drawn entities.

# The entity id list is identical for every mention, so build it once.
all_entity_ids = list(entity_dict.keys())

for i, sample in enumerate(valid_samples):
    if i % 1000 == 0:
        print(i)

    # Start from a copy of the hard negatives of this mention.
    neg_samples = list(hard_neg_cands_valid[i])

    # Ids that must not be drawn at random: the hard negatives already chosen
    # and, for non-NIL mentions, the ids listed for the gold entity in
    # entity_pool. A set makes the exclusion a single pass and tolerates ids
    # that occur twice or in both groups (list.remove would raise ValueError).
    excluded = {str(e_id) for e_id in neg_samples}
    if sample['label_id'] is not None:
        excluded.update(str(e_id) for e_id in entity_pool[sample['label_id']])

    # Filtering keeps the original id order, so a fixed seed still produces
    # the same random draws as removing the ids one by one.
    e_ids = [e_id for e_id in all_entity_ids if e_id not in excluded]

    # Add random negatives
    neg_samples += list(np.random.choice(e_ids, num_random_neg_cands, replace=False))

    sample['negative_cands'] = [int(x) for x in neg_samples]

# Encode the validation mentions and serve them in a fixed order.
valid_data, valid_tensor_data = process_mention_data(
    valid_samples,
    tokenizer,
    max_context_length,
    max_cand_length
)
valid_sampler = SequentialSampler(valid_tensor_data)
valid_dataloader = DataLoader(valid_tensor_data, sampler=valid_sampler, batch_size=eval_batch_size)

In [None]:
# One AdamW optimizer per module, all with the same learning rate.
optim_cand = torch.optim.AdamW(cand_model.parameters(), lr=2e-5)
optim_ctxt = torch.optim.AdamW(ctxt_model.parameters(), lr=2e-5)
optim_m = torch.optim.AdamW(m.parameters(), lr=2e-5)

# Linear decay without warm-up. The schedule length is the number of
# optimizer updates: accumulated steps per epoch times the number of epochs.
scheduler_cand = get_linear_schedule_with_warmup(
    optim_cand,
    num_warmup_steps=0,
    num_training_steps=(len(train_dataloader) // grad_acc_steps) * num_train_epochs,
)
scheduler_ctxt = get_linear_schedule_with_warmup(
    optim_ctxt,
    num_warmup_steps=0,
    num_training_steps=(len(train_dataloader) // grad_acc_steps) * num_train_epochs,
)
scheduler_m = get_linear_schedule_with_warmup(
    optim_m,
    num_warmup_steps=0,
    num_training_steps=(len(train_dataloader) // grad_acc_steps) * num_train_epochs,
)

In [None]:
# Evaluate the loaded models on the validation set before any fine-tuning.
ctxt_model.eval()
cand_model.eval()
m.eval()
all_loss = 0
print("Number of steps: ", len(valid_dataloader))
with torch.no_grad():
    num_correct = 0
    num_all = 0
    for step, batch in enumerate(valid_dataloader):

        context_input, candidate_input, e_ids = batch
        # Trim trailing padding down to the longest candidate of the batch.
        # Counting non-zero ids (instead of argmin over the ids) also gives
        # the right length for a candidate that fills the whole row and
        # therefore contains no padding id at all.
        longest_cand = int((candidate_input != 0).sum(dim=1).max())
        candidate_input = candidate_input[:, :longest_cand]

        if step % 10 == 0:
            print("Step:", step, " longest cand ", longest_cand)

        context_token_idx, context_segment_idx, context_mask = to_bert_input(context_input.to(device), device)
        candidate_token_idx, candidate_segment_idx, candidate_mask = to_bert_input(candidate_input.to(device), device)

        # NOTE(review): the encoders receive (ids, segment ids, mask)
        # positionally. If they are Hugging Face BertModel instances, forward()
        # takes attention_mask before token_type_ids - confirm this order is
        # the one the checkpoints were trained with before changing it.
        # Keep the vector at position 0 (the [CLS] token) of each sequence.
        context_rep = ctxt_model(context_token_idx, context_segment_idx, context_mask)[0][:, 0, :]
        cand_rep = cand_model(candidate_token_idx, candidate_segment_idx, candidate_mask)[0][:, 0, :]

        # Element-wise product of both representations, scored by `m`.
        scores = context_rep.mul(cand_rep)
        scores = m(scores)

        loss = torch.nn.functional.cross_entropy(scores, e_ids.to(device))
        # Accumulate a Python float rather than a (possibly GPU) tensor.
        all_loss += loss.item()
        # Predicted class per record, compared with the labels on the CPU.
        outputs = scores.argmax(dim=1).cpu()
        num_correct += (outputs == e_ids).sum().item()
        num_all += context_rep.size(0)
# Mean of the per-batch losses.
all_loss /= len(valid_dataloader)
print("Val_Loss: ", all_loss)
print("Val_Acc: ", num_correct/num_all)

In [None]:
# Put all three modules into training mode before fine-tuning starts.
ctxt_model.train()
cand_model.train()
m.train()

In [None]:
print('Number of steps per epoch: ',len(train_dataloader))
print('Number of steps with accumulation: ',len(train_dataloader)//grad_acc_steps)

# Start from clean gradients.
optim_cand.zero_grad()
optim_ctxt.zero_grad()
optim_m.zero_grad()
start = time.time()
# Loop over epochs
for epoch in range(num_train_epochs):
    print("Epoch ",epoch)
    # Un-scaled loss of every mini-batch of this epoch.
    avg_loss = []
    # Loop over mini-batches
    for step, batch in enumerate(train_dataloader):
        context_input, candidate_input, e_ids = batch

        # Trim trailing padding down to the longest candidate of the batch.
        # Counting non-zero ids (instead of argmin over the ids) also gives
        # the right length for a candidate that fills the whole row and
        # therefore contains no padding id at all.
        longest_cand = int((candidate_input != 0).sum(dim=1).max())
        candidate_input = candidate_input[:, :longest_cand]

        context_token_idx, context_segment_idx, context_mask = to_bert_input(context_input.to(device), device)
        candidate_token_idx, candidate_segment_idx, candidate_mask = to_bert_input(candidate_input.to(device), device)
        # NOTE(review): the encoders receive (ids, segment ids, mask)
        # positionally. If they are Hugging Face BertModel instances, forward()
        # takes attention_mask before token_type_ids - confirm this order is
        # the one the checkpoints were trained with before changing it.
        # Keep the vector at position 0 (the [CLS] token) of each sequence.
        context_rep = ctxt_model(context_token_idx, context_segment_idx, context_mask)[0][:, 0, :]
        cand_rep = cand_model(candidate_token_idx, candidate_segment_idx, candidate_mask)[0][:, 0, :]

        # Element-wise product of both representations, scored by `m`.
        scores = context_rep.mul(cand_rep)
        scores = m(scores)

        # Record the un-scaled loss for reporting.
        loss = torch.nn.functional.cross_entropy(scores, e_ids.to(device))
        avg_loss.append(loss.item())

        # Scale so the accumulated gradient is the mean over grad_acc_steps
        # mini-batches.
        loss = loss / grad_acc_steps
        loss.backward()

        # Update the weights once enough mini-batches have been accumulated.
        if (step + 1) % grad_acc_steps == 0:
            if (step + 1) % 1000 == 0:
                print("\tStep: ",step+1," Loss: ",avg_loss[-1]," Longest Cand: ",longest_cand," ",time.time()-start)
            # Clip the gradient norm of each module separately.
            torch.nn.utils.clip_grad_norm_(ctxt_model.parameters(), grad_norm)
            torch.nn.utils.clip_grad_norm_(cand_model.parameters(), grad_norm)
            torch.nn.utils.clip_grad_norm_(m.parameters(), grad_norm)
            # Step each optimizer and its scheduler, then reset its gradients.
            optim_cand.step()
            scheduler_cand.step()
            optim_cand.zero_grad()
            optim_ctxt.step()
            scheduler_ctxt.step()
            optim_ctxt.zero_grad()
            optim_m.step()
            scheduler_m.step()
            optim_m.zero_grad()

    # Gradients of an incomplete accumulation window at the end of the epoch
    # are dropped, not applied.
    optim_cand.zero_grad()
    optim_ctxt.zero_grad()
    optim_m.zero_grad()

    # Validation pass for this epoch.
    ctxt_model.eval()
    cand_model.eval()
    m.eval()
    all_loss = 0
    with torch.no_grad():
        num_correct = 0
        num_all = 0
        for step, batch in enumerate(valid_dataloader):
            context_input, candidate_input, e_ids = batch

            # Same padding trim as in the training step above.
            longest_cand = int((candidate_input != 0).sum(dim=1).max())
            candidate_input = candidate_input[:, :longest_cand]

            context_token_idx, context_segment_idx, context_mask = to_bert_input(context_input.to(device), device)
            candidate_token_idx, candidate_segment_idx, candidate_mask = to_bert_input(candidate_input.to(device), device)
            context_rep = ctxt_model(context_token_idx, context_segment_idx, context_mask)[0][:, 0, :]
            cand_rep = cand_model(candidate_token_idx, candidate_segment_idx, candidate_mask)[0][:, 0, :]
            scores = context_rep.mul(cand_rep)
            scores = m(scores)
            loss = torch.nn.functional.cross_entropy(scores, e_ids.to(device))
            # Accumulate a Python float rather than a (possibly GPU) tensor.
            all_loss += loss.item()
            # Predicted class per record, compared with the labels on the CPU.
            outputs = scores.argmax(dim=1).cpu()
            num_correct += (outputs == e_ids).sum().item()
            num_all += context_rep.size(0)
    # Mean of the per-batch validation losses.
    all_loss /= len(valid_dataloader)
    print("Val_Loss: ",all_loss)
    print("Val_Acc: ",num_correct/num_all)
    print("Train_loss",np.mean(avg_loss))
    # Back to training mode for the next epoch.
    ctxt_model.train()
    cand_model.train()
    m.train()

# Save the whole module objects of this round under INP_PATH.
torch.save(ctxt_model,INP_PATH+"hardneg_ctxt_model_"+str(ROUND_NUMBER)+".pt")
torch.save(cand_model,INP_PATH+"hardneg_cand_model_"+str(ROUND_NUMBER)+".pt")
torch.save(m,INP_PATH+"hardneg_m_"+str(ROUND_NUMBER)+".pt")