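# Implementation of the Kuribayashi BERT-minus model

BERT-minus span representations are built for argumentative markers (AMs) and argumentative components (ACs), and each AM/AC pair is then classified.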

## libraries

In [1]:
!pip install transformers
!pip install ipywidgets
!pip install IProgress
!pip install datasets

Looking in indexes: https://pypi.org/simple, https://pypi.ngc.nvidia.com
Looking in indexes: https://pypi.org/simple, https://pypi.ngc.nvidia.com
Looking in indexes: https://pypi.org/simple, https://pypi.ngc.nvidia.com
Looking in indexes: https://pypi.org/simple, https://pypi.ngc.nvidia.com


In [2]:
import transformers
from transformers import BertTokenizer, BertConfig
from transformers import BertModel, BertForSequenceClassification
from transformers import BatchEncoding

import torch
import torch.nn as nn

import datasets

## tokenizer

In [3]:
tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")

## data pre-processing

In [4]:
batch_size = 2 
# for now. will be globalized. 

#### batch 1

In [5]:
paragraph_1 = """
Some people may argue that children will be more material,
neglect their study for earning money or be exploited by the
employers. However, if children get good care and
instructions from their parents, they can take advantages of the
work to learn valuable things and avoid going in a wrong way.
"""

In [6]:
paragraph_2 = """
Firstly, it is a fact that tourists from different cultures will probably cause changes to the 
cultural identity of the tourist destinations. For example, in the Vietnam War, many 
American soldiers came to Thailand for a break and involved in sexual and drug activities. 
"""

In [7]:
# The two example paragraphs form one batch, in order.
batch_texts = [paragraph_1, paragraph_2]

In [8]:
# Spans of paragraph_1: inclusive (start, end) token indices, ordered AM, AC, AM, AC.
am_1_span, ac_1_span = (0, 4), (5, 23)
am_2_span, ac_2_span = (24, 25), (26, 55)

span_1 = [am_1_span, ac_1_span, am_2_span, ac_2_span]

In [9]:
# Spans of paragraph_2: inclusive (start, end) token indices, ordered AM, AC, AM, AC.
am_1_span, ac_1_span = (0, 6), (7, 23)
am_2_span, ac_2_span = (24, 26), (27, 48)

span_2 = [am_1_span, ac_1_span, am_2_span, ac_2_span]

In [10]:
# Span lists aligned index-for-index with batch_texts.
batch_spans = [span_1, span_2]

In [11]:
batch_spans

[[(0, 4), (5, 23), (24, 25), (26, 55)], [(0, 6), (7, 23), (24, 26), (27, 48)]]

In [12]:
batch_texts

['\nSome people may argue that children will be more material,\nneglect their study for earning money or be exploited by the\nemployers. However, if children get good care and\ninstructions from their parents, they can take advantages of the\nwork to learn valuable things and avoid going in a wrong way.\n',
 '\nFirstly, it is a fact that tourists from different cultures will probably cause changes to the \ncultural identity of the tourist destinations. For example, in the Vietnam War, many \nAmerican soldiers came to Thailand for a break and involved in sexual and drug activities. \n']

In [13]:
# One class label per (AM, AC) pair of each paragraph.
labels_1, labels_2 = [1, 2], [1, 0]
batch_labels = [labels_1, labels_2]

## dataset

In [14]:
dataset_d = {
    'texts' : batch_texts,
    'spans' : batch_spans,
    'labels' : batch_labels
}

In [15]:
dataset = datasets.Dataset.from_dict(dataset_d)

In [16]:
dataset

Dataset({
    features: ['texts', 'spans', 'labels'],
    num_rows: 2
})

In [17]:
dataset['texts']

['\nSome people may argue that children will be more material,\nneglect their study for earning money or be exploited by the\nemployers. However, if children get good care and\ninstructions from their parents, they can take advantages of the\nwork to learn valuable things and avoid going in a wrong way.\n',
 '\nFirstly, it is a fact that tourists from different cultures will probably cause changes to the \ncultural identity of the tourist destinations. For example, in the Vietnam War, many \nAmerican soldiers came to Thailand for a break and involved in sexual and drug activities. \n']

In [18]:
dataset['spans']

[[[0, 4], [5, 23], [24, 25], [26, 55]], [[0, 6], [7, 23], [24, 26], [27, 48]]]

In [19]:
dataset['labels']

[[1, 2], [1, 0]]

In [20]:
def tokenize(dataset, text_tokenizer=None, span_pad=(0, 0), label_pad=-100):
    """Tokenize a batch of texts and pad its spans and labels into rectangular tensors.

    Args:
        dataset: mapping (a ``datasets.Dataset`` or plain dict) with columns
            ``'texts'`` (list of str), ``'spans'`` (per text, a list of
            ``[start, end]`` pairs) and ``'labels'`` (per text, a list of ints).
        text_tokenizer: optional callable used instead of the notebook-level
            ``tokenizer``; called as ``text_tokenizer(texts, padding=True, return_tensors="pt")``.
        span_pad: span appended to shorter span lists (``(0, 0)`` as before).
        label_pad: value appended to shorter label lists. ``-100`` is the default
            ``ignore_index`` of ``torch.nn.CrossEntropyLoss``, so padded entries
            do not contribute to the loss.

    Returns:
        The tokenizer's encoding with two extra entries: ``'spans'`` of shape
        (batch, max_spans, 2) and ``'labels'`` of shape (batch, max_labels).
    """
    encode = tokenizer if text_tokenizer is None else text_tokenizer
    encoded = encode(dataset['texts'], padding=True, return_tensors="pt")

    # Fetch each column once; on a datasets.Dataset every column access materialises the whole column.
    span_lists = dataset['spans']
    label_lists = dataset['labels']

    max_spans = max((len(spans) for spans in span_lists), default=0)
    max_labels = max((len(labels) for labels in label_lists), default=0)

    padded_spans = [
        list(spans) + [list(span_pad)] * (max_spans - len(spans))
        for spans in span_lists
    ]
    # Labels must be padded too, otherwise torch.tensor fails on ragged label lists.
    padded_labels = [
        list(labels) + [label_pad] * (max_labels - len(labels))
        for labels in label_lists
    ]

    encoded['spans'] = torch.tensor(padded_spans)
    encoded['labels'] = torch.tensor(padded_labels)
    return encoded

In [21]:
tokenized_input = tokenize(dataset)

## batching

In [22]:
batch = tokenize(dataset)

In [23]:
batch_tokenized = BatchEncoding({k:batch[k] for k in ('input_ids','token_type_ids','attention_mask')})

In [24]:
batch_spans = BatchEncoding({'spans':batch['spans']})

In [25]:
batch_labels = BatchEncoding({'labels':batch['labels']})

In [30]:
batch_spans['spans'].shape

torch.Size([2, 4, 2])

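## Span representation function

For each inclusive span with (shifted) endpoints s and e, the BERT-minus representation concatenates h_s − h_e and h_e − h_s. It is computed separately for AM spans (even positions) and AC spans (odd positions).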

In [26]:
def get_span_representations(outputs, spans):
    """Build BERT-minus representations for AM and AC spans.

    Spans alternate AM, AC, AM, AC, ... along dim 1 (even positions are AMs,
    odd positions are ACs). Each span is an inclusive ``(start, end)`` pair of
    token indices. Indices are shifted by one, presumably to skip the leading
    ``[CLS]`` token. For shifted endpoints ``s`` and ``e``, the representation
    is ``cat([h_s - h_e, h_e - h_s])``.

    Args:
        outputs: token hidden states, shape (batch, seq_len, hidden).
        spans: integer tensor, shape (batch, nr_spans, 2). Padding spans such as
            ``[0, 0]`` yield all-zero representations.

    Returns:
        ``(am_minus_representations, ac_minus_representations)`` with shapes
        (batch, ceil(nr_spans / 2), 2 * hidden) and
        (batch, floor(nr_spans / 2), 2 * hidden).
    """
    batch_size = outputs.shape[0]
    shifted = spans.to(device=outputs.device, dtype=torch.long) + 1

    # (batch, 1) row index that broadcasts against (batch, n) column indices, so
    # each sample gathers only its own tokens (no batch x batch intermediate).
    rows = torch.arange(batch_size, device=outputs.device).unsqueeze(1)

    def _minus(span_bounds):
        # Minus representation for spans of shape (batch, n, 2) -> (batch, n, 2 * hidden).
        starts = outputs[rows, span_bounds[..., 0]]
        ends = outputs[rows, span_bounds[..., 1]]
        return torch.cat([starts - ends, ends - starts], dim=-1)

    am_minus_representations = _minus(shifted[:, 0::2, :])
    ac_minus_representations = _minus(shifted[:, 1::2, :])
    return am_minus_representations, ac_minus_representations

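## Custom BERT model

A token-level BERT encoder produces span representations. Separate AM and AC encoders contextualise them, and a linear layer classifies each AM/AC pair.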

In [27]:
class CustomBERTKuri(nn.Module):
    """Two-stage classifier over BERT-minus AM/AC span representations.

    ``first_model`` encodes the tokens of each text. AM and AC span
    representations are extracted by ``get_span_representations`` and projected
    to the hidden size of the second-stage encoders. They are then passed as
    ``inputs_embeds`` through ``model_am`` / ``model_ac``, concatenated per
    AM/AC pair, and classified by a linear layer.

    Args:
        first_model: token encoder; ``first_model(**batch_tokenized)[0]`` must be
            the per-token hidden states. Its ``config.hidden_size`` (if present)
            sets the span-representation width.
        model_am, model_ac: encoders accepting ``inputs_embeds`` and exposing
            ``config.hidden_size``.
        nr_classes: number of output classes per AM/AC pair.
    """

    def __init__(self, first_model, model_am, model_ac, nr_classes):
        super().__init__()

        self.first_model = first_model

        # Minus representations concatenate two hidden-size difference vectors.
        # Fall back to BERT-base's 768 when the token encoder exposes no config.
        token_hidden = getattr(getattr(first_model, "config", None), "hidden_size", 768)
        span_repr_size = 2 * token_hidden

        # Project span representations to each second-stage encoder's width.
        self.intermediate_linear_am = nn.Linear(span_repr_size, model_am.config.hidden_size)
        self.intermediate_linear_ac = nn.Linear(span_repr_size, model_ac.config.hidden_size)

        self.model_am = model_am
        self.model_ac = model_ac

        self.nr_classes = nr_classes

        self.fc = nn.Linear(model_am.config.hidden_size + model_ac.config.hidden_size, nr_classes)

    def forward(self, batch_tokenized, batch_spans):
        """Return logits of shape (batch * nr_pairs, nr_classes).

        ``batch_tokenized`` is unpacked into ``first_model``; ``batch_spans['spans']``
        holds the (batch, nr_spans, 2) span tensor.
        """
        token_states = self.first_model(**batch_tokenized)[0]
        am_repr, ac_repr = get_span_representations(token_states, batch_spans['spans'])

        am_embeds = self.intermediate_linear_am(am_repr)
        # Each branch uses its own projection (the AC branch previously reused the AM one).
        ac_embeds = self.intermediate_linear_ac(ac_repr)

        am_states = self.model_am(inputs_embeds=am_embeds)[0]
        ac_states = self.model_ac(inputs_embeds=ac_embeds)[0]

        # One representation per AM/AC pair: AM and AC states side by side.
        adu_representations = torch.cat([am_states, ac_states], dim=-1)

        logits = self.fc(adu_representations)
        return logits.reshape(-1, self.nr_classes)

## Run

In [28]:
# Token encoder: load pretrained weights from the same checkpoint as the tokenizer.
# BertModel(BertConfig...) only builds a randomly initialised network, and the
# "bert-base-cased" vocabulary does not match the ids produced by the
# "bert-base-uncased" tokenizer used above.
first_model = BertModel.from_pretrained("bert-base-uncased")

In [29]:
model_am = BertModel(BertConfig.from_pretrained("bert-base-cased"))

In [30]:
model_ac = BertModel(BertConfig.from_pretrained("bert-base-cased"))

In [31]:
custom_model = CustomBERTKuri(first_model, model_am, model_ac, 3)

In [32]:
model_output = custom_model(batch_tokenized, batch_spans)

In [33]:
model_output.shape

torch.Size([4, 3])

In [34]:
labels = batch_labels['labels']

In [35]:
labels = labels.flatten()

In [36]:
criterion = torch.nn.CrossEntropyLoss()

In [37]:
# lr=0.01 is far above the usual BERT fine-tuning range (roughly 2e-5 to 5e-5)
# and tends to make the encoder diverge; keep it as a named, tunable constant.
LEARNING_RATE = 2e-5
optimizer = torch.optim.Adam(custom_model.parameters(), lr=LEARNING_RATE)

In [38]:
loss = criterion(model_output, labels)

In [39]:
optimizer.zero_grad()

In [40]:
loss.backward() 

In [41]:
optimizer.step()