In [1]:
import torch
import torch.nn as nn
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
from torch.profiler import profile, record_function, ProfilerActivity
import pandas as pd
import numpy as np
import matplotlib
import matplotlib.pyplot as plt
from transformers import BertTokenizer, BertTokenizerFast, BertModel
import gc

In [5]:
# open reference docs
import webbrowser as wb
ref_list = [
    "https://huggingface.co/docs/transformers/model_doc/bert",
    "https://huggingface.co/docs/transformers/v4.24.0/en/model_doc/bert#transformers.BertConfig",
    "https://huggingface.co/docs/transformers/v4.24.0/en/model_doc/bert#transformers.BertTokenizer",
    "https://arxiv.org/pdf/1810.04805.pdf"
]
for url in ref_list:
    wb.open(url)

In [2]:
def subdict(d, keys):
    return {key:d[key] for key in keys}

device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Device: {device}")

class MNLIDataset():
    def __init__(self, root, type_map, tokenizer, batch_size, embedding_size):
        self.idx = 0
        self.type_map = type_map
        self.bs = batch_size
        self.es = embedding_size
        self.tokenizer = tokenizer
        
        print("Loading data...")
        with open(root, 'r') as f:
            self.df = pd.read_json(f, lines=True)
        self.df = self.df[['sentence1', 'sentence2', 'gold_label']]
       
        target_gold_label_cats = set(['neutral', 'entailment', 'contradiction'])
        self.df = self.df.loc[self.df['gold_label'].isin(target_gold_label_cats)]
        self.df['gold_label'] = self.df['gold_label'].apply(lambda x: self.type_mapper(x))
        
        print("Tokenizing data...")
        # Generate list of embeddings in format: [CLS] * [tokens of sent1] * [SEP] * [tokens of sent2] * [SEP]
        # Returns list of dicts of tensors 'input_ids', 'attention_mask', and 'token_type_ids'
        self.data = [self.tokenizer(self.df['sentence1'].iloc[i*self.bs:(i+1)*self.bs].tolist(), 
                                   self.df['sentence2'].iloc[i*self.bs:(i+1)*self.bs].tolist(), 
                                   return_tensors='pt',
                                   padding=True,
                                   truncation=True
                                   ) for i in range(self.df.shape[0] // self.bs) ]
        self.labels = torch.tensor(list(self.df['gold_label']))
        del self.df
        print("Done.")

    def get_batch_num(self):
        example_num = len(self)
        batch_num = example_num // self.bs + 1 if example_num % self.bs > 0 else example_num // self.bs
        return batch_num

    # transform class name into integer
    def type_mapper(self, text):
        return self.type_map[text]    

    def __len__(self):
        return len(self.data) * self.bs
    
    def set_idx(self, idx):
        self.idx = idx

    def get_batch(self, device):
        # cycle through batches and restart if idx is too large
        if self.idx >= len(self.data): 
            self.idx = 0
        data = {k:torch.tensor(self.data[self.idx][k], device=device) for k in self.data[self.idx]}
        batch_input = subdict(data, ['input_ids', 'attention_mask', 'token_type_ids'])
        batch_labels = torch.tensor(self.labels[self.idx*self.bs : (self.idx+1)*self.bs], device=device)
        self.idx += 1
        return batch_input, batch_labels

Device: cuda


In [3]:
bs = 8 # about the best I can do with 8 Gb VRAM and no frozen layers
debug_data = MNLIDataset(root='data/multinli_1.0_train_debug.json', 
                         type_map={'neutral':0, 'entailment':1, 'contradiction':2},
                         tokenizer=BertTokenizerFast.from_pretrained('bert-base-uncased'),
                         batch_size=bs,
                         embedding_size=512
                         )

train_data = MNLIDataset(root='data/multinli_1.0_train.jsonl', 
                         type_map={'neutral':0, 'entailment':1, 'contradiction':2},
                         tokenizer=BertTokenizerFast.from_pretrained('bert-base-uncased'),
                         batch_size=bs,
                         embedding_size=512
                         )

dev_data = MNLIDataset(root='data/multinli_1.0_dev_matched.jsonl', 
                       type_map={'neutral':0, 'entailment':1, 'contradiction':2},
                       tokenizer=BertTokenizerFast.from_pretrained('bert-base-uncased'),
                       batch_size=bs,
                       embedding_size=512
                       )


Loading data...
Tokenizing data...
Done.
Loading data...
Tokenizing data...
Done.
Loading data...
Tokenizing data...
Done.


In [4]:
# Pass tokenization of sentences into bert and map the final embedding of the [CLS] token into 3D vector for 3-class classification
class BertClassifier(nn.Module):
    def __init__(self):
        super().__init__()
        self.bert = BertModel.from_pretrained('bert-base-uncased')
        self.fc = nn.Linear(self.bert.config.hidden_size, 3) # MNLI has 3 classes
    
    def forward(self, input):
        x = self.bert(**input)
        x = x.last_hidden_state[:, 0, :].squeeze() # pick out [CLS] final embedding
        logits = self.fc(x)
        return logits

model = BertClassifier().to(device)
model.load_state_dict(torch.load('data/bertmnli.pt'))

Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertModel: ['cls.seq_relationship.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.LayerNorm.bias', 'cls.seq_relationship.weight', 'cls.predictions.transform.dense.bias', 'cls.predictions.decoder.weight', 'cls.predictions.transform.dense.weight', 'cls.predictions.bias']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


<All keys matched successfully>

In [6]:
# look at model architecture
print(model.bert)

BertModel(
  (embeddings): BertEmbeddings(
    (word_embeddings): Embedding(30522, 768, padding_idx=0)
    (position_embeddings): Embedding(512, 768)
    (token_type_embeddings): Embedding(2, 768)
    (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
    (dropout): Dropout(p=0.1, inplace=False)
  )
  (encoder): BertEncoder(
    (layer): ModuleList(
      (0): BertLayer(
        (attention): BertAttention(
          (self): BertSelfAttention(
            (query): Linear(in_features=768, out_features=768, bias=True)
            (key): Linear(in_features=768, out_features=768, bias=True)
            (value): Linear(in_features=768, out_features=768, bias=True)
            (dropout): Dropout(p=0.1, inplace=False)
          )
          (output): BertSelfOutput(
            (dense): Linear(in_features=768, out_features=768, bias=True)
            (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
            (dropout): Dropout(p=0.1, inplace=False)
          

In [5]:
# Does it work with initial layers frozen? Not really...
for param in model.bert.embeddings.parameters():
    param.detach_()
frozen_layers = 8
for layer in range(frozen_layers):
    for param in model.bert.encoder.layer[layer].parameters():
        param.detach_() # detaching all initial layers increases batch size on my single gpu by cutting off comp graph, saving memory

In [5]:
loss_fn = nn.CrossEntropyLoss()
optim = torch.optim.Adam(model.parameters(), lr = 2e-5)
losses = []

def train_epoch(model, loss_fn, optim, train_data, debug=False):
    bs = train_data.bs
    batch_num = train_data.get_batch_num()
    for i in range(batch_num):
        if debug:
            print(f"memory allocated before get_batch on round {i}: {torch.cuda.memory_allocated() // 1024 ** 2}")
            with profile(activities=[
                ProfilerActivity.CPU, ProfilerActivity.CUDA], record_shapes=True, profile_memory=True) as prof:
                with record_function("model_inference"):
                    input, labels = train_data.get_batch(device)
                    print(input['attention_mask'].shape)
                    print(input['input_ids'][0:10])
                    output = model(input) # why does output have larger memory footprint on full train data vs debug data?
                    loss = loss_fn(output, labels)
            print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=10))
            prof.export_chrome_trace("trace2.json")
            print(f"memory allocated after output on round {i}: {torch.cuda.memory_allocated() // 1024 ** 2}")
        
        else:
            input, labels = train_data.get_batch(device)
            output = model(input)
            loss = loss_fn(output, labels)

            if i % 500 == 0:
                losses.append(loss.item())
                print(f"loss={loss.item()} [{i * bs} / {batch_num * bs}]")
            
            optim.zero_grad()
            loss.backward()
            optim.step()     

In [10]:
gc.collect()
torch.cuda.empty_cache()

In [None]:
epoch_num = 5
for epochs in range(epoch_num):
    train_epoch(model, loss_fn, optim, train_data)
    print(f"End of epoch {epochs}")

In [9]:
# torch.save(model.state_dict(), "bertmnli.pt")

In [6]:
def test(model, loss_fn, data, losses):
    with torch.no_grad():
        batch_num = data.get_batch_num()
        
        for i in range(batch_num):
            input, labels = data.get_batch(device)
            preds = model(input)
            loss = loss_fn(preds, labels)
            losses.append(loss.item())
            if i % 100 == 0:
                print(f"loss={loss.item()} [{i * bs} / {batch_num * bs}]")
            
        print(f"Avg loss: {sum(losses) / batch_num}")

In [7]:
test(model, loss_fn, dev_data, [])

  data = {k:torch.tensor(self.data[self.idx][k], device=device) for k in self.data[self.idx]}
  batch_labels = torch.tensor(self.labels[self.idx*self.bs : (self.idx+1)*self.bs], device=device)


loss=0.20467762649059296 [0 / 9808]
loss=0.0006142045021988451 [800 / 9808]
loss=1.7694125175476074 [1600 / 9808]
loss=0.584679365158081 [2400 / 9808]
loss=0.8866902589797974 [3200 / 9808]
loss=1.0042858123779297 [4000 / 9808]
loss=1.0168482065200806 [4800 / 9808]
loss=1.634853720664978 [5600 / 9808]
loss=0.054585762321949005 [6400 / 9808]
loss=0.01108003593981266 [7200 / 9808]
loss=1.797624945640564 [8000 / 9808]
loss=0.3181334435939789 [8800 / 9808]
loss=0.5450186133384705 [9600 / 9808]
Avg loss: 0.7827006019944115


In [None]:
epoch_num = 20
for epochs in range(epoch_num):
    train_epoch(model, loss_fn, optim, debug_data)
    print(f"End of epoch {epochs}")