# Implementation of the Kuribayashi BERT minus model

## libraries

In [None]:
!pip install transformers --upgrade
!pip install ipywidgets
!pip install IProgress
!pip install datasets
!pip install torch-lr-finder

In [2]:
import transformers
from transformers import BertTokenizer, BertConfig
from transformers import BertModel, BertForSequenceClassification
from transformers import BatchEncoding, default_data_collator, DataCollatorWithPadding
from transformers import Trainer, TrainingArguments


import torch
import torch.nn as nn

import numpy as np

from sklearn.metrics import classification_report
from sklearn.metrics import f1_score

import datasets
from datasets import load_metric

from torch.utils.data import DataLoader

from tqdm import tqdm
from operator import itemgetter

In [3]:
print(transformers.__version__)

4.26.0


## tokenizer

In [4]:
tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")

## data

In [5]:
DATA_FILE = '/notebooks/KURI-BERT/notebooks/full_formula_w_fts/Link_Identification_Task/pe_dataset_for_bert_minus_w_fts_combined_link_task.pt'
RESULTS_FOLDER = '/notebooks/KURI-BERT/results'

In [6]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

## Load data

In [7]:
dataset = torch.load(DATA_FILE)

### preprocessing

In [14]:
MAX_LENGTH = 0
MAX_SPAN = 0

for split in ['train', 'test', 'validation']:
    
    for col_name in ['am_spans_new', 'ac_spans_new']:
        
        for x in dataset[split][col_name]:
            
            if max(x,key=itemgetter(1))[1] > MAX_SPAN:
                
                MAX_SPAN = max(x,key=itemgetter(1))[1]
                MAX_SPAN = min(MAX_SPAN, tokenizer.model_max_length - 2)
            
            if len(x) > MAX_LENGTH:
                
                MAX_LENGTH = len(x)

In [17]:
def get_padding(batch, padding_target):    
    
    if padding_target == 'am_spans':
        
        col_name = 'paragraph_am_spans'
        padding_val = [[-1,-1]]
        max_length = MAX_LENGTH
        
    elif padding_target == 'ac_spans':
        
        col_name = 'paragraph_ac_spans'
        padding_val = [[-1,-1]]
        max_length = MAX_LENGTH
        
    elif padding_target == 'label':
    
        col_name = 'paragraph_labels'
        padding_val = [-100] # -1 previously       
        max_length = MAX_LENGTH 

    padded_spans = []

    for idx, span in enumerate(batch[col_name]):

        padded_span = batch[col_name][idx] + (max_length - len(span)) * padding_val
        padded_spans.append(padded_span)

    return padded_spans         

In [18]:
def get_combined_spans(am_spans_ll, ac_spans_ll):
    
    spans_ll = []
    
    for am_spans, ac_spans in zip(am_spans_ll, ac_spans_ll):
        
        spans = []
        
        for am_span, ac_span in zip(am_spans, ac_spans):

            span = [am_span, ac_span]
            spans.extend(span)
            
        spans_ll.append(spans)

    return spans_ll

### tokenize 

In [19]:
def tokenize(batch):
    
    tokenized_text = tokenizer(batch['paragraph'], truncation=True, padding=True, max_length=512)
    tokenized_text['label'] = get_padding(batch, 'label')
    tokenized_text['am_spans'] = get_padding(batch, 'am_spans')
    tokenized_text['ac_spans'] = get_padding(batch, 'ac_spans')
    tokenized_text['spans'] = get_combined_spans(tokenized_text['am_spans'], tokenized_text['ac_spans'])      
    
    return tokenized_text

In [None]:
dataset = dataset.map(tokenize, batched=True, batch_size=len(dataset['train']))

In [22]:
dataset['test'].features['spans'] = datasets.Array2D(shape=(24, 2), dtype="int32")
dataset['train'].features['spans'] = datasets.Array2D(shape=(24, 2), dtype="int32")
dataset['validation'].features['spans'] = datasets.Array2D(shape=(24, 2), dtype="int32")

