# Implementation of the Kuribayashi BERT minus model

## libraries

In [1]:
!pip install transformers --upgrade
!pip install ipywidgets
!pip install IProgress
!pip install datasets
!pip install torch-lr-finder

Looking in indexes: https://pypi.org/simple, https://pypi.ngc.nvidia.com
Looking in indexes: https://pypi.org/simple, https://pypi.ngc.nvidia.com
Looking in indexes: https://pypi.org/simple, https://pypi.ngc.nvidia.com
Looking in indexes: https://pypi.org/simple, https://pypi.ngc.nvidia.com
Looking in indexes: https://pypi.org/simple, https://pypi.ngc.nvidia.com


In [2]:
import transformers
from transformers import BertTokenizer, BertConfig
from transformers import BertModel, BertForSequenceClassification
from transformers import BatchEncoding, default_data_collator, DataCollatorWithPadding

import torch
import torch.nn as nn

import numpy as np

from sklearn.metrics import classification_report
from sklearn.metrics import f1_score

import datasets

from torch.utils.data import DataLoader

from tqdm import tqdm
from operator import itemgetter

In [3]:
print(transformers.__version__)

4.25.1


## tokenizer

In [4]:
tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")

In [5]:
# tokenizer.model_max_length = 256

## data

In [6]:
# DATA_FOLDER = '/notebooks/Data/bert_sequence_classification'
DATA_FILE = '/notebooks/KURI-BERT/notebooks/full_formula_w_fts/Link_Identification_Task/pe_dataset_for_bert_minus_w_fts_combined_link_task.pt'
RESULTS_FOLDER = '/notebooks/KURI-BERT/results'

In [7]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [8]:
device

device(type='cuda')

## Load data

In [9]:
dataset = torch.load(DATA_FILE)

In [10]:
dataset

