# BERT from scratch

We build a BERT model from scratch. We use the AG_NEWS dataset that is built-in in torchtext and use some of the tokenization tools from torchtext. A huggingface pipeline might have taken care of all of the pre-training steps but we wanted to get a more detailed understanding of the entire pipeline. 

We train a small-ish model for 5 epochs and test it in two ways. First, we look at its predictions on the first test batch to see if the predictions are plausible. Secondly, we investigate the test loss on a random network vs. our trained network. 

Our model is not perfect but we think it is sufficiently different from random chance that we can say it has learned something and our pipeline is functional. Our goal was not to reproduce or beat the state of the art but just to built a working pipeline so we stop there. 

Interestingly, a larger model learned to just predict the "[PAD]" token whereas a smaller model predicted mostly ",", ".", "a", "the", and the likes. Not sure where this comes from. Maybe we'd have to train it for longer.

In [1]:
import torch
import numpy as np
import matplotlib.pyplot as plt
import torch.nn as nn
print(torch.__version__)

1.10.0+cu102


# Data

In [2]:
import torchtext

In [3]:
dataset_train, dataset_test = torchtext.datasets.AG_NEWS()
print(len(dataset_train))
print(len(dataset_test))

120000
7600


In [4]:
from torchtext.data.utils import get_tokenizer
from torchtext.vocab import build_vocab_from_iterator

tokenizer = get_tokenizer('basic_english')
train_iter = torchtext.datasets.AG_NEWS(split='train')

def yield_tokens(data_iter):
    for _, text in data_iter:
        yield tokenizer(text)

vocab = build_vocab_from_iterator(yield_tokens(train_iter), min_freq=2,
                                  specials=["[PAD]", "[UNK]", "[CLS]", "[MASK]"])
vocab.set_default_index(vocab["[UNK]"])

In [5]:
#vocab(['here', 'is', 'an', 'example'])
vocab(["the"])
full_vocab = vocab.vocab.get_stoi().keys()
print(len(full_vocab))

53130


In [6]:
import random
import math
import copy

# copied from tutorial, added padding
def text_pipeline(x, max_len, percentage_masked=0.15):
    vocab_list = np.array(vocab(tokenizer(x)))[:max_len]
    k = len(vocab_list)
    missing_len = max_len - k
    missing_list = missing_len * vocab(["[PAD]"])
    # true labels
    labels = np.concatenate([copy.deepcopy(vocab_list), missing_list])
    # padding mask
    mask = np.random.choice([True, False], k, p=[percentage_masked, 1-percentage_masked])
    mask_idxs = np.arange(len(vocab_list))[mask]
    random.shuffle(mask_idxs)
    mask_80 = mask_idxs[:math.floor(0.8*len(mask_idxs))]
    mask_10 = mask_idxs[math.floor(0.8*len(mask_idxs)):math.floor(0.9*len(mask_idxs))]
    vocab_list[mask_80] = vocab(["[MASK]"])
    # fill remaining 10 percent with random words
    random_words = np.random.choice(list(full_vocab), len(mask_10), replace=True)
    vocab_list[mask_10] = vocab(list(random_words))
    
    return(np.concatenate([vocab_list, missing_list]), labels, np.concatenate([mask, np.zeros(missing_len, dtype=bool)]))

print(text_pipeline('He married Mabel Scott in 1890, but they soon separated. Unable to get an English divorce, in 1900, he became the first celebrity to get one in Nevada, and remarried there, but the divorce was invalid in England. In June 1901, he was arrested for bigamy, and was convicted before the House of Lords, the last time a peer was convicted by the Lords.', 50))