In [None]:
dataset = dataset.map(lambda batch: batch, batched=True)

In [24]:
dataset.set_format('torch', columns=['input_ids', 'attention_mask', 'token_type_ids', 'spans', 'label'])

In [25]:
def minus_one(t):
    
    return torch.where(t == 0, 0, t-1)

def plus_one(t):
    
    return torch.where(t >= MAX_SPAN, MAX_SPAN, t+1)

## span representation function

In [30]:
def get_span_representations(outputs, spans):

    batch_size = spans.shape[0]
    nr_span_indices = spans.shape[1]
    
    
    idx_l_ams = range(0, nr_span_indices, 2)
    idx_l_acs = range(1, nr_span_indices, 2)
    
    # Add 1 to all span indices (both am and ac) to offset for the CLS token in the input_ids.
    
    am_spans = spans[:, idx_l_ams, :] + 1 
    ac_spans = spans[:, idx_l_acs, :] + 1
    
    
    am_spans_minus_one = minus_one(am_spans)
    am_spans_plus_one = plus_one(am_spans)
    
    ac_spans_minus_one = minus_one(ac_spans)
    ac_spans_plus_one = plus_one(ac_spans)
    
    
    am_spans = am_spans.flatten(start_dim=1)
    ac_spans = ac_spans.flatten(start_dim=1)
    
    am_spans_minus_one = am_spans_minus_one.flatten(start_dim=1)
    am_spans_plus_one = am_spans_plus_one.flatten(start_dim=1)
    
    ac_spans_minus_one = ac_spans_minus_one.flatten(start_dim=1)
    ac_spans_plus_one = ac_spans_plus_one.flatten(start_dim=1)
    
    
    # BERT-minus representation for AMs =====================================================================
    
    outputs_am = outputs[:,am_spans,:]
    outputs_am = torch.cat([outputs_am[i,i,:,:] for i in range(batch_size)], dim=0)
    outputs_am = outputs_am.reshape(batch_size, nr_span_indices, -1)    
    
    
    outputs_am_minus_one = outputs[:,am_spans_minus_one,:]
    outputs_am_minus_one = torch.cat([outputs_am_minus_one[i,i,:,:] for i in range(batch_size)], dim=0)
    outputs_am_minus_one = outputs_am_minus_one.reshape(batch_size, nr_span_indices, -1)
    
    outputs_am_plus_one = outputs[:,am_spans_plus_one,:]
    outputs_am_plus_one = torch.cat([outputs_am_plus_one[i,i,:,:] for i in range(batch_size)], dim=0)
    outputs_am_plus_one = outputs_am_plus_one.reshape(batch_size, nr_span_indices, -1)    
    

    
    ## 1st term 
    
    outputs_am_first_term = torch.cat([outputs_am[:,i+1,:] - outputs_am_minus_one[:,i,:] for i in range(0, nr_span_indices, 2)], dim=1) # i + 1 here means j in kuri, i here means i in kuri
    outputs_am_first_term = outputs_am_first_term.reshape(batch_size, -1, 768)
    
    
    
    ## 2nd term
    
    outputs_am_second_term = torch.cat([outputs_am[:,i,:] - outputs_am_plus_one[:,i+1,:] for i in range(0, nr_span_indices, 2)], dim=1) # changed + 2 to + 1 to make it run # changed from +1 to +2 to ensure +2 is not a problem for AMs
    outputs_am_second_term = outputs_am_second_term.reshape(batch_size, -1, 768)
    
    
    ## 3rd term 
    
    outputs_am_third_term = torch.cat([outputs_am_minus_one[:,i,:] for i in range(0, nr_span_indices, 2)], dim=1)
    outputs_am_third_term = outputs_am_third_term.reshape(batch_size, -1, 768)
    

    ## 4rth term
    
    outputs_am_fourth_term = torch.cat([outputs_am_plus_one[:,i+1,:] for i in range(0, nr_span_indices, 2)], dim=1) # changed + 2 to + 1 to make it run # changed from +1 to +2 to ensure +2 is not a problem for AMs
    outputs_am_fourth_term = outputs_am_fourth_term.reshape(batch_size, -1, 768)

    
    ## Concatenate the four terms    
    
    am_minus_representations = torch.cat([outputs_am_first_term, outputs_am_second_term, outputs_am_third_term, outputs_am_fourth_term], dim=-1) 
    
    
    
    # BERT-minus representation for ACs =====================================================================
    
    outputs_ac = outputs[:,ac_spans,:]
    outputs_ac = torch.cat([outputs_ac[i,i,:,:] for i in range(batch_size)], dim=0)
    outputs_ac = outputs_ac.reshape(batch_size, nr_span_indices, -1)
    
    
    outputs_ac_minus_one = outputs[:,ac_spans_minus_one,:]
    outputs_ac_minus_one = torch.cat([outputs_ac_minus_one[i,i,:,:] for i in range(batch_size)], dim=0)
    outputs_ac_minus_one = outputs_ac_minus_one.reshape(batch_size, nr_span_indices, -1)
    
    outputs_ac_plus_one = outputs[:,ac_spans_plus_one,:]
    outputs_ac_plus_one = torch.cat([outputs_ac_plus_one[i,i,:,:] for i in range(batch_size)], dim=0)
    outputs_ac_plus_one = outputs_ac_plus_one.reshape(batch_size, nr_span_indices, -1)   
    
    
    
    ## 1st term
    
    outputs_ac_first_term = torch.cat([outputs_ac[:,i+1,:] - outputs_ac_minus_one[:,i,:] for i in range(0, nr_span_indices, 2)], dim=1)
    outputs_ac_first_term = outputs_ac_first_term.reshape(batch_size, -1, 768)
    
    
    ## 2nd term
    
    outputs_ac_second_term = torch.cat([outputs_ac[:,i,:] - outputs_ac_plus_one[:,i+1,:] for i in range(0, nr_span_indices, 2)], dim=1) # changed + 2 to + 1 to make it run
    outputs_ac_second_term = outputs_ac_second_term.reshape(batch_size, -1, 768) 
        
    
    ## 3rd term
    
    outputs_ac_third_term = torch.cat([outputs_ac_minus_one[:,i,:] for i in range(0, nr_span_indices, 2)], dim=1)
    outputs_ac_third_term = outputs_ac_third_term.reshape(batch_size, -1, 768)
        
    
    ## 4rth term
    
    outputs_ac_fourth_term = torch.cat([outputs_ac_plus_one[:,i+1,:] for i in range(0, nr_span_indices, 2)], dim=1) # changed + 2 to + 1 to make it run
    outputs_ac_fourth_term = outputs_ac_fourth_term.reshape(batch_size, -1, 768)
    
    ## Concatenate the four terms
    
    
    ac_minus_representations = torch.cat([outputs_ac_first_term, outputs_ac_second_term, outputs_ac_third_term, outputs_ac_fourth_term], dim=-1)
        
    
    return am_minus_representations, ac_minus_representations                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                            

