In [1]:
############LIBRARIES################
import pandas as pd
import pickle
import random
import numpy as np
from torch.utils.data import DataLoader
import torch
from datasets import load_metric
from seqeval.scheme import IOB2
from seqeval.metrics import classification_report

random.seed(0)
np.random.seed(0)
torch.manual_seed(0)
torch.cuda.manual_seed_all(0)

In [2]:
############FUNCTIONS################
#From the example GitHub Notebook
def compute_metrics(p,id2tag):
    predictions, labels = p
    predictions = np.argmax(predictions, axis=2)
    # Remove ignored index (special tokens)
    true_predictions = [
        [id2tag[p] for (p, l) in zip(prediction, label) if l != -100]
        for prediction, label in zip(predictions, labels)
    ]
    true_labels = [
        [id2tag[l] for (p, l) in zip(prediction, label) if l != -100]
        for prediction, label in zip(predictions, labels)
    ]

    results = metric.compute(predictions=true_predictions, references=true_labels)
    print("\t\tORG Precision: ",results['_ORG']['precision'])
    print("\t\tORG Recall: ",results['_ORG']['recall'])
    print("\t\tORG F1: ",results['_ORG']['f1'])
    print("\t\tGRT Precision: ",results['_GRT']['precision'])
    print("\t\tGRT Recall: ",results['_GRT']['recall'])
    print("\t\tGRT F1: ",results['_GRT']['f1'])
    
    
#Evaluate the model on the train_monitor set
def eval_on_valid(model, train_monitor_loader,id2tag):
    #Accumulate the predictions here
    val_preds = np.zeros((0,512,5))
    #Accumulate the labels here
    val_lbls = np.zeros((0,512))
    #Accumulate the oss here
    val_loss = 0
    #Loop over minibatches
    for i_val, batch_val in enumerate(train_monitor_loader):
        print(i_val,"/",len(train_monitor_loader))
        #Get the max length in this batch and crop based on that
        seq_lens = batch_val['seq_len']
        max_len_for_batch = max(seq_lens.cpu().detach().numpy())
        #Get inputs and labels for that batch and crop
        input_ids_val = torch.tensor(batch_val['input_ids'][:,:max_len_for_batch].detach().numpy()).to(device)
        attention_mask_val = torch.tensor(batch_val['attention_mask'][:,:max_len_for_batch].detach().numpy()).to(device)
        labels_val = torch.tensor(batch_val['labels'][:,:max_len_for_batch].detach().numpy()).to(device)
        #Do a forward pass
        outputs_val = model(input_ids_val, attention_mask=attention_mask_val, labels=labels_val)
        #First index is the loss. Since the output loss is the mean over minibatch samples,
        #we multiply it with batch size. Later, we divide it by the number of samples
        val_loss += outputs_val[0].item()
        #Save the loss and labels
        these_preds = outputs_val[1].cpu().detach().numpy()
        these_labels= labels_val.cpu().detach().numpy()
        #Pad the predictions again
        new_preds = np.ones((len(input_ids_val),512,5)) * -100
        new_labels= np.ones((len(input_ids_val),512)) * -100
        new_preds[:,:max_len_for_batch,:] = these_preds
        new_labels[:,:max_len_for_batch] = these_labels
        #Store in array
        val_preds = np.concatenate([val_preds,new_preds],axis=0)
        val_lbls = np.concatenate([val_lbls,new_labels],axis=0)
    print("\tValidation Loss: ",val_loss/len(train_monitor_loader))
    p = (val_preds, val_lbls)
    print("\tValidation Results: ")
    compute_metrics(p,id2tag)
    return val_preds


#Class for funding bodies dataset
class FB_Dataset(torch.utils.data.Dataset):
    def __init__(self, encodings, labels,at_mask,seq_lens):
        self.encodings = encodings
        self.labels = labels
        self.at_mask = at_mask
        self.seq_lens = seq_lens

    def __getitem__(self, idx):
        item = dict()
        item['input_ids'] = torch.tensor(self.encodings[idx])
        item['attention_mask'] = torch.tensor(self.at_mask[idx])
        item['labels'] = torch.tensor(self.labels[idx])
        item['seq_len'] =self.seq_lens[idx]
        return item

    def __len__(self):
        return len(self.labels)

In [3]:
############VARIABLES################
id2tag = {0: 'I_GRT', 1: 'O', 2: 'B_GRT', 3: 'B_ORG', 4: 'I_ORG'}
tag2id = {'I_GRT':0,  'O':1,  'B_GRT':2,  'B_ORG':3, 'I_ORG':4,-100:-100}
with open("bert_validation_data.pkl","rb") as f:
    val_dataset=pickle.load(f)
    valid_withlong=pickle.load(f)
    valid=pickle.load(f)
    too_long_valid=pickle.load(f)
#Load metric for evaluation
metric = load_metric("seqeval")
#Load the model and put in evaluation mode
model = torch.load("bert_epoch_1.pt")
model.eval()

