In [1]:
import os
import random
import torch
from tqdm import trange
import time
import numpy as np
import pickle
import json
from torch.utils.data import DataLoader, RandomSampler, SequentialSampler, TensorDataset
from biencoder import *
import data_process as data
from transformers import get_linear_schedule_with_warmup
from transformers import  BertTokenizerFast, BertModel
from queue import PriorityQueue

In [2]:
# --- Run configuration ---------------------------------------------------
# NOTE: despite its name, `embeddings_folder` is the path of a pickle *file*
# (it is passed straight to open()); the name is kept because later cells use it.
embeddings_folder = "entity_embdes_2.pkl"
input_batch_size = 256
# Directory holding the trained checkpoints. Set the LINKER_MODEL_DIR
# environment variable to run on another machine instead of editing the
# notebook; the default is the location the notebook was originally run from.
model_dir = os.environ.get(
    "LINKER_MODEL_DIR",
    'C:\\Users\\aydxng\\Documents\\ds-fundingbodies-linkingcomponent-masterthesis\\Thesis\\LinkerRound1',
)
model_path = os.path.join(model_dir, 'ctxt_model_2_epoch_0.pt')
model_path_m = os.path.join(model_dir, 'm_2_epoch_0.pt')
seed = 0
device='cuda'
train_fname = "biencoder_monitor.jsonl"

In [3]:
# SECURITY: pickle.load can execute arbitrary code stored in the file, so only
# point `embeddings_folder` at a pickle produced by a trusted run.
# Later cells use the loaded object as a mapping (indexing by entity id and
# iterating .items()).
with open(embeddings_folder, mode="rb") as embeddings_file:
    entity_emebeddings = pickle.load(embeddings_file)

In [4]:
# Seed the Python and NumPy generators so reruns are repeatable.
for seed_rng in (random.seed, np.random.seed):
    seed_rng(seed)
# Seed PyTorch last: as the cell's final expression, the returned Generator
# is still what the cell displays.
torch.manual_seed(seed)

<torch._C.Generator at 0x29ef4a85050>

In [5]:
from data_process import *
def process_mention_data_2(
    samples,
    tokenizer,
    max_context_length=64,
    mention_key="mention",
    context_key="context",
    ent_start_token="[unused0]",
    ent_end_token="[unused1]",
):
    """Tokenize mention samples into a tensor of context token ids.

    Only the mention/context side is processed; no label or candidate
    fields are read here.

    Args:
        samples: iterable of sample records, forwarded one by one to
            ``get_context_representation`` (presumably dicts carrying
            ``mention_key`` / ``context_key`` entries -- confirm in
            ``data_process``).
        tokenizer: tokenizer object forwarded to
            ``get_context_representation``.
        max_context_length: forwarded to ``get_context_representation``
            (default 64, the value previously hard-coded here).
        mention_key: forwarded key name for the mention text.
        context_key: forwarded key-name prefix for the context text.
        ent_start_token: forwarded marker token for the mention start.
        ent_end_token: forwarded marker token for the mention end.

    Returns:
        tuple ``(data, tensor_data)`` where ``data`` is a dict with
        ``"context_vecs"`` (a ``torch.long`` tensor) and ``"sample"`` (the
        input samples as a list, in input order), and ``tensor_data`` is a
        ``TensorDataset`` wrapping ``context_vecs``.
    """
    # Materialise once so generators are supported and order is pinned.
    all_samples = list(samples)

    processed_samples = [
        {
            "context": get_context_representation(
                sample,
                tokenizer,
                max_context_length,
                mention_key,
                context_key,
                ent_start_token,
                ent_end_token,
            )
        }
        for sample in all_samples
    ]

    # select_field presumably pulls record["context"]["ids"] for every
    # record -- confirm in data_process.
    context_vecs = torch.tensor(
        select_field(processed_samples, "context", "ids"), dtype=torch.long,
    )
    # Named mention_data (not `data`) to avoid shadowing the notebook's
    # `import data_process as data` alias inside this function.
    mention_data = {
        "context_vecs": context_vecs,
        "sample": all_samples,
    }

    tensor_data = TensorDataset(context_vecs)
    return mention_data, tensor_data