## custom BERT model

In [31]:
class CustomBERTKuri(nn.Module):

    def __init__(self, first_model, model_am, model_ac, nr_classes):
        
        super(CustomBERTKuri, self).__init__()
        
        self.first_model = first_model
        
        self.intermediate_linear_am = nn.Linear(768 * 4, 768)
        self.intermediate_linear_ac = nn.Linear(768 * 4, 768)        
        
        self.model_am = model_am
        self.model_ac = model_ac
        
        self.nr_classes = nr_classes
                
        self.fc = nn.Linear(self.model_am.config.hidden_size + self.model_ac.config.hidden_size, self.nr_classes)        

    def forward(self, inputs):
        
        batch_tokenized, batch_spans = inputs         
        outputs = self.first_model(batch_tokenized, output_hidden_states=True)[1][12]
        am_minus_representations, ac_minus_representations = get_span_representations(outputs, batch_spans)
        
        am_minus_representations = self.intermediate_linear_am(am_minus_representations)
        ac_minus_representations = self.intermediate_linear_ac(ac_minus_representations)
        
        output_model_am = self.model_am(inputs_embeds = am_minus_representations)[0]
        output_model_ac = self.model_ac(inputs_embeds = ac_minus_representations)[0]

        adu_representations = torch.cat([output_model_am, output_model_ac], dim=-1)
        output = self.fc(adu_representations)
        
        return output