BertForTokenClassification(
  (bert): BertModel(
    (embeddings): BertEmbeddings(
      (word_embeddings): Embedding(28996, 768, padding_idx=0)
      (position_embeddings): Embedding(512, 768)
      (token_type_embeddings): Embedding(2, 768)
      (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): BertEncoder(
      (layer): ModuleList(
        (0): BertLayer(
          (attention): BertAttention(
            (self): BertSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): BertSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (LayerNorm): LayerNorm((768,), eps=1e-12, elementwis

In [4]:
############RUN THE MODEL################
device = torch.device('cuda')

#Put model to device
model.to(device)

#Set batch size
batch_size=32

#Initialize the valid loader without shuffling
valid_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)

#Get the predictions and preliminary results
with torch.no_grad():
    val_preds = eval_on_valid(model, valid_loader,id2tag)

0 / 158
1 / 158
2 / 158
3 / 158
4 / 158
5 / 158
6 / 158
7 / 158
8 / 158
9 / 158
10 / 158
11 / 158
12 / 158
13 / 158
14 / 158
15 / 158
16 / 158
17 / 158
18 / 158
19 / 158
20 / 158
21 / 158
22 / 158
23 / 158
24 / 158
25 / 158
26 / 158
27 / 158
28 / 158
29 / 158
30 / 158
31 / 158
32 / 158
33 / 158
34 / 158
35 / 158
36 / 158
37 / 158
38 / 158
39 / 158
40 / 158
41 / 158
42 / 158
43 / 158
44 / 158
45 / 158
46 / 158
47 / 158
48 / 158
49 / 158
50 / 158
51 / 158
52 / 158
53 / 158
54 / 158
55 / 158
56 / 158
57 / 158
58 / 158
59 / 158
60 / 158
61 / 158
62 / 158
63 / 158
64 / 158
65 / 158
66 / 158
67 / 158
68 / 158
69 / 158
70 / 158
71 / 158
72 / 158
73 / 158
74 / 158
75 / 158
76 / 158
77 / 158
78 / 158
79 / 158
80 / 158
81 / 158
82 / 158
83 / 158
84 / 158
85 / 158
86 / 158
87 / 158
88 / 158
89 / 158
90 / 158
91 / 158
92 / 158
93 / 158
94 / 158
95 / 158
96 / 158
97 / 158
98 / 158
99 / 158
100 / 158
101 / 158
102 / 158
103 / 158
104 / 158
105 / 158
106 / 158
107 / 158
108 / 158
109 / 158
110 / 158




		ORG Precision:  0.7845483005366727
		ORG Recall:  0.8579812924130342
		ORG F1:  0.819623302671923
		GRT Precision:  0.9459433509361498
		GRT Recall:  0.9742879746835443
		GRT F1:  0.9599064646563064


In [5]:
############PREPROCESS THE OUTPUT################

#Get predicted label index
val_preds2 = np.argmax(val_preds,axis=2)
#Get the labels
valid_labels = val_dataset.labels
#Get predicted label and discard -100 tags
val_preds_tagged = []
for i in range(len(valid_labels)):
    lbls = valid_labels[i]
    preds = val_preds2[i]
    new_preds = []
    for j in range(len(lbls)):
        lbl = lbls[j]
        pred = preds[j]
        if lbl != -100:
            new_preds.append(id2tag[pred])
    val_preds_tagged.append(new_preds)
valid['Preds'] = val_preds_tagged


#Part of validation without the split sentences
valid_ok = valid[valid.index<(len(valid_withlong)-len(too_long_valid))].copy(deep=True)
#Part of validation with the split sentences
valid_merge = valid[valid.index>=(len(valid_withlong)-len(too_long_valid))]

#Get the merged predictions for the split sentences
preds = valid_merge.groupby('ID').Preds.apply(sum)

#Extract the part that we will paste (these are the long sentences)
to_be_pasted = valid_withlong[valid_withlong.index.isin(too_long_valid)].copy(deep=True)

#Append predictions for the long sentences
new_preds = []
for index, row in to_be_pasted.iterrows():
    new_preds.append(preds[index])
to_be_pasted['Preds'] = new_preds

#Construct the new validation set by merging them
valid_new = pd.concat([valid_ok,to_be_pasted])

#Make sure we did not miss anything
print(valid_new.shape[0] == valid_withlong.shape[0])

#Get gold labels wihout the -100
new_gold = []
for index, row in valid_new.iterrows():
    new_gold.append([x for x in row['Gold_Span_Tags_IOB'] if x!=-100])
valid_new['Gold'] = new_gold

valid_new.reset_index(drop=True,inplace=True)

True


In [6]:
#####GET THE RESULTS########
print(classification_report(list(valid_new.Gold.values),list(valid_new.Preds.values),scheme=IOB2,
                           digits=5,mode='default'))



              precision    recall  f1-score   support

        _GRT    0.94594   0.97429   0.95991     10112
        _ORG    0.78444   0.85796   0.81955     16355

   micro avg    0.84387   0.90241   0.87216     26467
   macro avg    0.86519   0.91613   0.88973     26467
weighted avg    0.84614   0.90241   0.87318     26467

