# Implementation of the Kuribayashi BERT minus model

## libraries

In [1]:
!pip install transformers --upgrade
!pip install ipywidgets
!pip install IProgress
!pip install datasets
!pip install torch-lr-finder

Looking in indexes: https://pypi.org/simple, https://pypi.ngc.nvidia.com
Collecting transformers
  Downloading transformers-4.25.1-py3-none-any.whl (5.8 MB)
[K     |████████████████████████████████| 5.8 MB 21.4 MB/s eta 0:00:01
Collecting tokenizers!=0.11.3,<0.14,>=0.11.1
  Downloading tokenizers-0.13.2-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (7.6 MB)
[K     |████████████████████████████████| 7.6 MB 61.0 MB/s eta 0:00:01
Installing collected packages: tokenizers, transformers
  Attempting uninstall: tokenizers
    Found existing installation: tokenizers 0.10.3
    Uninstalling tokenizers-0.10.3:
      Successfully uninstalled tokenizers-0.10.3
  Attempting uninstall: transformers
    Found existing installation: transformers 4.12.5
    Uninstalling transformers-4.12.5:
      Successfully uninstalled transformers-4.12.5
Successfully installed tokenizers-0.13.2 transformers-4.25.1
Looking in indexes: https://pypi.org/simple, https://pypi.ngc.nvidia.com
Looking in index

In [2]:
import transformers
from transformers import BertTokenizer, BertConfig
from transformers import BertModel, BertForSequenceClassification
from transformers import BatchEncoding, default_data_collator, DataCollatorWithPadding

import torch
import torch.nn as nn

import numpy as np

from sklearn.metrics import classification_report
from sklearn.metrics import f1_score

import datasets

from torch.utils.data import DataLoader

from tqdm import tqdm
from operator import itemgetter

In [3]:
print(transformers.__version__)

4.25.1


## tokenizer

In [4]:
tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")

In [5]:
# tokenizer.model_max_length = 256

## data

In [6]:
DATA_FILE = '/notebooks/KURI-BERT/notebooks/full_formula_w_fts/CLS_work/pe_dataset_for_component_features_implementation.pt'

In [7]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

## Load data

In [8]:
dataset = torch.load(DATA_FILE)

In [9]:
dataset

DatasetDict({
    train: Dataset({
        features: ['paragraph', 'paragraph_components_list', 'paragraph_labels_list', 'paragraph_markers_list', 'split', 'essay_nr', 'structural_fts_as_text_combined', 'paragraph_labels_numeric', 'full_paragraph_w_comp_fts', 'new_full_para', 'component_indices', 'feature_indices', 'sanity_1'],
        num_rows: 1088
    })
    test: Dataset({
        features: ['paragraph', 'paragraph_components_list', 'paragraph_labels_list', 'paragraph_markers_list', 'split', 'essay_nr', 'structural_fts_as_text_combined', 'paragraph_labels_numeric', 'full_paragraph_w_comp_fts', 'new_full_para', 'component_indices', 'feature_indices', 'sanity_1'],
        num_rows: 358
    })
    validation: Dataset({
        features: ['paragraph', 'paragraph_components_list', 'paragraph_labels_list', 'paragraph_markers_list', 'split', 'essay_nr', 'structural_fts_as_text_combined', 'paragraph_labels_numeric', 'full_paragraph_w_comp_fts', 'new_full_para', 'component_indices', 'featur

In [10]:
dataset['train']['new_full_para'][498]

'component Computers have revolutionized our lifestyle during last century. However, some people especially young ones use computers extremely and this excessive usage lead to increase stress among them, but computers have brought many advantages to our lives. From my point of view, computers have an important and inevitable role in economic, business, and other part of societies. features 1, Yes, No, Yes, No component In addition, they make the way of living easier by doing our routine duties. features 1, No, Yes, Yes, No'

In [11]:
dataset['train']['component_indices'][498]

[0, 76]

In [12]:
dataset['train']['feature_indices'][498]

[66, 93]

In [15]:
for i,x in enumerate(tokenizer.tokenize(dataset['train']['new_full_para'][498])):
    print(i, x)

0 component
1 computers
2 have
3 revolution
4 ##ized
5 our
6 lifestyle
7 during
8 last
9 century
10 .
11 however
12 ,
13 some
14 people
15 especially
16 young
17 ones
18 use
19 computers
20 extremely
21 and
22 this
23 excessive
24 usage
25 lead
26 to
27 increase
28 stress
29 among
30 them
31 ,
32 but
33 computers
34 have
35 brought
36 many
37 advantages
38 to
39 our
40 lives
41 .
42 from
43 my
44 point
45 of
46 view
47 ,
48 computers
49 have
50 an
51 important
52 and
53 inevitable
54 role
55 in
56 economic
57 ,
58 business
59 ,
60 and
61 other
62 part
63 of
64 societies
65 .
66 features
67 1
68 ,
69 yes
70 ,
71 no
72 ,
73 yes
74 ,
75 no
76 component
77 in
78 addition
79 ,
80 they
81 make
82 the
83 way
84 of
85 living
86 easier
87 by
88 doing
89 our
90 routine
91 duties
92 .
93 features
94 1
95 ,
96 no
97 ,
98 yes
99 ,
100 yes
101 ,
102 no


### preprocessing

In [16]:
MAX_LENGTH = 0
MAX_CLS_IDX = 0

for split in ['train', 'test', 'validation']:
    
    for col_name in ['component_indices', 'feature_indices']:
        
        for x in dataset[split][col_name]:
            
            if max(x) > MAX_CLS_IDX:
                
                MAX_CLS_IDX = max(x)
                MAX_CLS_IDX = min(MAX_CLS_IDX, tokenizer.model_max_length - 2)
            
            if len(x) > MAX_LENGTH:
                
                MAX_LENGTH = len(x)

In [17]:
MAX_LENGTH

12

In [18]:
MAX_CLS_IDX

304

In [19]:
def get_padding(batch, padding_target):    
    
    if padding_target == 'adu_cls_indices':
        
        col_name = 'component_indices'
        padding_val = [-1]
        max_length = MAX_LENGTH
        
        
    elif padding_target == 'features_cls_indices':
        
        col_name = 'feature_indices'
        padding_val = [-1]
        max_length = MAX_LENGTH
        
    elif padding_target == 'label':
    
        col_name = 'paragraph_labels_numeric'
        padding_val = [-100]      
        max_length = MAX_LENGTH

    padded_indices_ll = []

    for i, indices_l in enumerate(batch[col_name]):

        padded_indices_l = indices_l + (max_length - len(indices_l)) * padding_val
        padded_indices_ll.append(padded_indices_l)

    return padded_indices_ll        

In [20]:
def get_cls_indices(adu_cls_indices_ll, features_cls_indices_ll):
    
    
    indices_ll = []
    
    for adu_cls_indices_l, features_cls_indices_l in zip(adu_cls_indices_ll, features_cls_indices_ll):


        indices_l = [[i,j] for i, j in zip(adu_cls_indices_l, features_cls_indices_l)]        
        indices_ll.append(indices_l)

    return indices_ll

### tokenize 

In [21]:
def tokenize(batch):
    
    tokenized_text = tokenizer(batch['new_full_para'], truncation=True, padding=True, max_length=512)
    tokenized_text['label'] = get_padding(batch, 'label')
    tokenized_text['adu_cls_indices'] = get_padding(batch, 'adu_cls_indices')
    tokenized_text['features_cls_indices'] = get_padding(batch, 'features_cls_indices')
    #tokenized_text['fts_spans'] = get_padding(batch, 'fts_spans')
    tokenized_text['cls_indices'] = get_cls_indices(tokenized_text['adu_cls_indices'], tokenized_text['features_cls_indices'])      
    
    return tokenized_text

In [22]:
dataset = dataset.map(tokenize, batched=True, batch_size=len(dataset['train']))



  0%|          | 0/1 [00:00<?, ?ba/s]

  0%|          | 0/1 [00:00<?, ?ba/s]

  0%|          | 0/1 [00:00<?, ?ba/s]

In [23]:
dataset

DatasetDict({
    train: Dataset({
        features: ['adu_cls_indices', 'attention_mask', 'cls_indices', 'component_indices', 'essay_nr', 'feature_indices', 'features_cls_indices', 'full_paragraph_w_comp_fts', 'input_ids', 'label', 'new_full_para', 'paragraph', 'paragraph_components_list', 'paragraph_labels_list', 'paragraph_labels_numeric', 'paragraph_markers_list', 'sanity_1', 'split', 'structural_fts_as_text_combined', 'token_type_ids'],
        num_rows: 1088
    })
    test: Dataset({
        features: ['adu_cls_indices', 'attention_mask', 'cls_indices', 'component_indices', 'essay_nr', 'feature_indices', 'features_cls_indices', 'full_paragraph_w_comp_fts', 'input_ids', 'label', 'new_full_para', 'paragraph', 'paragraph_components_list', 'paragraph_labels_list', 'paragraph_labels_numeric', 'paragraph_markers_list', 'sanity_1', 'split', 'structural_fts_as_text_combined', 'token_type_ids'],
        num_rows: 358
    })
    validation: Dataset({
        features: ['adu_cls_indices', 

In [24]:
dataset['test'].features['cls_indices'] = datasets.Array2D(shape=(12, 2), dtype="int32")
dataset['train'].features['cls_indices'] = datasets.Array2D(shape=(12, 2), dtype="int32")
dataset['validation'].features['cls_indices'] = datasets.Array2D(shape=(12, 2), dtype="int32")

In [25]:
dataset = dataset.map(lambda batch: batch, batched=True)

  0%|          | 0/2 [00:00<?, ?ba/s]

  0%|          | 0/1 [00:00<?, ?ba/s]

  0%|          | 0/1 [00:00<?, ?ba/s]

In [26]:
dataset.set_format('torch', columns=['input_ids', 'attention_mask', 'token_type_ids', 'cls_indices', 'label'])

In [27]:
dataset['train']['cls_indices'][498]

tensor([[ 0, 66],
        [76, 93],
        [-1, -1],
        [-1, -1],
        [-1, -1],
        [-1, -1],
        [-1, -1],
        [-1, -1],
        [-1, -1],
        [-1, -1],
        [-1, -1],
        [-1, -1]])

In [28]:
### Work

### Get outputs at CLS indices for both ADU and Features

In [29]:
def get_cls_outputs(outputs, cls_indices):
    
    batch_size = cls_indices.shape[0]    
    nr_adus = cls_indices.shape[1]
    
    cls_adu_indices = cls_indices[:,:,0]
    cls_features_indices = cls_indices[:,:,1]
    
    outputs_adu_cls_indices = outputs[:,cls_adu_indices,:]
    outputs_features_cls_indices = outputs[:,cls_features_indices,:]
    
    outputs_adus = torch.cat([outputs_adu_cls_indices[i,i,:,:] for i in range(batch_size)], dim=0)
    outputs_adus = outputs_adus.reshape(batch_size, nr_adus, -1)
    
    outputs_features = torch.cat([outputs_features_cls_indices[i,i,:,:] for i in range(batch_size)], dim=0)
    outputs_features = outputs_features.reshape(batch_size, nr_adus, -1)
    
    return outputs_adus, outputs_features

## span representation function

In [30]:
# def minus_one(t):
    
#     return torch.where(t == 0, 0, t-1)

# def plus_one(t):
    
#     return torch.where(t >= MAX_SPAN, MAX_SPAN, t+1)

## custom BERT model

In [31]:
class CustomBERTKuri(nn.Module):

    def __init__(self, first_model, model_adu, model_fts, nr_classes):
        
        super(CustomBERTKuri, self).__init__()
        
        self.first_model = first_model
        
#         self.intermediate_linear_am = nn.Linear(3072, 768)
#         self.intermediate_linear_ac = nn.Linear(3072, 768) 
#         self.intermediate_linear_fts = nn.Linear(3072, 768)
        
        self.model_adu = model_adu
        self.model_fts = model_fts
        
        self.nr_classes = nr_classes
                
        # self.fc = nn.Linear(self.model_am.config.hidden_size + self.model_ac.config.hidden_size + self.model_fts.config.hidden_size, self.nr_classes)
        self.fc = nn.Linear(self.model_adu.config.hidden_size + self.model_fts.config.hidden_size, self.nr_classes)

    def forward(self, inputs):
        
        batch_tokenized, batch_cls_indices = inputs 
        # outputs = self.first_model(batch_tokenized, output_hidden_states=True)[1][12]
        outputs = self.first_model(batch_tokenized)[0]
        adu_cls_outputs, fts_cls_outputs = get_cls_outputs(outputs, batch_cls_indices)


#         am_minus_representations = self.intermediate_linear_am(am_minus_representations)
#         ac_minus_representations = self.intermediate_linear_ac(ac_minus_representations)
#         fts_minus_representations = self.intermediate_linear_fts(fts_minus_representations)


        output_model_adu = self.model_adu(inputs_embeds = adu_cls_outputs)[0]
        output_model_fts = self.model_fts(inputs_embeds = fts_cls_outputs)[0]
        
        #output_model_fts = self.model_fts(inputs_embeds = fts_minus_representations)[0]

        adu_representations = torch.cat([output_model_adu, output_model_fts], dim=-1)
        output = self.fc(adu_representations)

        return output

## Run

In [32]:
NB_EPOCHS = 40
BATCH_SIZE = 32

In [33]:
# first_model = BertForSequenceClassification.from_pretrained("bert-base-uncased", num_labels=3)
# first_model.load_state_dict(torch.load('/notebooks/KURI-BERT/notebooks/full_formula_w_fts/icann_finetuned_work/finetuned_model.pth'))

In [34]:
first_model = BertModel(BertConfig.from_pretrained("bert-base-uncased"))

In [35]:
model_adu = BertModel(BertConfig.from_pretrained("bert-base-uncased"))

In [36]:
model_fts = BertModel(BertConfig.from_pretrained("bert-base-uncased"))

In [37]:
custom_model = CustomBERTKuri(first_model, model_adu, model_fts, 3)

In [38]:
custom_model.to(device)

CustomBERTKuri(
  (first_model): BertModel(
    (embeddings): BertEmbeddings(
      (word_embeddings): Embedding(30522, 768, padding_idx=0)
      (position_embeddings): Embedding(512, 768)
      (token_type_embeddings): Embedding(2, 768)
      (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): BertEncoder(
      (layer): ModuleList(
        (0): BertLayer(
          (attention): BertAttention(
            (self): BertSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): BertSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_aff

In [39]:
loss = nn.CrossEntropyLoss(ignore_index=- 100)

In [41]:
optimizer = torch.optim.AdamW(custom_model.parameters(), lr=1.8738174228603844e-05)

In [42]:
# 1.8738174228603844e-05
# 0.008111308307896872
# best learning rate found by the whole leslie business
# new best LR found= 9.999999999999997e-06

In [43]:
NR_BATCHES = len(dataset['train']) / BATCH_SIZE
num_training_steps = NB_EPOCHS * NR_BATCHES
num_warmup_steps = int(0.2 * num_training_steps)

In [44]:
def lr_lambda(current_step: int):
        if current_step < num_warmup_steps:
            return float(current_step) / float(max(1, num_warmup_steps))
        return max(
            0.0, float(num_training_steps - current_step) / float(max(1, num_training_steps - num_warmup_steps))
        )

In [45]:
scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lr_lambda)

### create dataloaders

In [46]:
train_dataloader = DataLoader(dataset['train'], batch_size=BATCH_SIZE, shuffle=True)
val_dataloader = DataLoader(dataset['validation'], batch_size=BATCH_SIZE, shuffle=True)

In [47]:
def flatten_list(list_of_lists):
    return [x for sublist in list_of_lists for x in sublist]

In [48]:
def remove_dummy_labels(test_preds, test_labels):
    
    idxes = []
    test_labels_l = []
    for idx, val in enumerate(test_labels):
        if val != -100:
            idxes.append(idx)
            test_labels_l.append(val)
    
    test_preds_l = []
    for idx, val in enumerate(test_preds):
        for good_idx in idxes:
            if idx == good_idx:
                test_preds_l.append(val)
        
    return test_preds_l, test_labels_l

### training loop

In [49]:
from datasets import load_metric

In [50]:
import random
from torch_lr_finder import LRFinder

### LR Finder Leslie Smith 

## training 

In [51]:
def train(model, loss=None, optimizer=None, train_dataloader=None, val_dataloader=None, nb_epochs=20):
    """Training loop"""

    min_f1 = -torch.inf
    train_losses = []
    val_losses = []

    # Iterrate over epochs
    for e in range(nb_epochs):

        # Training
        train_loss = 0.0

        for i, batch in enumerate(tqdm(train_dataloader)):            
            
            #print(i)
            # unpack batch             
            labels = batch['label'].to(device)
            cls_indices = batch['cls_indices'].to(device)
            input_ids = batch['input_ids'].to(device)
            
            inputs = input_ids, cls_indices
            
            # Reset gradients to 0
            optimizer.zero_grad()

            # Forward Pass
            outputs = model(inputs)
            
            # Compute training loss
            current_loss = loss(outputs.flatten(0,1), labels.flatten())
            train_loss += current_loss.detach().item()

            # Compute gradients
            current_loss.backward()

            # Update weights
            optimizer.step()            
            
            del batch
        
        scheduler.step()
            
        
        # Validation
        val_loss = 0.0
        
        # torch.cuda.empty_cache()
        
        # Put model in eval mode
        model.eval()
        
        preds_l = []
        labels_l = []
        
        for batch in tqdm(val_dataloader):            
            
            # unpack batch             
            labels = batch['label'].to(device)
            cls_indices = batch['cls_indices'].to(device)
            input_ids = batch['input_ids'].to(device)
            
            inputs = input_ids, cls_indices
            
            # Forward Pass
            outputs = model(inputs)

            # Compute validation loss
            current_loss = loss(outputs.flatten(0,1), labels.flatten())
            val_loss += current_loss.detach().item()
            
            preds_for_f1 = torch.argmax(outputs, dim=2).flatten().tolist()
            labels_for_f1 = labels.flatten().tolist()
            
            preds_l.append(preds_for_f1)
            labels_l.append(labels_for_f1)
            
            del batch
        
        # Prints
        
        preds_l = flatten_list(preds_l)
        labels_l = flatten_list(labels_l)
        
        preds_l, labels_l = remove_dummy_labels(preds_l, labels_l)
        
        f1_score_epoch = f1_score(preds_l, labels_l, average='macro')        
        
        print(f"Epoch {e+1}/{nb_epochs} \
                \t Training Loss: {train_loss/len(train_dataloader):.3f} \
                \t Validation Loss: {val_loss/len(val_dataloader):.3f} \
                \t F1 score: {f1_score_epoch}")
        
        train_losses.append(train_loss/len(train_dataloader))
        val_losses.append(val_loss/len(val_dataloader))
        

        # Save model if val loss decreases
        if f1_score_epoch > min_f1:

            min_f1 = f1_score_epoch
            torch.save(model.first_model.state_dict(), 'first_model.pt')
            torch.save(model.model_adu.state_dict(), 'model_adu.pt')
            torch.save(model.model_fts.state_dict(), 'model_fts.pt')
            # torch.save(model.model_fts.state_dict(), 'model_fts.pt')
            torch.save(model.state_dict(), 'best_model.pt')
            
    return train_losses, val_losses

In [52]:
train_losses, val_losses = train(custom_model, loss, optimizer, train_dataloader, val_dataloader, NB_EPOCHS)

100%|██████████| 34/34 [00:38<00:00,  1.12s/it]
100%|██████████| 9/9 [00:02<00:00,  3.02it/s]


Epoch 1/40                 	 Training Loss: 1.088                 	 Validation Loss: 1.059                 	 F1 score: 0.15506100138265022


100%|██████████| 34/34 [00:37<00:00,  1.10s/it]
100%|██████████| 9/9 [00:03<00:00,  2.97it/s]


Epoch 2/40                 	 Training Loss: 0.895                 	 Validation Loss: 0.835                 	 F1 score: 0.2613299449385853


100%|██████████| 34/34 [00:37<00:00,  1.11s/it]
100%|██████████| 9/9 [00:03<00:00,  2.98it/s]


Epoch 3/40                 	 Training Loss: 0.792                 	 Validation Loss: 0.762                 	 F1 score: 0.2708873562228989


100%|██████████| 34/34 [00:37<00:00,  1.11s/it]
100%|██████████| 9/9 [00:03<00:00,  2.95it/s]


Epoch 4/40                 	 Training Loss: 0.711                 	 Validation Loss: 0.741                 	 F1 score: 0.346091798763929


100%|██████████| 34/34 [00:37<00:00,  1.11s/it]
100%|██████████| 9/9 [00:03<00:00,  2.97it/s]


Epoch 5/40                 	 Training Loss: 0.651                 	 Validation Loss: 0.758                 	 F1 score: 0.3939505648048362


100%|██████████| 34/34 [00:37<00:00,  1.11s/it]
100%|██████████| 9/9 [00:03<00:00,  2.97it/s]


Epoch 6/40                 	 Training Loss: 0.610                 	 Validation Loss: 0.801                 	 F1 score: 0.37741802716422


100%|██████████| 34/34 [00:37<00:00,  1.10s/it]
100%|██████████| 9/9 [00:03<00:00,  2.96it/s]


Epoch 7/40                 	 Training Loss: 0.581                 	 Validation Loss: 0.901                 	 F1 score: 0.3665925376822337


100%|██████████| 34/34 [00:37<00:00,  1.10s/it]
100%|██████████| 9/9 [00:03<00:00,  2.96it/s]


Epoch 8/40                 	 Training Loss: 0.559                 	 Validation Loss: 0.971                 	 F1 score: 0.3439160590809222


100%|██████████| 34/34 [00:37<00:00,  1.10s/it]
100%|██████████| 9/9 [00:03<00:00,  2.96it/s]


Epoch 9/40                 	 Training Loss: 0.537                 	 Validation Loss: 1.016                 	 F1 score: 0.3948702209598915


100%|██████████| 34/34 [00:37<00:00,  1.11s/it]
100%|██████████| 9/9 [00:03<00:00,  2.97it/s]


Epoch 10/40                 	 Training Loss: 0.528                 	 Validation Loss: 1.108                 	 F1 score: 0.36439956748780844


100%|██████████| 34/34 [00:37<00:00,  1.10s/it]
100%|██████████| 9/9 [00:03<00:00,  2.96it/s]


Epoch 11/40                 	 Training Loss: 0.516                 	 Validation Loss: 1.109                 	 F1 score: 0.42304687687212744


100%|██████████| 34/34 [00:37<00:00,  1.11s/it]
100%|██████████| 9/9 [00:03<00:00,  2.97it/s]


Epoch 12/40                 	 Training Loss: 0.507                 	 Validation Loss: 1.178                 	 F1 score: 0.41683075289023247


100%|██████████| 34/34 [00:37<00:00,  1.11s/it]
100%|██████████| 9/9 [00:03<00:00,  2.97it/s]


Epoch 13/40                 	 Training Loss: 0.496                 	 Validation Loss: 1.165                 	 F1 score: 0.42639575914628275


100%|██████████| 34/34 [00:37<00:00,  1.11s/it]
100%|██████████| 9/9 [00:03<00:00,  2.96it/s]


Epoch 14/40                 	 Training Loss: 0.492                 	 Validation Loss: 1.082                 	 F1 score: 0.4276158998400428


100%|██████████| 34/34 [00:37<00:00,  1.11s/it]
100%|██████████| 9/9 [00:03<00:00,  2.97it/s]


Epoch 15/40                 	 Training Loss: 0.478                 	 Validation Loss: 1.156                 	 F1 score: 0.42254572601717055


100%|██████████| 34/34 [00:37<00:00,  1.10s/it]
100%|██████████| 9/9 [00:03<00:00,  2.97it/s]


Epoch 16/40                 	 Training Loss: 0.471                 	 Validation Loss: 1.125                 	 F1 score: 0.4224357140712341


100%|██████████| 34/34 [00:37<00:00,  1.10s/it]
100%|██████████| 9/9 [00:03<00:00,  2.96it/s]


Epoch 17/40                 	 Training Loss: 0.455                 	 Validation Loss: 1.067                 	 F1 score: 0.4293876933720218


100%|██████████| 34/34 [00:37<00:00,  1.11s/it]
100%|██████████| 9/9 [00:03<00:00,  2.97it/s]


Epoch 18/40                 	 Training Loss: 0.451                 	 Validation Loss: 1.081                 	 F1 score: 0.4285885190059124


100%|██████████| 34/34 [00:37<00:00,  1.10s/it]
100%|██████████| 9/9 [00:03<00:00,  2.96it/s]


Epoch 19/40                 	 Training Loss: 0.438                 	 Validation Loss: 1.028                 	 F1 score: 0.4360038172877847


100%|██████████| 34/34 [00:37<00:00,  1.11s/it]
100%|██████████| 9/9 [00:03<00:00,  2.97it/s]


Epoch 20/40                 	 Training Loss: 0.432                 	 Validation Loss: 0.995                 	 F1 score: 0.4366061971511668


100%|██████████| 34/34 [00:37<00:00,  1.11s/it]
100%|██████████| 9/9 [00:03<00:00,  2.97it/s]


Epoch 21/40                 	 Training Loss: 0.430                 	 Validation Loss: 1.016                 	 F1 score: 0.4284485132694071


100%|██████████| 34/34 [00:37<00:00,  1.11s/it]
100%|██████████| 9/9 [00:03<00:00,  2.96it/s]


Epoch 22/40                 	 Training Loss: 0.414                 	 Validation Loss: 0.983                 	 F1 score: 0.43246619635508526


100%|██████████| 34/34 [00:37<00:00,  1.11s/it]
100%|██████████| 9/9 [00:03<00:00,  2.95it/s]


Epoch 23/40                 	 Training Loss: 0.407                 	 Validation Loss: 0.972                 	 F1 score: 0.43106453747874945


100%|██████████| 34/34 [00:37<00:00,  1.11s/it]
100%|██████████| 9/9 [00:03<00:00,  2.97it/s]


Epoch 24/40                 	 Training Loss: 0.402                 	 Validation Loss: 0.950                 	 F1 score: 0.4411581002980638


100%|██████████| 34/34 [00:37<00:00,  1.11s/it]
100%|██████████| 9/9 [00:03<00:00,  2.96it/s]


Epoch 25/40                 	 Training Loss: 0.398                 	 Validation Loss: 0.976                 	 F1 score: 0.42448626433052133


100%|██████████| 34/34 [00:37<00:00,  1.10s/it]
100%|██████████| 9/9 [00:03<00:00,  2.96it/s]


Epoch 26/40                 	 Training Loss: 0.386                 	 Validation Loss: 0.967                 	 F1 score: 0.43585278068036687


100%|██████████| 34/34 [00:37<00:00,  1.10s/it]
100%|██████████| 9/9 [00:03<00:00,  2.97it/s]


Epoch 27/40                 	 Training Loss: 0.373                 	 Validation Loss: 0.922                 	 F1 score: 0.4329321991820521


100%|██████████| 34/34 [00:37<00:00,  1.10s/it]
100%|██████████| 9/9 [00:03<00:00,  2.96it/s]


Epoch 28/40                 	 Training Loss: 0.363                 	 Validation Loss: 0.839                 	 F1 score: 0.43055511118221085


100%|██████████| 34/34 [00:37<00:00,  1.10s/it]
100%|██████████| 9/9 [00:03<00:00,  2.97it/s]


Epoch 29/40                 	 Training Loss: 0.361                 	 Validation Loss: 0.904                 	 F1 score: 0.4276445468252998


100%|██████████| 34/34 [00:37<00:00,  1.10s/it]
100%|██████████| 9/9 [00:03<00:00,  2.96it/s]


Epoch 30/40                 	 Training Loss: 0.368                 	 Validation Loss: 0.850                 	 F1 score: 0.43549834657478964


100%|██████████| 34/34 [00:37<00:00,  1.10s/it]
100%|██████████| 9/9 [00:03<00:00,  2.96it/s]


Epoch 31/40                 	 Training Loss: 0.351                 	 Validation Loss: 0.835                 	 F1 score: 0.4202269656376127


100%|██████████| 34/34 [00:37<00:00,  1.10s/it]
100%|██████████| 9/9 [00:03<00:00,  2.96it/s]


Epoch 32/40                 	 Training Loss: 0.331                 	 Validation Loss: 0.828                 	 F1 score: 0.43004897938887643


100%|██████████| 34/34 [00:37<00:00,  1.10s/it]
100%|██████████| 9/9 [00:03<00:00,  2.95it/s]


Epoch 33/40                 	 Training Loss: 0.329                 	 Validation Loss: 0.838                 	 F1 score: 0.4182216010580398


100%|██████████| 34/34 [00:37<00:00,  1.10s/it]
100%|██████████| 9/9 [00:03<00:00,  2.97it/s]


Epoch 34/40                 	 Training Loss: 0.323                 	 Validation Loss: 0.830                 	 F1 score: 0.41633627545457275


100%|██████████| 34/34 [00:37<00:00,  1.10s/it]
100%|██████████| 9/9 [00:03<00:00,  2.96it/s]


Epoch 35/40                 	 Training Loss: 0.316                 	 Validation Loss: 0.829                 	 F1 score: 0.4146324928777048


100%|██████████| 34/34 [00:37<00:00,  1.10s/it]
100%|██████████| 9/9 [00:03<00:00,  2.95it/s]


Epoch 36/40                 	 Training Loss: 0.305                 	 Validation Loss: 0.818                 	 F1 score: 0.4075412595589587


100%|██████████| 34/34 [00:37<00:00,  1.11s/it]
100%|██████████| 9/9 [00:03<00:00,  2.95it/s]


Epoch 37/40                 	 Training Loss: 0.303                 	 Validation Loss: 0.714                 	 F1 score: 0.4921233823910902


100%|██████████| 34/34 [00:37<00:00,  1.11s/it]
100%|██████████| 9/9 [00:03<00:00,  2.98it/s]


Epoch 38/40                 	 Training Loss: 0.295                 	 Validation Loss: 0.777                 	 F1 score: 0.45759106825589324


100%|██████████| 34/34 [00:37<00:00,  1.10s/it]
100%|██████████| 9/9 [00:03<00:00,  2.96it/s]


Epoch 39/40                 	 Training Loss: 0.285                 	 Validation Loss: 0.915                 	 F1 score: 0.3452729126647664


100%|██████████| 34/34 [00:37<00:00,  1.10s/it]
100%|██████████| 9/9 [00:03<00:00,  2.97it/s]


Epoch 40/40                 	 Training Loss: 0.279                 	 Validation Loss: 0.807                 	 F1 score: 0.4503360112975332


### Predictions

In [50]:
test_dataloader = DataLoader(dataset['test'], batch_size=BATCH_SIZE, shuffle=True)

In [51]:
# first_model =  BertForSequenceClassification.from_pretrained("bert-base-uncased", num_labels=3)
# first_model.load_state_dict(torch.load('first_model.pt'))

In [52]:
first_model = BertModel(BertConfig.from_pretrained("bert-base-uncased"))
first_model.load_state_dict(torch.load('first_model.pt'))

<All keys matched successfully>

In [53]:
model_adu = BertModel(BertConfig.from_pretrained("bert-base-uncased"))
model_adu.load_state_dict(torch.load('model_adu.pt'))

<All keys matched successfully>

In [54]:
model_fts = BertModel(BertConfig.from_pretrained("bert-base-uncased"))
model_fts.load_state_dict(torch.load('model_fts.pt'))

<All keys matched successfully>

In [55]:
# Load best model

custom_model_2 = CustomBERTKuri(first_model, model_adu, model_fts, 3)
custom_model_2.load_state_dict(torch.load('best_model.pt'))

custom_model_2.to(device).eval()

CustomBERTKuri(
  (first_model): BertModel(
    (embeddings): BertEmbeddings(
      (word_embeddings): Embedding(30522, 768, padding_idx=0)
      (position_embeddings): Embedding(512, 768)
      (token_type_embeddings): Embedding(2, 768)
      (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): BertEncoder(
      (layer): ModuleList(
        (0): BertLayer(
          (attention): BertAttention(
            (self): BertSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): BertSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_aff

In [56]:
def predict(model, test_dataloader=None):
    
    """Prediction loop"""

    preds_l = []
    labels_l = []
    
    model.eval()

    for batch in test_dataloader:            
            
        # unpack batch             
        labels = batch['label'].to(device).flatten().tolist()
        cls_indices = batch['cls_indices'].to(device)
        input_ids = batch['input_ids'].to(device)
        
        inputs = input_ids, cls_indices

        # get output
        
        raw_preds = model(inputs).to('cpu')
        # print(raw_preds.shape)
        raw_preds = raw_preds.detach()#.numpy()

        # Compute argmax
        
        predictions = torch.argmax(raw_preds, dim=2).flatten().tolist()
        preds_l.append(predictions)
        labels_l.append(labels)        
        
        del batch
            
    return flatten_list(preds_l), flatten_list(labels_l)

In [57]:
test_preds, test_labels = predict(custom_model_2, test_dataloader)

In [58]:
print(classification_report(test_labels, test_preds, digits=3))

              precision    recall  f1-score   support

        -100      0.000     0.000     0.000      3033
           0      0.000     0.000     0.000       155
           1      0.000     0.000     0.000       303
           2      0.187     1.000     0.316       805

    accuracy                          0.187      4296
   macro avg      0.047     0.250     0.079      4296
weighted avg      0.035     0.187     0.059      4296



  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


In [59]:
test_preds_l, test_labels_l = remove_dummy_labels(test_preds, test_labels)

In [60]:
print(classification_report(test_labels_l, test_preds_l, digits=3))

              precision    recall  f1-score   support

           0      0.000     0.000     0.000       155
           1      0.000     0.000     0.000       303
           2      0.637     1.000     0.779       805

    accuracy                          0.637      1263
   macro avg      0.212     0.333     0.260      1263
weighted avg      0.406     0.637     0.496      1263



  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