## Run

In [32]:
NB_EPOCHS = 40
BATCH_SIZE = 24

In [None]:
first_model = BertForSequenceClassification.from_pretrained("bert-base-uncased", num_labels=3)
first_model.load_state_dict(torch.load('/notebooks/KURI-BERT/notebooks/full_formula_w_fts/icann_finetuned_work/finetuned_model.pth'))

In [34]:
model_am = BertModel(BertConfig.from_pretrained("bert-base-uncased"))

In [35]:
model_ac = BertModel(BertConfig.from_pretrained("bert-base-uncased"))

In [36]:
custom_model = CustomBERTKuri(first_model, model_am, model_ac, 2)

In [None]:
custom_model.to(device)

In [38]:
loss = nn.CrossEntropyLoss(ignore_index=- 100)

In [39]:
optimizer = torch.optim.AdamW(custom_model.parameters(), lr=1.8738174228603844e-05)

In [40]:
NR_BATCHES = len(dataset['train']) / BATCH_SIZE
num_training_steps = NB_EPOCHS * NR_BATCHES
num_warmup_steps = int(0.2 * num_training_steps)

In [41]:
def lr_lambda(current_step: int):
        if current_step < num_warmup_steps:
            return float(current_step) / float(max(1, num_warmup_steps))
        return max(
            0.0, float(num_training_steps - current_step) / float(max(1, num_training_steps - num_warmup_steps))
        )


In [42]:
scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lr_lambda)

### create dataloaders

In [43]:
train_dataloader = DataLoader(dataset['train'], batch_size=BATCH_SIZE, shuffle=False) # xxx.
val_dataloader = DataLoader(dataset['validation'], batch_size=BATCH_SIZE, shuffle=True)
test_dataloader = DataLoader(dataset['test'], batch_size=BATCH_SIZE, shuffle=True)

In [44]:
def flatten_list(list_of_lists):
    return [x for sublist in list_of_lists for x in sublist]

In [45]:
def remove_dummy_labels(test_preds, test_labels):
    
    idxes = []
    test_labels_l = []
    for idx, val in enumerate(test_labels):
        if val != -100:
            idxes.append(idx)
            test_labels_l.append(val)
    
    test_preds_l = []
    for idx, val in enumerate(test_preds):
        for good_idx in idxes:
            if idx == good_idx:
                test_preds_l.append(val)
        
    return test_preds_l, test_labels_l

## training 

