### Demonstrates simple usage for pretraining BERT;
* Uses Masked Language Modelling and Next Sentence Prediction as pretraining tasks;
* Trains on the top 2 books from Project Gutenberg's "yesterday" top-downloads list;
* Details of how I did the data processing can be found in the bert_data_processing.py file;
* Details of the BertModel implementation can be found in the bert.py file;

In [1]:
import torch
from torch import nn
import bert
import bert_data_processing as bdp
# Seed PyTorch's global RNG so weight initialisation and dropout are reproducible across runs.
# NOTE(review): any non-torch randomness inside bert_data_processing (e.g. Python's `random`
# module used for masking / sentence pairing) is not covered by this seed — confirm if exact
# reruns matter.
torch.manual_seed(42)

<torch._C.Generator at 0x7f617003c5b0>

### Get the data;

In [2]:
# Loader settings: examples per batch and the fixed (padded) sequence length.
batch_size = 512
max_len = 64
# Build the pretraining loader and vocabulary from 2 Gutenberg books; tokens seen fewer
# than 5 times are dropped from the vocabulary (min_freq).
# NOTE(review): per the recorded output, the loader prompts interactively ([y/n]) when the
# target file already exists, which blocks a non-interactive Restart & Run All.
train_iter, vocab = bdp.get_gutenberg_loader_and_vocab(
    batch_size,
    max_len,
    "gutenberg_books.txt",
    num_books=2,
    truncate=True,
    min_freq=5,
)

A file with that name already exists, if truncate is True I will overwrite it. Continue [y/n]:y
downloading 2 books...


### Instantiate a BertModel instance;
* Uses Next Sentence Prediction as well as Masked Language Modelling as pretraining tasks;
* Hence, there will be some CrossEntropy losses involved later in the notebook;
* This initialisation is somewhat motivated by bert_small;

In [3]:
vocab_size, hidden_dim, ffn_hidden, num_heads = len(vocab), 768, 1024, 4
# The LayerNorm shape and FFN input width were previously the literal 768, i.e. equal to
# hidden_dim. They are derived from hidden_dim here so they stay in sync if the model width
# changes. Assumes bert.BertModel needs them to match the model width — confirm in bert.py.
norm_dim, ffn_input, num_layers, dropout, with_bias = [hidden_dim], hidden_dim, 2, 0.2, True

# MLM and NSP heads both read hidden_dim-wide encodings and use a 2x wider hidden layer.
bert_model = bert.BertModel(hidden_dim, hidden_dim, hidden_dim, hidden_dim, num_heads,
                            norm_dim, ffn_input, ffn_hidden, num_layers, vocab_size, pos_encoding_size=1000,
                            mlm_input=hidden_dim, mlm_hiddens=2*hidden_dim, nsp_input=hidden_dim,
                            nsp_hidden=2*hidden_dim, dropout=dropout, with_bias=with_bias)

### Get two losses for NSP and MLM;
* Since the sequences in the dataset have been padded to the same length, I use weights when calculating the MLM loss;
* A weight of *0* for <pad\> tokens and *1* for real tokens, so padding does not contribute to the loss;
* The MLM loss therefore uses *reduction="none"*, and the per-token losses are averaged over the weights manually;
* For the NSP task the reduction is the default one, i.e. *reduction="mean"*;

In [4]:
# MLM keeps one loss per masked position (no reduction) so padded slots can be zero-weighted;
# NSP uses the default mean reduction over the batch.
loss_mlm, loss_nsp = nn.CrossEntropyLoss(reduction="none"), nn.CrossEntropyLoss()

### I choose to optimise with Adam;

In [5]:
# Adam over all model parameters; lr spelled out explicitly (it is Adam's default, 1e-3).
adam_optim = torch.optim.Adam(params=bert_model.parameters(), lr=1e-3)

In [6]:
# Pull the first batch and run one gradient-free forward pass to inspect the output shapes.
(tokens, segments, attention_masks, masked_positions, weights_for_masks,
 original_labels_for_masks, nsp_labels) = next(iter(train_iter))
with torch.no_grad():
    encodings, mlm_preds, nsp_preds = bert_model(
        tokens, segments, attention_masks, masked_positions
    )

print(f"shape of encoding: {encodings.shape} --> (batch_size, seq_len, embed_size)")
print(f"shape of mlm_preds: {mlm_preds.shape} --> (batch_size, num_masks, vocab_size)")
print(f"shape of nsp_preds: {nsp_preds.shape} --> (batch_size, num_classes_to_predict)")

shape of encoding: torch.Size([512, 64, 768]) --> (batch_size, seq_len, embed_size)
shape of mlm_preds: torch.Size([512, 10, 1816]) --> (batch_size, num_masks, vocab_size)
shape of nsp_preds: torch.Size([512, 2]) --> (batch_size, num_classes_to_predict)


### Combine the losses from the pretraining tasks;