(array([5.2000e+01, 6.6170e+03, 4.7691e+04, 2.6430e+03, 1.0000e+01,
       1.0000e+00, 6.0000e+00, 4.8000e+01, 7.0000e+01, 7.4700e+02,
       3.0000e+00, 4.0000e+00, 3.0000e+00, 7.0000e+00, 2.2500e+02,
       3.0000e+00, 1.8870e+03, 1.3152e+04, 6.0000e+00, 1.0000e+01,
       1.0000e+00, 6.0000e+00, 5.2000e+01, 1.3610e+03, 5.0000e+00,
       5.0000e+01, 7.9570e+03, 7.0000e+00, 2.2500e+02, 6.4000e+01,
       1.0000e+01, 6.4220e+03, 6.0000e+00, 1.1000e+01, 1.0000e+00,
       2.3200e+02, 6.0000e+00, 4.8000e+01, 5.0000e+00, 1.3152e+04,
       3.8000e+01, 1.5491e+04, 3.0000e+00, 3.1800e+02, 4.0000e+00,
       1.0000e+01, 1.9190e+03, 3.0441e+04, 6.0000e+00, 5.2000e+01]), array([5.2000e+01, 6.6170e+03, 4.7691e+04, 2.6430e+03, 1.0000e+01,
       1.0000e+00, 6.0000e+00, 4.8000e+01, 7.0000e+01, 7.4700e+02,
       1.0618e+04, 4.0000e+00, 4.3080e+03, 7.0000e+00, 2.2500e+02,
       3.3000e+01, 1.8870e+03, 1.3152e+04, 6.0000e+00, 1.0000e+01,
       1.0000e+00, 6.0000e+00, 5.2000e+01, 1.3610e+03, 5.00

In [7]:
# copied from tutorial, removed offsets
from torch.utils.data import DataLoader
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(DEVICE)

def collate_batch(batch):
    label_list, text_list, mask_list = [], [], []
    for _, text in batch:
        input_, label_, mask_ = text_pipeline(text, max_len=100, percentage_masked=0.15)
        text_list.append(torch.tensor(input_, dtype=torch.int64))
        label_list.append(label_)
        mask_list.append(torch.tensor(mask_))
    label_list = torch.tensor(label_list, dtype=torch.int64)
    text_list = torch.cat(text_list).view(len(label_list), -1)
    mask_list = torch.cat(mask_list).view(len(label_list), -1)
    return text_list.to(DEVICE), label_list.to(DEVICE), mask_list

from torchtext.data.functional import to_map_style_dataset
train_iter = to_map_style_dataset(dataset_train)  #Map-style dataset
BATCH_SIZE = 32
dataloader = DataLoader(list(train_iter), batch_size=BATCH_SIZE, collate_fn=collate_batch)

cuda


# Architecture

In [8]:
### build classifier transformer
import torch.nn.functional as F

class MyBERT(nn.Module):
    
    def __init__(self, embedding_dim, heads, seq_length, vocab_size, device, depth=5, num_classes=2):
        super().__init__()

        self.vocab_size = vocab_size
        self.token_emb = nn.Embedding(vocab_size, embedding_dim)
        self.pos_emb = nn.Embedding(seq_length, embedding_dim)
        self.num_heads = heads
        self.device = device

        # sequence of transformers
        tblocks = []
        for i in range(depth):
            tblocks.append(nn.TransformerEncoderLayer(d_model=embedding_dim,
                                                            nhead=self.num_heads, 
                                                            batch_first=True, dropout=0.1))
        self.tblocks = nn.Sequential(*tblocks)
        
        # final linear layer
        self.last_linear = nn.Linear(embedding_dim, vocab_size)

    def forward(self, x, mask_idxs):
        # generate token embeddings
        tokens = self.token_emb(x)
        batch_size, token_size, embed_size = tokens.size()

        # generate position embeddings
        positions = torch.arange(token_size).to(self.device)
        positions = self.pos_emb(positions).expand(batch_size, token_size, embed_size)

        x = tokens + positions
        x = self.tblocks(x)

        out = self.last_linear(x)
        return out

# Training

In [9]:
# gpt-mini: n_layer=6, n_head=6, n_embd=192
# gpt-micro: n_layer=4, n_head=4, n_embd=128
# gpt-nano: n_layer=3, n_head=3, n_embd=48
# gpt-mini2: n_layers=6, n_head=16, n_embd=128
# gpt-mini3: n_layers=10, n_head=32, n_embd=256


EMBED_DIM = 48
NUM_HEADS = 6
NUM_LAYERS = 3
SEQ_LENGTH = 100
VOCAB_SIZE = len(full_vocab)

my_bert = MyBERT(embedding_dim=EMBED_DIM,
                 heads=NUM_HEADS, 
                 seq_length=100, 
                 vocab_size=len(full_vocab),
                 device=DEVICE,
                 depth=NUM_LAYERS).to(DEVICE)

optimizer = torch.optim.Adam(my_bert.parameters(), lr=0.001)
criterion = nn.CrossEntropyLoss()

In [10]:
# training
from tqdm.notebook import tqdm

num_epochs = 10

for epoch in range(num_epochs):
    print("epoch: ", epoch)
    training_loss = 0
    i = 0
    for (x, y, mask_idxs) in tqdm(dataloader):
        
        i += 1
        optimizer.zero_grad()
        out = my_bert(x, mask_idxs)
        out = torch.swapaxes(out, 1, 2)

        loss = criterion(out, y)
        training_loss += loss
        loss.backward()
        optimizer.step()
    
    print("training_loss: ", training_loss)

epoch:  0


  0%|          | 0/3750 [00:00<?, ?it/s]

  label_list = torch.tensor(label_list, dtype=torch.int64)
  label_list = torch.tensor(label_list, dtype=torch.int64)


training_loss:  tensor(3906.4558, device='cuda:0', grad_fn=<AddBackward0>)
epoch:  1


  0%|          | 0/3750 [00:00<?, ?it/s]

training_loss:  tensor(1912.2423, device='cuda:0', grad_fn=<AddBackward0>)
epoch:  2


  0%|          | 0/3750 [00:00<?, ?it/s]

training_loss:  tensor(1692.7605, device='cuda:0', grad_fn=<AddBackward0>)
epoch:  3


  0%|          | 0/3750 [00:00<?, ?it/s]

training_loss:  tensor(1599.1382, device='cuda:0', grad_fn=<AddBackward0>)
epoch:  4


  0%|          | 0/3750 [00:00<?, ?it/s]

training_loss:  tensor(1543.2677, device='cuda:0', grad_fn=<AddBackward0>)
epoch:  5


  0%|          | 0/3750 [00:00<?, ?it/s]

training_loss:  tensor(1505.9883, device='cuda:0', grad_fn=<AddBackward0>)
epoch:  6


  0%|          | 0/3750 [00:00<?, ?it/s]

training_loss:  tensor(1474.9716, device='cuda:0', grad_fn=<AddBackward0>)
epoch:  7


  0%|          | 0/3750 [00:00<?, ?it/s]

training_loss:  tensor(1448.3875, device='cuda:0', grad_fn=<AddBackward0>)
epoch:  8


  0%|          | 0/3750 [00:00<?, ?it/s]

training_loss:  tensor(1427.6262, device='cuda:0', grad_fn=<AddBackward0>)
epoch:  9


  0%|          | 0/3750 [00:00<?, ?it/s]

training_loss:  tensor(1408.1722, device='cuda:0', grad_fn=<AddBackward0>)


In [11]:
# save model
BERT_PATH = "parameters/my_bert_mini3.pth"
print("saving model at: {}".format(BERT_PATH))
torch.save(my_bert.state_dict(), BERT_PATH)

saving model at: parameters/my_bert_mini3.pth


# Testing selected outputs

In [12]:
# load model
BERT_PATH = "parameters/my_bert_mini3.pth"
my_bert = MyBERT(embedding_dim=EMBED_DIM,
                 heads=NUM_HEADS, 
                 seq_length=100, 
                 vocab_size=len(full_vocab),
                 device=DEVICE,
                 depth=NUM_LAYERS).to(DEVICE)

print("loading model from: {}".format(BERT_PATH))
my_bert.load_state_dict(torch.load(BERT_PATH))

loading model from: parameters/my_bert_mini3.pth


<All keys matched successfully>

In [13]:
# just one sentence 

test_sentence, test_labels, test_masks = text_pipeline('He married Mabel Scott in 1890, but they soon separated. \
Unable to get an English divorce, in 1900, he became the first celebrity to get one in Nevada, \
and remarried there, but the divorce was invalid in England. In June 1901, he was arrested for bigamy, \
and was convicted before the House of Lords, the last time a peer was convicted by the Lords.', 50)

def print_outputs(test_sentence, test_labels, test_masks):
    test_sentence = torch.tensor(test_sentence).long().cuda().view(1,-1)
    #print(test_sentence)
    test_mask_idxs = torch.ones_like(test_sentence).bool().cuda().view(1, -1)
    #test_mask_idxs[:,40:] = False
    #print(test_masks)
    out = my_bert(test_sentence, test_mask_idxs)
    #print(out.size())
    out_sm = torch.softmax(out, dim=-1).argmax(dim=-1)
    print(out)
    #print(len(out_sm[0]))
    
    pprint(out_sm[0], test_labels, test_masks)

    
def pprint(predictions, labels, masks):
    li = ['_'.join([vocab.vocab.get_itos()[pred], vocab.vocab.get_itos()[label]]) 
                    if mask 
                    else vocab.vocab.get_itos()[pred]
                    for pred, label, mask in zip(predictions, torch.tensor(labels).long(), masks)]
    li = [x for x in li if x != '[PAD]']
    print(' '.join(li))

#print([vocab.vocab.get_itos()[i] for i in test_sentence])
#print([vocab.vocab.get_itos()[i] for i in out_sm[0]])
#print([vocab.vocab.get_itos()[i] for i in torch.tensor(test_labels).long()])

print_outputs(test_sentence, test_labels, test_masks)

tensor([[[-11.8586,   6.7565, -17.0153,  ...,  -9.8805, -11.0341, -11.1836],
         [ -1.5625,   3.8778, -10.2055,  ...,  -6.5444,  -6.1065,  -5.9842],
         [-10.5377,   4.4939, -13.0039,  ..., -10.3370,  -6.1239, -12.6618],
         ...,
         [ -9.7811,   2.6191, -17.0421,  ...,  -7.3072,  -5.5557, -16.9369],
         [ -4.1034,   4.1237, -10.6320,  ...,  -7.3408,  -5.9452,  -4.7843],
         [ -8.0328,   4.1775, -16.4952,  ...,  -9.2391, -11.7229,  -1.8939]]],
       device='cuda:0', grad_fn=<AddBackward0>)
he ,_married mabel ,_scott in [UNK] ,_, but they soon separated ,_. unable_unable to get an english nips_divorce , in [UNK] , he became the first celebrity to get one in nevada , and [UNK] the_there , but the the_divorce was invalid in england . in june 1901 ,_, he


In [14]:
### test the first batch
test_iter = to_map_style_dataset(dataset_test)  #Map-style dataset
BATCH_SIZE = 32
testloader = DataLoader(list(test_iter), batch_size=BATCH_SIZE, collate_fn=collate_batch)

In [15]:
# test 

for i, (x, y, mask_idxs) in tqdm(enumerate(testloader), total=len(dataset_test)//BATCH_SIZE):
    
    #optimizer.zero_grad()
    out = my_bert(x, mask_idxs)

    out_sm = torch.softmax(out, dim=-1).argmax(dim=-1)
    
    for j in range(BATCH_SIZE):
        print(j)
        pprint(out_sm[j], y[j], mask_idxs[j])
    
    break

  0%|          | 0/237 [00:00<?, ?it/s]

  label_list = torch.tensor(label_list, dtype=torch.int64)
  for pred, label, mask in zip(predictions, torch.tensor(labels).long(), masks)]


0
fears '_for t n pension after talks unions jail_representing workers at turner the_newall say they are_are ' disappointed ' after talks the_with stricken parent firm federal the_mogul .
1
the race is on second private ,_team sets launch date for human spaceflight ( space . com ) space . com - toronto ,_, canada space_-- a [UNK] of ,_rocketeers competing for the #36 10 million ansari x prize , a contest [UNK] funded suborbital space flight , has officially announced the [UNK] date for its manned rocket .
2
ky ._. company wins grant to study peptides ( ap_ap ) ap - a company founded ritz_by a chemistry researcher at the university of_of louisville won a grant to develop the_a method the_of producing better peptides ,_, which are short chains of amino acids , the ,_building blocks of proteins .
3
prediction ,_unit helps forecast wildfires ( ap_ap ) ap - burden_it ' s ,_barely dawn when mike fitzpatrick starts his shift with a blur of colorful the_maps , figures and endless charts , but 

mighty to_ortiz makes sure_sure sox can ,_rest easy just imagine what david ortiz could do on a ,_good ,_night ' s rest . ortiz spent ,_the night before last with his baby boy , d ' angelo the_, the_who is barely 1 month old . he had planned on attending the red sox ' family day at ,_fenway ,_park yesterday morning , but he had to ,_sleep in . the_after all , ortiz had a son at home , and he ._. ._. .
30
they ' ve to_caught his eye in quot helping themselves , quot the_ricky bryant , [UNK] [UNK] dance_, the_michael jennings , and_and the_david the_patten did nothing friday the_night to make bill belichick ' s decision on what to do with his ,_receivers any easier .
31
indians mount charge the cleveland indians pulled within one game of the al central lead by beating the minnesota twins , 7-1 , saturday night with the_home runs by_by travis hafner and victor martinez .


## test against random loss

In [16]:
### test the model on a random init version of the network
random_network = MyBERT(embedding_dim=EMBED_DIM, 
               heads=NUM_HEADS, 
               seq_length=SEQ_LENGTH,
               vocab_size=VOCAB_SIZE,
               device=DEVICE,
               depth=NUM_LAYERS).to(DEVICE)

random_network.eval()

random_test_loss = 0
for i, (x, y, masks_idxs) in tqdm(enumerate(testloader), total=len(dataset_test)//BATCH_SIZE):

        #if i > 50: break
        x, y, masks_idxs = x.to(DEVICE), y.to(DEVICE), masks_idxs.to(DEVICE)

        out = random_network(x, masks_idxs)
        out = torch.swapaxes(out, 1, 2)
        
        loss = criterion(out, y)
        random_test_loss += loss.item()
        
print("test loss from random network: {:.03f}".format(random_test_loss))

  0%|          | 0/237 [00:00<?, ?it/s]

  label_list = torch.tensor(label_list, dtype=torch.int64)


test loss from random network: 2711.031


In [17]:
### load model

trained_model = MyBERT(embedding_dim=EMBED_DIM, 
               heads=NUM_HEADS, 
               seq_length=SEQ_LENGTH,
               vocab_size=VOCAB_SIZE,
               device=DEVICE,
               depth=NUM_LAYERS).to(DEVICE)

trained_model.load_state_dict(torch.load(BERT_PATH))
trained_model.eval()

test_loss = 0
for i, (x, y, masks_idxs) in tqdm(enumerate(testloader), total=len(dataset_test)//BATCH_SIZE):

        #if i > 50: break
        x, y, masks_idxs = x.to(DEVICE), y.to(DEVICE), masks_idxs.to(DEVICE)

        out = trained_model(x, masks_idxs)
        out = torch.swapaxes(out, 1, 2)
        
        loss = criterion(out, y)
        test_loss += loss.item()
        
print("test loss from random network: {:.03f}".format(test_loss))

  0%|          | 0/237 [00:00<?, ?it/s]

  label_list = torch.tensor(label_list, dtype=torch.int64)


test loss from random network: 87.116
