# Implementation of the Kuribayashi BERT minus model

## libraries

In [1]:
!pip install transformers --upgrade
!pip install ipywidgets
!pip install IProgress
!pip install datasets

Looking in indexes: https://pypi.org/simple, https://pypi.ngc.nvidia.com
Collecting transformers
  Downloading transformers-4.24.0-py3-none-any.whl (5.5 MB)
[K     |████████████████████████████████| 5.5 MB 18.9 MB/s eta 0:00:01
Collecting tokenizers!=0.11.3,<0.14,>=0.11.1
  Downloading tokenizers-0.13.2-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (7.6 MB)
[K     |████████████████████████████████| 7.6 MB 2.1 MB/s eta 0:00:01
Installing collected packages: tokenizers, transformers
  Attempting uninstall: tokenizers
    Found existing installation: tokenizers 0.10.3
    Uninstalling tokenizers-0.10.3:
      Successfully uninstalled tokenizers-0.10.3
  Attempting uninstall: transformers
    Found existing installation: transformers 4.12.5
    Uninstalling transformers-4.12.5:
      Successfully uninstalled transformers-4.12.5
Successfully installed tokenizers-0.13.2 transformers-4.24.0
Looking in indexes: https://pypi.org/simple, https://pypi.ngc.nvidia.com
Looking in indexe

In [2]:
import transformers
from transformers import BertTokenizer, BertConfig
from transformers import BertModel, BertForSequenceClassification
from transformers import BatchEncoding, default_data_collator, DataCollatorWithPadding

import torch
import torch.nn as nn

import numpy as np

from sklearn.metrics import classification_report
from sklearn.metrics import f1_score

import datasets

from torch.utils.data import DataLoader

from tqdm import tqdm

In [3]:
print(transformers.__version__)

4.24.0


## tokenizer

In [4]:
tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")

## data

In [5]:
# DATA_FOLDER = '/notebooks/Data/bert_sequence_classification'
DATA_FILE = '/notebooks/KURI-BERT/data/pe_dataset_for_bert_minus.pt'
RESULTS_FOLDER = '/notebooks/KURI-BERT/results'

In [6]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [7]:
device

device(type='cuda')

## Load data

In [8]:
dataset = torch.load(DATA_FILE)

In [9]:
dataset

DatasetDict({
    train: Dataset({
        features: ['paragraph', 'paragraph_components_list', 'paragraph_labels_list', 'paragraph_markers_list', 'split', 'essay_nr', 'paragraph_labels', 'paragraph_ac_spans', 'paragraph_am_spans'],
        num_rows: 1088
    })
    test: Dataset({
        features: ['paragraph', 'paragraph_components_list', 'paragraph_labels_list', 'paragraph_markers_list', 'split', 'essay_nr', 'paragraph_labels', 'paragraph_ac_spans', 'paragraph_am_spans'],
        num_rows: 358
    })
    validation: Dataset({
        features: ['paragraph', 'paragraph_components_list', 'paragraph_labels_list', 'paragraph_markers_list', 'split', 'essay_nr', 'paragraph_labels', 'paragraph_ac_spans', 'paragraph_am_spans'],
        num_rows: 273
    })
})

### preprocessing

In [10]:
MAX_LENGTH = 0

for split in ['train', 'test', 'validation']:
    
    for col_name in ['paragraph_am_spans', 'paragraph_ac_spans']:
        
        for x in dataset[split][col_name]:
        
            if len(x) > MAX_LENGTH:
                
                MAX_LENGTH = len(x)

In [11]:
def get_padding(batch, padding_target):    
    
    if padding_target == 'am_spans':
        
        col_name = 'paragraph_am_spans'
        padding_val = [[-1,-1]]
        max_length = MAX_LENGTH
        
    elif padding_target == 'ac_spans':
        
        col_name = 'paragraph_ac_spans'
        padding_val = [[-1,-1]]
        max_length = MAX_LENGTH
        
    elif padding_target == 'label':
    
        col_name = 'paragraph_labels'
        padding_val = [-100] # -1 previously       
        max_length = MAX_LENGTH # max([len(l) for l in batch[col_name]]) # cause some batch had 4 x 10

    padded_spans = []

    for idx, span in enumerate(batch[col_name]):

        padded_span = batch[col_name][idx] + (max_length - len(span)) * padding_val
        padded_spans.append(padded_span)

    return padded_spans         

In [12]:
def get_combined_spans(am_spans_ll, ac_spans_ll):
    
    spans_ll = []
    
    for am_spans, ac_spans in zip(am_spans_ll, ac_spans_ll):
        
        spans = []
        
        for am_span, ac_span in zip(am_spans, ac_spans):

            span = [am_span, ac_span]
            spans.extend(span)
            
        spans_ll.append(spans)

    return spans_ll

### tokenize 

In [13]:
# max_length = 200 for use in the max_length in the tokenizer so that the things are of equal dim.

In [14]:
def tokenize(batch):
    
    tokenized_text = tokenizer(batch['paragraph'], truncation=True, padding=True, max_length=512)
    tokenized_text['label'] = get_padding(batch, 'label')
    tokenized_text['am_spans'] = get_padding(batch, 'am_spans')
    tokenized_text['ac_spans'] = get_padding(batch, 'ac_spans')
    tokenized_text['spans'] = get_combined_spans(tokenized_text['am_spans'], tokenized_text['ac_spans'])      
    
    return tokenized_text

In [15]:
dataset = dataset.map(tokenize, batched=True, batch_size=len(dataset['train']))



  0%|          | 0/1 [00:00<?, ?ba/s]

  0%|          | 0/1 [00:00<?, ?ba/s]

  0%|          | 0/1 [00:00<?, ?ba/s]

In [16]:
dataset

DatasetDict({
    train: Dataset({
        features: ['ac_spans', 'am_spans', 'attention_mask', 'essay_nr', 'input_ids', 'label', 'paragraph', 'paragraph_ac_spans', 'paragraph_am_spans', 'paragraph_components_list', 'paragraph_labels', 'paragraph_labels_list', 'paragraph_markers_list', 'spans', 'split', 'token_type_ids'],
        num_rows: 1088
    })
    test: Dataset({
        features: ['ac_spans', 'am_spans', 'attention_mask', 'essay_nr', 'input_ids', 'label', 'paragraph', 'paragraph_ac_spans', 'paragraph_am_spans', 'paragraph_components_list', 'paragraph_labels', 'paragraph_labels_list', 'paragraph_markers_list', 'spans', 'split', 'token_type_ids'],
        num_rows: 358
    })
    validation: Dataset({
        features: ['ac_spans', 'am_spans', 'attention_mask', 'essay_nr', 'input_ids', 'label', 'paragraph', 'paragraph_ac_spans', 'paragraph_am_spans', 'paragraph_components_list', 'paragraph_labels', 'paragraph_labels_list', 'paragraph_markers_list', 'spans', 'split', 'token_type_

In [17]:
dataset['train'].features

{'ac_spans': Sequence(feature=Sequence(feature=Value(dtype='int64', id=None), length=-1, id=None), length=-1, id=None),
 'am_spans': Sequence(feature=Sequence(feature=Value(dtype='int64', id=None), length=-1, id=None), length=-1, id=None),
 'attention_mask': Sequence(feature=Value(dtype='int8', id=None), length=-1, id=None),
 'essay_nr': Value(dtype='string', id=None),
 'input_ids': Sequence(feature=Value(dtype='int32', id=None), length=-1, id=None),
 'label': Sequence(feature=Value(dtype='int64', id=None), length=-1, id=None),
 'paragraph': Value(dtype='string', id=None),
 'paragraph_ac_spans': Sequence(feature=Sequence(feature=Value(dtype='int64', id=None), length=-1, id=None), length=-1, id=None),
 'paragraph_am_spans': Sequence(feature=Sequence(feature=Value(dtype='int64', id=None), length=-1, id=None), length=-1, id=None),
 'paragraph_components_list': Sequence(feature=Sequence(feature=Value(dtype='string', id=None), length=-1, id=None), length=-1, id=None),
 'paragraph_labels': S

In [18]:
dataset['test'].features

{'ac_spans': Sequence(feature=Sequence(feature=Value(dtype='int64', id=None), length=-1, id=None), length=-1, id=None),
 'am_spans': Sequence(feature=Sequence(feature=Value(dtype='int64', id=None), length=-1, id=None), length=-1, id=None),
 'attention_mask': Sequence(feature=Value(dtype='int8', id=None), length=-1, id=None),
 'essay_nr': Value(dtype='string', id=None),
 'input_ids': Sequence(feature=Value(dtype='int32', id=None), length=-1, id=None),
 'label': Sequence(feature=Value(dtype='int64', id=None), length=-1, id=None),
 'paragraph': Value(dtype='string', id=None),
 'paragraph_ac_spans': Sequence(feature=Sequence(feature=Value(dtype='int64', id=None), length=-1, id=None), length=-1, id=None),
 'paragraph_am_spans': Sequence(feature=Sequence(feature=Value(dtype='int64', id=None), length=-1, id=None), length=-1, id=None),
 'paragraph_components_list': Sequence(feature=Sequence(feature=Value(dtype='string', id=None), length=-1, id=None), length=-1, id=None),
 'paragraph_labels': S

In [19]:
dataset['test'].features['spans'] = datasets.Array2D(shape=(24, 2), dtype="int32")
dataset['train'].features['spans'] = datasets.Array2D(shape=(24, 2), dtype="int32")
dataset['validation'].features['spans'] = datasets.Array2D(shape=(24, 2), dtype="int32")

In [20]:
dataset = dataset.map(lambda batch: batch, batched=True)

  0%|          | 0/2 [00:00<?, ?ba/s]

  0%|          | 0/1 [00:00<?, ?ba/s]

  0%|          | 0/1 [00:00<?, ?ba/s]

In [21]:
dataset.set_format('torch', columns=['input_ids', 'attention_mask', 'token_type_ids', 'spans', 'label'])

## span representation function

In [22]:
def get_span_representations(outputs, spans):

    batch_size = spans.shape[0]
    nr_span_indices = spans.shape[1]
    
    idx_l_ams = range(0, nr_span_indices, 2) # [0,2,4,6 etc]
    idx_l_acs = range(1, nr_span_indices, 2) # [1,3,5,7 etc]
    
    am_spans = spans[:, idx_l_ams, :] + 1 # adds 1 to all span indices (both am and ac) to offset for the CLS token in the input_ids.
    ac_spans = spans[:, idx_l_acs, :] + 1
    
    am_spans = am_spans.flatten(start_dim=1)
    ac_spans = ac_spans.flatten(start_dim=1)
    
    #print("am spans:", am_spans.shape)
    #print("ac spans:", ac_spans.shape)
    
    outputs_am = outputs[:,am_spans,:]
    outputs_am = torch.cat([outputs_am[i,i,:,:] for i in range(batch_size)], dim=0)
    outputs_am = outputs_am.reshape(batch_size, nr_span_indices, -1)
    
    #print("outputs am:", outputs_am.shape)
    
    outputs_am_r = torch.cat([outputs_am[:,i,:] - outputs_am[:,i+1,:] for i in range(0, nr_span_indices, 2)], dim=1)
    outputs_am_r = outputs_am_r.reshape(batch_size, -1, 768)
    
    outputs_am_l = torch.cat([outputs_am[:,i+1,:] - outputs_am[:,i,:] for i in range(0, nr_span_indices, 2)], dim=1)
    outputs_am_l = outputs_am_r.reshape(batch_size, -1, 768)
    
    am_minus_representations = torch.cat([outputs_am_r, outputs_am_l], dim=-1)
    
    #print("am_minus_representations:", am_minus_representations.shape)
    
    outputs_ac = outputs[:,ac_spans,:]
    outputs_ac = torch.cat([outputs_ac[i,i,:,:] for i in range(batch_size)], dim=0)
    outputs_ac = outputs_ac.reshape(batch_size, nr_span_indices, -1)
    
    #print("outputs ac:", outputs_ac.shape)
    
    outputs_ac_r = torch.cat([outputs_ac[:,i,:] - outputs_ac[:,i+1,:] for i in range(0, nr_span_indices, 2)], dim=1)
    outputs_ac_r = outputs_ac_r.reshape(batch_size, -1, 768)
    
    outputs_ac_l = torch.cat([outputs_ac[:,i+1,:] - outputs_ac[:,i,:] for i in range(0, nr_span_indices, 2)], dim=1)
    outputs_ac_l = outputs_ac_l.reshape(batch_size, -1, 768)
    
    ac_minus_representations = torch.cat([outputs_ac_r, outputs_ac_l], dim=-1)
    
    #print("ac_minus_representations:", ac_minus_representations.shape)
    
    return am_minus_representations, ac_minus_representations                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                            

## custom BERT model

In [23]:
class CustomBERTKuri(nn.Module):

    def __init__(self, first_model, model_am, model_ac, nr_classes):
        
        super(CustomBERTKuri, self).__init__()
        
        self.first_model = first_model
        
        self.intermediate_linear_am = nn.Linear(1536, 768)
        self.intermediate_linear_ac = nn.Linear(1536, 768)        
        
        self.model_am = model_am
        self.model_ac = model_ac
        
        self.nr_classes = nr_classes
                
        self.fc = nn.Linear(self.model_am.config.hidden_size + self.model_ac.config.hidden_size, self.nr_classes)        

    def forward(self, batch_tokenized, batch_spans):
        
        outputs = self.first_model(batch_tokenized)[0] # ** removed to correct error.
        # spans = batch_spans # remove this spans thing cause we are now giving it just the spans themselves.
        am_minus_representations, ac_minus_representations = get_span_representations(outputs, batch_spans)

        am_minus_representations = self.intermediate_linear_am(am_minus_representations)
        ac_minus_representations = self.intermediate_linear_am(ac_minus_representations)

        output_model_am = self.model_am(inputs_embeds = am_minus_representations)[0]
        output_model_ac = self.model_ac(inputs_embeds = ac_minus_representations)[0]

        adu_representations = torch.cat([output_model_am, output_model_ac], dim=-1)
        # print("adu rep:", adu_representations.shape)
        output = self.fc(adu_representations)
        # print("model class output avant reshape:", output.shape)
        # output = output.reshape(-1, self.nr_classes)
        # print("model class output apres:", output.shape)
        return output

## Run

In [24]:
first_model = BertModel(BertConfig.from_pretrained("bert-base-uncased"))

In [25]:
model_am = BertModel(BertConfig.from_pretrained("bert-base-uncased"))

In [26]:
model_ac = BertModel(BertConfig.from_pretrained("bert-base-uncased"))

In [27]:
custom_model = CustomBERTKuri(first_model, model_am, model_ac, 3)

In [28]:
custom_model.to(device)

CustomBERTKuri(
  (first_model): BertModel(
    (embeddings): BertEmbeddings(
      (word_embeddings): Embedding(30522, 768, padding_idx=0)
      (position_embeddings): Embedding(512, 768)
      (token_type_embeddings): Embedding(2, 768)
      (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): BertEncoder(
      (layer): ModuleList(
        (0): BertLayer(
          (attention): BertAttention(
            (self): BertSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): BertSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_aff

In [29]:
loss = nn.CrossEntropyLoss(ignore_index=- 100)

In [30]:
optimizer = torch.optim.AdamW(custom_model.parameters(), lr=1e-6, eps=1e-06, weight_decay=0.0)

In [31]:
scheduler = torch.optim.lr_scheduler.LinearLR(optimizer)

### create dataloaders

In [32]:
NB_EPOCHS = 40
BATCH_SIZE = 48

In [33]:
train_dataloader = DataLoader(dataset['train'], batch_size=BATCH_SIZE, shuffle=True)
val_dataloader = DataLoader(dataset['validation'], batch_size=BATCH_SIZE, shuffle=True)
test_dataloader = DataLoader(dataset['test'], batch_size=BATCH_SIZE, shuffle=True)

In [34]:
def flatten_list(list_of_lists):
    return [x for sublist in list_of_lists for x in sublist]

In [35]:
def remove_dummy_labels(test_preds, test_labels):
    
    idxes = []
    test_labels_l = []
    for idx, val in enumerate(test_labels):
        if val != -100:
            idxes.append(idx)
            test_labels_l.append(val)
    
    test_preds_l = []
    for idx, val in enumerate(test_preds):
        for good_idx in idxes:
            if idx == good_idx:
                test_preds_l.append(val)
        
    return test_preds_l, test_labels_l

### training loop

In [36]:
from transformers import Trainer, TrainingArguments
from datasets import load_metric

In [37]:
import random

In [38]:
# def compute_loss(self, model, inputs, return_outputs=False):

#     labels = inputs["labels"]
#     labels = labels.flatten() # xxx.



#     outputs = model(inputs['input_ids'], inputs['spans'])
#     outputs = outputs.flatten(0,1) # xxx. for the 4 x 12 , 3 output of the main class.


#     loss_fct = nn.CrossEntropyLoss()#(weight=class_weights)
#     # loss = loss_fct(outputs, labels.flatten())
#     loss = loss_fct(outputs, labels) # xxx
#     #print("loss:", loss)
#     return (loss, outputs) if return_outputs else loss

In [39]:
# metric = load_metric('f1')

# def compute_metrics(eval_pred):

#     logits, labels = eval_pred

    
#     print("logits", logits.shape)    
#     predictions = np.argmax(logits, axis=-1)
    
#     print("preds:", predictions.shape)
    
#     return metric.compute(predictions=predictions, references=labels, average='macro')

In [40]:
def train(model, loss=None, optimizer=None, train_dataloader=None, val_dataloader=None, nb_epochs=20):
    """Training loop"""

    min_f1 = -torch.inf
    train_losses = []
    val_losses = []

    # Iterrate over epochs
    for e in range(nb_epochs):

        # Training
        train_loss = 0.0

        for batch in tqdm(train_dataloader):            
            
            # unpack batch             
            labels = batch['label'].to(device)
            spans = batch['spans'].to(device)
            input_ids = batch['input_ids'].to(device)
            
            # Reset gradients to 0
            optimizer.zero_grad()

            # Forward Pass
            outputs = model(input_ids, spans)
            
            # Compute training loss
            current_loss = loss(outputs.flatten(0,1), labels.flatten())
            train_loss += current_loss.detach().item()

            # Compute gradients
            current_loss.backward()

            # Update weights
            optimizer.step()            
            
            del batch
        
        scheduler.step()
            
        
        # Validation
        val_loss = 0.0

        # Put model in eval mode
        model.eval()
        
        preds_l = []
        labels_l = []
        
        for batch in tqdm(val_dataloader):            
            
            # unpack batch             
            labels = batch['label'].to(device)
            spans = batch['spans'].to(device)
            input_ids = batch['input_ids'].to(device)
            
            # Forward Pass
            outputs = model(input_ids, spans)

            # Compute validation loss
            current_loss = loss(outputs.flatten(0,1), labels.flatten())
            val_loss += current_loss.detach().item()
            
            preds_for_f1 = torch.argmax(outputs, dim=2).flatten().tolist()
            labels_for_f1 = labels.flatten().tolist()
            
            preds_l.append(preds_for_f1)
            labels_l.append(labels_for_f1)
            
            del batch
        
        # Prints
        
        preds_l = flatten_list(preds_l)
        labels_l = flatten_list(labels_l)
        
        preds_l, labels_l = remove_dummy_labels(preds_l, labels_l)
        
        f1_score_epoch = f1_score(preds_l, labels_l, average='macro')        
        
        print(f"Epoch {e+1}/{nb_epochs} \
                \t Training Loss: {train_loss/len(train_dataloader):.3f} \
                \t Validation Loss: {val_loss/len(val_dataloader):.3f} \
                \t F1 score: {f1_score_epoch}")
        
        train_losses.append(train_loss/len(train_dataloader))
        val_losses.append(val_loss/len(val_dataloader))
        

        # Save model if val loss decreases
        if f1_score_epoch > min_f1:

            min_f1 = f1_score_epoch
            torch.save(model.first_model.state_dict(), 'first_model.pt')
            torch.save(model.model_am.state_dict(), 'model_am.pt')
            torch.save(model.model_ac.state_dict(), 'model_ac.pt')
            torch.save(model.state_dict(), 'best_model.pt')
            
    return train_losses, val_losses

In [41]:
train_losses, val_losses = train(custom_model, loss, optimizer, train_dataloader, val_dataloader, NB_EPOCHS)

100%|██████████| 23/23 [00:29<00:00,  1.29s/it]
100%|██████████| 6/6 [00:02<00:00,  2.59it/s]


Epoch 1/40                 	 Training Loss: 0.936                 	 Validation Loss: 0.864                 	 F1 score: 0.2697703989480833


100%|██████████| 23/23 [00:26<00:00,  1.17s/it]
100%|██████████| 6/6 [00:02<00:00,  2.67it/s]


Epoch 2/40                 	 Training Loss: 0.815                 	 Validation Loss: 0.754                 	 F1 score: 0.4481508651565327


100%|██████████| 23/23 [00:26<00:00,  1.16s/it]
100%|██████████| 6/6 [00:02<00:00,  2.62it/s]


Epoch 3/40                 	 Training Loss: 0.712                 	 Validation Loss: 0.691                 	 F1 score: 0.5672779833180837


100%|██████████| 23/23 [00:26<00:00,  1.16s/it]
100%|██████████| 6/6 [00:02<00:00,  2.67it/s]


Epoch 4/40                 	 Training Loss: 0.651                 	 Validation Loss: 0.643                 	 F1 score: 0.6368813843684239


100%|██████████| 23/23 [00:26<00:00,  1.17s/it]
100%|██████████| 6/6 [00:02<00:00,  2.62it/s]


Epoch 5/40                 	 Training Loss: 0.601                 	 Validation Loss: 0.613                 	 F1 score: 0.6794031823604069


100%|██████████| 23/23 [00:26<00:00,  1.17s/it]
100%|██████████| 6/6 [00:02<00:00,  2.65it/s]


Epoch 6/40                 	 Training Loss: 0.554                 	 Validation Loss: 0.578                 	 F1 score: 0.6915457670762905


100%|██████████| 23/23 [00:27<00:00,  1.18s/it]
100%|██████████| 6/6 [00:02<00:00,  2.69it/s]


Epoch 7/40                 	 Training Loss: 0.516                 	 Validation Loss: 0.559                 	 F1 score: 0.6930550334388


100%|██████████| 23/23 [00:26<00:00,  1.16s/it]
100%|██████████| 6/6 [00:02<00:00,  2.65it/s]


Epoch 8/40                 	 Training Loss: 0.488                 	 Validation Loss: 0.549                 	 F1 score: 0.7102742994555634


100%|██████████| 23/23 [00:26<00:00,  1.17s/it]
100%|██████████| 6/6 [00:02<00:00,  2.59it/s]


Epoch 9/40                 	 Training Loss: 0.461                 	 Validation Loss: 0.532                 	 F1 score: 0.7101976708108563


100%|██████████| 23/23 [00:27<00:00,  1.19s/it]
100%|██████████| 6/6 [00:02<00:00,  2.60it/s]


Epoch 10/40                 	 Training Loss: 0.435                 	 Validation Loss: 0.538                 	 F1 score: 0.710681861537227


100%|██████████| 23/23 [00:26<00:00,  1.17s/it]
100%|██████████| 6/6 [00:02<00:00,  2.66it/s]


Epoch 11/40                 	 Training Loss: 0.415                 	 Validation Loss: 0.524                 	 F1 score: 0.7118606524167026


100%|██████████| 23/23 [00:26<00:00,  1.17s/it]
100%|██████████| 6/6 [00:02<00:00,  2.62it/s]


Epoch 12/40                 	 Training Loss: 0.395                 	 Validation Loss: 0.531                 	 F1 score: 0.7291639392013912


100%|██████████| 23/23 [00:26<00:00,  1.17s/it]
100%|██████████| 6/6 [00:02<00:00,  2.70it/s]


Epoch 13/40                 	 Training Loss: 0.375                 	 Validation Loss: 0.521                 	 F1 score: 0.7092079996253103


100%|██████████| 23/23 [00:27<00:00,  1.19s/it]
100%|██████████| 6/6 [00:02<00:00,  2.63it/s]


Epoch 14/40                 	 Training Loss: 0.353                 	 Validation Loss: 0.530                 	 F1 score: 0.7155464070256912


100%|██████████| 23/23 [00:27<00:00,  1.19s/it]
100%|██████████| 6/6 [00:02<00:00,  2.54it/s]


Epoch 15/40                 	 Training Loss: 0.334                 	 Validation Loss: 0.520                 	 F1 score: 0.7130015364248964


100%|██████████| 23/23 [00:27<00:00,  1.19s/it]
100%|██████████| 6/6 [00:02<00:00,  2.63it/s]


Epoch 16/40                 	 Training Loss: 0.316                 	 Validation Loss: 0.533                 	 F1 score: 0.7101275903976579


100%|██████████| 23/23 [00:27<00:00,  1.19s/it]
100%|██████████| 6/6 [00:02<00:00,  2.60it/s]


Epoch 17/40                 	 Training Loss: 0.296                 	 Validation Loss: 0.533                 	 F1 score: 0.7132311202210371


100%|██████████| 23/23 [00:27<00:00,  1.18s/it]
100%|██████████| 6/6 [00:02<00:00,  2.60it/s]


Epoch 18/40                 	 Training Loss: 0.277                 	 Validation Loss: 0.536                 	 F1 score: 0.7069549931501702


100%|██████████| 23/23 [00:27<00:00,  1.19s/it]
100%|██████████| 6/6 [00:02<00:00,  2.61it/s]


Epoch 19/40                 	 Training Loss: 0.263                 	 Validation Loss: 0.543                 	 F1 score: 0.711213011494082


100%|██████████| 23/23 [00:27<00:00,  1.18s/it]
100%|██████████| 6/6 [00:02<00:00,  2.65it/s]


Epoch 20/40                 	 Training Loss: 0.241                 	 Validation Loss: 0.545                 	 F1 score: 0.7042232043803521


100%|██████████| 23/23 [00:27<00:00,  1.20s/it]
100%|██████████| 6/6 [00:02<00:00,  2.65it/s]


Epoch 21/40                 	 Training Loss: 0.224                 	 Validation Loss: 0.554                 	 F1 score: 0.7095019552566252


100%|██████████| 23/23 [00:27<00:00,  1.19s/it]
100%|██████████| 6/6 [00:02<00:00,  2.64it/s]


Epoch 22/40                 	 Training Loss: 0.206                 	 Validation Loss: 0.567                 	 F1 score: 0.7032036068067482


100%|██████████| 23/23 [00:27<00:00,  1.19s/it]
100%|██████████| 6/6 [00:02<00:00,  2.64it/s]


Epoch 23/40                 	 Training Loss: 0.190                 	 Validation Loss: 0.578                 	 F1 score: 0.7090917286454811


100%|██████████| 23/23 [00:27<00:00,  1.19s/it]
100%|██████████| 6/6 [00:02<00:00,  2.66it/s]


Epoch 24/40                 	 Training Loss: 0.173                 	 Validation Loss: 0.579                 	 F1 score: 0.7022074950611823


100%|██████████| 23/23 [00:27<00:00,  1.19s/it]
100%|██████████| 6/6 [00:02<00:00,  2.62it/s]


Epoch 25/40                 	 Training Loss: 0.156                 	 Validation Loss: 0.587                 	 F1 score: 0.705280533152045


100%|██████████| 23/23 [00:27<00:00,  1.19s/it]
100%|██████████| 6/6 [00:02<00:00,  2.61it/s]


Epoch 26/40                 	 Training Loss: 0.141                 	 Validation Loss: 0.616                 	 F1 score: 0.7063044417710348


100%|██████████| 23/23 [00:27<00:00,  1.19s/it]
100%|██████████| 6/6 [00:02<00:00,  2.62it/s]


Epoch 27/40                 	 Training Loss: 0.126                 	 Validation Loss: 0.623                 	 F1 score: 0.7093435239154332


100%|██████████| 23/23 [00:27<00:00,  1.19s/it]
100%|██████████| 6/6 [00:02<00:00,  2.62it/s]


Epoch 28/40                 	 Training Loss: 0.112                 	 Validation Loss: 0.640                 	 F1 score: 0.6975343491386418


100%|██████████| 23/23 [00:27<00:00,  1.18s/it]
100%|██████████| 6/6 [00:02<00:00,  2.68it/s]


Epoch 29/40                 	 Training Loss: 0.099                 	 Validation Loss: 0.666                 	 F1 score: 0.6910969842147289


100%|██████████| 23/23 [00:27<00:00,  1.20s/it]
100%|██████████| 6/6 [00:02<00:00,  2.63it/s]


Epoch 30/40                 	 Training Loss: 0.088                 	 Validation Loss: 0.671                 	 F1 score: 0.7023371741056197


100%|██████████| 23/23 [00:27<00:00,  1.19s/it]
100%|██████████| 6/6 [00:02<00:00,  2.65it/s]


Epoch 31/40                 	 Training Loss: 0.077                 	 Validation Loss: 0.701                 	 F1 score: 0.700093097767342


100%|██████████| 23/23 [00:27<00:00,  1.19s/it]
100%|██████████| 6/6 [00:02<00:00,  2.68it/s]


Epoch 32/40                 	 Training Loss: 0.068                 	 Validation Loss: 0.694                 	 F1 score: 0.7021329708459355


100%|██████████| 23/23 [00:27<00:00,  1.20s/it]
100%|██████████| 6/6 [00:02<00:00,  2.62it/s]


Epoch 33/40                 	 Training Loss: 0.059                 	 Validation Loss: 0.723                 	 F1 score: 0.7063936896511643


100%|██████████| 23/23 [00:27<00:00,  1.20s/it]
100%|██████████| 6/6 [00:02<00:00,  2.57it/s]


Epoch 34/40                 	 Training Loss: 0.052                 	 Validation Loss: 0.725                 	 F1 score: 0.701228148614943


100%|██████████| 23/23 [00:27<00:00,  1.19s/it]
100%|██████████| 6/6 [00:02<00:00,  2.63it/s]


Epoch 35/40                 	 Training Loss: 0.046                 	 Validation Loss: 0.756                 	 F1 score: 0.6974140329210067


100%|██████████| 23/23 [00:27<00:00,  1.19s/it]
100%|██████████| 6/6 [00:02<00:00,  2.58it/s]


Epoch 36/40                 	 Training Loss: 0.041                 	 Validation Loss: 0.767                 	 F1 score: 0.702998677667131


100%|██████████| 23/23 [00:27<00:00,  1.19s/it]
100%|██████████| 6/6 [00:02<00:00,  2.61it/s]


Epoch 37/40                 	 Training Loss: 0.036                 	 Validation Loss: 0.792                 	 F1 score: 0.6999218176514411


100%|██████████| 23/23 [00:27<00:00,  1.19s/it]
100%|██████████| 6/6 [00:02<00:00,  2.57it/s]


Epoch 38/40                 	 Training Loss: 0.032                 	 Validation Loss: 0.789                 	 F1 score: 0.6926683750858432


100%|██████████| 23/23 [00:27<00:00,  1.19s/it]
100%|██████████| 6/6 [00:02<00:00,  2.64it/s]


Epoch 39/40                 	 Training Loss: 0.029                 	 Validation Loss: 0.813                 	 F1 score: 0.6984287116899117


100%|██████████| 23/23 [00:27<00:00,  1.19s/it]
100%|██████████| 6/6 [00:02<00:00,  2.68it/s]

Epoch 40/40                 	 Training Loss: 0.026                 	 Validation Loss: 0.830                 	 F1 score: 0.6954972011193766





### Predictions

In [42]:
# # Load best model
# network_2 = Network()

# network_2.load_state_dict(torch.load(path))
# network_2.eval()

In [43]:
first_model = BertModel(BertConfig.from_pretrained("bert-base-uncased"))
first_model.load_state_dict(torch.load('first_model.pt'))

<All keys matched successfully>

In [44]:
model_am = BertModel(BertConfig.from_pretrained("bert-base-uncased"))
model_am.load_state_dict(torch.load('model_am.pt'))

<All keys matched successfully>

In [45]:
model_ac = BertModel(BertConfig.from_pretrained("bert-base-uncased"))
model_ac.load_state_dict(torch.load('model_ac.pt'))

<All keys matched successfully>

In [46]:
# Load best model

custom_model_2 = CustomBERTKuri(first_model, model_am, model_ac, 3)
custom_model_2.load_state_dict(torch.load('best_model.pt'))

custom_model_2.to(device).eval()

CustomBERTKuri(
  (first_model): BertModel(
    (embeddings): BertEmbeddings(
      (word_embeddings): Embedding(30522, 768, padding_idx=0)
      (position_embeddings): Embedding(512, 768)
      (token_type_embeddings): Embedding(2, 768)
      (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): BertEncoder(
      (layer): ModuleList(
        (0): BertLayer(
          (attention): BertAttention(
            (self): BertSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): BertSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_aff

In [47]:
def predict(model, test_dataloader=None):
    
    """Prediction loop"""

    preds_l = []
    labels_l = []
    
    model.eval()

    for batch in test_dataloader:            
            
        # unpack batch             
        labels = batch['label'].to(device).flatten().tolist()
        spans = batch['spans'].to(device)
        input_ids = batch['input_ids'].to(device)

        # get output
        
        raw_preds = model(input_ids, spans).to('cpu')
        # print(raw_preds.shape)
        raw_preds = raw_preds.detach()#.numpy()

        # Compute argmax
        
        predictions = torch.argmax(raw_preds, dim=2).flatten().tolist()
        preds_l.append(predictions)
        labels_l.append(labels)        
        
        del batch
            
    return flatten_list(preds_l), flatten_list(labels_l)

In [48]:
#test_preds, test_labels = predict(custom_model, test_dataloader)

In [49]:
#print(classification_report(test_labels, test_preds, digits=3))

In [50]:
# remove -100s

In [51]:
#test_preds_l, test_labels_l = remove_dummy_labels(test_preds, test_labels)

In [52]:
#print(classification_report(test_labels_l, test_preds_l, digits=3))

In [53]:
### -100 done!

In [54]:
test_preds, test_labels = predict(custom_model_2, test_dataloader)

In [55]:
print(classification_report(test_labels, test_preds, digits=3))

              precision    recall  f1-score   support

        -100      0.000     0.000     0.000      3033
           0      0.507     0.729     0.598       155
           1      0.111     0.571     0.186       303
           2      0.286     0.892     0.433       805

    accuracy                          0.234      4296
   macro avg      0.226     0.548     0.304      4296
weighted avg      0.080     0.234     0.116      4296



  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


In [56]:
test_preds_l, test_labels_l = remove_dummy_labels(test_preds, test_labels)

In [57]:
print(classification_report(test_labels_l, test_preds_l, digits=3))

              precision    recall  f1-score   support

           0      0.807     0.729     0.766       155
           1      0.601     0.571     0.585       303
           2      0.860     0.892     0.876       805

    accuracy                          0.795      1263
   macro avg      0.756     0.731     0.742      1263
weighted avg      0.791     0.795     0.793      1263

