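# Implementation of the Kuribayashi et al. BERT-minus model

This notebook adapts the "minus" span representation of Kuribayashi et al. to BERT:
1. The hidden states of a first BERT encoder at span boundaries are subtracted to represent argumentative markers (AMs) and argumentative components (ACs).
2. Each span type is re-encoded by its own BERT encoder.
3. A linear layer classifies each AM/AC pair into one of three classes.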

## Libraries

In [9]:
!pip install transformers
!pip install ipywidgets
!pip install IProgress
!pip install datasets

Looking in indexes: https://pypi.org/simple, https://pypi.ngc.nvidia.com
Looking in indexes: https://pypi.org/simple, https://pypi.ngc.nvidia.com
Looking in indexes: https://pypi.org/simple, https://pypi.ngc.nvidia.com
Looking in indexes: https://pypi.org/simple, https://pypi.ngc.nvidia.com


In [10]:
import transformers
from transformers import BertTokenizer, BertConfig
from transformers import BertModel, BertForSequenceClassification
from transformers import BatchEncoding

import torch
import torch.nn as nn

import datasets

## Tokenizer

In [11]:
tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")

## Data paths

In [12]:
# DATA_FOLDER = '/notebooks/Data/bert_sequence_classification'
DATA_FILE = '/notebooks/KURI-BERT/data/pe_dataset_for_bert_minus.pt'
# RESULTS_FOLDER = '/notebooks/Results/bert_sequence_classification'

In [13]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [14]:
device

device(type='cuda')

## Load data

In [15]:
dataset = torch.load(DATA_FILE)

In [16]:
dataset

Dataset({
    features: ['paragraph', 'paragraph_ACs', 'paragraph_AMs', 'paragraph_labels', 'paragraph_ac_spans', 'paragraph_am_spans', 'split', 'essay_nr'],
    num_rows: 1719
})

In [17]:
dataset['paragraph_ac_spans']

[[[87, 96]],
 [[4, 23], [25, 51], [53, 94], [96, 116]],
 [[5, 19], [27, 32], [36, 56], [70, 131], [148, 158]],
 [[24, 37]],
 [[59, 75]],
 [[2, 21], [23, 55], [57, 75], [77, 99], [107, 115]],
 [[2, 10], [12, 29], [34, 48], [50, 78], [80, 99], [107, 117]],
 [[4, 21], [30, 50]],
 [[48, 66]],
 [[2, 21], [23, 55], [57, 79], [82, 90]],
 [[2, 17], [22, 57], [59, 77], [81, 92]],
 [[10, 26]],
 [[21, 29], [35, 50]],
 [[10, 25], [27, 76], [78, 97], [103, 112]],
 [[8, 20], [22, 69], [75, 98], [106, 111]],
 [[7, 20]],
 [[49, 59]],
 [[2, 25], [27, 63]],
 [[2, 23], [25, 42], [44, 56], [60, 69]],
 [[2, 13], [15, 28], [32, 50]],
 [[3, 15], [19, 38]],
 [[22, 26], [28, 42]],
 [[4, 20], [22, 40], [45, 53], [57, 69], [71, 84]],
 [[5, 22], [27, 35], [37, 56]],
 [[0, 8],
  [10, 27],
  [31, 38],
  [40, 57],
  [59, 66],
  [68, 82],
  [84, 97],
  [100, 110]],
 [[3, 20]],
 [[63, 73]],
 [[4, 18], [20, 61], [63, 101], [105, 119]],
 [[5, 26], [34, 49], [85, 97], [103, 126]],
 [[6, 29], [31, 47], [53, 67]],
 [[78, 8

### Tokenization

In [20]:
def tokenize(dataset, text_key="texts", spans_key="spans", labels_key="labels",
             text_tokenizer=None):
    """Tokenize all texts and pad the per-example span lists into one batch.

    Args:
        dataset: mapping (e.g. a ``datasets.Dataset`` or a dict of lists) with
            one column of texts, one column of per-example ``[start, end]`` span
            lists and one column of labels.
        text_key, spans_key, labels_key: names of those columns. The defaults
            keep the original behaviour. NOTE(review): the dataset loaded in
            this notebook exposes ``paragraph`` / ``paragraph_*_spans`` /
            ``paragraph_labels``, so the defaults raise ``KeyError`` on it.
        text_tokenizer: tokenizer to call; defaults to the notebook-level
            ``tokenizer``.

    Returns:
        The tokenizer output (padded, ``return_tensors="pt"``) extended with
        ``"spans"`` and ``"labels"``. ``"spans"`` is a tensor of every span
        list right-padded with ``[0, 0]`` to the longest list. ``"labels"`` is
        ``torch.tensor`` of the label column.
    """
    tok = tokenizer if text_tokenizer is None else text_tokenizer
    encoded = tok(dataset[text_key], padding=True, return_tensors="pt")

    # Read the column once: on a datasets.Dataset every dataset[col] access
    # materialises the whole column, which made the old per-row access quadratic.
    span_lists = dataset[spans_key]
    max_spans = max((len(spans) for spans in span_lists), default=0)
    # [0, 0] filler lets ragged span lists stack into one tensor; nothing
    # downstream masks it out.
    padded_spans = [list(spans) + [[0, 0]] * (max_spans - len(spans))
                    for spans in span_lists]

    encoded['spans'] = torch.tensor(padded_spans)
    encoded['labels'] = torch.tensor(dataset[labels_key])
    return encoded

In [21]:
tokenized_input = tokenize(dataset)

## Batching

In [22]:
# Reuse the encoding computed above instead of tokenizing the whole dataset a second time.
batch = tokenized_input

In [23]:
batch_tokenized = BatchEncoding({k:batch[k] for k in ('input_ids','token_type_ids','attention_mask')})

In [24]:
batch_spans = BatchEncoding({'spans':batch['spans']})

In [25]:
batch_labels = BatchEncoding({'labels':batch['labels']})

In [30]:
batch_spans['spans'].shape

torch.Size([2, 4, 2])

## Span representation function

In [26]:
def get_span_representations(outputs, spans):
    """Build "minus" representations for interleaved AM/AC spans.

    Args:
        outputs: token states from the first encoder, indexed as
            (batch, position, hidden).
        spans: integer tensor of shape (batch, n_spans, 2) holding
            ``[start, end]`` positions. Even-indexed spans (0, 2, ...) are
            treated as AMs and odd-indexed spans (1, 3, ...) as ACs.
            Every position is shifted by +1 before indexing ``outputs``.
            Presumably this skips the leading [CLS] token, i.e. spans are
            offsets without it — TODO confirm.

    Returns:
        ``(am_minus_representations, ac_minus_representations)``. Each has
        shape (batch, n_spans_of_that_kind, 2 * hidden) and holds
        ``[h_start - h_end ; h_end - h_start]`` per span. When ``n_spans`` is
        odd, there is one more AM than AC.
    """
    batch_size = spans.shape[0]
    hidden_size = outputs.shape[-1]
    # gather() needs int64 indices on the same device as `outputs`.
    spans = spans.to(device=outputs.device, dtype=torch.long)

    def _minus_representations(kind_spans):
        # (batch, k, 2) boundaries -> (batch, k, 2 * hidden) minus vectors.
        nr_spans = kind_spans.shape[1]
        positions = (kind_spans + 1).reshape(batch_size, nr_spans * 2)
        index = positions.unsqueeze(-1).expand(-1, -1, hidden_size)
        # One boundary state per example and position, without the
        # (batch, batch, ...) intermediate that outputs[:, idx, :] creates.
        boundary_states = outputs.gather(1, index).reshape(batch_size, nr_spans, 2, hidden_size)
        start_states = boundary_states[:, :, 0, :]
        end_states = boundary_states[:, :, 1, :]
        return torch.cat([start_states - end_states, end_states - start_states], dim=-1)

    am_minus_representations = _minus_representations(spans[:, 0::2, :])
    ac_minus_representations = _minus_representations(spans[:, 1::2, :])
    return am_minus_representations, ac_minus_representations

## Custom BERT model

In [27]:
class CustomBERTKuri(nn.Module):
    """Two-level model over "minus" span representations of AMs and ACs.

    Pipeline:
    1. ``first_model`` encodes the tokenized text.
    2. ``get_span_representations`` turns span-boundary states into minus
       vectors.
    3. Each span kind is projected to the hidden size of its own encoder and
       re-encoded by ``model_am`` / ``model_ac`` through ``inputs_embeds``.
    4. The two outputs are concatenated per position and classified by a
       linear layer into ``nr_classes`` classes.

    The three encoders must expose ``config.hidden_size``. Each must return a
    sequence whose element ``[0]`` is the hidden-state tensor.
    """

    def __init__(self, first_model, model_am, model_ac, nr_classes):
        super().__init__()

        self.first_model = first_model

        # A minus vector is [start - end ; end - start], i.e. twice the first
        # encoder's hidden size (1536 -> 768 for bert-base).
        span_dim = 2 * first_model.config.hidden_size
        self.intermediate_linear_am = nn.Linear(span_dim, model_am.config.hidden_size)
        self.intermediate_linear_ac = nn.Linear(span_dim, model_ac.config.hidden_size)

        self.model_am = model_am
        self.model_ac = model_ac

        self.nr_classes = nr_classes

        self.fc = nn.Linear(self.model_am.config.hidden_size + self.model_ac.config.hidden_size,
                            self.nr_classes)

    def forward(self, batch_tokenized, batch_spans):
        """Return class logits for each (AM, AC) position.

        Args:
            batch_tokenized: mapping unpacked as keyword arguments into ``first_model``.
            batch_spans: mapping whose ``'spans'`` entry is the span tensor
                passed to ``get_span_representations``.

        Returns:
            Logits of shape (batch * n_positions, nr_classes). The AM and AC
            encoders must produce the same number of positions, otherwise the
            concatenation fails.
        """
        token_states = self.first_model(**batch_tokenized)[0]
        am_minus, ac_minus = get_span_representations(token_states, batch_spans['spans'])

        # Each span kind uses its own projection.
        am_embeds = self.intermediate_linear_am(am_minus)
        ac_embeds = self.intermediate_linear_ac(ac_minus)

        output_model_am = self.model_am(inputs_embeds=am_embeds)[0]
        output_model_ac = self.model_ac(inputs_embeds=ac_embeds)[0]

        adu_representations = torch.cat([output_model_am, output_model_ac], dim=-1)
        logits = self.fc(adu_representations)
        # Flatten batch and position dimensions for CrossEntropyLoss.
        return logits.reshape(-1, self.nr_classes)

## Run

In [28]:
# BertModel(config) alone is randomly initialised, so load the pretrained weights.
# Use the same checkpoint as `tokenizer` (bert-base-uncased) so token ids index the
# right vocabulary: the bert-base-cased config has 28996 entries versus 30522 for
# uncased, so uncased ids can fall outside the cased embedding table.
# from_pretrained returns the model in eval mode; .train() keeps dropout active for training.
first_model = BertModel.from_pretrained("bert-base-uncased").train()

In [29]:
model_am = BertModel(BertConfig.from_pretrained("bert-base-cased"))

In [30]:
model_ac = BertModel(BertConfig.from_pretrained("bert-base-cased"))

In [31]:
custom_model = CustomBERTKuri(first_model, model_am, model_ac, 3)

In [32]:
model_output = custom_model(batch_tokenized, batch_spans)

In [33]:
model_output.shape

torch.Size([4, 3])

In [34]:
labels = batch_labels['labels']

In [35]:
labels = labels.flatten()

In [36]:
criterion = torch.nn.CrossEntropyLoss()

In [37]:
# lr=1e-2 is orders of magnitude above the 2e-5 to 5e-5 range used for fine-tuning
# BERT and usually makes training diverge.
LEARNING_RATE = 2e-5
optimizer = torch.optim.Adam(custom_model.parameters(), lr=LEARNING_RATE)

In [38]:
loss = criterion(model_output, labels)

In [39]:
optimizer.zero_grad()

In [40]:
loss.backward() 

In [41]:
optimizer.step()