In [42]:
import torch
import torch.nn as nn
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
import pandas as pd
import numpy as np
import matplotlib
import matplotlib.pyplot as plt
import typing
from transformers import BertTokenizer, BertTokenizerFast, BertModel

In [51]:
import webbrowser as wb
ref_list = [
    "https://huggingface.co/docs/transformers/model_doc/bert",
    "https://huggingface.co/docs/transformers/v4.24.0/en/model_doc/bert#transformers.BertConfig",
    "https://huggingface.co/docs/transformers/v4.24.0/en/model_doc/bert#transformers.BertTokenizer",
    "https://arxiv.org/pdf/1810.04805.pdf"
]
for url in ref_list:
    wb.open(url)

def subdict(d, keys):
    return {key:d[key] for key in keys}

In [44]:
class MNLIDataset():
    def __init__(self, root, type_map, tokenizer, batch_size, embedding_size):
        self.idx = 0
        self.type_map = type_map
        self.bs = batch_size
        self.es = embedding_size
        self.tokenizer = tokenizer
        self.df = pd.read_json(root, lines=True)
        self.df = self.df[['sentence1', 'sentence2', 'gold_label']]
        self.df['gold_label'] = self.df['gold_label'].apply(lambda x: self.type_mapper(x))

    def type_mapper(self, text):
        return self.type_map[text]    

    def __len__(self):
        return self.df.shape[0]
    
    def set_idx(self, idx):
        self.idx = idx

    def get_batch(self):
        # cycle through batches and restart if idx is too large
        start = 0
        if self.idx + self.bs < len(self):
            start = self.idx
        else:
            self.idx = 0
        end = min(start + self.bs, len(self))
        data = self.tokenizer(self.df['sentence1'].iloc[start:end].tolist(), 
                              self.df['sentence2'].iloc[start:end].tolist(), 
                              return_tensors='pt',
                              padding=True
                              )
        self.idx += self.bs
        batch_input = subdict(data, ['input_ids', 'attention_mask', 'token_type_ids'])
        batch_labels = torch.tensor(list(self.df['gold_label'].iloc[start:end]))
        return batch_input, batch_labels 
        # return {'input_ids': torch.concat([ data['input_ids'], torch.zeros(self.bs, self.es - data['input_ids'].shape[1]) ], dim=1).to(dtype=torch.int)}

train_data = MNLIDataset('mnli/multinli_1.0/train_debug.json', 
                         {'neutral':0, 'entailment':1, 'contradiction':2},
                         BertTokenizerFast.from_pretrained('bert-base-uncased'),
                         batch_size=32,
                         embedding_size=512
                         )

In [45]:
# Use bert for classification by sending [cls] [tokenization of sentence1] [sep] [tokenization of sentence2]
# into bert and then using embedding of [cls] as input to classifier
# device = "cuda" if torch.cuda.is_available() else "cpu"
device = "cpu"
print(f"Device: {device}")

class BertClassifier(nn.Module):
    def __init__(self):
        super().__init__()
        self.bert = BertModel.from_pretrained('bert-base-uncased')
        self.fc = nn.Linear(self.bert.config.hidden_size, 3) # MNLI has 3 classes
        
    def forward(self, input):
        x = self.bert(**input)
        x = x.last_hidden_state[:, 0, :].squeeze() # pick out [CLS] final embedding
        logits = self.fc(x)
        return logits

model = BertClassifier().to(device)