DatasetDict({
    train: Dataset({
        features: ['paragraph', 'paragraph_components_list', 'paragraph_labels_list', 'paragraph_markers_list', 'split', 'essay_nr', 'paragraph_fts_as_txt_list', 'paragraph_labels', 'paragraph_ac_spans', 'paragraph_am_spans', 'paragraph_w_fts_as_txt', 'ac_spans_new', 'sanity_new', 'am_spans_new', 'feature_spans_new'],
        num_rows: 1088
    })
    test: Dataset({
        features: ['paragraph', 'paragraph_components_list', 'paragraph_labels_list', 'paragraph_markers_list', 'split', 'essay_nr', 'paragraph_fts_as_txt_list', 'paragraph_labels', 'paragraph_ac_spans', 'paragraph_am_spans', 'paragraph_w_fts_as_txt', 'ac_spans_new', 'sanity_new', 'am_spans_new', 'feature_spans_new'],
        num_rows: 358
    })
    validation: Dataset({
        features: ['paragraph', 'paragraph_components_list', 'paragraph_labels_list', 'paragraph_markers_list', 'split', 'essay_nr', 'paragraph_fts_as_txt_list', 'paragraph_labels', 'paragraph_ac_spans', 'paragraph_am_sp

In [11]:
dataset['train']['paragraph_w_fts_as_txt'][0]

'Firstly, by paying taxes for public school, affluent people effectively contribute to narrowing down the gap between rich and poor [SEP] Structural Features: 2, Yes, No, No, No [SEP]. It is true that many poor families are not able to afford tuition fees for their kids to attend a course [SEP] Structural Features: 2, No, No, No, No [SEP]. With the the tax amount for which they pay, the rich may help a vast number of students from families with poverty background to continue their studies, and earn a better quality of life [SEP] Structural Features: 2, No, Yes, No, No [SEP].'

In [12]:
for l in dataset['train']['feature_spans_new']:
    print(len(l))

3
3
3
2
4
7
4
1
2
3
4
5
1
1
4
5
2
2
5
6
4
6
7
2
1
5
1
3
3
4
6
2
1
5
3
4
7
4
5
1
7
4
8
3
1
7
2
1
6
2
4
6
2
1
3
1
6
3
7
1
6
1
6
4
5
3
5
5
1
6
1
1
5
3
2
1
3
5
1
5
7
6
5
1
5
6
7
3
6
1
2
5
7
5
4
4
1
4
11
9
4
4
4
4
3
4
1
3
3
2
1
1
2
11
3
2
4
4
1
1
5
4
3
1
4
4
1
1
3
4
6
1
1
4
1
3
7
2
3
2
4
3
2
2
4
6
4
6
3
6
4
4
1
4
4
5
3
1
7
6
3
12
4
2
1
2
5
4
1
3
1
5
3
1
7
2
1
4
1
3
6
1
1
5
2
4
3
3
5
3
1
2
5
3
6
4
10
1
6
2
1
8
5
2
4
2
6
4
7
1
2
4
5
4
2
1
2
7
7
3
2
4
6
4
3
1
8
2
1
6
1
1
6
7
3
5
2
4
1
1
4
4
6
2
1
4
5
1
3
1
6
2
5
4
4
3
2
3
3
6
4
1
2
2
7
1
1
5
1
2
6
5
3
5
1
1
3
4
6
4
1
1
1
6
4
2
1
4
3
4
4
1
1
3
2
2
2
5
2
5
1
3
1
2
4
1
5
3
1
9
3
1
3
6
6
2
1
5
5
7
3
1
1
2
6
6
3
4
2
2
2
2
4
1
7
1
3
2
5
4
3
1
8
4
5
7
1
3
4
7
5
3
2
6
6
7
2
2
1
2
2
2
3
4
6
2
2
4
4
1
10
4
6
1
5
7
4
3
3
1
2
4
4
4
8
5
3
2
2
4
1
8
4
4
5
6
1
4
4
6
1
4
6
5
1
7
4
4
4
6
3
3
5
4
3
1
5
1
1
5
4
3
6
2
6
1
4
6
5
1
2
4
6
3
5
2
4
1
2
4
1
1
2
1
4
6
6
2
1
5
1
4
3
2
1
2
1
3
5
5
3
3
1
5
1
2
1
9
6
1
4
1
2
6
3
3
4
7
3
3
1
6
2
5
2
5
5
1
3
2
2
3
4
2
8
3
3
5

In [13]:
flat_list = [item for sublist in dataset['train']['feature_spans_new'] for item in sublist]

In [14]:
flat_list

[[23, 34],
 [59, 70],
 [110, 121],
 [47, 58],
 [95, 106],
 [127, 138],
 [25, 36],
 [56, 67],
 [100, 111],
 [21, 32],
 [69, 80],
 [14, 25],
 [59, 70],
 [86, 97],
 [110, 121],
 [9, 20],
 [35, 46],
 [65, 76],
 [90, 101],
 [113, 124],
 [148, 159],
 [184, 195],
 [23, 34],
 [44, 55],
 [70, 81],
 [115, 126],
 [43, 54],
 [22, 33],
 [52, 63],
 [19, 30],
 [60, 71],
 [84, 95],
 [25, 36],
 [62, 73],
 [98, 109],
 [126, 137],
 [21, 32],
 [46, 57],
 [73, 84],
 [97, 108],
 [128, 139],
 [31, 42],
 [61, 72],
 [31, 42],
 [68, 79],
 [113, 124],
 [158, 169],
 [19, 30],
 [49, 60],
 [84, 95],
 [115, 126],
 [146, 157],
 [81, 92],
 [114, 125],
 [32, 43],
 [53, 64],
 [23, 34],
 [48, 59],
 [78, 89],
 [107, 118],
 [137, 148],
 [16, 27],
 [43, 54],
 [79, 90],
 [104, 115],
 [131, 142],
 [164, 175],
 [18, 29],
 [54, 65],
 [81, 92],
 [117, 128],
 [16, 27],
 [37, 48],
 [69, 80],
 [97, 108],
 [135, 146],
 [171, 182],
 [19, 30],
 [55, 66],
 [95, 106],
 [126, 137],
 [166, 177],
 [203, 214],
 [248, 259],
 [24, 35],
 [50, 

In [15]:
flat_list_2 = [item for sublist in flat_list for item in sublist]

In [16]:
flat_list_2.sort()

In [17]:
flat_list_2[-25:]

[273,
 273,
 276,
 277,
 278,
 279,
 280,
 284,
 284,
 284,
 288,
 289,
 290,
 290,
 291,
 294,
 295,
 301,
 302,
 305,
 311,
 313,
 322,
 336,
 347]

### preprocessing

In [18]:
MAX_LENGTH = 0
MAX_SPAN = 0

for split in ['train', 'test', 'validation']:
    
    for col_name in ['am_spans_new', 'ac_spans_new', 'feature_spans_new']:
        
        for x in dataset[split][col_name]:
            
            if max(x,key=itemgetter(1))[1] > MAX_SPAN:
                
                MAX_SPAN = max(x,key=itemgetter(1))[1]
                MAX_SPAN = min(MAX_SPAN, tokenizer.model_max_length - 2)
            
            if len(x) > MAX_LENGTH:
                
                MAX_LENGTH = len(x)

In [19]:
MAX_SPAN

347

In [20]:
tokenizer.model_max_length

512

In [21]:
def get_padding(batch, padding_target):    
    
    if padding_target == 'am_spans':
        
        col_name = 'am_spans_new'
        padding_val = [[-1,-1]]
        max_length = MAX_LENGTH
        
    elif padding_target == 'ac_spans':
        
        col_name = 'ac_spans_new'
        padding_val = [[-1,-1]]
        max_length = MAX_LENGTH
        
    elif padding_target == 'fts_spans':
        
        col_name = 'feature_spans_new'
        padding_val = [[-1,-1]]
        max_length = MAX_LENGTH
        
    elif padding_target == 'label':
    
        col_name = 'paragraph_labels'
        padding_val = [-100] # -1 previously       
        max_length = MAX_LENGTH # max([len(l) for l in batch[col_name]]) # cause some batch had 4 x 10

    padded_spans = []

    for idx, span in enumerate(batch[col_name]):

        padded_span = batch[col_name][idx] + (max_length - len(span)) * padding_val
        padded_spans.append(padded_span)

    return padded_spans         

In [22]:
def get_combined_spans(am_spans_ll, ac_spans_ll, fts_spans_ll):
    
    spans_ll = []
    
    for am_spans, ac_spans, fts_spans in zip(am_spans_ll, ac_spans_ll, fts_spans_ll):
        
        spans = []
        
        for am_span, ac_span, fts_span in zip(am_spans, ac_spans, fts_spans):
            
            for idx in [0,1]:
                
                if am_span[idx] > MAX_SPAN:
                    am_span[idx] = MAX_SPAN
                if ac_span[idx] > MAX_SPAN:
                    ac_span[idx] = MAX_SPAN
                if fts_span[idx] > MAX_SPAN:
                    fts_span[idx] = MAX_SPAN

            span = [am_span, ac_span, fts_span]
            spans.extend(span)
            
        spans_ll.append(spans)

    return spans_ll

### tokenize 

In [23]:
# max_length = 200 for use in the max_length in the tokenizer so that the things are of equal dim.

In [24]:
def tokenize(batch):
    
    tokenized_text = tokenizer(batch['paragraph_w_fts_as_txt'], truncation=True, padding=True, max_length=512)
    tokenized_text['label'] = get_padding(batch, 'label')
    tokenized_text['am_spans'] = get_padding(batch, 'am_spans')
    tokenized_text['ac_spans'] = get_padding(batch, 'ac_spans')
    tokenized_text['fts_spans'] = get_padding(batch, 'fts_spans')
    tokenized_text['spans'] = get_combined_spans(tokenized_text['am_spans'], tokenized_text['ac_spans'], tokenized_text['fts_spans'])      
    
    return tokenized_text

In [25]:
dataset = dataset.map(tokenize, batched=True, batch_size=len(dataset['train']))



  0%|          | 0/1 [00:00<?, ?ba/s]

  0%|          | 0/1 [00:00<?, ?ba/s]

  0%|          | 0/1 [00:00<?, ?ba/s]

In [26]:
dataset

DatasetDict({
    train: Dataset({
        features: ['ac_spans', 'ac_spans_new', 'am_spans', 'am_spans_new', 'attention_mask', 'essay_nr', 'feature_spans_new', 'fts_spans', 'input_ids', 'label', 'paragraph', 'paragraph_ac_spans', 'paragraph_am_spans', 'paragraph_components_list', 'paragraph_fts_as_txt_list', 'paragraph_labels', 'paragraph_labels_list', 'paragraph_markers_list', 'paragraph_w_fts_as_txt', 'sanity_new', 'spans', 'split', 'token_type_ids'],
        num_rows: 1088
    })
    test: Dataset({
        features: ['ac_spans', 'ac_spans_new', 'am_spans', 'am_spans_new', 'attention_mask', 'essay_nr', 'feature_spans_new', 'fts_spans', 'input_ids', 'label', 'paragraph', 'paragraph_ac_spans', 'paragraph_am_spans', 'paragraph_components_list', 'paragraph_fts_as_txt_list', 'paragraph_labels', 'paragraph_labels_list', 'paragraph_markers_list', 'paragraph_w_fts_as_txt', 'sanity_new', 'spans', 'split', 'token_type_ids'],
        num_rows: 358
    })
    validation: Dataset({
        feat

In [27]:
dataset['train'][0]['spans']

[[0, 0],
 [2, 21],
 [23, 34],
 [-1, -1],
 [41, 57],
 [59, 70],
 [-1, -1],
 [73, 108],
 [110, 121],
 [-1, -1],
 [-1, -1],
 [-1, -1],
 [-1, -1],
 [-1, -1],
 [-1, -1],
 [-1, -1],
 [-1, -1],
 [-1, -1],
 [-1, -1],
 [-1, -1],
 [-1, -1],
 [-1, -1],
 [-1, -1],
 [-1, -1],
 [-1, -1],
 [-1, -1],
 [-1, -1],
 [-1, -1],
 [-1, -1],
 [-1, -1],
 [-1, -1],
 [-1, -1],
 [-1, -1],
 [-1, -1],
 [-1, -1],
 [-1, -1]]

In [28]:
len(dataset['train'][0]['spans'])

36

In [29]:
dataset['train'][1]['label']

[1, 1, 0, -100, -100, -100, -100, -100, -100, -100, -100, -100]

In [30]:
dataset['train'].features

{'ac_spans': Sequence(feature=Sequence(feature=Value(dtype='int64', id=None), length=-1, id=None), length=-1, id=None),
 'ac_spans_new': Sequence(feature=Sequence(feature=Value(dtype='int64', id=None), length=-1, id=None), length=-1, id=None),
 'am_spans': Sequence(feature=Sequence(feature=Value(dtype='int64', id=None), length=-1, id=None), length=-1, id=None),
 'am_spans_new': Sequence(feature=Sequence(feature=Value(dtype='int64', id=None), length=-1, id=None), length=-1, id=None),
 'attention_mask': Sequence(feature=Value(dtype='int8', id=None), length=-1, id=None),
 'essay_nr': Value(dtype='string', id=None),
 'feature_spans_new': Sequence(feature=Sequence(feature=Value(dtype='int64', id=None), length=-1, id=None), length=-1, id=None),
 'fts_spans': Sequence(feature=Sequence(feature=Value(dtype='int64', id=None), length=-1, id=None), length=-1, id=None),
 'input_ids': Sequence(feature=Value(dtype='int32', id=None), length=-1, id=None),
 'label': Sequence(feature=Value(dtype='int64',

In [31]:
dataset['test'].features

{'ac_spans': Sequence(feature=Sequence(feature=Value(dtype='int64', id=None), length=-1, id=None), length=-1, id=None),
 'ac_spans_new': Sequence(feature=Sequence(feature=Value(dtype='int64', id=None), length=-1, id=None), length=-1, id=None),
 'am_spans': Sequence(feature=Sequence(feature=Value(dtype='int64', id=None), length=-1, id=None), length=-1, id=None),
 'am_spans_new': Sequence(feature=Sequence(feature=Value(dtype='int64', id=None), length=-1, id=None), length=-1, id=None),
 'attention_mask': Sequence(feature=Value(dtype='int8', id=None), length=-1, id=None),
 'essay_nr': Value(dtype='string', id=None),
 'feature_spans_new': Sequence(feature=Sequence(feature=Value(dtype='int64', id=None), length=-1, id=None), length=-1, id=None),
 'fts_spans': Sequence(feature=Sequence(feature=Value(dtype='int64', id=None), length=-1, id=None), length=-1, id=None),
 'input_ids': Sequence(feature=Value(dtype='int32', id=None), length=-1, id=None),
 'label': Sequence(feature=Value(dtype='int64',

In [32]:
dataset['test'].features['spans'] = datasets.Array2D(shape=(36, 2), dtype="int32")
dataset['train'].features['spans'] = datasets.Array2D(shape=(36, 2), dtype="int32")
dataset['validation'].features['spans'] = datasets.Array2D(shape=(36, 2), dtype="int32")

In [33]:
dataset = dataset.map(lambda batch: batch, batched=True)

  0%|          | 0/2 [00:00<?, ?ba/s]

  0%|          | 0/1 [00:00<?, ?ba/s]

  0%|          | 0/1 [00:00<?, ?ba/s]

In [34]:
dataset.set_format('torch', columns=['input_ids', 'attention_mask', 'token_type_ids', 'spans', 'label'])

## span representation function

In [35]:
def minus_one(t):
    
    return torch.where(t == 0, 0, t-1)

def plus_one(t):
    
    return torch.where(t >= MAX_SPAN, MAX_SPAN, t+1) # changed from t == MAX_SPAN

In [36]:
def get_span_representations(outputs, spans):
    
    #print('spans shape', spans.shape)
    
    batch_size = spans.shape[0]
    
    nr_span_indices = spans.shape[1]
    
    #print('nr span indices', nr_span_indices)
    
    # nr_span_indices = 24 # xxx. hardcode just to check
    
    idx_l_ams = range(0, nr_span_indices, 3) # [0,2,4,6 etc]
    idx_l_acs = range(1, nr_span_indices, 3) # [1,3,5,7 etc]
    idx_l_fts = range(2, nr_span_indices, 3)
    
    am_spans = spans[:, idx_l_ams, :] + 1 # adds 1 to all span indices (both am and ac) to offset for the CLS token in the input_ids.
    ac_spans = spans[:, idx_l_acs, :] + 1
    fts_spans = spans[:, idx_l_fts, :] + 1
    
    
    # fix for [-1, +1] problem
    
    am_spans_minus_one = minus_one(am_spans) # xxx. added to solve bug 2
    am_spans_plus_one = plus_one(am_spans) # xxx. added to solve bug 2
    
    ac_spans_minus_one = minus_one(ac_spans) # xxx. added to solve bug 2
    ac_spans_plus_one = plus_one(ac_spans) # xxx. added to solve bug 2
    
    fts_spans_minus_one = minus_one(fts_spans) # xxx. added to solve bug 2
    fts_spans_plus_one = plus_one(fts_spans) # xxx. added to solve bug 2
    
    
    am_spans = am_spans.flatten(start_dim=1)
    ac_spans = ac_spans.flatten(start_dim=1)
    fts_spans = fts_spans.flatten(start_dim=1)
    
    nr_adus = ac_spans.shape[1] // 2
    
    
    am_spans_minus_one = am_spans_minus_one.flatten(start_dim=1) # xxx. added to solve bug 2
    am_spans_plus_one = am_spans_plus_one.flatten(start_dim=1) # xxx. added to solve bug 2
    
    ac_spans_minus_one = ac_spans_minus_one.flatten(start_dim=1) # xxx. added to solve bug 2
    ac_spans_plus_one = ac_spans_plus_one.flatten(start_dim=1) # xxx. added to solve bug 2
    
    fts_spans_minus_one = fts_spans_minus_one.flatten(start_dim=1) # xxx. added to solve bug 2
    fts_spans_plus_one = fts_spans_plus_one.flatten(start_dim=1) # xxx. added to solve bug 2
    
    ############# FOR AMs #################
    
    outputs_am = outputs[:,am_spans,:]
    #print('outputs am juste directly from outputs:', outputs_am.shape)
    outputs_am = torch.cat([outputs_am[i,i,:,:] for i in range(batch_size)], dim=0)
    outputs_am = outputs_am.reshape(batch_size, nr_adus * 2, -1)
    
    #print('outputs_ am after reshape:', outputs_am.shape)
    
    
    
    outputs_am_minus_one = outputs[:,am_spans_minus_one,:] # xxx. added to solve bug 2
    outputs_am_minus_one = torch.cat([outputs_am_minus_one[i,i,:,:] for i in range(batch_size)], dim=0) # xxx. added to solve bug 2
    outputs_am_minus_one = outputs_am_minus_one.reshape(batch_size, nr_adus * 2, -1) # xxx. added to solve bug 2
    
    outputs_am_plus_one = outputs[:,am_spans_plus_one,:] # xxx. added to solve bug 2
    outputs_am_plus_one = torch.cat([outputs_am_plus_one[i,i,:,:] for i in range(batch_size)], dim=0) # xxx. added to solve bug 2
    outputs_am_plus_one = outputs_am_plus_one.reshape(batch_size, nr_adus * 2, -1) # xxx. added to solve bug 2
    
    
    
    ### Now that we have outputs_am i.e. outputs at am_span indices, now create the four Kuri forumlas for AMs
    
    # ============== the corrected 1st one =================== 
    
    outputs_am_first_term = torch.cat([outputs_am[:,i+1,:] - outputs_am_minus_one[:,i,:] for i in range(0, nr_adus * 2, 2)], dim=1) # i + 1 here means j in kuri, i here means i in kuri
    outputs_am_first_term = outputs_am_first_term.reshape(batch_size, -1, 768)
    
     # ============== the corrected 2nd one ==================
    
    outputs_am_second_term = torch.cat([outputs_am[:,i,:] - outputs_am_plus_one[:,i+1,:] for i in range(0, nr_adus * 2, 2)], dim=1) # changed + 2 to + 1 to make it run # changed from +1 to +2 to ensure +2 is not a problem for AMs
    outputs_am_second_term = outputs_am_second_term.reshape(batch_size, -1, 768)
    
    
        # ============== the corrected third one ================== 
    
    outputs_am_third_term = torch.cat([outputs_am_minus_one[:,i,:] for i in range(0, nr_adus * 2, 2)], dim=1)
    outputs_am_third_term = outputs_am_third_term.reshape(batch_size, -1, 768)
    

        # ============== the corrected fourth one ==================
    
    outputs_am_fourth_term = torch.cat([outputs_am_plus_one[:,i+1,:] for i in range(0, nr_adus * 2, 2)], dim=1) # changed + 2 to + 1 to make it run # changed from +1 to +2 to ensure +2 is not a problem for AMs
    outputs_am_fourth_term = outputs_am_fourth_term.reshape(batch_size, -1, 768)
    
    # ============== NOW CONCATENATE THEM =========
    
    
    am_minus_representations = torch.cat([outputs_am_first_term, outputs_am_second_term, outputs_am_third_term, outputs_am_fourth_term], dim=-1)   
    
    
    
    ### am minus span representation according to kuribayashi paper is now above.
    
    
    outputs_ac = outputs[:,ac_spans,:]
    outputs_ac = torch.cat([outputs_ac[i,i,:,:] for i in range(batch_size)], dim=0)
    outputs_ac = outputs_ac.reshape(batch_size, nr_adus * 2, -1)
    
    
    outputs_ac_minus_one = outputs[:,ac_spans_minus_one,:] # xxx. added to solve bug 2
    outputs_ac_minus_one = torch.cat([outputs_ac_minus_one[i,i,:,:] for i in range(batch_size)], dim=0) # xxx. added to solve bug 2
    outputs_ac_minus_one = outputs_ac_minus_one.reshape(batch_size, nr_adus * 2, -1) # xxx. added to solve bug 2
    
    outputs_ac_plus_one = outputs[:,ac_spans_plus_one,:] # xxx. added to solve bug 2
    outputs_ac_plus_one = torch.cat([outputs_ac_plus_one[i,i,:,:] for i in range(batch_size)], dim=0) # xxx. added to solve bug 2
    outputs_ac_plus_one = outputs_ac_plus_one.reshape(batch_size, nr_adus * 2, -1) # xxx. added to solve bug 2
    
    ### Now that we have outputs_ac i.e. outputs at ac_span indices, now create the four Kuri forumlas for ACs
    
    
    # ============== the corrected first one ===================
    
    outputs_ac_first_term = torch.cat([outputs_ac[:,i+1,:] - outputs_ac_minus_one[:,i,:] for i in range(0, nr_adus * 2, 2)], dim=1)
    outputs_ac_first_term = outputs_ac_first_term.reshape(batch_size, -1, 768)
    
        # ============== the correct second one ==================
    
    outputs_ac_second_term = torch.cat([outputs_ac[:,i,:] - outputs_ac_plus_one[:,i+1,:] for i in range(0, nr_adus * 2, 2)], dim=1) # changed + 2 to + 1 to make it run
    outputs_ac_second_term = outputs_ac_second_term.reshape(batch_size, -1, 768)
    
    
    # ============== the corrected third term ==================
    
    outputs_ac_third_term = torch.cat([outputs_ac_minus_one[:,i,:] for i in range(0, nr_adus * 2, 2)], dim=1)
    outputs_ac_third_term = outputs_ac_third_term.reshape(batch_size, -1, 768)

    
     # ============== the corrected fourth term ================== xxx. for .788
    
    outputs_ac_fourth_term = torch.cat([outputs_ac_plus_one[:,i+1,:] for i in range(0, nr_adus * 2, 2)], dim=1) # changed + 2 to + 1 to make it run
    outputs_ac_fourth_term = outputs_ac_fourth_term.reshape(batch_size, -1, 768)
    
    # ============== NOW CONCATENATE THEM =========
    
    
    ac_minus_representations = torch.cat([outputs_ac_first_term, outputs_ac_second_term, outputs_ac_third_term, outputs_ac_fourth_term], dim=-1)   
    
    
    
    ### ac minus span representation according to kuribayashi paper is now above.
    
    
     ############# FOR Fts #################
    
    outputs_fts = outputs[:,am_spans,:] # am spans for checking
    #print('outputs fts juste directly from the outputs:', outputs_fts.shape)
    outputs_fts = torch.cat([outputs_fts[i,i,:,:] for i in range(batch_size)], dim=0)
    outputs_fts = outputs_fts.reshape(batch_size, nr_adus * 2, -1)
    
    #print('output fts after reshape:', outputs_fts.shape)
    
    
    
    outputs_fts_minus_one = outputs[:,fts_spans_minus_one,:] # xxx. added to solve bug 2
    outputs_fts_minus_one = torch.cat([outputs_fts_minus_one[i,i,:,:] for i in range(batch_size)], dim=0) # xxx. added to solve bug 2
    outputs_fts_minus_one = outputs_fts_minus_one.reshape(batch_size, nr_adus * 2, -1) # xxx. added to solve bug 2
    
    outputs_fts_plus_one = outputs[:,fts_spans_plus_one,:] # xxx. added to solve bug 2
    outputs_fts_plus_one = torch.cat([outputs_fts_plus_one[i,i,:,:] for i in range(batch_size)], dim=0) # xxx. added to solve bug 2
    outputs_fts_plus_one = outputs_fts_plus_one.reshape(batch_size, nr_adus * 2, -1) # xxx. added to solve bug 2
    
    #print('outputs_fts_plus_one', outputs_fts_plus_one.shape)
    
    
    
    ### Now that we have outputs_am i.e. outputs at am_span indices, now create the four Kuri forumlas for AMs
    
    # ============== (batch_size x nr_adus x 768) =================== 
    
    outputs_fts_first_term = torch.cat([outputs_fts[:,i+1,:] - outputs_fts_minus_one[:,i,:] for i in range(0, nr_adus * 2, 2)], dim=1) # i + 1 here means j in kuri, i here means i in kuri
    outputs_fts_first_term = outputs_fts_first_term.reshape(batch_size, -1, 768)
    #print('outputs_fts_first_term:', outputs_fts_first_term.shape)
    
     # ============== the corrected 2nd one ==================
    
    outputs_fts_second_term = torch.cat([outputs_fts[:,i,:] - outputs_fts_plus_one[:,i+1,:] for i in range(0, nr_adus * 2, 2)], dim=1)
    outputs_fts_second_term = outputs_fts_second_term.reshape(batch_size, -1, 768)
    #print('outputs_fts_second_term:', outputs_fts_second_term.shape)
    
        # ============== the corrected third one ================== 
    
    outputs_fts_third_term = torch.cat([outputs_fts_minus_one[:,i,:] for i in range(0, nr_adus * 2, 2)], dim=1)
    outputs_fts_third_term = outputs_fts_third_term.reshape(batch_size, -1, 768)
    

        # ============== the corrected fourth one ==================
    
    outputs_fts_fourth_term = torch.cat([outputs_fts_plus_one[:,i+1,:] for i in range(0, nr_adus * 2, 2)], dim=1)
    outputs_fts_fourth_term = outputs_fts_fourth_term.reshape(batch_size, -1, 768)
    #print('outputs_fts_fourth_term:', outputs_fts_fourth_term.shape)
    
    # ============== NOW CONCATENATE THEM =========
    
    
    fts_minus_representations = torch.cat([outputs_fts_first_term, outputs_fts_second_term, outputs_fts_third_term, outputs_fts_fourth_term], dim=-1)   
    
    
    
    ### fts minus span representation according to kuribayashi paper is now above.
    
#     print('am rep final:', am_minus_representations.shape)
#     print('ac rep final:', ac_minus_representations.shape)
#     print('fts rep final:', fts_minus_representations.shape)
    

    
    return am_minus_representations, ac_minus_representations, fts_minus_representations                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                            

### span representation function old

## custom BERT model

In [37]:
class CustomBERTKuri(nn.Module):

    def __init__(self, first_model, model_am, model_ac, model_fts, nr_classes):
        
        super(CustomBERTKuri, self).__init__()
        
        self.first_model = first_model
        
        self.intermediate_linear_am = nn.Linear(3072, 768)
        self.intermediate_linear_ac = nn.Linear(3072, 768) 
        self.intermediate_linear_fts = nn.Linear(3072, 768)
        
        self.model_am = model_am
        self.model_ac = model_ac
        self.model_fts = model_fts
        
        self.nr_classes = nr_classes
                
        self.fc = nn.Linear(self.model_am.config.hidden_size + self.model_ac.config.hidden_size + self.model_fts.config.hidden_size, self.nr_classes)
        # self.fc = nn.Linear(self.model_am.config.hidden_size + self.model_ac.config.hidden_size, self.nr_classes) 

    def forward(self, inputs):
        
        batch_tokenized, batch_spans = inputs 
        #outputs = self.first_model(batch_tokenized)[0]
        outputs = self.first_model(batch_tokenized, output_hidden_states=True)[1][12] # ** removed to correct error.
        # spans = batch_spans # remove this spans thing cause we are now giving it just the spans themselves.
        am_minus_representations, ac_minus_representations, fts_minus_representations = get_span_representations(outputs, batch_spans)
#         print('passed and gotten reps for all three successfully!')
#         print('am rep final in model:', am_minus_representations.shape)
#         print('ac rep final in model:', ac_minus_representations.shape)
#         print('fts rep final in model:', fts_minus_representations.shape)
        #am_minus_representations, ac_minus_representations = get_span_representations(outputs, batch_spans)
        
        # am_minus_representations = 
        
        
        #print('am minus rep:', am_minus_representations.shape)
        #print('ac minus rep:', ac_minus_representations.shape)
        #print('linear layer shape:', self.intermediate_linear_am.in_features)

        am_minus_representations = self.intermediate_linear_am(am_minus_representations)
        #print('passed first linear')
        ac_minus_representations = self.intermediate_linear_ac(ac_minus_representations)
        #print('passed second linear')
        fts_minus_representations = self.intermediate_linear_fts(fts_minus_representations)
        #print('passed third linear')
        
        #print('am minus rep:', am_minus_representations.shape)
        #print('ac minus rep:', ac_minus_representations.shape)
        #print('fts minus rep:', fts_minus_representations.shape)

        output_model_am = self.model_am(inputs_embeds = am_minus_representations)[0]
        output_model_ac = self.model_ac(inputs_embeds = ac_minus_representations)[0]
        output_model_fts = self.model_fts(inputs_embeds = fts_minus_representations)[0]

        adu_representations = torch.cat([output_model_am, output_model_ac, output_model_fts], dim=-1)
        #adu_representations = torch.cat([output_model_am, output_model_ac], dim=-1)
        # print("adu rep:", adu_representations.shape)
        output = self.fc(adu_representations)
        # print("model class output avant reshape:", output.shape)
        # output = output.reshape(-1, self.nr_classes)
        # print("model class output apres:", output.shape)
        return output

## Run

In [38]:
NB_EPOCHS = 40
BATCH_SIZE = 24

In [39]:
# first_model = BertModel(BertConfig.from_pretrained("bert-base-uncased"))
first_model = BertForSequenceClassification.from_pretrained("bert-base-uncased", num_labels=5)
first_model.load_state_dict(torch.load('/notebooks/KURI-BERT/notebooks/full_formula_w_fts/icann_finetuned_work/link_identification_finetuned_model_new.pth'))

Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertForSequenceClassification: ['cls.seq_relationship.bias', 'cls.predictions.bias', 'cls.predictions.decoder.weight', 'cls.seq_relationship.weight', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.dense.bias', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.transform.dense.weight']
- This IS expected if you are initializing BertForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of BertForSequenceClassification were not initialized from the model checkpoint at

<All keys matched successfully>

In [40]:
model_am = BertModel(BertConfig.from_pretrained("bert-base-uncased"))

In [41]:
model_ac = BertModel(BertConfig.from_pretrained("bert-base-uncased"))

In [42]:
model_fts = BertModel(BertConfig.from_pretrained("bert-base-uncased"))

In [43]:
# custom_model = CustomBERTKuri(first_model, model_am, model_ac, model_fts, 3)
custom_model = CustomBERTKuri(first_model, model_am, model_ac, model_fts, 2)

In [44]:
custom_model.to(device)

CustomBERTKuri(
  (first_model): BertForSequenceClassification(
    (bert): BertModel(
      (embeddings): BertEmbeddings(
        (word_embeddings): Embedding(30522, 768, padding_idx=0)
        (position_embeddings): Embedding(512, 768)
        (token_type_embeddings): Embedding(2, 768)
        (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
        (dropout): Dropout(p=0.1, inplace=False)
      )
      (encoder): BertEncoder(
        (layer): ModuleList(
          (0): BertLayer(
            (attention): BertAttention(
              (self): BertSelfAttention(
                (query): Linear(in_features=768, out_features=768, bias=True)
                (key): Linear(in_features=768, out_features=768, bias=True)
                (value): Linear(in_features=768, out_features=768, bias=True)
                (dropout): Dropout(p=0.1, inplace=False)
              )
              (output): BertSelfOutput(
                (dense): Linear(in_features=768, out_features=768, b

In [45]:
loss = nn.CrossEntropyLoss(ignore_index=- 100)

In [46]:
optimizer = torch.optim.AdamW(custom_model.parameters(), lr=0.004328761281083062)

In [47]:
# 1.8738174228603844e-05
# 0.008111308307896872
# best learning rate found by the whole leslie business
# new best LR found= 9.999999999999997e-06

In [48]:
NR_BATCHES = len(dataset['train']) / BATCH_SIZE
num_training_steps = NB_EPOCHS * NR_BATCHES
num_warmup_steps = int(0.2 * num_training_steps)

In [49]:
def lr_lambda(current_step: int):
        if current_step < num_warmup_steps:
            return float(current_step) / float(max(1, num_warmup_steps))
        return max(
            0.0, float(num_training_steps - current_step) / float(max(1, num_training_steps - num_warmup_steps))
        )

    # return LambdaLR(optimizer, lr_lambda, last_epoch)

In [50]:
# commented for LR Finder. remove it from optimizer.
scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lr_lambda)

### create dataloaders

In [51]:
train_dataloader = DataLoader(dataset['train'], batch_size=BATCH_SIZE, shuffle=True)
val_dataloader = DataLoader(dataset['validation'], batch_size=BATCH_SIZE, shuffle=True)

In [52]:
# xxx. delete datasets for memory work
# del dataset

In [53]:
def flatten_list(list_of_lists):
    return [x for sublist in list_of_lists for x in sublist]

In [54]:
def remove_dummy_labels(test_preds, test_labels):
    
    idxes = []
    test_labels_l = []
    for idx, val in enumerate(test_labels):
        if val != -100:
            idxes.append(idx)
            test_labels_l.append(val)
    
    test_preds_l = []
    for idx, val in enumerate(test_preds):
        for good_idx in idxes:
            if idx == good_idx:
                test_preds_l.append(val)
        
    return test_preds_l, test_labels_l

In [55]:
for i, b in enumerate(train_dataloader):
    if i == 40:
        break

In [56]:
b

{'attention_mask': tensor([[1, 1, 1,  ..., 0, 0, 0],
         [1, 1, 1,  ..., 0, 0, 0],
         [1, 1, 1,  ..., 0, 0, 0],
         ...,
         [1, 1, 1,  ..., 0, 0, 0],
         [1, 1, 1,  ..., 0, 0, 0],
         [1, 1, 1,  ..., 0, 0, 0]]),
 'input_ids': tensor([[  101,  1996,  4274,  ...,     0,     0,     0],
         [  101,  2559,  2013,  ...,     0,     0,     0],
         [  101,  1999,  7091,  ...,     0,     0,     0],
         ...,
         [  101,  2000,  4088,  ...,     0,     0,     0],
         [  101,  2035,  1999,  ...,     0,     0,     0],
         [  101, 11865,  8525,  ...,     0,     0,     0]]),
 'label': tensor([[   0, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100],
         [   1,    1,    0,    1, -100, -100, -100, -100, -100, -100, -100, -100],
         [   0,    0,    1,    0,    1, -100, -100, -100, -100, -100, -100, -100],
         [   0,    1,    1, -100, -100, -100, -100, -100, -100, -100, -100, -100],
         [   1,    0,    1,    1

### training loop

In [57]:
from transformers import Trainer, TrainingArguments
from datasets import load_metric

In [58]:
import random
from torch_lr_finder import LRFinder

### LR Finder Leslie Smith 

## training 

In [59]:
def train(model, loss=None, optimizer=None, train_dataloader=None, val_dataloader=None, nb_epochs=20):
    """Training loop"""

    min_f1 = -torch.inf
    train_losses = []
    val_losses = []

    # Iterrate over epochs
    for e in range(nb_epochs):

        # Training
        train_loss = 0.0

        for i, batch in enumerate(tqdm(train_dataloader)):            
            
            #print(i)
            # unpack batch             
            labels = batch['label'].to(device)
            spans = batch['spans'].to(device)
            input_ids = batch['input_ids'].to(device)
            
            inputs = input_ids, spans
            
            # Reset gradients to 0
            optimizer.zero_grad()

            # Forward Pass
            outputs = model(inputs)
            
            # Compute training loss
            current_loss = loss(outputs.flatten(0,1), labels.flatten())
            train_loss += current_loss.detach().item()

            # Compute gradients
            current_loss.backward()

            # Update weights
            optimizer.step()            
            
            del batch
        
        scheduler.step()
            
        
        # Validation
        val_loss = 0.0
        
        # torch.cuda.empty_cache()
        
        # Put model in eval mode
        model.eval()
        
        preds_l = []
        labels_l = []
        
        for batch in tqdm(val_dataloader):            
            
            # unpack batch             
            labels = batch['label'].to(device)
            spans = batch['spans'].to(device)
            input_ids = batch['input_ids'].to(device)
            
            inputs = input_ids, spans
            
            # Forward Pass
            outputs = model(inputs)

            # Compute validation loss
            current_loss = loss(outputs.flatten(0,1), labels.flatten())
            val_loss += current_loss.detach().item()
            
            preds_for_f1 = torch.argmax(outputs, dim=2).flatten().tolist()
            labels_for_f1 = labels.flatten().tolist()
            
            preds_l.append(preds_for_f1)
            labels_l.append(labels_for_f1)
            
            del batch
        
        # Prints
        
        preds_l = flatten_list(preds_l)
        labels_l = flatten_list(labels_l)
        
        preds_l, labels_l = remove_dummy_labels(preds_l, labels_l)
        
        f1_score_epoch = f1_score(preds_l, labels_l, average='macro')        
        
        print(f"Epoch {e+1}/{nb_epochs} \
                \t Training Loss: {train_loss/len(train_dataloader):.3f} \
                \t Validation Loss: {val_loss/len(val_dataloader):.3f} \
                \t F1 score: {f1_score_epoch}")
        
        train_losses.append(train_loss/len(train_dataloader))
        val_losses.append(val_loss/len(val_dataloader))
        

        # Save model if val loss decreases
        if f1_score_epoch > min_f1:

            min_f1 = f1_score_epoch
            torch.save(model.first_model.state_dict(), 'first_model.pt')
            torch.save(model.model_am.state_dict(), 'model_am.pt')
            torch.save(model.model_ac.state_dict(), 'model_ac.pt')
            torch.save(model.model_fts.state_dict(), 'model_fts.pt')
            torch.save(model.state_dict(), 'best_model.pt')
            
    return train_losses, val_losses

In [60]:
train_losses, val_losses = train(custom_model, loss, optimizer, train_dataloader, val_dataloader, NB_EPOCHS)

100%|██████████| 46/46 [00:49<00:00,  1.08s/it]
100%|██████████| 12/12 [00:03<00:00,  3.24it/s]


Epoch 1/40                 	 Training Loss: 0.878                 	 Validation Loss: 0.917                 	 F1 score: 0.28625570776255704


100%|██████████| 46/46 [00:49<00:00,  1.08s/it]
100%|██████████| 12/12 [00:03<00:00,  3.17it/s]


Epoch 2/40                 	 Training Loss: 1.044                 	 Validation Loss: 0.545                 	 F1 score: 0.70889331342293


100%|██████████| 46/46 [00:50<00:00,  1.09s/it]
100%|██████████| 12/12 [00:03<00:00,  3.24it/s]


Epoch 3/40                 	 Training Loss: 0.492                 	 Validation Loss: 0.462                 	 F1 score: 0.7887607549889635


100%|██████████| 46/46 [00:51<00:00,  1.11s/it]
100%|██████████| 12/12 [00:03<00:00,  3.19it/s]


Epoch 4/40                 	 Training Loss: 0.407                 	 Validation Loss: 0.428                 	 F1 score: 0.7943125303860801


100%|██████████| 46/46 [00:50<00:00,  1.10s/it]
100%|██████████| 12/12 [00:03<00:00,  3.17it/s]


Epoch 5/40                 	 Training Loss: 0.293                 	 Validation Loss: 0.436                 	 F1 score: 0.7971028610393258


100%|██████████| 46/46 [00:50<00:00,  1.10s/it]
100%|██████████| 12/12 [00:03<00:00,  3.25it/s]


Epoch 6/40                 	 Training Loss: 0.188                 	 Validation Loss: 0.556                 	 F1 score: 0.8034173669467787


100%|██████████| 46/46 [00:50<00:00,  1.10s/it]
100%|██████████| 12/12 [00:03<00:00,  3.24it/s]


Epoch 7/40                 	 Training Loss: 0.163                 	 Validation Loss: 0.662                 	 F1 score: 0.7794646313810077


100%|██████████| 46/46 [00:50<00:00,  1.09s/it]
100%|██████████| 12/12 [00:03<00:00,  3.21it/s]


Epoch 8/40                 	 Training Loss: 0.115                 	 Validation Loss: 0.725                 	 F1 score: 0.7713508743494156


100%|██████████| 46/46 [00:50<00:00,  1.09s/it]
100%|██████████| 12/12 [00:03<00:00,  3.20it/s]


Epoch 9/40                 	 Training Loss: 0.120                 	 Validation Loss: 0.681                 	 F1 score: 0.7917715595900987


100%|██████████| 46/46 [00:50<00:00,  1.09s/it]
100%|██████████| 12/12 [00:03<00:00,  3.22it/s]


Epoch 10/40                 	 Training Loss: 0.106                 	 Validation Loss: 0.875                 	 F1 score: 0.7960437309030537


100%|██████████| 46/46 [00:50<00:00,  1.10s/it]
100%|██████████| 12/12 [00:03<00:00,  3.21it/s]


Epoch 11/40                 	 Training Loss: 0.174                 	 Validation Loss: 0.652                 	 F1 score: 0.7654474479291997


100%|██████████| 46/46 [00:50<00:00,  1.09s/it]
100%|██████████| 12/12 [00:03<00:00,  3.19it/s]


Epoch 12/40                 	 Training Loss: 0.123                 	 Validation Loss: 0.770                 	 F1 score: 0.7706691046615295


100%|██████████| 46/46 [00:50<00:00,  1.09s/it]
100%|██████████| 12/12 [00:03<00:00,  3.17it/s]


Epoch 13/40                 	 Training Loss: 0.094                 	 Validation Loss: 0.778                 	 F1 score: 0.7818118369625908


100%|██████████| 46/46 [00:50<00:00,  1.09s/it]
100%|██████████| 12/12 [00:03<00:00,  3.21it/s]


Epoch 14/40                 	 Training Loss: 0.064                 	 Validation Loss: 0.912                 	 F1 score: 0.7982660138718891


100%|██████████| 46/46 [00:50<00:00,  1.09s/it]
100%|██████████| 12/12 [00:03<00:00,  3.22it/s]


Epoch 15/40                 	 Training Loss: 0.039                 	 Validation Loss: 1.013                 	 F1 score: 0.7839642208970703


100%|██████████| 46/46 [00:50<00:00,  1.09s/it]
100%|██████████| 12/12 [00:03<00:00,  3.23it/s]


Epoch 16/40                 	 Training Loss: 0.030                 	 Validation Loss: 0.936                 	 F1 score: 0.7859373145502186


100%|██████████| 46/46 [00:50<00:00,  1.09s/it]
100%|██████████| 12/12 [00:03<00:00,  3.22it/s]


Epoch 17/40                 	 Training Loss: 0.111                 	 Validation Loss: 0.761                 	 F1 score: 0.7835257556416659


100%|██████████| 46/46 [00:50<00:00,  1.09s/it]
100%|██████████| 12/12 [00:03<00:00,  3.22it/s]


Epoch 18/40                 	 Training Loss: 0.188                 	 Validation Loss: 0.830                 	 F1 score: 0.7675797438571161


100%|██████████| 46/46 [00:50<00:00,  1.09s/it]
100%|██████████| 12/12 [00:03<00:00,  3.16it/s]


Epoch 19/40                 	 Training Loss: 0.156                 	 Validation Loss: 0.582                 	 F1 score: 0.7542268159963148


100%|██████████| 46/46 [00:50<00:00,  1.10s/it]
100%|██████████| 12/12 [00:03<00:00,  3.25it/s]


Epoch 20/40                 	 Training Loss: 0.112                 	 Validation Loss: 0.889                 	 F1 score: 0.7395100502512563


100%|██████████| 46/46 [00:50<00:00,  1.10s/it]
100%|██████████| 12/12 [00:03<00:00,  3.15it/s]


Epoch 21/40                 	 Training Loss: 0.160                 	 Validation Loss: 0.701                 	 F1 score: 0.7542727149871615


100%|██████████| 46/46 [00:50<00:00,  1.09s/it]
100%|██████████| 12/12 [00:03<00:00,  3.20it/s]


Epoch 22/40                 	 Training Loss: 0.222                 	 Validation Loss: 0.687                 	 F1 score: 0.7532875131264013


100%|██████████| 46/46 [00:50<00:00,  1.10s/it]
100%|██████████| 12/12 [00:03<00:00,  3.14it/s]


Epoch 23/40                 	 Training Loss: 0.346                 	 Validation Loss: 0.640                 	 F1 score: 0.6270583650603787


100%|██████████| 46/46 [00:50<00:00,  1.09s/it]
100%|██████████| 12/12 [00:03<00:00,  3.22it/s]


Epoch 24/40                 	 Training Loss: 0.611                 	 Validation Loss: 0.625                 	 F1 score: 0.6734196953486198


100%|██████████| 46/46 [00:50<00:00,  1.09s/it]
100%|██████████| 12/12 [00:03<00:00,  3.25it/s]


Epoch 25/40                 	 Training Loss: 0.621                 	 Validation Loss: 0.628                 	 F1 score: 0.6734196953486198


100%|██████████| 46/46 [00:50<00:00,  1.09s/it]
100%|██████████| 12/12 [00:03<00:00,  3.25it/s]


Epoch 26/40                 	 Training Loss: 0.599                 	 Validation Loss: 0.609                 	 F1 score: 0.6734196953486198


100%|██████████| 46/46 [00:50<00:00,  1.10s/it]
100%|██████████| 12/12 [00:03<00:00,  3.22it/s]


Epoch 27/40                 	 Training Loss: 0.601                 	 Validation Loss: 0.609                 	 F1 score: 0.6734196953486198


100%|██████████| 46/46 [00:50<00:00,  1.10s/it]
100%|██████████| 12/12 [00:03<00:00,  3.12it/s]


Epoch 28/40                 	 Training Loss: 0.601                 	 Validation Loss: 0.614                 	 F1 score: 0.6734196953486198


100%|██████████| 46/46 [00:50<00:00,  1.09s/it]
100%|██████████| 12/12 [00:03<00:00,  3.24it/s]


Epoch 29/40                 	 Training Loss: 0.597                 	 Validation Loss: 0.636                 	 F1 score: 0.6734196953486198


100%|██████████| 46/46 [00:50<00:00,  1.10s/it]
100%|██████████| 12/12 [00:03<00:00,  3.24it/s]


Epoch 30/40                 	 Training Loss: 0.597                 	 Validation Loss: 0.610                 	 F1 score: 0.6734196953486198


100%|██████████| 46/46 [00:50<00:00,  1.09s/it]
100%|██████████| 12/12 [00:03<00:00,  3.18it/s]


Epoch 31/40                 	 Training Loss: 0.602                 	 Validation Loss: 0.603                 	 F1 score: 0.6734196953486198


100%|██████████| 46/46 [00:50<00:00,  1.10s/it]
100%|██████████| 12/12 [00:03<00:00,  3.18it/s]


Epoch 32/40                 	 Training Loss: 0.594                 	 Validation Loss: 0.613                 	 F1 score: 0.6734196953486198


100%|██████████| 46/46 [00:50<00:00,  1.10s/it]
100%|██████████| 12/12 [00:03<00:00,  3.23it/s]


Epoch 33/40                 	 Training Loss: 0.599                 	 Validation Loss: 0.607                 	 F1 score: 0.6734196953486198


100%|██████████| 46/46 [00:50<00:00,  1.09s/it]
100%|██████████| 12/12 [00:03<00:00,  3.17it/s]


Epoch 34/40                 	 Training Loss: 0.596                 	 Validation Loss: 0.629                 	 F1 score: 0.6734196953486198


100%|██████████| 46/46 [00:50<00:00,  1.09s/it]
100%|██████████| 12/12 [00:03<00:00,  3.26it/s]


Epoch 35/40                 	 Training Loss: 0.598                 	 Validation Loss: 0.612                 	 F1 score: 0.6734196953486198


100%|██████████| 46/46 [00:50<00:00,  1.09s/it]
100%|██████████| 12/12 [00:03<00:00,  3.22it/s]


Epoch 36/40                 	 Training Loss: 0.600                 	 Validation Loss: 0.605                 	 F1 score: 0.6734196953486198


100%|██████████| 46/46 [00:50<00:00,  1.10s/it]
100%|██████████| 12/12 [00:03<00:00,  3.20it/s]


Epoch 37/40                 	 Training Loss: 0.598                 	 Validation Loss: 0.607                 	 F1 score: 0.6734196953486198


100%|██████████| 46/46 [00:50<00:00,  1.09s/it]
100%|██████████| 12/12 [00:03<00:00,  3.20it/s]


Epoch 38/40                 	 Training Loss: 0.594                 	 Validation Loss: 0.599                 	 F1 score: 0.6734196953486198


100%|██████████| 46/46 [00:50<00:00,  1.09s/it]
100%|██████████| 12/12 [00:03<00:00,  3.24it/s]


Epoch 39/40                 	 Training Loss: 0.593                 	 Validation Loss: 0.595                 	 F1 score: 0.6734196953486198


100%|██████████| 46/46 [00:50<00:00,  1.09s/it]
100%|██████████| 12/12 [00:03<00:00,  3.19it/s]


Epoch 40/40                 	 Training Loss: 0.601                 	 Validation Loss: 0.604                 	 F1 score: 0.6734196953486198


### Predictions

In [61]:
# # Load best model
# network_2 = Network()

# network_2.load_state_dict(torch.load(path))
# network_2.eval()

In [62]:
test_dataloader = DataLoader(dataset['test'], batch_size=BATCH_SIZE, shuffle=True)

In [63]:
first_model =  BertForSequenceClassification.from_pretrained("bert-base-uncased", num_labels=5)
first_model.load_state_dict(torch.load('first_model.pt'))

Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertForSequenceClassification: ['cls.seq_relationship.bias', 'cls.predictions.bias', 'cls.predictions.decoder.weight', 'cls.seq_relationship.weight', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.dense.bias', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.transform.dense.weight']
- This IS expected if you are initializing BertForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of BertForSequenceClassification were not initialized from the model checkpoint at

<All keys matched successfully>

In [64]:
# first_model = BertModel(BertConfig.from_pretrained("bert-base-uncased"))
# first_model.load_state_dict(torch.load('first_model.pt'))

In [65]:
model_am = BertModel(BertConfig.from_pretrained("bert-base-uncased"))
model_am.load_state_dict(torch.load('model_am.pt'))

<All keys matched successfully>

In [66]:
model_ac = BertModel(BertConfig.from_pretrained("bert-base-uncased"))
model_ac.load_state_dict(torch.load('model_ac.pt'))

<All keys matched successfully>

In [67]:
model_fts = BertModel(BertConfig.from_pretrained("bert-base-uncased"))
model_fts.load_state_dict(torch.load('model_fts.pt'))

<All keys matched successfully>

In [68]:
# Load best model

# custom_model_2 = CustomBERTKuri(first_model, model_am, model_ac, model_fts, 3)
custom_model_2 = CustomBERTKuri(first_model, model_am, model_ac, model_fts, 2)
custom_model_2.load_state_dict(torch.load('best_model.pt'))

custom_model_2.to(device).eval()

CustomBERTKuri(
  (first_model): BertForSequenceClassification(
    (bert): BertModel(
      (embeddings): BertEmbeddings(
        (word_embeddings): Embedding(30522, 768, padding_idx=0)
        (position_embeddings): Embedding(512, 768)
        (token_type_embeddings): Embedding(2, 768)
        (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
        (dropout): Dropout(p=0.1, inplace=False)
      )
      (encoder): BertEncoder(
        (layer): ModuleList(
          (0): BertLayer(
            (attention): BertAttention(
              (self): BertSelfAttention(
                (query): Linear(in_features=768, out_features=768, bias=True)
                (key): Linear(in_features=768, out_features=768, bias=True)
                (value): Linear(in_features=768, out_features=768, bias=True)
                (dropout): Dropout(p=0.1, inplace=False)
              )
              (output): BertSelfOutput(
                (dense): Linear(in_features=768, out_features=768, b

In [69]:
def predict(model, test_dataloader=None):
    
    """Prediction loop"""

    preds_l = []
    labels_l = []
    
    model.eval()

    for batch in test_dataloader:            
            
        # unpack batch             
        labels = batch['label'].to(device).flatten().tolist()
        spans = batch['spans'].to(device)
        input_ids = batch['input_ids'].to(device)
        
        inputs = input_ids, spans

        # get output
        
        raw_preds = model(inputs).to('cpu')
        # print(raw_preds.shape)
        raw_preds = raw_preds.detach()#.numpy()

        # Compute argmax
        
        predictions = torch.argmax(raw_preds, dim=2).flatten().tolist()
        preds_l.append(predictions)
        labels_l.append(labels)        
        
        del batch
            
    return flatten_list(preds_l), flatten_list(labels_l)

In [70]:
#test_preds, test_labels = predict(custom_model, test_dataloader)

In [71]:
#print(classification_report(test_labels, test_preds, digits=3))

In [72]:
# remove -100s

In [73]:
#test_preds_l, test_labels_l = remove_dummy_labels(test_preds, test_labels)

In [74]:
#print(classification_report(test_labels_l, test_preds_l, digits=3))

In [75]:
### -100 done!

In [76]:
test_preds, test_labels = predict(custom_model_2, test_dataloader)

In [77]:
print(classification_report(test_labels, test_preds, digits=3))

              precision    recall  f1-score   support

        -100      0.000     0.000     0.000      3033
           0      0.110     0.697     0.191       518
           1      0.650     0.894     0.753       745

    accuracy                          0.239      4296
   macro avg      0.253     0.530     0.314      4296
weighted avg      0.126     0.239     0.153      4296



  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


In [78]:
test_preds_l, test_labels_l = remove_dummy_labels(test_preds, test_labels)

In [79]:
print(classification_report(test_labels_l, test_preds_l, digits=3))

              precision    recall  f1-score   support

           0      0.820     0.697     0.754       518
           1      0.809     0.894     0.849       745

    accuracy                          0.813      1263
   macro avg      0.815     0.795     0.802      1263
weighted avg      0.814     0.813     0.810      1263



precision    recall  f1-score   support

           0      0.893     0.865     0.879       155
           1      0.663     0.683     0.673       303
           2      0.896     0.892     0.894       805

    accuracy                          0.838      1263
   macro avg      0.818     0.813     0.815      1263
weighted avg      0.840     0.838     0.839      1263

12 batch size, 0.006579332246575687