# Inputs

Set the inputs in the first code cell.

* `path_to_bert_training_data`: Path to a pickle file that contains both training and dev data. 
    1. Prepare the data
        * The dataset should be a list of dictionaries. Each dictionary should correspond to one sample. The keys of the dictionary are "tokens", which stores a list of strings corresponding to the tokenized version of the input sentence, and "tags", the IOB tag for each token.
            ```python
            data = [
                        {"tokens":['token_a','token_b'],"tags":['tag_a','tag_b']},
                        {"tokens":['token_a','token_b','token_c'],"tags":['tag_a','tag_b','tag_c']},
                        ...
                    ]
            ```
        * Create two such variables, one for the training set and another one for the dev set (say `train_data` and `dev_data`). 
        * IOB tags: `"O", "B_ORG", "I_ORG", "B_GRT", "I_GRT"`
        * Tokenization of the input: Please make sure you tokenize the input with:
        ```python
        from transformers import PreTrainedTokenizerFast
        tokenizer = PreTrainedTokenizerFast.from_pretrained('bert-base-cased')
        ```
        If a word is split into subwords, make sure to tag it appropriately. Example:
            * word: `"word_xyz"`, tag: `"O"`, tokenized: `"word", "##_", "##xyz"`
                * Corresponding tags for the tokenized wordpieces: `"O",-100,-100`
        Hence, assign the tag to the first WordPiece and assign the integer -100 to the rest of the WordPieces.
    2. Pickle the data
    ```python
    with open(path_to_bert_training_data,'wb') as f:
        pickle.dump(train_data,f)
        pickle.dump(dev_data,f)
    ```
    3. How will the data be unpickled here?
    ```python
    with open(path_to_bert_training_data,'rb') as f:
        train_dataset=pickle.load(f)
        dev_dataset=pickle.load(f)
    ```
* `path_to_bert_sc`: Path to the folder containing BERT Scopus. Use "bert-base-cased" if BERT Scopus is not available.

# Output
* `bert_sc_ner.pt`: Trained BERT (Scopus) NER model file

In [None]:
# Pickle file that is read twice in a row further down: first the training
# set, then the dev set.
path_to_bert_training_data = "bert_training_data.pkl"
# Model folder or model name handed to BertForTokenClassification.from_pretrained.
path_to_bert_sc = "bert-base-cased"

# Credits
* https://huggingface.co/transformers/custom_datasets.html#tok-ner
* https://github.com/huggingface/notebooks/blob/master/examples/token_classification.ipynb
* https://medium.com/@prakashakshay/fine-tuning-bert-model-using-pytorch-f34148d58a37
* https://mccormickml.com/2019/07/22/BERT-fine-tuning/

In [None]:
import time
import numpy as np
import random
import pickle
import torch
from torch.utils.data import DataLoader
from transformers import get_linear_schedule_with_warmup
from datasets import load_metric
import gc
from transformers import BertForTokenClassification, PreTrainedTokenizerFast
# Seed the NumPy, Python and PyTorch (CPU and CUDA) random number generators
# with one shared value.
seed_num = 0
for _seed_rng in (
    np.random.seed,
    random.seed,
    torch.manual_seed,
    torch.cuda.manual_seed,
    torch.cuda.manual_seed_all,
):
    _seed_rng(seed_num)