In [6]:
# Both checkpoints are loaded as whole pickled objects (torch.load's result is
# used directly via .to()/.eval()), so the files must come from a trusted source.
# map_location puts the tensors on their target device at load time instead of
# first materialising them on the device they were saved from; the trailing
# .to(...) is then a cheap no-op kept as a safeguard.
ctxt_model = torch.load(model_path, map_location=device).to(device)
# `m` stays on the CPU: a later cell converts its parameters to NumPy.
m = torch.load(model_path_m, map_location='cpu').to('cpu')
tokenizer = BertTokenizerFast.from_pretrained('bert-base-cased')
# Inference only: put both modules in evaluation mode.
ctxt_model.eval()
m.eval()



Linear(in_features=768, out_features=2, bias=False)

In [7]:
# Load train data: the file holds one JSON object per line.
train_samples = []
with open(train_fname, mode="r", encoding="utf-8") as jsonl_file:
    train_samples.extend(json.loads(raw_line.strip()) for raw_line in jsonl_file)
print(len(train_samples))

4635


In [8]:
# Gold entity id per mention, aligned index-for-index with train_samples.
# A value of None marks a NIL mention (see the `is not None` checks below).
correct_entities = [sample['label_id'] for sample in train_samples]

In [9]:
train_data, train_tensor_data = process_mention_data_2(train_samples, tokenizer)

# Sequential sampling keeps the batches in train_samples order, so the
# embeddings computed from them stay aligned with correct_entities.
train_sampler = SequentialSampler(train_tensor_data)
train_dataloader = DataLoader(
    train_tensor_data,
    batch_size=input_batch_size,
    sampler=train_sampler,
)

In [10]:
# Encode every mention context and keep one NumPy vector per mention, in
# dataloader (= train_samples) order.
ctxt_model.eval()
print(len(train_dataloader))
mention_embeddings = []
with torch.no_grad():
    start = time.time()
    for step, context_input in enumerate(train_dataloader):
        if step%10==0:
            print("Step: ",step," ",time.time()-start)
        # The TensorDataset wraps a single tensor, so each batch is a 1-tuple.
        context_ids = context_input[0].to(device)
        # [0] selects the model's first output; [:, 0, :] keeps the vector at
        # sequence position 0 for every row of the batch.
        ctxt_rep = ctxt_model(context_ids)[0][:,0,:]
        # Copy the whole batch to host memory once (previously one
        # device-to-host copy per row). extend() then appends one row per
        # mention; the rows are views into the batch array.
        mention_embeddings.extend(ctxt_rep.detach().cpu().numpy())

19
Step:  0   0.0020003318786621094
Step:  10   26.73799729347229


In [11]:
# Sanity check: there must be exactly one embedding per gold label.
print(len(mention_embeddings), len(correct_entities), sep="\n")

4635
4635


In [12]:
# Inspect the parameters of `m`; the displayed list holds a single 2-row
# weight tensor (the module was saved with bias=False).
list(m.parameters())

[Parameter containing:
 tensor([[-0.0239, -0.0422, -0.0118,  ..., -0.0118, -0.0314,  0.0001],
         [ 0.0168,  0.0165,  0.0089,  ...,  0.0202,  0.0258, -0.0039]],
        requires_grad=True)]

In [13]:
#Param for positive class
# Row 1 of the first parameter tensor of `m`, as a NumPy vector; it weights
# the element-wise mention*entity product when scoring candidates.
m_second_param = next(iter(m.parameters()))[1].detach().numpy()

In [14]:
# Capacity of the per-mention priority queue of hard negatives.
num_hard_negs=10