Device: cpu


Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertModel: ['cls.seq_relationship.bias', 'cls.predictions.transform.dense.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.bias', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.decoder.weight', 'cls.predictions.transform.dense.weight', 'cls.seq_relationship.weight']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


In [46]:
frozen_layers = 10
for layer in range(frozen_layers):
    for param in model.bert.encoder.layer[layer].parameters():
        param.requires_grad = False

In [47]:
input, labels = train_data.get_batch()
input, labels = {key:input[key].to(device) for key in input}, labels.to(device)
output = model(input)
print(f"{output.shape}\n{labels.shape}")
print(output)
print(labels)

torch.Size([32, 3])
torch.Size([32])
tensor([[ 0.4482, -0.2323,  0.2887],
        [ 0.5486, -0.3704, -0.0165],
        [ 0.3745, -0.1011,  0.1391],
        [ 0.3059, -0.3736,  0.1182],
        [ 0.5399, -0.3053,  0.0831],
        [ 0.4903, -0.2889,  0.2261],
        [ 0.6127, -0.2593,  0.0258],
        [ 0.5393, -0.4015,  0.2033],
        [ 0.2678, -0.2319, -0.0842],
        [ 0.5263, -0.0207,  0.3270],
        [ 0.2517,  0.0069,  0.1417],
        [ 0.1881, -0.4228, -0.1700],
        [ 0.4877, -0.5912, -0.0872],
        [ 0.3928, -0.0749, -0.2020],
        [ 0.5052, -0.3313,  0.1834],
        [ 0.5916, -0.3129,  0.1199],
        [ 0.3025, -0.2111, -0.0886],
        [ 0.0874, -0.3144,  0.2670],
        [ 0.4075, -0.3485,  0.1285],
        [ 0.1913, -0.2660, -0.0771],
        [ 0.3413, -0.1709, -0.1335],
        [ 0.4636, -0.2988, -0.0686],
        [ 0.2597, -0.1512, -0.0459],
        [ 0.4030, -0.2846, -0.0413],
        [ 0.5565, -0.2495, -0.0131],
        [ 0.3969, -0.3045,  0.1001],
 

In [48]:
loss_fn = nn.CrossEntropyLoss()
preds = torch.argmax(output, dim=1)
print(preds)

loss = loss_fn(output, labels)
print(loss)

tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0])
tensor(1.1935, grad_fn=<NllLossBackward0>)


In [49]:
optim = torch.optim.Adam(model.parameters())
losses = []
def train_epoch(model, loss_fn, optim, train_data):
    bs = train_data.bs
    train_example_num = len(train_data)
    batch_num = train_example_num // bs if train_example_num % bs > 0 else train_example_num // bs + 1
    
    for i in range(batch_num):
        input, labels = train_data.get_batch()
        input, labels = {key:input[key].to(device) for key in input}, labels.to(device)
        output = model(input)
        
        loss = loss_fn(output, labels)
        if i % 2 == 0:
            losses.append(loss.item())
            print(f"loss={loss.item()}.2f [{i * bs} / {train_example_num}]")

        optim.zero_grad()
        loss.backward()
        optim.step()         

In [50]:
epoch_num = 20
for epochs in range(epoch_num):
    train_epoch(model, loss_fn, optim, train_data)
    print(f"End of epoch {epochs}")

loss=1.124358057975769.2f [0 / 200]
loss=2.788536310195923.2f [64 / 200]
loss=1.2627285718917847.2f [128 / 200]
End of epoch 0
loss=0.8348004221916199.2f [0 / 200]
loss=0.9568321704864502.2f [64 / 200]
loss=0.7655072212219238.2f [128 / 200]
End of epoch 1
loss=0.27687278389930725.2f [0 / 200]
loss=0.4760887324810028.2f [64 / 200]
loss=0.195520281791687.2f [128 / 200]
End of epoch 2
loss=0.0427311547100544.2f [0 / 200]
loss=0.30090799927711487.2f [64 / 200]
loss=0.024307992309331894.2f [128 / 200]
End of epoch 3
loss=0.16933958232402802.2f [0 / 200]
loss=0.30625542998313904.2f [64 / 200]
loss=0.21990647912025452.2f [128 / 200]
End of epoch 4
loss=0.11982423067092896.2f [0 / 200]
loss=0.045377954840660095.2f [64 / 200]
loss=0.03413793444633484.2f [128 / 200]
End of epoch 5
loss=0.035563837736845016.2f [0 / 200]
loss=0.010574637912213802.2f [64 / 200]
loss=0.010204879567027092.2f [128 / 200]
End of epoch 6
loss=0.0032322166953235865.2f [0 / 200]
loss=0.043013814836740494.2f [64 / 200]
los

In [None]:
# It's fitting!