In [None]:
#From the example GitHub Notebook
def compute_metrics(p,id2tag):
    """Print precision, recall and F1 for the ORG and GRT entity types.

    Args:
        p: Pair ``(predictions, labels)``. ``predictions`` holds per-token
            class scores (the arg-max is taken over axis 2) and ``labels``
            holds the label ids, with -100 marking positions to ignore.
        id2tag: Mapping from label id to tag string.

    The scores come from the module-level ``metric`` object. Nothing is
    returned; the results are only printed.
    """
    scores_per_token, label_ids = p
    predicted_ids = np.argmax(scores_per_token, axis=2)

    # Keep only the positions whose label is not the ignore value (-100) and
    # translate both the prediction and the label into tag strings.
    true_predictions = []
    true_labels = []
    for predicted_row, label_row in zip(predicted_ids, label_ids):
        kept_predictions = []
        kept_labels = []
        for predicted_id, label_id in zip(predicted_row, label_row):
            if label_id == -100:
                continue
            kept_predictions.append(id2tag[predicted_id])
            kept_labels.append(id2tag[label_id])
        true_predictions.append(kept_predictions)
        true_labels.append(kept_labels)

    results = metric.compute(predictions=true_predictions, references=true_labels)

    score_rows = (
        (" Precision: ", 'precision'),
        (" Recall: ", 'recall'),
        (" F1: ", 'f1'),
    )
    for entity_key, entity_name in (('_ORG', "ORG"), ('_GRT', "GRT")):
        for caption, score_key in score_rows:
            print("\t\t" + entity_name + caption, results[entity_key][score_key])

In [None]:
#Evaluate the model on the train_monitor set
def eval_on_valid(model, train_monitor_loader,id2tag):
    """Run ``model`` over ``train_monitor_loader`` and print loss and metrics.

    Args:
        model: Called as ``model(input_ids, attention_mask=..., labels=...)``;
            index 0 of the result is used as the loss and index 1 as the
            per-token class scores.
        train_monitor_loader: Iterable of batches, each a dict with the keys
            'input_ids', 'attention_mask', 'labels' and 'seq_len'.
        id2tag: Mapping from label id to tag string, passed on to
            ``compute_metrics``.

    The batch tensors are moved to the module-level ``device``. Switching the
    model to eval mode and disabling gradients is left to the caller.
    """
    #Width every batch is padded back to, and the number of classes per token.
    #These match the defaults used elsewhere in this notebook (max_len=512,
    #num_labels=5).
    pad_len = 512
    num_labels = 5
    #Collect the padded per-batch arrays in lists and concatenate once at the
    #end; concatenating inside the loop re-copies everything on every batch.
    #The empty seed arrays keep the result well-defined for an empty loader.
    pred_chunks = [np.zeros((0,pad_len,num_labels))]
    label_chunks = [np.zeros((0,pad_len))]
    #Accumulate the loss here
    val_loss = 0
    #Loop over minibatches
    for batch_val in train_monitor_loader:
        #Get the max length in this batch and crop based on that
        max_len_for_batch = int(batch_val['seq_len'].max())
        #Get inputs and labels for that batch, crop, and move to the device
        input_ids_val = batch_val['input_ids'][:,:max_len_for_batch].to(device)
        attention_mask_val = batch_val['attention_mask'][:,:max_len_for_batch].to(device)
        labels_val = batch_val['labels'][:,:max_len_for_batch].to(device)
        #Do a forward pass
        outputs_val = model(input_ids_val, attention_mask=attention_mask_val, labels=labels_val)
        #Index 0 is the loss of this batch. The per-batch losses are summed
        #here and divided by the number of batches when printed below.
        val_loss += outputs_val[0].item()
        #Index 1 holds the predictions; bring them and the labels back to numpy
        these_preds = outputs_val[1].detach().cpu().numpy()
        these_labels = labels_val.detach().cpu().numpy()
        #Pad predictions and labels back to the full width with -100
        new_preds = np.full((len(input_ids_val),pad_len,num_labels), -100.0)
        new_labels = np.full((len(input_ids_val),pad_len), -100.0)
        new_preds[:,:max_len_for_batch,:] = these_preds
        new_labels[:,:max_len_for_batch] = these_labels
        pred_chunks.append(new_preds)
        label_chunks.append(new_labels)
    val_preds = np.concatenate(pred_chunks,axis=0)
    val_lbls = np.concatenate(label_chunks,axis=0)
    print("\tValidation Loss: ",val_loss/len(train_monitor_loader))
    p = (val_preds, val_lbls)
    print("\tValidation Results: ")
    compute_metrics(p,id2tag)