In [7]:
def loss_per_batch(bert_model, loss_mlm, loss_nsp, vocab_size, tokens, segments,
                   attention_masks, masked_positions, weights_for_masks,
                   original_labels_for_masks, nsp_labels):
    """Run one forward pass and return the MLM, NSP and combined pretraining losses.

    Args:
        bert_model: callable invoked as
            ``bert_model(tokens, segments, attention_masks, masked_positions)`` and returning
            ``(encodings, mlm_preds, nsp_preds)``; the encodings are not used here.
        loss_mlm: MLM criterion; must use ``reduction="none"`` so each masked position
            yields its own loss that can be weighted.
        loss_nsp: NSP criterion; its own reduction is used as-is.
        vocab_size: size of the last axis of ``mlm_preds``, used to flatten the logits.
        tokens, segments, attention_masks, masked_positions: forwarded to ``bert_model``.
        weights_for_masks: per-masked-position weights (0 for padding slots, 1 otherwise).
        original_labels_for_masks: target token ids for the masked positions.
        nsp_labels: targets for the next-sentence-prediction head.

    Returns:
        ``(mlm_loss, nsp_loss, mlm_loss + nsp_loss)``. ``mlm_loss`` is the weighted sum of
        per-position losses divided by the total weight (plus a tiny epsilon to avoid 0/0).
    """
    _, mlm_preds, nsp_preds = bert_model(tokens, segments, attention_masks, masked_positions)

    # Flatten (batch, num_masks, ...) so every masked position is one classification example.
    flat_logits = mlm_preds.reshape(-1, vocab_size)
    flat_targets = original_labels_for_masks.reshape(-1)
    flat_weights = weights_for_masks.reshape(-1)

    # Zero-weighted (padding) positions drop out of the numerator and the denominator.
    per_position = loss_mlm(flat_logits, flat_targets)
    weighted_total = (per_position * flat_weights).sum()
    mlm_loss = weighted_total / (weights_for_masks.sum() + 1e-9)

    nsp_loss = loss_nsp(nsp_preds, nsp_labels)
    return mlm_loss, nsp_loss, mlm_loss + nsp_loss

In [8]:
def train_loop(bert_model, loss_mlm, loss_nsp, optim, train_iter, iterations, vocab_size=None):
    """Pretrain ``bert_model`` on the joint MLM + NSP objective for ``iterations`` steps.

    One optimisation step is taken per batch. If ``train_iter`` runs out before
    ``iterations`` steps have been taken, it is iterated again (another epoch), so the
    requested number of steps is honoured. The losses are printed about five times per run.

    Args:
        bert_model: model passed to ``loss_per_batch``; switched to training mode first.
        loss_mlm: per-element MLM criterion (expected to use ``reduction="none"``).
        loss_nsp: NSP criterion.
        optim: optimiser over ``bert_model``'s parameters.
        train_iter: re-iterable yielding 7-tuples ``(tokens, segments, attention_masks,
            masked_positions, weights_for_masks, original_labels_for_masks, nsp_labels)``.
        iterations: number of optimisation steps to run; ``<= 0`` runs nothing.
        vocab_size: vocabulary size used to flatten the MLM logits. When omitted, falls back
            to the notebook-level ``vocab_size`` global (the previous, implicit behaviour).
    """
    if iterations <= 0:
        return
    if vocab_size is None:
        # Backward compatibility: earlier versions read this global implicitly.
        vocab_size = globals()["vocab_size"]

    # Ensure dropout etc. are active even if the model was left in eval mode elsewhere.
    bert_model.train()
    report_every = max(1, iterations // 5)
    curr_step = 0
    while True:
        saw_batch = False
        for (tokens, segments, attention_masks, masked_positions, weights_for_masks,
             original_labels_for_masks, nsp_labels) in train_iter:
            saw_batch = True
            # forward call and loss calculation
            mlm_loss, nsp_loss, sum_of_losses = loss_per_batch(bert_model, loss_mlm, loss_nsp,
                                                               vocab_size, tokens, segments,
                                                               attention_masks, masked_positions,
                                                               weights_for_masks,
                                                               original_labels_for_masks, nsp_labels)
            # now the backward pass
            optim.zero_grad()
            sum_of_losses.backward()
            optim.step()
            curr_step += 1
            if curr_step % report_every == 0:
                print(f"mlm_loss: {mlm_loss:.5f}\tnsp_loss: {nsp_loss:.5f}\toverall_loss: {sum_of_losses:.5f}")
            if curr_step >= iterations:
                return
        if not saw_batch:
            # An empty loader would otherwise make this loop spin forever.
            return

In [10]:
# Smoke test: a single optimisation step. Nothing here moves tensors to a device, so this runs
# on the CPU; for a GPU you would call bert_model.to(device) and also move each batch.
train_loop(bert_model, loss_mlm, loss_nsp,
           optim=adam_optim, train_iter=train_iter, iterations=1)

mlm_loss: 6.57617	nsp_loss: 3.34041	overall_loss: 9.91658