In [48]:
def train(model, loss=None, optimizer=None, train_dataloader=None, val_dataloader=None, nb_epochs=20):
    """Training loop"""

    min_f1 = -torch.inf
    train_losses = []
    val_losses = []

    # Iterrate over epochs
    for e in range(nb_epochs):

        # Training
        train_loss = 0.0

        for batch in tqdm(train_dataloader):            
            
            # unpack batch             
            labels = batch['label'].to(device)
            spans = batch['spans'].to(device)
            input_ids = batch['input_ids'].to(device)
            
            inputs = input_ids, spans
            
            # Reset gradients to 0
            optimizer.zero_grad()

            # Forward Pass
            outputs = model(inputs)
            
            # Compute training loss
            current_loss = loss(outputs.flatten(0,1), labels.flatten())
            train_loss += current_loss.detach().item()

            # Compute gradients
            current_loss.backward()

            # Update weights
            optimizer.step()            
            
            del batch
        
        scheduler.step()
            
        
        # Validation
        val_loss = 0.0

        # Put model in eval mode
        model.eval()
        
        preds_l = []
        labels_l = []
        
        for batch in tqdm(val_dataloader):            
            
            # unpack batch             
            labels = batch['label'].to(device)
            spans = batch['spans'].to(device)
            input_ids = batch['input_ids'].to(device)
            
            inputs = input_ids, spans
            
            # Forward Pass
            outputs = model(inputs)

            # Compute validation loss
            current_loss = loss(outputs.flatten(0,1), labels.flatten())
            val_loss += current_loss.detach().item()
            
            preds_for_f1 = torch.argmax(outputs, dim=2).flatten().tolist()
            labels_for_f1 = labels.flatten().tolist()
            
            preds_l.append(preds_for_f1)
            labels_l.append(labels_for_f1)
            
            del batch
        
        # Prints
        
        preds_l = flatten_list(preds_l)
        labels_l = flatten_list(labels_l)
        
        preds_l, labels_l = remove_dummy_labels(preds_l, labels_l)
        
        f1_score_epoch = f1_score(preds_l, labels_l, average='macro')        
        
        print(f"Epoch {e+1}/{nb_epochs} \
                \t Training Loss: {train_loss/len(train_dataloader):.3f} \
                \t Validation Loss: {val_loss/len(val_dataloader):.3f} \
                \t F1 score: {f1_score_epoch}")
        
        train_losses.append(train_loss/len(train_dataloader))
        val_losses.append(val_loss/len(val_dataloader))
        

        # Save model if val loss decreases
        if f1_score_epoch > min_f1:

            min_f1 = f1_score_epoch
            torch.save(model.first_model.state_dict(), 'first_model.pt')
            torch.save(model.model_am.state_dict(), 'model_am.pt')
            torch.save(model.model_ac.state_dict(), 'model_ac.pt')
            torch.save(model.state_dict(), 'best_model.pt')
            
    return train_losses, val_losses

In [49]:
train_losses, val_losses = train(custom_model, loss, optimizer, train_dataloader, val_dataloader, NB_EPOCHS)

100%|██████████| 46/46 [00:30<00:00,  1.53it/s]
100%|██████████| 12/12 [00:02<00:00,  4.86it/s]


Epoch 1/40                 	 Training Loss: 0.842                 	 Validation Loss: 0.839                 	 F1 score: 0.4127672001828824


100%|██████████| 46/46 [00:29<00:00,  1.56it/s]
100%|██████████| 12/12 [00:02<00:00,  5.06it/s]


Epoch 2/40                 	 Training Loss: 0.782                 	 Validation Loss: 0.733                 	 F1 score: 0.5030875557191347


100%|██████████| 46/46 [00:29<00:00,  1.55it/s]
100%|██████████| 12/12 [00:02<00:00,  5.09it/s]


Epoch 3/40                 	 Training Loss: 0.697                 	 Validation Loss: 0.662                 	 F1 score: 0.5879640229759191


100%|██████████| 46/46 [00:29<00:00,  1.56it/s]
100%|██████████| 12/12 [00:02<00:00,  4.95it/s]


Epoch 4/40                 	 Training Loss: 0.638                 	 Validation Loss: 0.619                 	 F1 score: 0.6401320363462689


100%|██████████| 46/46 [00:29<00:00,  1.54it/s]
100%|██████████| 12/12 [00:02<00:00,  5.00it/s]


Epoch 5/40                 	 Training Loss: 0.602                 	 Validation Loss: 0.600                 	 F1 score: 0.655005407438755


100%|██████████| 46/46 [00:29<00:00,  1.55it/s]
100%|██████████| 12/12 [00:02<00:00,  5.09it/s]


Epoch 6/40                 	 Training Loss: 0.572                 	 Validation Loss: 0.572                 	 F1 score: 0.6711057651000945


100%|██████████| 46/46 [00:29<00:00,  1.55it/s]
100%|██████████| 12/12 [00:02<00:00,  4.93it/s]


Epoch 7/40                 	 Training Loss: 0.545                 	 Validation Loss: 0.564                 	 F1 score: 0.6825833180871218


100%|██████████| 46/46 [00:29<00:00,  1.54it/s]
100%|██████████| 12/12 [00:02<00:00,  4.99it/s]