In [15]:
# Hard-negative mining and top-1 linking.
# For every mention, score all entities with
#     sigmoid(sum(mention_embedding * entity_embedding * m_second_param))
# and record:
#   - predictions / score_predictions: the best-scoring entity and its score;
#   - corr_entity_scores: the score of the gold entity (0 for NIL mentions);
#   - hard_negatives: a PriorityQueue with the num_hard_negs highest-scoring
#     non-gold entities whose score is >= the gold score, or None if there
#     are none.
hard_negatives = []
predictions = []
score_predictions = []
corr_entity_scores = []
start = time.time()
for i in range(len(mention_embeddings)):
    if i%100 == 0:
        print(i, " ",time.time()-start)
    this_ment_embed = mention_embeddings[i]
    this_corr_ent = correct_entities[i]
    #For NIL mentions this score is set to 0 for the negative mining strategy
    score_corr_ent = 0.
    if this_corr_ent is not None:
        this_corr_ent = int(this_corr_ent)
        score_corr_ent =  1/(1 + np.exp(-np.sum(np.multiply(np.multiply(this_ment_embed,entity_emebeddings[this_corr_ent]),m_second_param))))

    link = None
    score_link = 0
    this_hard_negatives = None

    for k,v in entity_emebeddings.items():
        score_this =  1/(1 + np.exp(-np.sum(np.multiply(np.multiply(this_ment_embed,v),m_second_param))))
        if score_this>score_link:
            score_link = score_this
            link = k
        if k!=this_corr_ent and score_this >= score_corr_ent:
            candidate = (score_this,k)
            if this_hard_negatives is None:
                # Created lazily so mentions without hard negatives keep None.
                this_hard_negatives = PriorityQueue(num_hard_negs)
            if this_hard_negatives.full():
                # get() pops the lowest-scoring entry. Keep whichever of it and
                # the new candidate scores higher, so the queue always holds
                # the top num_hard_negs negatives. (Previously the new
                # candidate replaced the popped entry even when it scored
                # lower.)
                weakest = this_hard_negatives.get()
                if weakest > candidate:
                    candidate = weakest
            this_hard_negatives.put(candidate,False)

    corr_entity_scores.append(score_corr_ent)
    hard_negatives.append(this_hard_negatives)
    predictions.append(link)
    score_predictions.append(score_link)

0   0.0
100   32.84685015678406
200   66.40477561950684
300   97.43366742134094
400   128.82752132415771
500   159.65541577339172
600   188.50035548210144
700   217.80125546455383
800   248.62711668014526
900   277.9820554256439
1000   309.2220368385315
1100   337.14903140068054
1200   368.3189971446991
1300   400.0930256843567
1400   431.1390235424042
1500   461.37198424339294
1600   490.5609850883484
1700   519.8990159034729
1800   550.4800109863281
1900   579.5689742565155
2000   609.9169683456421
2100   640.2099559307098
2200   670.5229458808899
2300   700.9999725818634
2400   730.4999628067017
2500   760.7699189186096
2600   790.6549112796783
2700   823.4849364757538
2800   852.9449265003204
2900   884.6279172897339
3000   914.2908797264099
3100   945.3879113197327
3200   976.0189082622528
3300   1006.4498980045319
3400   1036.5666904449463
3500   1066.5614838600159
3600   1095.912284374237
3700   1125.6070432662964
3800   1154.815840959549
3900   1184.3646712303162
4000   1214.21

In [16]:
# PriorityQueue objects are replaced by their underlying list of
# (score, entity_id) tuples (the `.queue` attribute, in heap order, not
# sorted); mentions without hard negatives stay None.
hard_negatives_new = [
    None if negatives is None else negatives.queue
    for negatives in hard_negatives
]

In [17]:
# Sanity check: every per-mention list should have one entry per sample.
print(
    len(train_samples),
    len(correct_entities),
    len(corr_entity_scores),
    len(hard_negatives_new),
    len(predictions),
    len(score_predictions),
    sep="\n",
)

4635
4635
4635
4635
4635
4635


In [18]:
# Persist the mined hard negatives (one entry per training sample).
with open("hard_negatives_monitor_round_2.pkl", mode="wb") as out_file:
    pickle.dump(hard_negatives_new, out_file)

In [19]:
# Among mentions that have a gold entity (non-NIL), count those with no hard
# negatives, i.e. no other entity scored at least as high as the gold one.
gold_indices = [
    idx for idx in range(len(train_samples))
    if correct_entities[idx] is not None
]
num_all = len(gold_indices)
num_correct = sum(1 for idx in gold_indices if hard_negatives_new[idx] is None)

In [20]:
# Non-NIL mentions for which no other entity scored >= the gold entity.
num_correct

3237

In [21]:
# Mentions that have a gold (non-NIL) entity.
num_all

3967

In [22]:
# Share of non-NIL mentions whose gold entity strictly outscored every other entity.
num_correct/num_all

0.8159818502646836

In [25]:
# Evaluate the raw top-1 links (no NIL threshold applied).
# NOTE(review): PRF_debug is defined in a cell further down this file; on a
# fresh kernel that cell has to run before this one.
PRF_debug(correct_entities, predictions)

