# Implementation of the Kuribayashi BERT minus model

## libraries

In [1]:
!pip install transformers
!pip install ipywidgets
!pip install IProgress
!pip install datasets

Looking in indexes: https://pypi.org/simple, https://pypi.ngc.nvidia.com
Looking in indexes: https://pypi.org/simple, https://pypi.ngc.nvidia.com
Looking in indexes: https://pypi.org/simple, https://pypi.ngc.nvidia.com
Looking in indexes: https://pypi.org/simple, https://pypi.ngc.nvidia.com


In [2]:
import transformers
from transformers import BertTokenizer, BertConfig
from transformers import BertModel, BertForSequenceClassification
from transformers import BatchEncoding

import torch
import torch.nn as nn

import numpy as np

import datasets

## tokenizer

In [3]:
tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")

## data

In [4]:
# DATA_FOLDER = '/notebooks/Data/bert_sequence_classification'
DATA_FILE = '/notebooks/KURI-BERT/data/pe_dataset_for_bert_minus.pt'
RESULTS_FOLDER = '/notebooks/KURI-BERT/results'

In [5]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [6]:
device

device(type='cuda')

## Load data

In [7]:
dataset = torch.load(DATA_FILE)

In [8]:
dataset

DatasetDict({
    train: Dataset({
        features: ['paragraph', 'paragraph_components_list', 'paragraph_labels_list', 'paragraph_markers_list', 'split', 'essay_nr', 'paragraph_labels', 'paragraph_ac_spans', 'paragraph_am_spans'],
        num_rows: 1088
    })
    test: Dataset({
        features: ['paragraph', 'paragraph_components_list', 'paragraph_labels_list', 'paragraph_markers_list', 'split', 'essay_nr', 'paragraph_labels', 'paragraph_ac_spans', 'paragraph_am_spans'],
        num_rows: 358
    })
    validation: Dataset({
        features: ['paragraph', 'paragraph_components_list', 'paragraph_labels_list', 'paragraph_markers_list', 'split', 'essay_nr', 'paragraph_labels', 'paragraph_ac_spans', 'paragraph_am_spans'],
        num_rows: 273
    })
})

### preprocessing

In [9]:
MAX_LENGTH = 0

for split in ['train', 'test', 'validation']:
    
    for col_name in ['paragraph_am_spans', 'paragraph_ac_spans']:
        
        for x in dataset[split][col_name]:
        
            if len(x) > MAX_LENGTH:
                
                MAX_LENGTH = len(x)

In [10]:
def get_padding(batch, padding_target):    
    
    if padding_target == 'am_spans':
        
        col_name = 'paragraph_am_spans'
        padding_val = [[-1,-1]]
        max_length = MAX_LENGTH
        
    elif padding_target == 'ac_spans':
        
        col_name = 'paragraph_ac_spans'
        padding_val = [[-1,-1]]
        max_length = MAX_LENGTH
        
    elif padding_target == 'label':
    
        col_name = 'paragraph_labels'
        padding_val = [-100] # -1 previously       
        max_length = MAX_LENGTH # max([len(l) for l in batch[col_name]]) # cause some batch had 4 x 10

    padded_spans = []

    for idx, span in enumerate(batch[col_name]):

        padded_span = batch[col_name][idx] + (max_length - len(span)) * padding_val
        padded_spans.append(padded_span)

    return padded_spans         

In [11]:
def get_combined_spans(am_spans_ll, ac_spans_ll):
    
    spans_ll = []
    
    for am_spans, ac_spans in zip(am_spans_ll, ac_spans_ll):
        
        spans = []
        
        for am_span, ac_span in zip(am_spans, ac_spans):

            span = [am_span, ac_span]
            spans.extend(span)
            
        spans_ll.append(spans)

    return spans_ll

### tokenize 

In [12]:
def tokenize(batch):
    
    tokenized_text = tokenizer(batch['paragraph'], truncation=True, padding=True, max_length=512)
    tokenized_text['label'] = get_padding(batch, 'label')
    tokenized_text['am_spans'] = get_padding(batch, 'am_spans')
    tokenized_text['ac_spans'] = get_padding(batch, 'ac_spans')
    tokenized_text['spans'] = get_combined_spans(tokenized_text['am_spans'], tokenized_text['ac_spans'])      
    
    return tokenized_text