Epoch 8/40                 	 Training Loss: 0.520                 	 Validation Loss: 0.548                 	 F1 score: 0.6921745867768596


100%|██████████| 46/46 [00:29<00:00,  1.54it/s]
100%|██████████| 12/12 [00:02<00:00,  5.06it/s]


Epoch 9/40                 	 Training Loss: 0.497                 	 Validation Loss: 0.526                 	 F1 score: 0.7018649689406078


100%|██████████| 46/46 [00:29<00:00,  1.55it/s]
100%|██████████| 12/12 [00:02<00:00,  4.98it/s]


Epoch 10/40                 	 Training Loss: 0.474                 	 Validation Loss: 0.515                 	 F1 score: 0.7162955182072829


100%|██████████| 46/46 [00:29<00:00,  1.54it/s]
100%|██████████| 12/12 [00:02<00:00,  5.05it/s]


Epoch 11/40                 	 Training Loss: 0.450                 	 Validation Loss: 0.491                 	 F1 score: 0.7270431588613406


100%|██████████| 46/46 [00:29<00:00,  1.55it/s]
100%|██████████| 12/12 [00:02<00:00,  4.98it/s]


Epoch 12/40                 	 Training Loss: 0.427                 	 Validation Loss: 0.486                 	 F1 score: 0.7385333316425519


100%|██████████| 46/46 [00:29<00:00,  1.55it/s]
100%|██████████| 12/12 [00:02<00:00,  4.96it/s]


Epoch 13/40                 	 Training Loss: 0.404                 	 Validation Loss: 0.488                 	 F1 score: 0.7464098037393792


100%|██████████| 46/46 [00:29<00:00,  1.55it/s]
100%|██████████| 12/12 [00:02<00:00,  4.99it/s]


Epoch 14/40                 	 Training Loss: 0.379                 	 Validation Loss: 0.472                 	 F1 score: 0.7569940321994777


100%|██████████| 46/46 [00:29<00:00,  1.54it/s]
100%|██████████| 12/12 [00:02<00:00,  5.06it/s]


Epoch 15/40                 	 Training Loss: 0.354                 	 Validation Loss: 0.466                 	 F1 score: 0.765532698746858


100%|██████████| 46/46 [00:29<00:00,  1.54it/s]
100%|██████████| 12/12 [00:02<00:00,  5.05it/s]


Epoch 16/40                 	 Training Loss: 0.330                 	 Validation Loss: 0.468                 	 F1 score: 0.774405004759962


100%|██████████| 46/46 [00:29<00:00,  1.55it/s]
100%|██████████| 12/12 [00:02<00:00,  4.86it/s]


Epoch 17/40                 	 Training Loss: 0.307                 	 Validation Loss: 0.456                 	 F1 score: 0.7782902179182776


100%|██████████| 46/46 [00:29<00:00,  1.55it/s]
100%|██████████| 12/12 [00:02<00:00,  5.02it/s]


Epoch 18/40                 	 Training Loss: 0.279                 	 Validation Loss: 0.485                 	 F1 score: 0.787602872834812


100%|██████████| 46/46 [00:29<00:00,  1.56it/s]
100%|██████████| 12/12 [00:02<00:00,  5.10it/s]


Epoch 19/40                 	 Training Loss: 0.252                 	 Validation Loss: 0.476                 	 F1 score: 0.7809380610412926


100%|██████████| 46/46 [00:29<00:00,  1.54it/s]
100%|██████████| 12/12 [00:02<00:00,  5.03it/s]


Epoch 20/40                 	 Training Loss: 0.226                 	 Validation Loss: 0.471                 	 F1 score: 0.782536276416048


100%|██████████| 46/46 [00:29<00:00,  1.54it/s]
100%|██████████| 12/12 [00:02<00:00,  5.07it/s]


Epoch 21/40                 	 Training Loss: 0.198                 	 Validation Loss: 0.517                 	 F1 score: 0.7846787761557914