In [None]:
#Class for funding bodies dataset

class FB_Dataset(torch.utils.data.Dataset):
    """Dataset over pre-encoded samples, indexed by position.

    Each of the four constructor arguments is indexed with the same sample
    index; the number of samples is taken from ``labels``.
    """

    def __init__(self, encodings, labels,at_mask,seq_lens):
        self.encodings, self.labels = encodings, labels
        self.at_mask, self.seq_lens = at_mask, seq_lens

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, idx):
        # 'seq_len' is handed through unchanged; the other three entries are
        # wrapped in tensors.
        return {
            'input_ids': torch.tensor(self.encodings[idx]),
            'attention_mask': torch.tensor(self.at_mask[idx]),
            'labels': torch.tensor(self.labels[idx]),
            'seq_len': self.seq_lens[idx],
        }
    
#Add [CLS] and [SEP] markers, then pad every item to exactly "pad_len" entries.
def add_and_pad(lst,pad_len,cls,sep,pad):
    """Frame each item with ``cls``/``sep`` and pad it to ``pad_len`` entries.

    Args:
        lst: Iterable of sequences (one per sample).
        pad_len: Exact length of every returned item; must be at least 2 so
            that the two markers fit.
        cls: Value placed in front of each item.
        sep: Value placed after each item.
        pad: Value appended until the item has ``pad_len`` entries.

    Returns:
        A new list of lists, each of length ``pad_len``. The inputs are not
        modified.

    Raises:
        ValueError: If ``pad_len`` is smaller than 2.

    Items with more than ``pad_len - 2`` entries are truncated so that the
    closing ``sep`` marker is kept. Without the truncation the padding loop
    could never reach ``pad_len`` and would run forever.
    """
    if pad_len < 2:
        raise ValueError("pad_len must be at least 2 to hold the cls and sep markers")
    max_body_len = pad_len - 2
    new_lst = []
    for item in lst:
        new_item = [cls] + list(item[:max_body_len]) + [sep]
        new_item.extend([pad] * (pad_len - len(new_item)))
        new_lst.append(new_item)
    return new_lst
    
def convert_to_fb_dataset(original_input,tokenizer,max_len=512):
    """Turn a list of ``{"tokens": [...], "tags": [...]}`` samples into an FB_Dataset.

    Args:
        original_input: List of dicts. "tokens" holds the WordPiece tokens of
            one sentence and "tags" holds one entry per token: an IOB tag
            string, or -100 for continuation pieces.
        tokenizer: Tokenizer whose vocabulary the tokens come from; only its
            ``convert_tokens_to_ids`` method is used.
        max_len: Length every sample is padded to (markers included).

    Returns:
        FB_Dataset with input ids, labels, attention mask and sequence lengths.

    Raises:
        ValueError: If ``max_len`` is smaller than 2, if a sample has a
            different number of tokens and tags, or if a token cannot be
            mapped to an id.

    The tokens are mapped to ids one-to-one instead of being run through the
    tokenizer again. Re-tokenizing can split a piece such as '##12' into
    several ids, which shifts every following tag onto the wrong token.
    Samples with more than ``max_len - 2`` tokens are truncated (tokens and
    tags alike) so that the markers still fit.
    """
    if max_len < 2:
        raise ValueError("max_len must be at least 2 to hold the [CLS] and [SEP] markers")
    tag2id = {'I_GRT':0, 'O':1,'B_GRT':2, 'B_ORG':3, 'I_ORG':4, -100:-100}
    #Ids of [CLS], [SEP] and [PAD] in the bert-base-cased vocabulary
    cls_id, sep_id, pad_id = 101, 102, 0
    max_body_len = max_len - 2

    encodings = []
    tags = []
    for sample_idx, sample in enumerate(original_input):
        tokens = list(sample['tokens'])
        sample_tags = list(sample['tags'])
        if len(tokens) != len(sample_tags):
            raise ValueError(
                f"Sample {sample_idx}: {len(tokens)} tokens but {len(sample_tags)} tags"
            )
        ids = tokenizer.convert_tokens_to_ids(tokens)
        if any(token_id is None for token_id in ids):
            raise ValueError(
                f"Sample {sample_idx}: contains a token that is not in the tokenizer vocabulary"
            )
        encodings.append(ids[:max_body_len])
        tags.append(sample_tags[:max_body_len])

    #Number of real positions per sample, including the two markers
    seq_lens = [len(x)+2 for x in encodings]
    encodings = add_and_pad(encodings,max_len,cls_id,sep_id,pad_id)
    #1 for the real positions, 0 for the padding that follows them
    attention_mask = [[1]*n + [0]*(max_len-n) for n in seq_lens]
    #The marker and padding positions get the ignore label -100
    labels = add_and_pad(tags,max_len,-100,-100,-100)
    labels = [[tag2id[x] for x in label]  for label in labels]

    return FB_Dataset(encodings, labels,attention_mask,seq_lens)

