# Implementation of the Kuribayashi BERT minus model

## libraries

In [1]:
!pip install transformers --upgrade
!pip install ipywidgets
!pip install IProgress
!pip install datasets
!pip install torch-lr-finder

Looking in indexes: https://pypi.org/simple, https://pypi.ngc.nvidia.com
Looking in indexes: https://pypi.org/simple, https://pypi.ngc.nvidia.com
Looking in indexes: https://pypi.org/simple, https://pypi.ngc.nvidia.com
Looking in indexes: https://pypi.org/simple, https://pypi.ngc.nvidia.com
Looking in indexes: https://pypi.org/simple, https://pypi.ngc.nvidia.com


In [2]:
import transformers
from transformers import BertTokenizer, BertConfig
from transformers import BertModel, BertForSequenceClassification
from transformers import BatchEncoding, default_data_collator, DataCollatorWithPadding

import torch
import torch.nn as nn

import numpy as np

from sklearn.metrics import classification_report
from sklearn.metrics import f1_score

import datasets

from torch.utils.data import DataLoader

from tqdm import tqdm

In [3]:
print(transformers.__version__)

4.25.1


## tokenizer

In [4]:
tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")

## data

In [5]:
# DATA_FOLDER = '/notebooks/Data/bert_sequence_classification'
DATA_FILE = '/notebooks/KURI-BERT/data/pe_dataset_for_bert_minus.pt'
RESULTS_FOLDER = '/notebooks/KURI-BERT/results'

In [6]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [7]:
device

device(type='cuda')

## Load data

In [8]:
dataset = torch.load(DATA_FILE)

In [10]:
dataset

DatasetDict({
    train: Dataset({
        features: ['paragraph', 'paragraph_components_list', 'paragraph_labels_list', 'paragraph_markers_list', 'split', 'essay_nr', 'paragraph_labels', 'paragraph_ac_spans', 'paragraph_am_spans'],
        num_rows: 1088
    })
    test: Dataset({
        features: ['paragraph', 'paragraph_components_list', 'paragraph_labels_list', 'paragraph_markers_list', 'split', 'essay_nr', 'paragraph_labels', 'paragraph_ac_spans', 'paragraph_am_spans'],
        num_rows: 358
    })
    validation: Dataset({
        features: ['paragraph', 'paragraph_components_list', 'paragraph_labels_list', 'paragraph_markers_list', 'split', 'essay_nr', 'paragraph_labels', 'paragraph_ac_spans', 'paragraph_am_spans'],
        num_rows: 273
    })
})

In [12]:
dataset['train']['paragraph'][0]

'Firstly, by paying taxes for public school, affluent people effectively contribute to narrowing down the gap between rich and poor. It is true that many poor families are not able to afford tuition fees for their kids to attend a course. With the the tax amount for which they pay, the rich may help a vast number of students from families with poverty background to continue their studies, and earn a better quality of life.'

### preprocessing

In [10]:
MAX_LENGTH = 0

for split in ['train', 'test', 'validation']:
    
    for col_name in ['paragraph_am_spans', 'paragraph_ac_spans']:
        
        for x in dataset[split][col_name]:
        
            if len(x) > MAX_LENGTH:
                
                MAX_LENGTH = len(x)

In [11]:
def get_padding(batch, padding_target):    
    
    if padding_target == 'am_spans':
        
        col_name = 'paragraph_am_spans'
        padding_val = [[-1,-1]]
        max_length = MAX_LENGTH
        
    elif padding_target == 'ac_spans':
        
        col_name = 'paragraph_ac_spans'
        padding_val = [[-1,-1]]
        max_length = MAX_LENGTH
        
    elif padding_target == 'label':
    
        col_name = 'paragraph_labels'
        padding_val = [-100] # -1 previously       
        max_length = MAX_LENGTH # max([len(l) for l in batch[col_name]]) # cause some batch had 4 x 10

    padded_spans = []

    for idx, span in enumerate(batch[col_name]):

        padded_span = batch[col_name][idx] + (max_length - len(span)) * padding_val
        padded_spans.append(padded_span)

    return padded_spans         

In [12]:
def get_combined_spans(am_spans_ll, ac_spans_ll):
    
    spans_ll = []
    
    for am_spans, ac_spans in zip(am_spans_ll, ac_spans_ll):
        
        spans = []
        
        for am_span, ac_span in zip(am_spans, ac_spans):

            span = [am_span, ac_span]
            spans.extend(span)
            
        spans_ll.append(spans)

    return spans_ll

### tokenize 

In [13]:
# max_length = 200 for use in the max_length in the tokenizer so that the things are of equal dim.

In [14]:
def tokenize(batch):
    
    tokenized_text = tokenizer(batch['paragraph'], truncation=True, padding=True, max_length=512)
    tokenized_text['label'] = get_padding(batch, 'label')
    tokenized_text['am_spans'] = get_padding(batch, 'am_spans')
    tokenized_text['ac_spans'] = get_padding(batch, 'ac_spans')
    tokenized_text['spans'] = get_combined_spans(tokenized_text['am_spans'], tokenized_text['ac_spans'])      
    
    return tokenized_text

In [15]:
dataset = dataset.map(tokenize, batched=True, batch_size=len(dataset['train']))



  0%|          | 0/1 [00:00<?, ?ba/s]

  0%|          | 0/1 [00:00<?, ?ba/s]

  0%|          | 0/1 [00:00<?, ?ba/s]

In [16]:
dataset