100%|██████████| 46/46 [00:29<00:00,  1.54it/s]
100%|██████████| 12/12 [00:02<00:00,  5.04it/s]


Epoch 22/40                 	 Training Loss: 0.170                 	 Validation Loss: 0.503                 	 F1 score: 0.7849264210138349


100%|██████████| 46/46 [00:29<00:00,  1.54it/s]
100%|██████████| 12/12 [00:02<00:00,  5.08it/s]


Epoch 23/40                 	 Training Loss: 0.142                 	 Validation Loss: 0.504                 	 F1 score: 0.7899340128172634


100%|██████████| 46/46 [00:29<00:00,  1.55it/s]
100%|██████████| 12/12 [00:02<00:00,  5.12it/s]


Epoch 24/40                 	 Training Loss: 0.116                 	 Validation Loss: 0.525                 	 F1 score: 0.7893092414831545


100%|██████████| 46/46 [00:29<00:00,  1.55it/s]
100%|██████████| 12/12 [00:02<00:00,  5.05it/s]


Epoch 25/40                 	 Training Loss: 0.094                 	 Validation Loss: 0.554                 	 F1 score: 0.78228794280143


100%|██████████| 46/46 [00:29<00:00,  1.54it/s]
100%|██████████| 12/12 [00:02<00:00,  5.04it/s]


Epoch 26/40                 	 Training Loss: 0.082                 	 Validation Loss: 0.651                 	 F1 score: 0.7592089371980677


100%|██████████| 46/46 [00:29<00:00,  1.55it/s]
100%|██████████| 12/12 [00:02<00:00,  5.02it/s]


Epoch 27/40                 	 Training Loss: 0.088                 	 Validation Loss: 0.725                 	 F1 score: 0.751216681075395


100%|██████████| 46/46 [00:29<00:00,  1.53it/s]
100%|██████████| 12/12 [00:02<00:00,  4.96it/s]


Epoch 28/40                 	 Training Loss: 0.121                 	 Validation Loss: 0.652                 	 F1 score: 0.7845330973570281


100%|██████████| 46/46 [00:29<00:00,  1.55it/s]
100%|██████████| 12/12 [00:02<00:00,  5.06it/s]


Epoch 29/40                 	 Training Loss: 0.066                 	 Validation Loss: 0.758                 	 F1 score: 0.7701263156038997


100%|██████████| 46/46 [00:29<00:00,  1.54it/s]
100%|██████████| 12/12 [00:02<00:00,  4.98it/s]


Epoch 30/40                 	 Training Loss: 0.038                 	 Validation Loss: 0.634                 	 F1 score: 0.7930984131511836


100%|██████████| 46/46 [00:29<00:00,  1.55it/s]
100%|██████████| 12/12 [00:02<00:00,  5.11it/s]


Epoch 31/40                 	 Training Loss: 0.018                 	 Validation Loss: 0.694                 	 F1 score: 0.7920476458642639


100%|██████████| 46/46 [00:29<00:00,  1.56it/s]
100%|██████████| 12/12 [00:02<00:00,  5.09it/s]


Epoch 32/40                 	 Training Loss: 0.014                 	 Validation Loss: 0.732                 	 F1 score: 0.7911012084848417


100%|██████████| 46/46 [00:29<00:00,  1.54it/s]
100%|██████████| 12/12 [00:02<00:00,  5.09it/s]


Epoch 33/40                 	 Training Loss: 0.011                 	 Validation Loss: 0.781                 	 F1 score: 0.7909649451971168


100%|██████████| 46/46 [00:29<00:00,  1.54it/s]
100%|██████████| 12/12 [00:02<00:00,  4.86it/s]


Epoch 34/40                 	 Training Loss: 0.007                 	 Validation Loss: 0.776                 	 F1 score: 0.7915285525642062


100%|██████████| 46/46 [00:29<00:00,  1.54it/s]
100%|██████████| 12/12 [00:02<00:00,  5.00it/s]


Epoch 35/40                 	 Training Loss: 0.006                 	 Validation Loss: 0.798                 	 F1 score: 0.7946787921132178