In [None]:
# Token-classification model with five output classes, loaded from the
# configured path.
model = BertForTokenClassification.from_pretrained(
    path_to_bert_sc,
    num_labels=5,
)
# NOTE(review): the tokenizer is always loaded from 'bert-base-cased', whatever
# path_to_bert_sc points to -- confirm both use the same vocabulary.
tokenizer = PreTrainedTokenizerFast.from_pretrained("bert-base-cased")

In [None]:
# NOTE: unpickling can execute arbitrary code; only load files you trust.
# The file holds two pickles back to back: the training set first, then the
# set used to monitor training.
with open(path_to_bert_training_data, "rb") as data_file:
    train_dataset, train_monitor_dataset = pickle.load(data_file), pickle.load(data_file)

# Example of what train_dataset (or train_monitor_dataset) should look like
# train_dataset = [{"tokens":['This','work','was','supported','by','National','Institute','of','Health','with','grant','number','ABC','##12','##3','.'],
#                    "tags": ['O','O','O','O','O','B_ORG','I_ORG','I_ORG','I_ORG','O','O','O','B_GRT',-100,-100,'O']} ,
#                  {"tokens":['Financial','support','received','from','Dutch','Ministry','of','Health','with','Grant','234'],
#                    "tags": ['O','O','O','O','B_ORG','I_ORG','I_ORG','I_ORG','O','O','B_GRT']}
#                 ]

# Replace the raw sample lists by their encoded, padded dataset form.
train_dataset = convert_to_fb_dataset(train_dataset, tokenizer)
train_monitor_dataset = convert_to_fb_dataset(train_monitor_dataset, tokenizer)

# Label id -> tag string, used when the metrics are computed.
id2tag = dict(enumerate(('I_GRT', 'O', 'B_GRT', 'B_ORG', 'I_ORG')))
metric = load_metric("seqeval")

In [None]:
#Pick the device: use the GPU when one is available, otherwise fall back to
#the CPU instead of failing on machines without CUDA
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

#Put model to device
model.to(device)

#Put model to training mode
model.train()

#Define training batch size
batch_size=8#increase this
num_epochs = 3

#Get training sample generator
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
train_monitor_loader = DataLoader(train_monitor_dataset, batch_size=batch_size, shuffle=True)

#Initialize optimizer
optim = torch.optim.AdamW(model.parameters(), lr=2e-5) 

#Determine how many steps each epoch will take
print("Steps per epoch: " ,len(train_loader))

#Learning-rate schedule with 50 warm-up steps, stepped once per minibatch
#over all epochs
scheduler = get_linear_schedule_with_warmup(optim, 
                                            num_warmup_steps = 50, 
                                            num_training_steps = len(train_loader)*num_epochs)