DatasetDict({
    train: Dataset({
        features: ['ac_spans', 'am_spans', 'attention_mask', 'essay_nr', 'input_ids', 'label', 'paragraph', 'paragraph_ac_spans', 'paragraph_am_spans', 'paragraph_components_list', 'paragraph_labels', 'paragraph_labels_list', 'paragraph_markers_list', 'spans', 'split', 'token_type_ids'],
        num_rows: 1088
    })
    test: Dataset({
        features: ['ac_spans', 'am_spans', 'attention_mask', 'essay_nr', 'input_ids', 'label', 'paragraph', 'paragraph_ac_spans', 'paragraph_am_spans', 'paragraph_components_list', 'paragraph_labels', 'paragraph_labels_list', 'paragraph_markers_list', 'spans', 'split', 'token_type_ids'],
        num_rows: 358
    })
    validation: Dataset({
        features: ['ac_spans', 'am_spans', 'attention_mask', 'essay_nr', 'input_ids', 'label', 'paragraph', 'paragraph_ac_spans', 'paragraph_am_spans', 'paragraph_components_list', 'paragraph_labels', 'paragraph_labels_list', 'paragraph_markers_list', 'spans', 'split', 'token_type_

In [17]:
dataset['train'][0]['spans']

[[0, 0],
 [2, 21],
 [-1, -1],
 [27, 43],
 [-1, -1],
 [45, 80],
 [-1, -1],
 [-1, -1],
 [-1, -1],
 [-1, -1],
 [-1, -1],
 [-1, -1],
 [-1, -1],
 [-1, -1],
 [-1, -1],
 [-1, -1],
 [-1, -1],
 [-1, -1],
 [-1, -1],
 [-1, -1],
 [-1, -1],
 [-1, -1],
 [-1, -1],
 [-1, -1]]

In [18]:
dataset['train'].features

{'ac_spans': Sequence(feature=Sequence(feature=Value(dtype='int64', id=None), length=-1, id=None), length=-1, id=None),
 'am_spans': Sequence(feature=Sequence(feature=Value(dtype='int64', id=None), length=-1, id=None), length=-1, id=None),
 'attention_mask': Sequence(feature=Value(dtype='int8', id=None), length=-1, id=None),
 'essay_nr': Value(dtype='string', id=None),
 'input_ids': Sequence(feature=Value(dtype='int32', id=None), length=-1, id=None),
 'label': Sequence(feature=Value(dtype='int64', id=None), length=-1, id=None),
 'paragraph': Value(dtype='string', id=None),
 'paragraph_ac_spans': Sequence(feature=Sequence(feature=Value(dtype='int64', id=None), length=-1, id=None), length=-1, id=None),
 'paragraph_am_spans': Sequence(feature=Sequence(feature=Value(dtype='int64', id=None), length=-1, id=None), length=-1, id=None),
 'paragraph_components_list': Sequence(feature=Sequence(feature=Value(dtype='string', id=None), length=-1, id=None), length=-1, id=None),
 'paragraph_labels': S

In [19]:
dataset['test'].features

{'ac_spans': Sequence(feature=Sequence(feature=Value(dtype='int64', id=None), length=-1, id=None), length=-1, id=None),
 'am_spans': Sequence(feature=Sequence(feature=Value(dtype='int64', id=None), length=-1, id=None), length=-1, id=None),
 'attention_mask': Sequence(feature=Value(dtype='int8', id=None), length=-1, id=None),
 'essay_nr': Value(dtype='string', id=None),
 'input_ids': Sequence(feature=Value(dtype='int32', id=None), length=-1, id=None),
 'label': Sequence(feature=Value(dtype='int64', id=None), length=-1, id=None),
 'paragraph': Value(dtype='string', id=None),
 'paragraph_ac_spans': Sequence(feature=Sequence(feature=Value(dtype='int64', id=None), length=-1, id=None), length=-1, id=None),
 'paragraph_am_spans': Sequence(feature=Sequence(feature=Value(dtype='int64', id=None), length=-1, id=None), length=-1, id=None),
 'paragraph_components_list': Sequence(feature=Sequence(feature=Value(dtype='string', id=None), length=-1, id=None), length=-1, id=None),
 'paragraph_labels': S

In [20]:
dataset['test'].features['spans'] = datasets.Array2D(shape=(24, 2), dtype="int32")
dataset['train'].features['spans'] = datasets.Array2D(shape=(24, 2), dtype="int32")
dataset['validation'].features['spans'] = datasets.Array2D(shape=(24, 2), dtype="int32")

In [21]:
dataset = dataset.map(lambda batch: batch, batched=True)

  0%|          | 0/2 [00:00<?, ?ba/s]

  0%|          | 0/1 [00:00<?, ?ba/s]

  0%|          | 0/1 [00:00<?, ?ba/s]

In [22]:
dataset.set_format('torch', columns=['input_ids', 'attention_mask', 'token_type_ids', 'spans', 'label'])

In [23]:
def minus_one(t):
    
    return torch.where(t == 0, 0, t-1)

def plus_one(t):
    
    return torch.where(t == 0, 0, t+1)

## span representation function

In [24]:
def get_span_representations(outputs, spans):

    batch_size = spans.shape[0]
    nr_span_indices = spans.shape[1]
    
    #print('nr span indices: ', nr_span_indices)
    
    idx_l_ams = range(0, nr_span_indices, 2) # [0,2,4,6 etc]
    idx_l_acs = range(1, nr_span_indices, 2) # [1,3,5,7 etc]
    
    am_spans = spans[:, idx_l_ams, :] + 1 # adds 1 to all span indices (both am and ac) to offset for the CLS token in the input_ids.
    ac_spans = spans[:, idx_l_acs, :] + 1
    
    # fix for [-1, +1] problem
    
    am_spans_minus_one = minus_one(am_spans) # xxx. added to solve bug 2
    am_spans_plus_one = plus_one(am_spans) # xxx. added to solve bug 2
    
    ac_spans_minus_one = minus_one(ac_spans) # xxx. added to solve bug 2
    ac_spans_plus_one = plus_one(ac_spans) # xxx. added to solve bug 2
    
    
    am_spans = am_spans.flatten(start_dim=1)
    ac_spans = ac_spans.flatten(start_dim=1)
    
    am_spans_minus_one = am_spans_minus_one.flatten(start_dim=1) # xxx. added to solve bug 2
    am_spans_plus_one = am_spans_plus_one.flatten(start_dim=1) # xxx. added to solve bug 2
    
    ac_spans_minus_one = ac_spans_minus_one.flatten(start_dim=1) # xxx. added to solve bug 2
    ac_spans_plus_one = ac_spans_plus_one.flatten(start_dim=1) # xxx. added to solve bug 2
    
    #print("am spans:", am_spans.shape)
    #print("ac spans:", ac_spans.shape)
    
    outputs_am = outputs[:,am_spans,:]
    #print("outputs am before cat:", outputs_am.shape)
    outputs_am = torch.cat([outputs_am[i,i,:,:] for i in range(batch_size)], dim=0)
    #print("outputs am after cat:", outputs_am.shape)
    outputs_am = outputs_am.reshape(batch_size, nr_span_indices, -1)
    
    # print("outputs am:", outputs_am.shape)
    
    
    
    outputs_am_minus_one = outputs[:,am_spans_minus_one,:] # xxx. added to solve bug 2
    outputs_am_minus_one = torch.cat([outputs_am_minus_one[i,i,:,:] for i in range(batch_size)], dim=0) # xxx. added to solve bug 2
    outputs_am_minus_one = outputs_am_minus_one.reshape(batch_size, nr_span_indices, -1) # xxx. added to solve bug 2
    
    outputs_am_plus_one = outputs[:,am_spans_plus_one,:] # xxx. added to solve bug 2
    outputs_am_plus_one = torch.cat([outputs_am_plus_one[i,i,:,:] for i in range(batch_size)], dim=0) # xxx. added to solve bug 2
    outputs_am_plus_one = outputs_am_plus_one.reshape(batch_size, nr_span_indices, -1) # xxx. added to solve bug 2
    
    
    
    ### Now that we have outputs_am i.e. outputs at am_span indices, now create the four Kuri forumlas for AMs
    
    # ============== the corrected 1st one =================== 
    
    outputs_am_first_term = torch.cat([outputs_am[:,i+1,:] - outputs_am_minus_one[:,i,:] for i in range(0, nr_span_indices, 2)], dim=1) # i + 1 here means j in kuri, i here means i in kuri
    outputs_am_first_term = outputs_am_first_term.reshape(batch_size, -1, 768)
    
#     # ============== FIRST TERM =================== # the one which gave us .788
    
#     outputs_am_first_term = torch.cat([outputs_am[:,i+1,:] - outputs_am[:,i-1,:] for i in range(0, nr_span_indices, 2)], dim=1) # j - (i - 1)
#     outputs_am_first_term = outputs_am_first_term.reshape(batch_size, -1, 768)
    
#     # ============== SECOND TERM ================== # the one which gave us .788
    
#     outputs_am_second_term = torch.cat([outputs_am[:,i,:] - outputs_am[:,i+1,:] for i in range(0, nr_span_indices, 2)], dim=1) # changed + 2 to + 1 to make it run # changed from +1 to +2 to ensure +2 is not a problem for AMs
#     outputs_am_second_term = outputs_am_second_term.reshape(batch_size, -1, 768)
    
     # ============== the corrected 2nd one ==================
    
    outputs_am_second_term = torch.cat([outputs_am[:,i,:] - outputs_am_plus_one[:,i+1,:] for i in range(0, nr_span_indices, 2)], dim=1) # changed + 2 to + 1 to make it run # changed from +1 to +2 to ensure +2 is not a problem for AMs
    outputs_am_second_term = outputs_am_second_term.reshape(batch_size, -1, 768)
    
    
#     # ============== THIRD TERM ================== # the one for .788
    
#     outputs_am_third_term = torch.cat([outputs_am[:,i-1,:] for i in range(0, nr_span_indices, 2)], dim=1)
#     outputs_am_third_term = outputs_am_third_term.reshape(batch_size, -1, 768)
    
        # ============== the corrected third one ================== 
    
    outputs_am_third_term = torch.cat([outputs_am_minus_one[:,i,:] for i in range(0, nr_span_indices, 2)], dim=1)
    outputs_am_third_term = outputs_am_third_term.reshape(batch_size, -1, 768)
    
#     # ============== FOURTH TERM ================== # the one for .788
    
#     outputs_am_fourth_term = torch.cat([outputs_am[:,i+1,:] for i in range(0, nr_span_indices, 2)], dim=1) # changed + 2 to + 1 to make it run # changed from +1 to +2 to ensure +2 is not a problem for AMs
#     outputs_am_fourth_term = outputs_am_fourth_term.reshape(batch_size, -1, 768)

        # ============== the corrected fourth one ==================
    
    outputs_am_fourth_term = torch.cat([outputs_am_plus_one[:,i+1,:] for i in range(0, nr_span_indices, 2)], dim=1) # changed + 2 to + 1 to make it run # changed from +1 to +2 to ensure +2 is not a problem for AMs
    outputs_am_fourth_term = outputs_am_fourth_term.reshape(batch_size, -1, 768)
    
    # ============== NOW CONCATENATE THEM =========
    
    
    am_minus_representations = torch.cat([outputs_am_first_term, outputs_am_second_term, outputs_am_third_term, outputs_am_fourth_term], dim=-1)   
    
    
    
    ### am minus span representation according to kuribayashi paper is now here.
    
    ### ========================================= OLD AM CALCULATIONS ==========================
    
#     outputs_am_r = torch.cat([outputs_am[:,i,:] - outputs_am[:,i+1,:] for i in range(0, nr_span_indices, 2)], dim=1)
#     outputs_am_r = outputs_am_r.reshape(batch_size, -1, 768)
    
#     outputs_am_l = torch.cat([outputs_am[:,i+1,:] - outputs_am[:,i,:] for i in range(0, nr_span_indices, 2)], dim=1)
#     outputs_am_l = outputs_am_r.reshape(batch_size, -1, 768)
    
#     am_minus_representations = torch.cat([outputs_am_r, outputs_am_l], dim=-1)
    
    #print("am_minus_representations:", am_minus_representations.shape)
    
    ### ====================================== FIN OLD AM CALCULATIONS =========================
    
    outputs_ac = outputs[:,ac_spans,:]
    outputs_ac = torch.cat([outputs_ac[i,i,:,:] for i in range(batch_size)], dim=0)
    outputs_ac = outputs_ac.reshape(batch_size, nr_span_indices, -1)
    
    #print("outputs ac:", outputs_ac.shape)
    
    outputs_ac_minus_one = outputs[:,ac_spans_minus_one,:] # xxx. added to solve bug 2
    outputs_ac_minus_one = torch.cat([outputs_ac_minus_one[i,i,:,:] for i in range(batch_size)], dim=0) # xxx. added to solve bug 2
    outputs_ac_minus_one = outputs_ac_minus_one.reshape(batch_size, nr_span_indices, -1) # xxx. added to solve bug 2
    
    outputs_ac_plus_one = outputs[:,ac_spans_plus_one,:] # xxx. added to solve bug 2
    outputs_ac_plus_one = torch.cat([outputs_ac_plus_one[i,i,:,:] for i in range(batch_size)], dim=0) # xxx. added to solve bug 2
    outputs_ac_plus_one = outputs_ac_plus_one.reshape(batch_size, nr_span_indices, -1) # xxx. added to solve bug 2
    
    ### Now that we have outputs_ac i.e. outputs at ac_span indices, now create the four Kuri forumlas for ACs
    
    
#     # ============== FIRST TERM =================== xxx. for .788
    
#     outputs_ac_first_term = torch.cat([outputs_ac[:,i+1,:] - outputs_ac[:,i-1,:] for i in range(0, nr_span_indices, 2)], dim=1)
#     outputs_ac_first_term = outputs_ac_first_term.reshape(batch_size, -1, 768)
    
    # ============== the corrected first one ===================
    
    outputs_ac_first_term = torch.cat([outputs_ac[:,i+1,:] - outputs_ac_minus_one[:,i,:] for i in range(0, nr_span_indices, 2)], dim=1)
    outputs_ac_first_term = outputs_ac_first_term.reshape(batch_size, -1, 768)
    
#     # ============== SECOND TERM ================== xxx. for .788
    
#     outputs_ac_second_term = torch.cat([outputs_ac[:,i,:] - outputs_ac[:,i+1,:] for i in range(0, nr_span_indices, 2)], dim=1) # changed + 2 to + 1 to make it run
#     outputs_ac_second_term = outputs_ac_second_term.reshape(batch_size, -1, 768)
    
        # ============== the correct second one ==================
    
    outputs_ac_second_term = torch.cat([outputs_ac[:,i,:] - outputs_ac_plus_one[:,i+1,:] for i in range(0, nr_span_indices, 2)], dim=1) # changed + 2 to + 1 to make it run
    outputs_ac_second_term = outputs_ac_second_term.reshape(batch_size, -1, 768)
    
    
#     # ============== THIRD TERM ================== # the one for .788
    
#     outputs_ac_third_term = torch.cat([outputs_ac[:,i-1,:] for i in range(0, nr_span_indices, 2)], dim=1)
#     outputs_ac_third_term = outputs_ac_third_term.reshape(batch_size, -1, 768)
    
    # ============== the corrected third term ==================
    
    outputs_ac_third_term = torch.cat([outputs_ac_minus_one[:,i,:] for i in range(0, nr_span_indices, 2)], dim=1)
    outputs_ac_third_term = outputs_ac_third_term.reshape(batch_size, -1, 768)
    
#     # ============== FOURTH TERM ================== xxx. for .788
    
#     outputs_ac_fourth_term = torch.cat([outputs_ac[:,i+1,:] for i in range(0, nr_span_indices, 2)], dim=1) # changed + 2 to + 1 to make it run
#     outputs_ac_fourth_term = outputs_ac_fourth_term.reshape(batch_size, -1, 768)
    
     # ============== the corrected fourth term ================== xxx. for .788
    
    outputs_ac_fourth_term = torch.cat([outputs_ac_plus_one[:,i+1,:] for i in range(0, nr_span_indices, 2)], dim=1) # changed + 2 to + 1 to make it run
    outputs_ac_fourth_term = outputs_ac_fourth_term.reshape(batch_size, -1, 768)
    
    # ============== NOW CONCATENATE THEM =========
    
    
    ac_minus_representations = torch.cat([outputs_ac_first_term, outputs_ac_second_term, outputs_ac_third_term, outputs_ac_fourth_term], dim=-1)   
    
    
    
    ### ac minus span representation according to kuribayashi paper is now here.
    
    ### ========================================= OLD AC CALCULATIONS ==========================
    
#     outputs_ac_r = torch.cat([outputs_ac[:,i,:] - outputs_ac[:,i+1,:] for i in range(0, nr_span_indices, 2)], dim=1)
#     outputs_ac_r = outputs_ac_r.reshape(batch_size, -1, 768)
    
#     outputs_ac_l = torch.cat([outputs_ac[:,i+1,:] - outputs_ac[:,i,:] for i in range(0, nr_span_indices, 2)], dim=1)
#     outputs_ac_l = outputs_ac_l.reshape(batch_size, -1, 768)
    
#     ac_minus_representations = torch.cat([outputs_ac_r, outputs_ac_l], dim=-1)

    ### ====================================== FIN OLD AC CALCULATIONS =========================
    
    #print("ac_minus_representations:", ac_minus_representations.shape)
    
    return am_minus_representations, ac_minus_representations                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                            

## custom BERT model

In [25]:
class CustomBERTKuri(nn.Module):

    def __init__(self, first_model, model_am, model_ac, nr_classes):
        
        super(CustomBERTKuri, self).__init__()
        
        self.first_model = first_model
        
        self.intermediate_linear_am = nn.Linear(768 * 4, 768)
        self.intermediate_linear_ac = nn.Linear(768 * 4, 768)        
        
        self.model_am = model_am
        self.model_ac = model_ac
        
        self.nr_classes = nr_classes
                
        self.fc = nn.Linear(self.model_am.config.hidden_size + self.model_ac.config.hidden_size, self.nr_classes)        

    def forward(self, inputs):
        
        batch_tokenized, batch_spans = inputs         
        outputs = self.first_model(batch_tokenized, output_hidden_states=True)[1][12] 
        #print('first model output:', outputs)
        am_minus_representations, ac_minus_representations = get_span_representations(outputs, batch_spans)
        
        # print('am minus rep:', am_minus_representations.shape)
        # print('ac minus rep:', ac_minus_representations.shape)

        am_minus_representations = self.intermediate_linear_am(am_minus_representations)
        ac_minus_representations = self.intermediate_linear_ac(ac_minus_representations)

        output_model_am = self.model_am(inputs_embeds = am_minus_representations)[0]
        output_model_ac = self.model_ac(inputs_embeds = ac_minus_representations)[0]

        adu_representations = torch.cat([output_model_am, output_model_ac], dim=-1)
        # print("adu rep:", adu_representations.shape)
        output = self.fc(adu_representations)
        # print("model class output avant reshape:", output.shape)
        # output = output.reshape(-1, self.nr_classes)
        # print("model class output apres:", output.shape)
        return output

## Run

In [26]:
NB_EPOCHS = 40
BATCH_SIZE = 12

In [27]:
first_model = BertForSequenceClassification.from_pretrained("bert-base-uncased", num_labels=3)
first_model.load_state_dict(torch.load('/notebooks/KURI-BERT/notebooks/full_formula_w_fts/icann_finetuned_work/finetuned_model.pth'))
#first_model = torch.load('/notebooks/KURI-BERT/notebooks/full_formula_w_fts/icann_finetuned_work/best-model-icann-finetune')

Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertForSequenceClassification: ['cls.predictions.transform.dense.bias', 'cls.predictions.transform.dense.weight', 'cls.seq_relationship.bias', 'cls.predictions.transform.LayerNorm.bias', 'cls.seq_relationship.weight', 'cls.predictions.bias', 'cls.predictions.decoder.weight', 'cls.predictions.transform.LayerNorm.weight']
- This IS expected if you are initializing BertForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of BertForSequenceClassification were not initialized from the model checkpoint at

<All keys matched successfully>

In [28]:
# model_am = BertForSequenceClassification.from_pretrained("bert-base-uncased", num_labels=3)
# model_am.load_state_dict(torch.load('/notebooks/KURI-BERT/notebooks/full_formula_w_fts/icann_finetuned_work/finetuned_model.pth'))
model_am = BertModel(BertConfig.from_pretrained("bert-base-uncased"))

In [29]:
model_ac = BertModel(BertConfig.from_pretrained("bert-base-uncased"))

In [30]:
custom_model = CustomBERTKuri(first_model, model_am, model_ac, 3)

In [31]:
custom_model.to(device)

CustomBERTKuri(
  (first_model): BertForSequenceClassification(
    (bert): BertModel(
      (embeddings): BertEmbeddings(
        (word_embeddings): Embedding(30522, 768, padding_idx=0)
        (position_embeddings): Embedding(512, 768)
        (token_type_embeddings): Embedding(2, 768)
        (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
        (dropout): Dropout(p=0.1, inplace=False)
      )
      (encoder): BertEncoder(
        (layer): ModuleList(
          (0): BertLayer(
            (attention): BertAttention(
              (self): BertSelfAttention(
                (query): Linear(in_features=768, out_features=768, bias=True)
                (key): Linear(in_features=768, out_features=768, bias=True)
                (value): Linear(in_features=768, out_features=768, bias=True)
                (dropout): Dropout(p=0.1, inplace=False)
              )
              (output): BertSelfOutput(
                (dense): Linear(in_features=768, out_features=768, b

In [32]:
loss = nn.CrossEntropyLoss(ignore_index=- 100)

In [33]:
optimizer = torch.optim.AdamW(custom_model.parameters(), lr=1.8738174228603844e-05)

In [34]:
# 1.8738174228603844e-05
# 6.579332246575682e-06 used for .783 corrected formula, 16 batch size
# 8.111308307896876e-06
# correct formula lr 9.999999999999997e-06 wo |LR
# best learning rate found by the whole leslie business
# new best LR found= 9.999999999999997e-06

In [35]:
NR_BATCHES = len(dataset['train']) / BATCH_SIZE
num_training_steps = NB_EPOCHS * NR_BATCHES
num_warmup_steps = int(0.2 * num_training_steps)

In [36]:
def lr_lambda(current_step: int):
        if current_step < num_warmup_steps:
            return float(current_step) / float(max(1, num_warmup_steps))
        return max(
            0.0, float(num_training_steps - current_step) / float(max(1, num_training_steps - num_warmup_steps))
        )

    # return LambdaLR(optimizer, lr_lambda, last_epoch)

In [37]:
# commented for LR Finder. remove it from optimizer.
scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lr_lambda)

### create dataloaders

In [38]:
train_dataloader = DataLoader(dataset['train'], batch_size=BATCH_SIZE, shuffle=True)
val_dataloader = DataLoader(dataset['validation'], batch_size=BATCH_SIZE, shuffle=True)
test_dataloader = DataLoader(dataset['test'], batch_size=BATCH_SIZE, shuffle=True)

In [39]:
def flatten_list(list_of_lists):
    return [x for sublist in list_of_lists for x in sublist]

In [40]:
def remove_dummy_labels(test_preds, test_labels):
    
    idxes = []
    test_labels_l = []
    for idx, val in enumerate(test_labels):
        if val != -100:
            idxes.append(idx)
            test_labels_l.append(val)
    
    test_preds_l = []
    for idx, val in enumerate(test_preds):
        for good_idx in idxes:
            if idx == good_idx:
                test_preds_l.append(val)
        
    return test_preds_l, test_labels_l

### training loop

In [41]:
from transformers import Trainer, TrainingArguments
from datasets import load_metric

In [42]:
import random
from torch_lr_finder import LRFinder

In [43]:
# def compute_loss(self, model, inputs, return_outputs=False):

#     labels = inputs["labels"]
#     labels = labels.flatten() # xxx.



#     outputs = model(inputs['input_ids'], inputs['spans'])
#     outputs = outputs.flatten(0,1) # xxx. for the 4 x 12 , 3 output of the main class.


#     loss_fct = nn.CrossEntropyLoss()#(weight=class_weights)
#     # loss = loss_fct(outputs, labels.flatten())
#     loss = loss_fct(outputs, labels) # xxx
#     #print("loss:", loss)
#     return (loss, outputs) if return_outputs else loss

In [44]:
# metric = load_metric('f1')

# def compute_metrics(eval_pred):

#     logits, labels = eval_pred

    
#     print("logits", logits.shape)    
#     predictions = np.argmax(logits, axis=-1)
    
#     print("preds:", predictions.shape)
    
#     return metric.compute(predictions=predictions, references=labels, average='macro')

### LR Finder Leslie Smith 

## training 

In [45]:
def train(model, loss=None, optimizer=None, train_dataloader=None, val_dataloader=None, nb_epochs=20):
    """Training loop"""

    min_f1 = -torch.inf
    train_losses = []
    val_losses = []

    # Iterrate over epochs
    for e in range(nb_epochs):

        # Training
        train_loss = 0.0

        for batch in tqdm(train_dataloader):            
            
            # unpack batch             
            labels = batch['label'].to(device)
            spans = batch['spans'].to(device)
            input_ids = batch['input_ids'].to(device)
            
            inputs = input_ids, spans
            
            # Reset gradients to 0
            optimizer.zero_grad()

            # Forward Pass
            outputs = model(inputs)
            
            # Compute training loss
            current_loss = loss(outputs.flatten(0,1), labels.flatten())
            train_loss += current_loss.detach().item()

            # Compute gradients
            current_loss.backward()

            # Update weights
            optimizer.step()            
            
            del batch
        
        scheduler.step()
            
        
        # Validation
        val_loss = 0.0

        # Put model in eval mode
        model.eval()
        
        preds_l = []
        labels_l = []
        
        for batch in tqdm(val_dataloader):            
            
            # unpack batch             
            labels = batch['label'].to(device)
            spans = batch['spans'].to(device)
            input_ids = batch['input_ids'].to(device)
            
            inputs = input_ids, spans
            
            # Forward Pass
            outputs = model(inputs)

            # Compute validation loss
            current_loss = loss(outputs.flatten(0,1), labels.flatten())
            val_loss += current_loss.detach().item()
            
            preds_for_f1 = torch.argmax(outputs, dim=2).flatten().tolist()
            labels_for_f1 = labels.flatten().tolist()
            
            preds_l.append(preds_for_f1)
            labels_l.append(labels_for_f1)
            
            del batch
        
        # Prints
        
        preds_l = flatten_list(preds_l)
        labels_l = flatten_list(labels_l)
        
        preds_l, labels_l = remove_dummy_labels(preds_l, labels_l)
        
        f1_score_epoch = f1_score(preds_l, labels_l, average='macro')        
        
        print(f"Epoch {e+1}/{nb_epochs} \
                \t Training Loss: {train_loss/len(train_dataloader):.3f} \
                \t Validation Loss: {val_loss/len(val_dataloader):.3f} \
                \t F1 score: {f1_score_epoch}")
        
        train_losses.append(train_loss/len(train_dataloader))
        val_losses.append(val_loss/len(val_dataloader))
        

        # Save model if val loss decreases
        if f1_score_epoch > min_f1:

            min_f1 = f1_score_epoch
            torch.save(model.first_model.state_dict(), 'first_model.pt')
            torch.save(model.model_am.state_dict(), 'model_am.pt')
            torch.save(model.model_ac.state_dict(), 'model_ac.pt')
            torch.save(model.state_dict(), 'best_model.pt')
            
    return train_losses, val_losses

In [46]:
train_losses, val_losses = train(custom_model, loss, optimizer, train_dataloader, val_dataloader, NB_EPOCHS)

100%|██████████| 91/91 [00:34<00:00,  2.64it/s]
100%|██████████| 23/23 [00:02<00:00,  8.26it/s]


Epoch 1/40                 	 Training Loss: 1.213                 	 Validation Loss: 1.210                 	 F1 score: 0.20732932796177686


100%|██████████| 91/91 [00:34<00:00,  2.63it/s]
100%|██████████| 23/23 [00:02<00:00,  8.27it/s]


Epoch 2/40                 	 Training Loss: 1.064                 	 Validation Loss: 0.953                 	 F1 score: 0.3068584610191107


100%|██████████| 91/91 [00:35<00:00,  2.59it/s]
100%|██████████| 23/23 [00:02<00:00,  8.01it/s]


Epoch 3/40                 	 Training Loss: 0.903                 	 Validation Loss: 0.838                 	 F1 score: 0.2997195637505926


100%|██████████| 91/91 [00:34<00:00,  2.65it/s]
100%|██████████| 23/23 [00:02<00:00,  8.07it/s]


Epoch 4/40                 	 Training Loss: 0.814                 	 Validation Loss: 0.761                 	 F1 score: 0.39528522635686764


100%|██████████| 91/91 [00:34<00:00,  2.63it/s]
100%|██████████| 23/23 [00:03<00:00,  7.52it/s]


Epoch 5/40                 	 Training Loss: 0.734                 	 Validation Loss: 0.691                 	 F1 score: 0.5347103042215756


100%|██████████| 91/91 [00:34<00:00,  2.62it/s]
100%|██████████| 23/23 [00:03<00:00,  7.44it/s]


Epoch 6/40                 	 Training Loss: 0.678                 	 Validation Loss: 0.646                 	 F1 score: 0.5926090116262962


100%|██████████| 91/91 [00:34<00:00,  2.62it/s]
100%|██████████| 23/23 [00:03<00:00,  7.41it/s]


Epoch 7/40                 	 Training Loss: 0.627                 	 Validation Loss: 0.612                 	 F1 score: 0.6362801483713604


100%|██████████| 91/91 [00:34<00:00,  2.62it/s]
100%|██████████| 23/23 [00:03<00:00,  7.25it/s]


Epoch 8/40                 	 Training Loss: 0.593                 	 Validation Loss: 0.584                 	 F1 score: 0.642851765046064


100%|██████████| 91/91 [00:34<00:00,  2.63it/s]
100%|██████████| 23/23 [00:03<00:00,  7.67it/s]


Epoch 9/40                 	 Training Loss: 0.560                 	 Validation Loss: 0.553                 	 F1 score: 0.6940990979027789


100%|██████████| 91/91 [00:34<00:00,  2.64it/s]
100%|██████████| 23/23 [00:03<00:00,  7.42it/s]


Epoch 10/40                 	 Training Loss: 0.534                 	 Validation Loss: 0.536                 	 F1 score: 0.7150257032144213


100%|██████████| 91/91 [00:34<00:00,  2.63it/s]
100%|██████████| 23/23 [00:03<00:00,  7.29it/s]


Epoch 11/40                 	 Training Loss: 0.513                 	 Validation Loss: 0.523                 	 F1 score: 0.7199898532279715


100%|██████████| 91/91 [00:34<00:00,  2.62it/s]
100%|██████████| 23/23 [00:03<00:00,  7.36it/s]


Epoch 12/40                 	 Training Loss: 0.490                 	 Validation Loss: 0.497                 	 F1 score: 0.7265434921049624


100%|██████████| 91/91 [00:34<00:00,  2.63it/s]
100%|██████████| 23/23 [00:03<00:00,  7.29it/s]


Epoch 13/40                 	 Training Loss: 0.471                 	 Validation Loss: 0.492                 	 F1 score: 0.7420584396040676


100%|██████████| 91/91 [00:34<00:00,  2.64it/s]
100%|██████████| 23/23 [00:03<00:00,  7.40it/s]


Epoch 14/40                 	 Training Loss: 0.454                 	 Validation Loss: 0.475                 	 F1 score: 0.7391835393254053


100%|██████████| 91/91 [00:34<00:00,  2.63it/s]
100%|██████████| 23/23 [00:02<00:00,  8.15it/s]


Epoch 15/40                 	 Training Loss: 0.432                 	 Validation Loss: 0.473                 	 F1 score: 0.7639073450745606


100%|██████████| 91/91 [00:35<00:00,  2.60it/s]
100%|██████████| 23/23 [00:02<00:00,  8.16it/s]


Epoch 16/40                 	 Training Loss: 0.413                 	 Validation Loss: 0.465                 	 F1 score: 0.7791738463513687


100%|██████████| 91/91 [00:35<00:00,  2.58it/s]
100%|██████████| 23/23 [00:02<00:00,  8.10it/s]


Epoch 17/40                 	 Training Loss: 0.401                 	 Validation Loss: 0.442                 	 F1 score: 0.7739132908290377


100%|██████████| 91/91 [00:34<00:00,  2.66it/s]
100%|██████████| 23/23 [00:02<00:00,  7.86it/s]


Epoch 18/40                 	 Training Loss: 0.380                 	 Validation Loss: 0.447                 	 F1 score: 0.7520767211106697


100%|██████████| 91/91 [00:34<00:00,  2.65it/s]
100%|██████████| 23/23 [00:02<00:00,  8.16it/s]


Epoch 19/40                 	 Training Loss: 0.358                 	 Validation Loss: 0.433                 	 F1 score: 0.7904539830519127


100%|██████████| 91/91 [00:34<00:00,  2.62it/s]
100%|██████████| 23/23 [00:02<00:00,  7.68it/s]


Epoch 20/40                 	 Training Loss: 0.338                 	 Validation Loss: 0.432                 	 F1 score: 0.7908265207704198


100%|██████████| 91/91 [00:34<00:00,  2.64it/s]
100%|██████████| 23/23 [00:03<00:00,  7.46it/s]


Epoch 21/40                 	 Training Loss: 0.323                 	 Validation Loss: 0.434                 	 F1 score: 0.7876651231279942


100%|██████████| 91/91 [00:34<00:00,  2.66it/s]
100%|██████████| 23/23 [00:02<00:00,  8.02it/s]


Epoch 22/40                 	 Training Loss: 0.304                 	 Validation Loss: 0.427                 	 F1 score: 0.7852553832068594


100%|██████████| 91/91 [00:34<00:00,  2.67it/s]
100%|██████████| 23/23 [00:02<00:00,  8.16it/s]


Epoch 23/40                 	 Training Loss: 0.281                 	 Validation Loss: 0.414                 	 F1 score: 0.7783211480916824


100%|██████████| 91/91 [00:34<00:00,  2.66it/s]
100%|██████████| 23/23 [00:02<00:00,  8.03it/s]


Epoch 24/40                 	 Training Loss: 0.262                 	 Validation Loss: 0.433                 	 F1 score: 0.7935038206425894


100%|██████████| 91/91 [00:34<00:00,  2.62it/s]
100%|██████████| 23/23 [00:03<00:00,  7.40it/s]


Epoch 25/40                 	 Training Loss: 0.244                 	 Validation Loss: 0.432                 	 F1 score: 0.7854418598052918


100%|██████████| 91/91 [00:34<00:00,  2.64it/s]
100%|██████████| 23/23 [00:02<00:00,  8.03it/s]


Epoch 26/40                 	 Training Loss: 0.220                 	 Validation Loss: 0.440                 	 F1 score: 0.7740474971423551


100%|██████████| 91/91 [00:34<00:00,  2.64it/s]
100%|██████████| 23/23 [00:02<00:00,  8.05it/s]


Epoch 27/40                 	 Training Loss: 0.200                 	 Validation Loss: 0.435                 	 F1 score: 0.7721055279104885


100%|██████████| 91/91 [00:34<00:00,  2.64it/s]
100%|██████████| 23/23 [00:02<00:00,  8.15it/s]


Epoch 28/40                 	 Training Loss: 0.180                 	 Validation Loss: 0.458                 	 F1 score: 0.7831198876029953


100%|██████████| 91/91 [00:34<00:00,  2.64it/s]
100%|██████████| 23/23 [00:02<00:00,  8.20it/s]


Epoch 29/40                 	 Training Loss: 0.159                 	 Validation Loss: 0.466                 	 F1 score: 0.7745431210645973


100%|██████████| 91/91 [00:34<00:00,  2.65it/s]
100%|██████████| 23/23 [00:02<00:00,  8.01it/s]


Epoch 30/40                 	 Training Loss: 0.135                 	 Validation Loss: 0.469                 	 F1 score: 0.7829230261672832


100%|██████████| 91/91 [00:34<00:00,  2.64it/s]
100%|██████████| 23/23 [00:02<00:00,  8.14it/s]


Epoch 31/40                 	 Training Loss: 0.116                 	 Validation Loss: 0.497                 	 F1 score: 0.7670895439989476


100%|██████████| 91/91 [00:34<00:00,  2.65it/s]
100%|██████████| 23/23 [00:02<00:00,  8.01it/s]


Epoch 32/40                 	 Training Loss: 0.095                 	 Validation Loss: 0.496                 	 F1 score: 0.7757765194228612


100%|██████████| 91/91 [00:34<00:00,  2.66it/s]
100%|██████████| 23/23 [00:02<00:00,  7.98it/s]


Epoch 33/40                 	 Training Loss: 0.078                 	 Validation Loss: 0.536                 	 F1 score: 0.7674114554782955


100%|██████████| 91/91 [00:34<00:00,  2.65it/s]
100%|██████████| 23/23 [00:02<00:00,  8.00it/s]


Epoch 34/40                 	 Training Loss: 0.066                 	 Validation Loss: 0.550                 	 F1 score: 0.7623812551783385


100%|██████████| 91/91 [00:34<00:00,  2.65it/s]
100%|██████████| 23/23 [00:02<00:00,  8.00it/s]


Epoch 35/40                 	 Training Loss: 0.051                 	 Validation Loss: 0.577                 	 F1 score: 0.7694560236337047


100%|██████████| 91/91 [00:34<00:00,  2.64it/s]
100%|██████████| 23/23 [00:02<00:00,  8.21it/s]


Epoch 36/40                 	 Training Loss: 0.042                 	 Validation Loss: 0.624                 	 F1 score: 0.7615759897313295


100%|██████████| 91/91 [00:34<00:00,  2.67it/s]
100%|██████████| 23/23 [00:02<00:00,  8.18it/s]


Epoch 37/40                 	 Training Loss: 0.038                 	 Validation Loss: 0.618                 	 F1 score: 0.7658171160268009


100%|██████████| 91/91 [00:34<00:00,  2.65it/s]
100%|██████████| 23/23 [00:02<00:00,  7.97it/s]


Epoch 38/40                 	 Training Loss: 0.026                 	 Validation Loss: 0.670                 	 F1 score: 0.7583747549474823


100%|██████████| 91/91 [00:34<00:00,  2.65it/s]
100%|██████████| 23/23 [00:02<00:00,  8.06it/s]


Epoch 39/40                 	 Training Loss: 0.019                 	 Validation Loss: 0.671                 	 F1 score: 0.7648929294783829


100%|██████████| 91/91 [00:34<00:00,  2.64it/s]
100%|██████████| 23/23 [00:02<00:00,  8.21it/s]


Epoch 40/40                 	 Training Loss: 0.014                 	 Validation Loss: 0.712                 	 F1 score: 0.7638981295785494


### Predictions

In [47]:
# # Load best model
# network_2 = Network()

# network_2.load_state_dict(torch.load(path))
# network_2.eval()

In [48]:
first_model =  BertForSequenceClassification.from_pretrained("bert-base-uncased", num_labels=3)
first_model.load_state_dict(torch.load('first_model.pt'))

Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertForSequenceClassification: ['cls.predictions.transform.dense.bias', 'cls.predictions.transform.dense.weight', 'cls.seq_relationship.bias', 'cls.predictions.transform.LayerNorm.bias', 'cls.seq_relationship.weight', 'cls.predictions.bias', 'cls.predictions.decoder.weight', 'cls.predictions.transform.LayerNorm.weight']
- This IS expected if you are initializing BertForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of BertForSequenceClassification were not initialized from the model checkpoint at

<All keys matched successfully>

In [49]:
# first_model = BertModel(BertConfig.from_pretrained("bert-base-uncased"))
# first_model.load_state_dict(torch.load('first_model.pt'))

In [50]:
model_am = BertModel(BertConfig.from_pretrained("bert-base-uncased"))
model_am.load_state_dict(torch.load('model_am.pt'))

<All keys matched successfully>

In [51]:
# model_am =  BertForSequenceClassification.from_pretrained("bert-base-uncased", num_labels=3)
# model_am.load_state_dict(torch.load('model_am.pt'))

In [52]:
model_ac = BertModel(BertConfig.from_pretrained("bert-base-uncased"))
model_ac.load_state_dict(torch.load('model_ac.pt'))

<All keys matched successfully>

In [53]:
# Load best model

custom_model_2 = CustomBERTKuri(first_model, model_am, model_ac, 3)
custom_model_2.load_state_dict(torch.load('best_model.pt'))

custom_model_2.to(device).eval()

CustomBERTKuri(
  (first_model): BertForSequenceClassification(
    (bert): BertModel(
      (embeddings): BertEmbeddings(
        (word_embeddings): Embedding(30522, 768, padding_idx=0)
        (position_embeddings): Embedding(512, 768)
        (token_type_embeddings): Embedding(2, 768)
        (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
        (dropout): Dropout(p=0.1, inplace=False)
      )
      (encoder): BertEncoder(
        (layer): ModuleList(
          (0): BertLayer(
            (attention): BertAttention(
              (self): BertSelfAttention(
                (query): Linear(in_features=768, out_features=768, bias=True)
                (key): Linear(in_features=768, out_features=768, bias=True)
                (value): Linear(in_features=768, out_features=768, bias=True)
                (dropout): Dropout(p=0.1, inplace=False)
              )
              (output): BertSelfOutput(
                (dense): Linear(in_features=768, out_features=768, b

In [54]:
def predict(model, test_dataloader=None):
    
    """Prediction loop"""

    preds_l = []
    labels_l = []
    
    model.eval()

    for batch in test_dataloader:            
            
        # unpack batch             
        labels = batch['label'].to(device).flatten().tolist()
        spans = batch['spans'].to(device)
        input_ids = batch['input_ids'].to(device)
        
        inputs = input_ids, spans

        # get output
        
        raw_preds = model(inputs).to('cpu')
        # print(raw_preds.shape)
        raw_preds = raw_preds.detach()#.numpy()

        # Compute argmax
        
        predictions = torch.argmax(raw_preds, dim=2).flatten().tolist()
        preds_l.append(predictions)
        labels_l.append(labels)        
        
        del batch
            
    return flatten_list(preds_l), flatten_list(labels_l)

In [55]:
#test_preds, test_labels = predict(custom_model, test_dataloader)

In [56]:
#print(classification_report(test_labels, test_preds, digits=3))

In [57]:
# remove -100s

In [58]:
#test_preds_l, test_labels_l = remove_dummy_labels(test_preds, test_labels)

In [59]:
#print(classification_report(test_labels_l, test_preds_l, digits=3))

In [60]:
### -100 done!

In [61]:
test_preds, test_labels = predict(custom_model_2, test_dataloader)

In [62]:
print(classification_report(test_labels, test_preds, digits=3))

              precision    recall  f1-score   support

        -100      0.000     0.000     0.000      3033
           0      0.085     0.813     0.154       155
           1      0.162     0.630     0.257       303
           2      0.440     0.892     0.589       805

    accuracy                          0.241      4296
   macro avg      0.172     0.584     0.250      4296
weighted avg      0.097     0.241     0.134      4296



  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


In [63]:
test_preds_l, test_labels_l = remove_dummy_labels(test_preds, test_labels)

In [64]:
print(classification_report(test_labels_l, test_preds_l, digits=3))

              precision    recall  f1-score   support

           0      0.840     0.813     0.826       155
           1      0.635     0.630     0.632       303
           2      0.884     0.892     0.888       805

    accuracy                          0.819      1263
   macro avg      0.786     0.778     0.782      1263
weighted avg      0.819     0.819     0.819      1263