100%|██████████| 46/46 [00:29<00:00,  1.54it/s]
100%|██████████| 12/12 [00:02<00:00,  5.09it/s]


Epoch 36/40                 	 Training Loss: 0.005                 	 Validation Loss: 0.817                 	 F1 score: 0.7919155048584706


100%|██████████| 46/46 [00:29<00:00,  1.54it/s]
100%|██████████| 12/12 [00:02<00:00,  5.00it/s]


Epoch 37/40                 	 Training Loss: 0.004                 	 Validation Loss: 0.841                 	 F1 score: 0.7904348022394799


100%|██████████| 46/46 [00:29<00:00,  1.55it/s]
100%|██████████| 12/12 [00:02<00:00,  4.97it/s]


Epoch 38/40                 	 Training Loss: 0.003                 	 Validation Loss: 0.827                 	 F1 score: 0.7891053292393371


100%|██████████| 46/46 [00:29<00:00,  1.54it/s]
100%|██████████| 12/12 [00:02<00:00,  4.96it/s]


Epoch 39/40                 	 Training Loss: 0.003                 	 Validation Loss: 0.859                 	 F1 score: 0.7897627416520211


100%|██████████| 46/46 [00:29<00:00,  1.55it/s]
100%|██████████| 12/12 [00:02<00:00,  5.04it/s]


Epoch 40/40                 	 Training Loss: 0.003                 	 Validation Loss: 0.900                 	 F1 score: 0.7947547412604553


### Predictions

In [None]:
first_model = BertForSequenceClassification.from_pretrained("bert-base-uncased", num_labels=3)
first_model.load_state_dict(torch.load('first_model.pt'))

In [None]:
model_am = BertModel(BertConfig.from_pretrained("bert-base-uncased"))
model_am.load_state_dict(torch.load('model_am.pt'))

In [None]:
model_ac = BertModel(BertConfig.from_pretrained("bert-base-uncased"))
model_ac.load_state_dict(torch.load('model_ac.pt'))

In [None]:
# Load best model

custom_model_2 = CustomBERTKuri(first_model, model_am, model_ac, 2)
custom_model_2.load_state_dict(torch.load('best_model.pt'))

custom_model_2.to(device).eval()

In [55]:
def predict(model, test_dataloader=None):
    
    """Prediction loop"""

    preds_l = []
    labels_l = []
    
    model.eval()

    for batch in test_dataloader:            
            
        # unpack batch             
        labels = batch['label'].to(device).flatten().tolist()
        spans = batch['spans'].to(device)
        input_ids = batch['input_ids'].to(device)
        
        inputs = input_ids, spans

        # get output
        
        raw_preds = model(inputs).to('cpu')
        # print(raw_preds.shape)
        raw_preds = raw_preds.detach()#.numpy()

        # Compute argmax
        
        predictions = torch.argmax(raw_preds, dim=2).flatten().tolist()
        preds_l.append(predictions)
        labels_l.append(labels)        
        
        del batch
            
    return flatten_list(preds_l), flatten_list(labels_l)

In [56]:
test_preds, test_labels = predict(custom_model_2, test_dataloader)

In [57]:
print(classification_report(test_labels, test_preds, digits=3))

              precision    recall  f1-score   support

        -100      0.000     0.000     0.000      3033
           0      0.107     0.728     0.187       518
           1      0.809     0.852     0.830       745

    accuracy                          0.236      4296
   macro avg      0.305     0.527     0.339      4296
weighted avg      0.153     0.236     0.167      4296



  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


In [58]:
test_preds_l, test_labels_l = remove_dummy_labels(test_preds, test_labels)

In [59]:
print(classification_report(test_labels_l, test_preds_l, digits=3))

              precision    recall  f1-score   support

           0      0.774     0.728     0.750       518
           1      0.818     0.852     0.835       745

    accuracy                          0.801      1263
   macro avg      0.796     0.790     0.793      1263
weighted avg      0.800     0.801     0.800      1263