#Loss of every training minibatch, in the order they were processed
minibatch_losses = []
#Every x step, print training loss
every_x_step=500

#Loop over epochs
for epoch in range(num_epochs):
    print("Epoch: ",epoch)
    #Accumulate training statistics for this epoch
    train_loss = 0
    #Padded per-batch predictions and labels are collected in lists and
    #concatenated once per epoch; concatenating on every step re-copies the
    #whole array each time. The empty seed arrays fix the final shape.
    train_pred_chunks = [np.zeros((0,512,5))]
    train_label_chunks = [np.zeros((0,512))]
    #Loop over minibatches
    for i, batch in enumerate(train_loader):
        gc.collect()
        start = time.time()
        #reset gradients
        optim.zero_grad()
        #Get the max length in this batch and crop based on that
        max_len_for_batch = int(batch['seq_len'].max())
        #get inputs, cropped to the longest sequence, on the device
        input_ids = batch['input_ids'][:,:max_len_for_batch].to(device)
        attention_mask = batch['attention_mask'][:,:max_len_for_batch].to(device)
        labels = batch['labels'][:,:max_len_for_batch].to(device)
        #When we call a classification model with the labels argument, the first returned element is the Cross Entropy loss between the predictions and the passed labels. 
        #Calculate loss
        outputs = model(input_ids, attention_mask=attention_mask, labels=labels)
        del attention_mask
        #loss is reduced by mean (so it roughly corresponds to loss of one sample)
        #https://pytorch.org/docs/stable/generated/torch.nn.CrossEntropyLoss.html
        #https://huggingface.co/transformers/_modules/transformers/models/bert/modeling_bert.html#BertForTokenClassification.forward
        loss = outputs[0]
        #Read the loss value once and reuse it for all bookkeeping below
        batch_loss = loss.item()
        train_loss+=batch_loss
        #Second index is the predictions, store them
        these_preds = outputs[1].detach().cpu().numpy()
        these_labels= labels.detach().cpu().numpy()
        del outputs
        del labels
        #Pad the predictions and labels back to the full width with -100
        new_preds = np.full((len(input_ids),512,5), -100.0)
        new_labels= np.full((len(input_ids),512), -100.0)
        del input_ids
        new_preds[:,:max_len_for_batch,:] = these_preds
        new_labels[:,:max_len_for_batch] = these_labels
        #Save the labels and predictions
        train_label_chunks.append(new_labels)
        train_pred_chunks.append(new_preds)
        #backpropagation
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        #update the parameters
        optim.step()
        scheduler.step()
        end = time.time()
        minibatch_losses.append(batch_loss)
        torch.cuda.empty_cache()
        #Every x step, print validation scores
        if (i+1)%every_x_step==0:          
            #Print training loss for this minibatch
            print("\tStep ",i+1,"/",len(train_loader))
            print("\t\tBatch padding: ",max_len_for_batch)
            print("\t\tMinibatch training loss: ",batch_loss)
            print("\t\tTime for this minibatch: ",end-start)
            #Check val loss
            model.eval()
            with torch.no_grad():
                #Evaluate the current model on the validation set
                eval_on_valid(model, train_monitor_loader,id2tag)
            model.train()

    train_preds = np.concatenate(train_pred_chunks,axis=0)
    train_lbls = np.concatenate(train_label_chunks,axis=0)
    #"Approximate" because the model changes between the minibatches of an epoch
    print("\tApproximate Training loss for this epoch: ",train_loss/len(train_loader))
    print("\tApproximate Training results: ")
    compute_metrics((train_preds, train_lbls),id2tag)
    #Check val loss
    model.eval()
    with torch.no_grad():
        #Evaluate the current model on the validation set
        eval_on_valid(model, train_monitor_loader,id2tag)
    model.train()
#Save the whole trained model object (not only its weights)
with open("bert_sc_ner.pt",'wb') as f:
    torch.save(model, f)


model.eval()