In [13]:
dataset = dataset.map(tokenize, batched=True, batch_size=len(dataset['train']))



  0%|          | 0/1 [00:00<?, ?ba/s]

  0%|          | 0/1 [00:00<?, ?ba/s]

  0%|          | 0/1 [00:00<?, ?ba/s]

In [14]:
dataset

DatasetDict({
    train: Dataset({
        features: ['ac_spans', 'am_spans', 'attention_mask', 'essay_nr', 'input_ids', 'label', 'paragraph', 'paragraph_ac_spans', 'paragraph_am_spans', 'paragraph_components_list', 'paragraph_labels', 'paragraph_labels_list', 'paragraph_markers_list', 'spans', 'split', 'token_type_ids'],
        num_rows: 1088
    })
    test: Dataset({
        features: ['ac_spans', 'am_spans', 'attention_mask', 'essay_nr', 'input_ids', 'label', 'paragraph', 'paragraph_ac_spans', 'paragraph_am_spans', 'paragraph_components_list', 'paragraph_labels', 'paragraph_labels_list', 'paragraph_markers_list', 'spans', 'split', 'token_type_ids'],
        num_rows: 358
    })
    validation: Dataset({
        features: ['ac_spans', 'am_spans', 'attention_mask', 'essay_nr', 'input_ids', 'label', 'paragraph', 'paragraph_ac_spans', 'paragraph_am_spans', 'paragraph_components_list', 'paragraph_labels', 'paragraph_labels_list', 'paragraph_markers_list', 'spans', 'split', 'token_type_

In [15]:
dataset['train'].features

{'ac_spans': Sequence(feature=Sequence(feature=Value(dtype='int64', id=None), length=-1, id=None), length=-1, id=None),
 'am_spans': Sequence(feature=Sequence(feature=Value(dtype='int64', id=None), length=-1, id=None), length=-1, id=None),
 'attention_mask': Sequence(feature=Value(dtype='int8', id=None), length=-1, id=None),
 'essay_nr': Value(dtype='string', id=None),
 'input_ids': Sequence(feature=Value(dtype='int32', id=None), length=-1, id=None),
 'label': Sequence(feature=Value(dtype='int64', id=None), length=-1, id=None),
 'paragraph': Value(dtype='string', id=None),
 'paragraph_ac_spans': Sequence(feature=Sequence(feature=Value(dtype='int64', id=None), length=-1, id=None), length=-1, id=None),
 'paragraph_am_spans': Sequence(feature=Sequence(feature=Value(dtype='int64', id=None), length=-1, id=None), length=-1, id=None),
 'paragraph_components_list': Sequence(feature=Sequence(feature=Value(dtype='string', id=None), length=-1, id=None), length=-1, id=None),
 'paragraph_labels': S

In [16]:
dataset['test'].features

{'ac_spans': Sequence(feature=Sequence(feature=Value(dtype='int64', id=None), length=-1, id=None), length=-1, id=None),
 'am_spans': Sequence(feature=Sequence(feature=Value(dtype='int64', id=None), length=-1, id=None), length=-1, id=None),
 'attention_mask': Sequence(feature=Value(dtype='int8', id=None), length=-1, id=None),
 'essay_nr': Value(dtype='string', id=None),
 'input_ids': Sequence(feature=Value(dtype='int32', id=None), length=-1, id=None),
 'label': Sequence(feature=Value(dtype='int64', id=None), length=-1, id=None),
 'paragraph': Value(dtype='string', id=None),
 'paragraph_ac_spans': Sequence(feature=Sequence(feature=Value(dtype='int64', id=None), length=-1, id=None), length=-1, id=None),
 'paragraph_am_spans': Sequence(feature=Sequence(feature=Value(dtype='int64', id=None), length=-1, id=None), length=-1, id=None),
 'paragraph_components_list': Sequence(feature=Sequence(feature=Value(dtype='string', id=None), length=-1, id=None), length=-1, id=None),
 'paragraph_labels': S

In [17]:
dataset['test'].features['spans'] = datasets.Array2D(shape=(24, 2), dtype="int32")
dataset['train'].features['spans'] = datasets.Array2D(shape=(24, 2), dtype="int32")
dataset['validation'].features['spans'] = datasets.Array2D(shape=(24, 2), dtype="int32")

In [18]:
dataset = dataset.map(lambda batch: batch, batched=True)

  0%|          | 0/2 [00:00<?, ?ba/s]

  0%|          | 0/1 [00:00<?, ?ba/s]

  0%|          | 0/1 [00:00<?, ?ba/s]

In [19]:
dataset.set_format('torch', columns=['input_ids', 'attention_mask', 'token_type_ids', 'spans', 'label'])

In [20]:
dataset['train']['input_ids'][0:4]

tensor([[  101, 15847,  1010,  2011,  7079,  7773,  2005,  2270,  2082,  1010,
         22666,  2111,  6464,  9002,  2000, 21978,  2091,  1996,  6578,  2090,
          4138,  1998,  3532,  1012,  2009,  2003,  2995,  2008,  2116,  3532,
          2945,  2024,  2025,  2583,  2000,  8984, 15413,  9883,  2005,  2037,
          4268,  2000,  5463,  1037,  2607,  1012,  2007,  1996,  1996,  4171,
          3815,  2005,  2029,  2027,  3477,  1010,  1996,  4138,  2089,  2393,
          1037,  6565,  2193,  1997,  2493,  2013,  2945,  2007,  5635,  4281,
          2000,  3613,  2037,  2913,  1010,  1998,  7796,  1037,  2488,  3737,
          1997,  2166,  1012,   102,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,  

In [21]:
dataset['validation']['label'][:].shape

torch.Size([273, 12])

In [22]:
dataset['train']['label'][:]

tensor([[   1,    2,    2,  ..., -100, -100, -100],
        [   2,    2,    1,  ..., -100, -100, -100],
        [   2,    2,    1,  ..., -100, -100, -100],
        ...,
        [   1,    0, -100,  ..., -100, -100, -100],
        [   1,    2,    2,  ..., -100, -100, -100],
        [   1,    2,    2,  ..., -100, -100, -100]])

In [23]:
dataset['validation']['label'][:]

tensor([[   1,    2,    2,  ..., -100, -100, -100],
        [   1,    2,    2,  ..., -100, -100, -100],
        [   0,    1, -100,  ..., -100, -100, -100],
        ...,
        [   0, -100, -100,  ..., -100, -100, -100],
        [   2,    2,    1,  ..., -100, -100, -100],
        [   2,    1,    2,  ..., -100, -100, -100]])

## span representation function

In [24]:
def get_span_representations(outputs, spans):

    batch_size = spans.shape[0]
    nr_span_indices = spans.shape[1]
    
    idx_l_ams = range(0, spans.shape[1], 2) 
    idx_l_acs = range(1, spans.shape[1], 2)  
    
    am_spans = spans[:, idx_l_ams, :] + 1
    ac_spans = spans[:, idx_l_acs, :] + 1
    
    am_spans = am_spans.flatten(start_dim=1)
    ac_spans = ac_spans.flatten(start_dim=1)
    
    outputs_am = outputs[:,am_spans,:]
    outputs_am = torch.cat([outputs_am[i,i,:,:] for i in range(batch_size)], dim=0)
    outputs_am = outputs_am.reshape(batch_size, nr_span_indices, -1)
    
    outputs_am_r = torch.cat([outputs_am[:,i,:] - outputs_am[:,i+1,:] for i in range(0, nr_span_indices, 2)], dim=1)
    outputs_am_r = outputs_am_r.reshape(batch_size, -1, 768)
    
    outputs_am_l = torch.cat([outputs_am[:,i+1,:] - outputs_am[:,i,:] for i in range(0, nr_span_indices, 2)], dim=1)
    outputs_am_l = outputs_am_r.reshape(batch_size, -1, 768)
    
    am_minus_representations = torch.cat([outputs_am_r, outputs_am_l], dim=-1)
    
    outputs_ac = outputs[:,ac_spans,:]
    outputs_ac = torch.cat([outputs_ac[i,i,:,:] for i in range(batch_size)], dim=0)
    outputs_ac = outputs_ac.reshape(batch_size, nr_span_indices, -1)
    
    outputs_ac_r = torch.cat([outputs_ac[:,i,:] - outputs_ac[:,i+1,:] for i in range(0, nr_span_indices, 2)], dim=1)
    outputs_ac_r = outputs_ac_r.reshape(batch_size, -1, 768)
    
    outputs_ac_l = torch.cat([outputs_ac[:,i+1,:] - outputs_ac[:,i,:] for i in range(0, nr_span_indices, 2)], dim=1)
    outputs_ac_l = outputs_ac_l.reshape(batch_size, -1, 768)
    
    ac_minus_representations = torch.cat([outputs_ac_r, outputs_ac_l], dim=-1)
    
    return am_minus_representations, ac_minus_representations                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                            

## custom BERT model

In [25]:
class CustomBERTKuri(nn.Module):

    def __init__(self, first_model, model_am, model_ac, nr_classes):
        
        super(CustomBERTKuri, self).__init__()
        
        self.first_model = first_model
        
        self.intermediate_linear_am = nn.Linear(1536, 768)
        self.intermediate_linear_ac = nn.Linear(1536, 768)        
        
        self.model_am = model_am
        self.model_ac = model_ac
        
        self.nr_classes = nr_classes
                
        self.fc = nn.Linear(self.model_am.config.hidden_size + self.model_ac.config.hidden_size, self.nr_classes)        

    def forward(self, batch_tokenized, batch_spans):
        
        outputs = self.first_model(batch_tokenized)[0] # ** removed to correct error.
        # spans = batch_spans # remove this spans thing cause we are now giving it just the spans themselves.
        am_minus_representations, ac_minus_representations = get_span_representations(outputs, batch_spans)

        am_minus_representations = self.intermediate_linear_am(am_minus_representations)
        ac_minus_representations = self.intermediate_linear_am(ac_minus_representations)

        output_model_am = self.model_am(inputs_embeds = am_minus_representations)[0]
        output_model_ac = self.model_ac(inputs_embeds = ac_minus_representations)[0]

        adu_representations = torch.cat([output_model_am, output_model_ac], dim=-1)
        # print("adu rep:", adu_representations.shape)
        output = self.fc(adu_representations)
        # print("model class output avant reshape:", output.shape)
        output = output.reshape(-1, self.nr_classes)
        # print("model class output apres:", output.shape)
        return output

## Run

In [26]:
# dummy_batch = dataset['train'][0:4]

In [27]:
# batch_tokenized = BatchEncoding({k:dummy_batch[k] for k in ('input_ids','token_type_ids','attention_mask')})

In [28]:
# batch_spans = BatchEncoding({'spans':dummy_batch['spans']})

In [29]:
# batch_labels = BatchEncoding({'labels':dummy_batch['labels']})

In [30]:
####

In [31]:
first_model = BertModel(BertConfig.from_pretrained("bert-base-uncased"))

In [32]:
model_am = BertModel(BertConfig.from_pretrained("bert-base-uncased"))

In [33]:
model_ac = BertModel(BertConfig.from_pretrained("bert-base-uncased"))

In [34]:
custom_model = CustomBERTKuri(first_model, model_am, model_ac, 3)

In [35]:
# custom_model.to(device)

In [36]:
# model_output = custom_model(batch_tokenized, batch_spans)

In [37]:
# model_output.shape

In [38]:
# batch_labels['labels'].shape

In [39]:
# we have 'de-padify' things to get the real answers
# no need. the -100s of labels takes care of it.

In [40]:
# labels = batch_labels['labels']

In [41]:
# labels = labels.flatten()

In [42]:
# labels

In [43]:
# criterion = torch.nn.CrossEntropyLoss()

In [44]:
# optimizer = torch.optim.Adam(custom_model.parameters(), lr=0.01)

In [45]:
# loss = criterion(model_output, labels)

In [46]:
# loss

In [47]:
# optimizer.zero_grad()

In [48]:
# loss.backward() 

In [49]:
# optimizer.step()

### trainer

In [50]:
from transformers import Trainer, TrainingArguments
from datasets import load_metric

In [51]:
# https://huggingface.co/transformers/main_classes/trainer.html
class CustomTrainer(Trainer):
    def compute_loss(self, model, inputs, return_outputs=False):
        #print("return outputs:", return_outputs)
        # print(inputs)
        # print(inputs['input_ids'])
        # print(inputs['spans'])
        labels = inputs["labels"]
        labels = labels.flatten() # xxx.
        #print("labels:", labels.shape)
        # print("labels flattened:", labels.shape)
        # print(labels) # ok this is correct now.
        # correct the output and loss computation.
        # so batch size is 4. see if you can define batch_tokenized and batch_spans cause that is what the custom model needs. but we already have that!
        
        
        outputs = model(inputs['input_ids'], inputs['spans'])
        #print("outputs:", outputs.shape)
        # ok. so this input_ids thing is a tensor. needs a dict or some mapping.
        
        # nice! so we have outputs now. now lets see what the fuck is it's shape.
        
        # print("outputs:", outputs.shape) # 48 x 3. huh. why though? what was I expecting here? ah no! correct. works out with the dummy example.
        
        
        # logits = outputs.get('logits') # no need for logits now.
        
        loss_fct = nn.CrossEntropyLoss()#(weight=class_weights)
        # loss = loss_fct(outputs, labels.flatten())
        loss = loss_fct(outputs, labels) # xxx
        #print("loss:", loss)
        return (loss, outputs) if return_outputs else loss

In [52]:
metric = load_metric('f1')

def compute_metrics(eval_pred):
    #print("eval preds:", eval_pred)
    logits, labels = eval_pred
    # print("logits", logits.shape)
    print("labels", labels)
    predictions = np.argmax(logits, axis=-1)
    print("predzis:", predictions, len(predictions))
    # print("labelzis:", labels.shape)
    
    return metric.compute(predictions=predictions, references=labels, average='macro')

In [53]:
metric

Metric(name: "f1", features: {'predictions': Value(dtype='int32', id=None), 'references': Value(dtype='int32', id=None)}, usage: """
Args:
    predictions: Predicted labels, as returned by a model.
    references: Ground truth labels.
    labels: The set of labels to include when average != 'binary', and
        their order if average is None. Labels present in the data can
        be excluded, for example to calculate a multiclass average ignoring
        a majority negative class, while labels not present in the data will
        result in 0 components in a macro average. For multilabel targets,
        labels are column indices. By default, all labels in y_true and
        y_pred are used in sorted order.
    average: This parameter is required for multiclass/multilabel targets.
        If None, the scores for each class are returned. Otherwise, this
        determines the type of averaging performed on the data:
            binary: Only report results for the class specified by pos

In [54]:
NB_EPOCHS = 2
BATCH_SIZE = 4

In [55]:
training_args = TrainingArguments(
    
    # output
    output_dir=RESULTS_FOLDER,          
    
    # params
    num_train_epochs=NB_EPOCHS,               # nb of epochs
    per_device_train_batch_size=BATCH_SIZE,   # batch size per device during training
    per_device_eval_batch_size=BATCH_SIZE,    # cf. paper Sun et al.
    learning_rate=1e-5,#2e-5,                 # cf. paper Sun et al.
#     warmup_steps=500,                         # number of warmup steps for learning rate scheduler
    warmup_ratio=0.1,                         # cf. paper Sun et al.
    weight_decay=0.01,                        # strength of weight decay
    
    # eval
    evaluation_strategy="steps",              # cf. paper Sun et al.
    eval_steps=20,                            # cf. paper Sun et al.
    
    # log
    logging_dir="/notebooks/KURI-BERT/results/tb_logs",  
    logging_strategy='steps',
    logging_steps=20,
    
    remove_unused_columns=False,
    
    # save
    save_strategy='steps',
    save_total_limit=2,
    # save_steps=20, # default 500
    load_best_model_at_end=True,              # cf. paper Sun et al.
    # metric_for_best_model='eval_loss' 
    metric_for_best_model='f1'
)

In [56]:
trainer = CustomTrainer( # Trainer(
    model=custom_model,
    args=training_args,
    tokenizer=tokenizer,
    train_dataset=dataset['train'],
    eval_dataset=dataset['validation'],
    compute_metrics=compute_metrics
)

In [57]:
results = trainer.train()

***** Running training *****
  Num examples = 1088
  Num Epochs = 2
  Instantaneous batch size per device = 4
  Total train batch size (w. parallel, distributed & accumulation) = 4
  Gradient Accumulation steps = 1
  Total optimization steps = 544


Step,Training Loss,Validation Loss


***** Running Evaluation *****
  Num examples = 273
  Batch size = 4


labels [[   1    2    2 ... -100 -100 -100]
 [   1    2    2 ... -100 -100 -100]
 [   0    1 -100 ... -100 -100 -100]
 ...
 [   0 -100 -100 ... -100 -100 -100]
 [   2    2    1 ... -100 -100 -100]
 [   2    1    2 ... -100 -100 -100]]
predzis: [2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2
 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2
 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2
 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2
 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2
 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2
 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2
 2 2 2 2 2 2 2 2 2 2 2 2 2 2] 273


TypeError: only size-1 arrays can be converted to Python scalars