Micro Precision:  0.6997
Micro Recall:  0.8175
Micro F1 Score:  0.754
Both not NIL and same:  3243
Both not NIL and different:  724
Correct is NIL, linked is not NIL:  668
Correct is not NIL, linked is NIL:  0
Both are NIL:  0


In [42]:
# Links whose score is below `threshold` are turned into NIL (None).
threshold = 0.66
thresholded_predictions = [
    predictions[idx] if score_predictions[idx] >= threshold else None
    for idx in range(len(train_samples))
]

In [43]:
# Same evaluation after low-scoring links (score < threshold) were mapped to NIL.
PRF_debug(correct_entities, thresholded_predictions)

Micro Precision:  0.8066
Micro Recall:  0.7862
Micro F1 Score:  0.7963
Both not NIL and same:  3119
Both not NIL and different:  529
Correct is NIL, linked is not NIL:  219
Correct is not NIL, linked is NIL:  319
Both are NIL:  449


In [24]:
def PRF_debug(correct, linked, print_err = False):
    """Print micro precision / recall / F1 of entity links plus a case breakdown.

    Each position ``i`` is classified as one of:

    * Case 1: both not NIL and same          -> tp
    * Case 2: both not NIL and different     -> fp and fn
    * Case 3: correct is NIL, linked is not  -> fp
    * Case 4: correct is not NIL, linked is  -> fn
    * Case 5: both are NIL                   -> nothing

    ``None`` means NIL. Ids are compared after ``str(int(...))`` normalisation
    so that e.g. ``12``, ``12.0`` and ``"12"`` are the same id (some ids were
    converted to float when read from csv).

    Args:
        correct: sequence of gold entity ids (or None), one per mention.
        linked: sequence of predicted entity ids (or None), indexable at every
            position of ``correct``.
        print_err: when True, print the error and index of every sample whose
            gold id cannot be parsed as an integer.

    Returns:
        None. The three metrics and the five case counts are printed.

    Raises:
        ValueError: if a non-NIL *predicted* id cannot be parsed as an integer
            (only unparsable gold ids are tolerated; see below).
    """
    tp = 0
    fp = 0
    fn = 0

    case1 = 0
    case2 = 0
    case3 = 0
    case4 = 0
    case5 = 0

    for i in range(len(correct)):
        correct_entity = correct[i]
        linked_entity = linked[i]

        if correct_entity is not None and linked_entity is not None:
            linked_id = str(int(linked_entity))
            # A gold id that cannot be parsed is skipped: it contributes to no
            # metric and to none of the five case counters.
            try:
                correct_id = str(int(correct_entity))
            except ValueError as e:
                if print_err:
                    print('Error str: ',e)
                    print('Index: ',i)
                continue

            if correct_id == linked_id:
                # Case 1
                tp += 1
                case1 += 1
            else:
                # Case 2
                fp += 1
                fn += 1
                case2 += 1

        elif correct_entity is None and linked_entity is not None:
            # Case 3
            fp += 1
            case3 += 1

        elif correct_entity is not None and linked_entity is None:
            # Case 4
            fn += 1
            case4 += 1

        else:
            # Case 5
            case5 += 1

    # The 1e-6 terms keep the ratios defined when a denominator would be 0.
    # F1 is deliberately computed from the rounded precision/recall so the
    # printed values match earlier runs of this notebook.
    micro_precision = np.round(tp/(tp+fp+(10**-6)),4)
    micro_recall = np.round(tp/(tp+fn+(10**-6)),4)
    micro_f1 = (2*micro_precision*micro_recall)/(micro_precision+micro_recall+(10**-6))

    print('Micro Precision: ', np.round(micro_precision,4))
    print('Micro Recall: ', np.round(micro_recall,4))
    print('Micro F1 Score: ', np.round(micro_f1,4))

    print('Both not NIL and same: ', case1)
    print('Both not NIL and different: ', case2)
    print('Correct is NIL, linked is not NIL: ', case3)
    print('Correct is not NIL, linked is NIL: ', case4)
    print('Both are NIL: ', case5)
    


In [44]:
# Scratch tensors: an all-ones target and a constant-logit prediction of the
# same (10, 64) shape.
target = torch.ones(10, 64, dtype=torch.float32)  # 64 classes, batch size = 10
output = torch.full((10, 64), fill_value=1.5)  # A prediction (logit)
