<a href="https://colab.research.google.com/github/jaroorhmodi/word2vec-and-BERT/blob/main/BERT.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

#BERT (Bidirectional Encoder Representations from Transformers)

In this notebook I will be replicating the model in the paper [**BERT: Pre-Training of Deep Bidirectional Transformers for Language Understanding**](https://arxiv.org/pdf/1810.04805.pdf).

While I will be creating the model from (mostly) scratch in PyTorch, I will not go into too much detail about why Multi-Head Attention is designed the way it is and how exactly the original [Transformer](https://arxiv.org/abs/1706.03762) architecture works. I have made another (*albeit messy*) [notebook that covers that paper](https://github.com/jaroorhmodi/transformer-from-scratch).

The model will be trained on the [**WikiText-2**](https://paperswithcode.com/dataset/wikitext-2) and [**Wikitext-103**](https://paperswithcode.com/dataset/wikitext-103) datasets.

In [None]:
!pip install portalocker transformers tokenizers

In [41]:
import os
import copy
import json

import torch
from torch import nn
import torch.nn.functional as F
from torch.optim.lr_scheduler import LambdaLR
from torch.utils.data import Dataset, DataLoader

from torchtext.data import to_map_style_dataset
from torchtext.data.utils import get_tokenizer
from torchtext.vocab import build_vocab_from_iterator
from torchtext.datasets import WikiText2, WikiText103 #our datasets for this project

from tokenizers import BertWordPieceTokenizer
from transformers import BertTokenizer, DataCollatorForLanguageModeling

import math
import nltk
import numpy as np
import pandas as pd
import pickle
import random
import spacy
import tqdm

DATASET_small = "WikiText2"
DATASET_large = "WikiText103"

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
TOKENIZER="basic_english"
DATA_DIRECTORY = "."

##Model Objective and Data

###Who (What) is BERT?


While BERT is a much more complex model and what it accomplishes isn't exactly akin to Word2Vec, the intuition behind both is similar. We pass in sentences and attempt to make a model learn how to represent text in a way that captures not only information about the tokens themselves but also something about their *meaning*.

Word2Vec does this by training a model on words and their context in sentences and learning  about their relationships with one another by either trying to predict context from words (*Skip-Gram*) or words from context (*CBOW*). The embeddings it produces are static for each word.

BERT trains a Transformer Encoder model on two specific objectives: *Masked Language Modeling* and *Next Sentence Prediction* to learn a wealth of information about tokens in their context and provide representations of them. Note that BERT is not simply learning static embeddings but rather representations that change based on context. Tokens in BERT are embedded using *WordPiece* embeddings.

The goal of the BERT paper was to introduce a way to represent words with a pre-trained transformer and not to make a model for a specific predictive goal. To this end it is trained in an unsupervised manner with the aforementioned MLM and NSP objectives (will be explained ahead).

###Data Processing

In [3]:
#We need to pull in the dataset and break it into sentence pairs for the NSP objective
#and we need to mask random words and create objectives for the MLM objective.
DATASET = DATASET_small
dataset_class = WikiText2 if DATASET == DATASET_small else WikiText103
data_train = dataset_class(DATA_DIRECTORY, split = "train")

OPT_VERSION = ''

TOKENS_LOCATION = os.path.join(DATA_DIRECTORY, "datasets", DATASET, DATASET.lower()[:8]+f"-{DATASET[8:]}")
TOKENIZER_LOCATION = os.path.join(DATA_DIRECTORY, "tokenizers", )
TOKENIZER_NAME = f"bert-wordpiece-{DATASET}{OPT_VERSION}" #just here to standardize naming for later
os.makedirs(TOKENIZER_LOCATION, exist_ok = True)

MAX_TOKENIZED_SENTENCE_LEN = 128 #maximum number of tokens in sentence

gen = iter(data_train)

sample = []
for i in range(20):
  sample.append(next(gen))

In [4]:
print(sample[:5])
print(len(sample))

[' \n', ' = Valkyria Chronicles III = \n', ' \n', ' Senjō no Valkyria 3 : <unk> Chronicles ( Japanese : 戦場のヴァルキュリア3 , lit . Valkyria of the Battlefield 3 ) , commonly referred to as Valkyria Chronicles III outside Japan , is a tactical role @-@ playing video game developed by Sega and Media.Vision for the PlayStation Portable . Released in January 2011 in Japan , it is the third game in the Valkyria series . <unk> the same fusion of tactical and real @-@ time gameplay as its predecessors , the story runs parallel to the first game and follows the " Nameless " , a penal military unit serving the nation of Gallia during the Second Europan War who perform secret black operations and are pitted against the Imperial unit " <unk> Raven " . \n', " The game began development in 2010 , carrying over a large portion of the work done on Valkyria Chronicles II . While it retained the standard features of the series , it also underwent multiple adjustments , such as making the game more <unk> for s

In the WikiText data, we see a lot of control characters like newlines, we want to make sure when we tokenize that we do not consider these. We also see that articles are delineated by a header given between single **=** signs and subheaders given by double, triple etc. equal signs.

We don't want header paragraphs and we don't want empty whitespace lines, so we need to preprocess the data some before we actually split them out into sentence pairs for our NSP task.

In [5]:
tokenizer = BertWordPieceTokenizer(
    clean_text = True, #removes control chars like \n
    handle_chinese_chars = False, #not anticipating chinese chars
    strip_accents = False, #keep accents in text
    lowercase = True #ignore case
)

tokenizer.train(
    files = os.path.join(TOKENS_LOCATION, "wiki.train.tokens"),
    vocab_size = 30_000 if DATASET == DATASET_small else 90_000, #bigger vocab for bigger dataset
    min_frequency = 10 if DATASET == DATASET_small else 50, #require higher freq for bigger dataset
    limit_alphabet = 1000,
    wordpieces_prefix = '##',
    special_tokens=['[PAD]', '[CLS]', '[SEP]', '[MASK]', '[UNK]'] #in wikitext, <unk> is used for unknown tokens
)

tokenizer.save_model(
    TOKENIZER_LOCATION,
    TOKENIZER_NAME
)

#This is the tokenizer we will use
tokenizer = BertTokenizer.from_pretrained(os.path.join(TOKENIZER_LOCATION, TOKENIZER_NAME)+'-vocab.txt')
#adding this to handle the inbuilt <unk> token in the dataset which otherwise
#gets split into < unk > which may affect (slightly, but still) model performance
# tokenizer.add_special_tokens({'additional_special_tokens': ['<unk>']})





All of the input text for BERT training is to be of the form

    "[CLS] <SENTENCE1> [SEP] <SENTENCE2> [SEP]"
with `[PAD]` tokens added at the end as needed.

**Notice below how `token_type_ids` denotes where the first sequence ends and the next begins.**

Data preprocessing will require us to set up the two  training tasks of NSP and MLM.

In [6]:
test = (sample[1], sample[3])

# test = "\n"
print(f"{test=}")
#automatically adds [CLS] before first sentence and [SEP] after each sentence
#this is how Bert separates two sentences in the input
tokenized = tokenizer.encode_plus(test[0], test[1], add_special_tokens = True, return_tensors = "pt")
for key, val in tokenized.items():
  print(f"__{key}__")
  print(f"len={len(val)}")
  print(val)

test=(' = Valkyria Chronicles III = \n', ' Senjō no Valkyria 3 : <unk> Chronicles ( Japanese : 戦場のヴァルキュリア3 , lit . Valkyria of the Battlefield 3 ) , commonly referred to as Valkyria Chronicles III outside Japan , is a tactical role @-@ playing video game developed by Sega and Media.Vision for the PlayStation Portable . Released in January 2011 in Japan , it is the third game in the Valkyria series . <unk> the same fusion of tactical and real @-@ time gameplay as its predecessors , the story runs parallel to the first game and follows the " Nameless " , a penal military unit serving the nation of Gallia during the Second Europan War who perform secret black operations and are pitted against the Imperial unit " <unk> Raven " . \n')
__input_ids__
len=1
tensor([[    1,    33,  7754,  7833,  2907,    33,     2,  2957,  5194,   746,
          7754,    23,    30,    32,   396,    34,  7833,    12,  2709,    30,
           236,   234,     4,    16,  1244,    18,  7754,   394,   381,  9510,
   

In [7]:
def new_article(para):
  #just checks if it is a new article
  return para.strip().startswith("= ") and not para.strip().startswith("= =")

def new_heading(para):
  #checks if it is a new heading
  #slightly different from the new article
  return para.strip().startswith("=")

def split_sentences(para):
  sentences = para.split('. ')

def article_sentence_pairs(article_sentences):
  #See note below
  sentence_pairs = []
  for i, sentence in enumerate(article_sentences):
    if i == len(article_sentences)-1:
      break
    else:
      sentence_pairs.append((sentence, article_sentences[i+1]))
  return sentence_pairs

def collect_sentence_pairs(dataset_iter):
  articles = []
  current_article = []
  for para in tqdm.auto.tqdm(dataset_iter):
    if new_article(para):
      #new article, stop collecting sentences
      if len(current_article) > 0:
        articles += article_sentence_pairs(current_article)
      current_article = []
      continue
    if para.strip() == '' or new_heading(para):
      #new heading or empty line, skip and continue collecting
      continue
    current_article += para.strip().split('. ')
  return articles


NOTE: there are a few ways we could have split this dataset. I chose to go with one that splits at `". "` for this approach but we could have tokenized the data first and split according to maximum sequence length.

In [8]:
#proper pipeline management, since we iterated once on the old data_train object we reset iteration
gen = iter(dataset_class(DATA_DIRECTORY, split = "train"))
train_sentence_pairs = collect_sentence_pairs(gen)
len(train_sentence_pairs)

0it [00:00, ?it/s]

80096

In [9]:
gen = iter(dataset_class(DATA_DIRECTORY, split = "valid"))
valid_sentence_pairs = collect_sentence_pairs(gen)
len(valid_sentence_pairs)


0it [00:00, ?it/s]

8414

In [10]:
train_sentence_pairs[900000:900010]

[]

In [11]:
#Persist sentence pairs to file
os.makedirs(os.path.join(DATA_DIRECTORY, "datasets", "pairs_data"), exist_ok = True)


SENTENCE_PAIRS_LOCATION = os.path.join(TOKENS_LOCATION, "pairs")
os.makedirs(SENTENCE_PAIRS_LOCATION, exist_ok=True)


#pickle 200k sentence pairs at a time into various pickle files
def pickle_pairs(pairs, prefix="train", chunk_size = 200_000):
  i = 0
  while chunk_size*(i+1) < len(pairs):
    with open(os.path.join(SENTENCE_PAIRS_LOCATION, f"{prefix}-{i}.pkl"), "wb") as f:
      pickle.dump(pairs[chunk_size*i:chunk_size*(i+1)], f)
    i+=1

pickle_pairs(train_sentence_pairs)
pickle_pairs(valid_sentence_pairs, "valid")


So we see now that we are able to make pairs of sentences next to one another. Note that our particular method makes it so the first sentence from the immediately following paragraph is treated as a "next sentence" but the first sentence of the following article is not.

###The Two Training Objectives

  The following examples are straight from the paper and used to illustrate the NSP objective but can be used to explain the MLM objective as well.
    
    Input: [CLS] the man went to [MASK] store [SEP] he bought a gallon [MASK] milk [SEP]
    Label: IsNext

    Input: [CLS] the man went to [MASK] store [SEP] penguin [MASK] are flight ##less birds [SEP]
    Label: NotNext

####Creating a Custom Dataset
####Masked Language Modeling
The **Masked Language Model** task masks out tokens with a given probability (and with some small probability substitutes with a random token) in the inputs and asks the model to predict the tokens that were masked out.

####Next Sentence Prediction

The **Next Sentence Prediction** task takes sentence pairs and creates a balanced classification task, where half of the time the following sentence is actually the next sentence and half of the time it is a random sentence. Then the model is trained to correctly predict whether or not the second sentence is the actual next sentence.

We have already done half of the work for this task by creating positive next sentence pairs for each sentence in the text. All we have to do is replace the second sentence in each pair with a random sentence half of the time.

In [12]:
#We will create a custom Dataset class for our purposes
class BERTDataset(Dataset):
  def __init__(self, sentence_pairs, tokenizer, seq_len=512):
    self.sentence_pairs = sentence_pairs
    self.num_pairs = len(sentence_pairs)
    self.tokenizer = tokenizer
    self.seq_len = seq_len

  def __len__(self):
    return self.num_pairs

  #The key component of a Dataset is the __getitem__ function
  def __getitem__(self, idx):
    #get nsp entry, encode, return
    sentence1, sentence2, isNextLabel = self.get_nsp_entry(idx)

    #TODO: CONSIDERING NOT ENCODING, CHECK LATER
    encoded_pair = tokenizer.encode_plus(
        sentence1,
        sentence2,
        add_special_tokens = True,
        max_length = self.seq_len,
        padding = 'max_length',
        truncation = 'longest_first',
        return_special_tokens_mask = True,
        return_tensors = "pt"
    )

    encoded_pair['nsp_labels'] = torch.Tensor([isNextLabel]).int()

    return encoded_pair


  def get_nsp_entry(self, idx):
    #Implement NSP randomization here
    sent1, sent2 = self.sentence_pairs[idx]
    if random.random() >= 0.5:
      #this is the case where we give positive nsp example
      return sent1, sent2, 1
    else:
      return sent1, self.get_non_next_sentence(idx), 0

  def get_non_next_sentence(self, idx):
    random_idx = random.randrange(self.num_pairs)
    while random_idx == idx:
      """
      this is just here for the small chance that
      our random index maps to the same one and gives
      us a false pair where the actual next sentence
      is mislabeled as NotNext
      """
      random_idx = random.randrange(self.num_pairs)
    return self.sentence_pairs[random_idx][1]



#####Example using the Transformers Library for MLM

We can use the `transformers.DataCollatorForLanguageModeling` class to handle the masking for us.

We see [in the documentation](https://github.com/huggingface/transformers/blob/b71f20a7c9f3716d30f6738501559acf863e2c5c/src/transformers/data/data_collator.py#L751C1-L751C108) that exactly like in the paper, `DataCollatorForLanguageModeling` will mask tokens 80% of the time, replace with a random token 10% of the time, and leave as is 10% of the time. This is all conditional on the `mlm_probability` value passed into the collator which is `0.15` in the paper.

In [13]:
collator = DataCollatorForLanguageModeling(tokenizer = tokenizer, mlm = True, mlm_probability = 0.15)

In [14]:
#We recreate our tokenization example from earlier,
#we just want to create a sentence pair for illustration
feats = tokenizer.encode_plus(
    sample[1],
    sample[3],
    return_special_tokens_mask = True,
    return_tensors = "pt",
    max_length = 256,
    truncation=True,
    padding = 'max_length'
)


In [15]:
feats['nsp_label'] = torch.Tensor([1])

In [16]:
from pprint import pprint

masked = collator([feats])
pprint(masked)

{'attention_mask': tensor([[[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
          1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
          1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
          1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
          1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
          1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
          1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0]]]),
 'input_ids': tensor([[[    1,    33,  7754,  7833,  2907,    33,     2,     3,  5194,   

In [17]:
tokenizer.decode(masked['input_ids'].squeeze())

'[CLS] = valkyria chronicles iii = [SEP] [MASK]jo no valkyria 3 : < [MASK] > chronicles ( [MASK] : 戦 場 [UNK], [MASK]. valkyria advert the [MASK] 3 ), [MASK] referred weekly as valkyria chronicles iii outside japan, is [MASK] tactical role @ - @ playing video game developed by sega and media. vision for the playstation portable. released in january 2011 in japan, it [MASK] the [MASK] game in the valkyria series [MASK] < [MASK] > the same fusion of tactical [MASK] real @ - [MASK] time gameplay as its [MASK] [MASK] the [MASK] runs parallel to the first game and follows the nueces nameless " [MASK] [MASK] penal military unit witnesses the nation of gall [MASK] during [MASK] second europan war who perform secret black operations and are pitted against the imperial unit " < [MASK] > [MASK] ". [SEP] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD

In [18]:
tokenizer.decode(feats['input_ids'].squeeze())

'[CLS] = valkyria chronicles iii = [SEP] senjo no valkyria 3 : < unk > chronicles ( japanese : 戦 場 [UNK], lit. valkyria of the battlefield 3 ), commonly referred to as valkyria chronicles iii outside japan, is a tactical role @ - @ playing video game developed by sega and media. vision for the playstation portable. released in january 2011 in japan, it is the third game in the valkyria series. < unk > the same fusion of tactical and real @ - @ time gameplay as its predecessors, the story runs parallel to the first game and follows the " nameless ", a penal military unit serving the nation of gallia during the second europan war who perform secret black operations and are pitted against the imperial unit " < unk > raven ". [SEP] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD

##Model Architecture

Here we will create the architecture of the model from scratch. I will not go into too much detail about it since it is so similar to the original transformer encoder (which I mentioned above at the beginning of this notebook). But there will be some helpful annotations and comments.

###Embedding

The Transformer Encoder (the decoder as well, for that matter) uses a Positional Embedding to provide information about the positions of tokens in the sequence. This is necessary because attention is bidirectional and has no way to account for absolute position of tokens in a sequence on its own.

We also have to find a way to include information about the two segments (since we pass in pairs of sentences). Our tokenizer provides information about the segments in a vector of 0s and 1s as per the paper specs. We add an embedding layer that takes two tokens in and outputs and embedding of the same dimension as our token embedding.

This segment embedding represents one of the few differences between the model architecture in BERT and the original Transformer.

In [19]:
class PositionalEncoding(torch.nn.Module):

  def __init__(self, d_model, max_len=128):
    super().__init__()

    # Compute the positional encodings once in log space.
    pe = torch.zeros(max_len, d_model).float().to(DEVICE)
    pe.require_grad = False

    for pos in range(max_len):
        # for each dimension of the each position
        for i in range(0, d_model, 2):
            pe[pos, i] = math.sin(pos / (10000 ** ((2 * i)/d_model)))
            pe[pos, i + 1] = math.cos(pos / (10000 ** ((2 * (i + 1))/d_model)))

    # include the batch size
    self.pe = pe.unsqueeze(0)

  def forward(self, x):
    return self.pe

class BERTEmbeddingLayer(nn.Module):
  """
  The essential steps of embedding here are:
  Embed tokens to our chosen embedding_dim
  Add positional encoding to the base embedding
  Add segment embeddings (to represent which sentence of two each token is part of)

  Note there are only two values for segment labels.
  The padding tokens after the second segment are labeled 0 just like the first.
  """
  def __init__(self, vocab_size, embedding_dim, seq_len = 512, dropout = 0.0):
    super(BERTEmbeddingLayer, self).__init__()
    self.positional_encoding = PositionalEncoding(embedding_dim, seq_len)
    self.token_embedding = nn.Embedding(vocab_size, embedding_dim, padding_idx = 0)
    self.segment_embedding = nn.Embedding(2, embedding_dim)

  def forward(self, x, segment_labels):
    # print(f"BERTEmbeddingLayer {x.shape=}")
    # print(f"BERTEmbeddingLayer {segment_labels.shape=}")
    seg = self.segment_embedding(segment_labels)
    tok = self.token_embedding(x)
    pos = self.positional_encoding(x)
    # print(f"BERTEmbeddingLayer {seg.shape=} {seg.get_device()=}")
    # print(f"BERTEmbeddingLayer {pos.shape=} {pos.get_device()=}")
    # print(f"BERTEmbeddingLayer {tok.shape=} {tok.get_device()=}")
    return seg+pos+tok

###Encoder Architecture

The Transformer Encoder (and BERT, by extension) only uses self-attention in a manner where `query`, `key`, and `value` are all the same since we don't employ an "encoder memory" anywhere the way we would with the decoder.

In the `BERT_base` model, there are 12 instead of 6 stacked blocks, `d_model = 768`, there are 12 heads of attention. These represent some of the differences in configuration between the original Transformer and BERT encoders but aren't actual architecture differences.

Another minor difference is that the feedforward activation function is chosen to be `gelu` instead of `relu`.

In [20]:
clone_layers = lambda layer, n: nn.ModuleList([copy.deepcopy(layer) for _ in range(n)])

In [21]:
#Implement MultiHeadAttention
class MultiHeadAttention(nn.Module):
  def __init__(self, d_model, num_heads, dropout = 0.1):
    super(MultiHeadAttention, self).__init__()
    self.d_model = d_model
    self.d_k = d_model // num_heads
    self.num_heads = num_heads
    self.dropout = nn.Dropout(p = dropout)

    self.query, self.key, self.value, self.out = clone_layers(nn.Linear(d_model, d_model), 4)

  def forward(self, query, key, value, attention_mask):
    #The shapes of these tensors are annotated for easier intuition
    #b = batch_size, s = seq_len, h = num_heads
    #(b, s, d_model)
    Q = self.query(query)
    K = self.key(key)
    V = self.value(value)

    #(b, s, d_model) -> (b, s, h, d_k) -> (b, h, s, d_k)
    # print(f"MHA {Q.shape=} {K.shape=} {V.shape=}")
    Q = Q.view(Q.size(0), -1, self.num_heads, self.d_k).transpose(1, 2)
    K = K.view(K.size(0), -1, self.num_heads, self.d_k).transpose(1, 2)
    V = V.view(V.size(0), -1, self.num_heads, self.d_k).transpose(1, 2)
    # print(f"MHA {Q.shape=} {K.shape=} {V.shape=}")

    #Note that MHA does Scaled Dot Product attention over each head
    #and then concatenates. So scores is actually done for each head.
    #(b, h, s, d_k) x (b, h, d_k, s) -> scalar * (b, h, s, s)
    scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(self.d_k)
    # print(f"MHA {scores.shape=}")
    # print(f"MHA {attention_mask.shape=}")

    #fill masked scores with low value (imagine using -inf)
    #to minimize impact on softmax output
    #attention mask shape has to be (b, 1, 1, s)
    #so we need to unsqueeze
    #fill does not change shape
    #(b, h, s, s)
    scores = scores.masked_fill(attention_mask.unsqueeze(1) == 0, -1e9)

    #softmax does not change shape
    #(b, h, s, s)
    weights = F.softmax(scores, dim = -1)
    weights = self.dropout(weights)

    #(b, h, s, s) x (b, h, s, d_k) -> (b, h, s, d_k)
    #this would be the memory or context in transformer
    memory = torch.matmul(weights, V)

    #(b, h, s, d_k) -> (b, s, h, d_k) -> (b, s, h * d_k) = (b, s, d_model)
    #back to original shape
    memory = memory.transpose(1, 2).contiguous().view(memory.size(0), -1, self.num_heads * self.d_k)

    #(b, s, d_model) -> (b, s, d_model)
    return self.out(memory)


The original BERT paper maintains the relationship between `d_model` and `d_ff` in the feedforward layer where `d_ff = 4 * d_model`.

In [22]:
#Implement FF Layer for Encoder
class FeedForwardLayer(nn.Module):
  def __init__(self, d_model, d_ff = 3072, dropout = 0.0):
    #Note in the BERT paper, BERT_base has d_ff = 3072 = 4 * d_model = 4 * 768
    super(FeedForwardLayer, self).__init__()
    self.linear1 = nn.Linear(d_model, d_ff)
    self.linear2 = nn.Linear(d_ff, d_model)
    self.dropout = nn.Dropout(p = dropout)
    self.activation = nn.GELU()

  def forward(self, x):
    return self.linear2(self.dropout(self.activation(self.linear1(x))))

Now we implement the EncoderLayer. We need to use layernormalization here.

In [23]:
#All default values are selected based on BERT_base in the paper
class EncoderLayer(nn.Module):
  def __init__(
      self,
      d_model=768,
      num_heads=12,
      d_ff = 3072,
      dropout = 0.1
    ):
    super(EncoderLayer, self).__init__()
    self.mha = MultiHeadAttention(d_model, num_heads, dropout)
    self.ff = FeedForwardLayer(d_model, d_ff, dropout)
    self.norm = nn.LayerNorm(d_model)
    self.dropout = nn.Dropout(p = dropout)

  def forward(self, embeddings, attention_mask):
    #Shapes: b = batch_size, s = seq_len
    #embeddings: (b, s, d_model)
    #mask: (b, 1, 1, s)
    # print(f"ENCODER LAYER 0 {embeddings.shape=} {attention_mask.shape=}")

    #attention_out: (b, s, d_model)
    attention_out = self.dropout(self.mha(embeddings, embeddings, embeddings, attention_mask))
    # print(f"ENCODER LAYER 1 {attention_out.shape=}")

    #attention_out: (b, s, d_model)
    attention_out = self.norm(attention_out+embeddings)
    # print(f"ENCODER LAYER 2 {attention_out.shape=}")

    #ff_out: (b, s, d_model)
    ff_out = self.dropout(self.ff(attention_out))
    # print(f"ENCODER LAYER 3 {ff_out.shape=}")

    return self.norm(ff_out + attention_out)


####BERT Model

Now we implement the model by putting together the above classes we created.

In [24]:
class BERT(nn.Module):
  def __init__(
      self,
      vocab_size,
      d_model = 768,
      blocks = 12,
      num_heads = 12,
      seq_len = 512,
      dropout = 0.1
    ):
    super(BERT, self).__init__()
    self.vocab_size = vocab_size
    self.d_model = d_model
    self.d_ff = 4 * d_model #as noted above, ff hidden layer dim is 4*d_model
    self.blocks = blocks
    self.num_heads = num_heads
    self.seq_len = seq_len
    self.dropout = dropout

    self.embedding = BERTEmbeddingLayer(
        self.vocab_size,
        self.d_model,
        seq_len = self.seq_len,
        dropout = self.dropout
    )

    self.encoder_blocks = clone_layers(EncoderLayer(d_model, num_heads, self.d_ff, dropout), blocks)

  def forward(self, x, segment_info, attention_mask):
    #Our BertTokenizer also passes out the attention mask for padding
    #this tells the Encoder to not learn padding information
    x = self.embedding(x, segment_info)

    for block in self.encoder_blocks:
      x = block(x, attention_mask)

    return x


The objective of BERT pretraining is to train the Encoder we have defined above. Once we train on NSP and MLM respectively, we only keep the pre-trained encoder in our final product.

We still need to create models corresponding to our NSP and MLM tasks that will be used to train the BERT encoder we defined above.

In [25]:
class NextSentencePrediction(nn.Module):
  def __init__(self, d_model):
    super(NextSentencePrediction, self).__init__()
    self.linear = nn.Linear(d_model, 2)
    self.softmax = nn.LogSoftmax(dim = -1)

  def forward(self, x):
    # print(f"NSP {x.shape=}")
    xlinear = self.linear(x[:, 0])
    # print(f"NSP {xlinear.shape=}")
    xsoftmax= self.softmax(xlinear)
    # print(f"NSP {xsoftmax.shape=}")
    # use only the first token which is the [CLS]
    # return self.softmax(self.linear(x[:, 0]))
    return xsoftmax


class MaskedLanguageModel(nn.Module):
  def __init__(self, d_model, vocab_size):
    super(MaskedLanguageModel, self).__init__()
    self.linear = nn.Linear(d_model, vocab_size)
    self.softmax = nn.LogSoftmax(dim = -1)

  def forward(self, x):
    # print(f"MLM {x.shape=}")
    xlinear = self.linear(x)
    # print(f"MLM {xlinear.shape=}")
    xsoftmax= self.softmax(xlinear)
    # print(f"MLM {xsoftmax.shape=}")
    return xsoftmax

    # return self.softmax(self.linear(x))

class BERTmodel(nn.Module):
  #This model trains on two tasks at once
  #We see that NSP and MLM layers are simple dense layers
  #The goal is to use them as a way to learn params for BERT's representation
  def __init__(self, BERTencoder):
    super(BERTmodel, self).__init__()
    self.BERT = BERTencoder
    self.d_model = self.BERT.d_model
    self.NSP = NextSentencePrediction(self.BERT.d_model)
    self.MLM = MaskedLanguageModel(self.BERT.d_model, self.BERT.vocab_size)

  def forward(self, x, segment_info, attention_mask):
    encoded = self.BERT(x, segment_info, attention_mask)
    # print(f"BERTmodel {encoded.shape=}")
    return self.NSP(encoded), self.MLM(encoded)


##Training

Now we just need to put all of this together and train our model!

In the paper the loss function is given by the sum of the mean masked LM probabilities and the nsp likelihood.

The optimizer they used was Adam with `lr` = 1e-4, $\beta_1$ = 0.9, $\beta_2$ = 0.999, L2 weight decay = 0.01, `warmup_steps` = 10_000, and linear decay of learning rate.

In [26]:
#We make this custom optimizer:
#Link that I got this from: https://stackoverflow.com/questions/65343377/adam-optimizer-with-warmup-on-pytorch
class NOAMOptimizer:
  "Optim wrapper that implements rate."
  def __init__(self, optimizer, model_size, warmup):
    self.optimizer = optimizer
    self._step = 0
    self.warmup = warmup
    self.model_size = model_size
    self._rate = 0

  def state_dict(self):
    """Returns the state of the warmup scheduler as a :class:`dict`.
    It contains an entry for every variable in self.__dict__ which
    is not the optimizer.
    """
    return {key: value for key, value in self.__dict__.items() if key != 'optimizer'}

  def load_state_dict(self, state_dict):
    """Loads the warmup scheduler's state.
    Arguments:
        state_dict (dict): warmup scheduler state. Should be an object returned
            from a call to :meth:`state_dict`.
    """
    self.__dict__.update(state_dict)

  def step(self):
    "Update parameters and rate"
    self._step += 1
    rate = self.rate()
    for p in self.optimizer.param_groups:
        p['lr'] = rate
    self._rate = rate
    self.optimizer.step()

  def rate(self, step = None):
    "Implement `lrate` above"
    if step is None:
        step = self._step
    return (self.model_size ** (-0.5) *
        min(step ** (-0.5), step * self.warmup ** (-1.5)))

  def zero_grad(self):
    self.optimizer.zero_grad()

In [27]:
def collate_for_mlm(batch):
  collated = collator(batch)
  collated['input_ids'] = collated['input_ids'].squeeze()
  collated['token_type_ids'] = collated['token_type_ids'].squeeze()
  collated['nsp_labels'] = collated['nsp_labels'].squeeze()
  collated['labels'] = collated['labels'].squeeze()
  collated['labels'][collated['labels']<0] = 0
  return collated



In [28]:
train_dataset = BERTDataset(train_sentence_pairs, tokenizer)
train_dataloader = iter(DataLoader(train_dataset, collate_fn=collate_for_mlm, batch_size=16, shuffle=True))

In [29]:
sample_batch = next(train_dataloader)
sample_batch['nsp_labels']

tensor([0, 0, 0, 1, 0, 1, 0, 0, 0, 1, 1, 1, 0, 1, 1, 1])

In [79]:
#Let's make a trainer class to hold all of the information we need for training.
#W
class BERTTrainer:
  def __init__(
      self,
      model,
      train_dataloader,
      valid_dataloader=None,
      lr= 1e-4,
      weight_decay=0.01,
      betas=(0.9, 0.999),
      warmup_steps=10_000,
      log_frequency=10,
      miniters_denom=50
  ):
    self.model = model.to(DEVICE)
    self.train_data = train_dataloader
    self.valid_data = valid_dataloader

    #Note that the loss function is described in the paper to be sum of likelihoods.
    self.loss_fn = nn.NLLLoss(ignore_index=0) #padding index

    adam_optimizer = torch.optim.Adam(self.model.parameters(), lr=lr, betas=betas, weight_decay=weight_decay)
    self.optimizer = NOAMOptimizer(adam_optimizer, model.d_model, warmup_steps)

    self.log_frequency = log_frequency
    self.miniters_denom = miniters_denom

  def train(self, epochs):
    for epoch in range(epochs):
      self.train_epoch(epoch)
      self.valid_epoch(epoch)

  def train_epoch(self, epoch):
    epoch_loss = 0.0
    total_correct = 0
    total_samples = 0

    dataloader = self.train_data

    data_iter = tqdm.tqdm(
      enumerate(dataloader),
      desc=f"EPOCH_TRAIN:{epoch}",
      total=len(dataloader),
      bar_format="{l_bar}{bar}{r_bar}",
      miniters=int(len(dataloader)/self.miniters_denom),
      position=0,
      leave=True
    )

    for i, batch in data_iter:
      #send tensors to DEVICE
      inputs = {
        "input_ids": batch['input_ids'].to(DEVICE),
        "attention_mask": batch['attention_mask'].to(DEVICE),
        "segment_info": batch['token_type_ids'].to(DEVICE),
      }
      outputs = {
          "nsp_labels": batch['nsp_labels'].to(DEVICE),
          "mlm_labels": batch['labels'].to(DEVICE)
      }

      #train mode
      self.model.train()

      #forward pass
      NSP_prediction, MLM_prediction = self.model(inputs['input_ids'], inputs['segment_info'], inputs['attention_mask'])

      #compute losses
      # print(f"TRAINING {outputs['nsp_labels'].shape=}")
      # print(f"TRAINING {outputs['mlm_labels'].shape=}")
      NSP_loss = self.loss_fn(NSP_prediction, outputs['nsp_labels'])
      MLM_loss = self.loss_fn(MLM_prediction.transpose(1,2), outputs['mlm_labels'])
      loss = NSP_loss + MLM_loss

      self.optimizer.zero_grad()
      loss.backward()
      self.optimizer.step()

      epoch_loss += loss.item()
      total_correct += NSP_prediction.argmax(dim=-1).eq(outputs['nsp_labels']).sum().item()
      total_samples += NSP_prediction.shape[0]

      post_fix = {
          "epoch": epoch,
          "iter": i,
          "avg_loss": epoch_loss / (i + 1),
          "avg_acc": total_correct / total_samples * 100,
          "loss": loss.item()
      }

      if i % self.log_frequency == 0:
          data_iter.write(str(post_fix))
    print(
      f"EP{epoch}, TRAIN: \
      avg_loss={epoch_loss / len(data_iter)}, \
      total_acc={total_correct * 100.0 / total_samples}"
    )

  def valid_epoch(self, epoch):
    epoch_loss = 0.0
    total_correct = 0
    total_samples = 0

    dataloader = self.valid_data

    data_iter = tqdm.tqdm(
      enumerate(dataloader),
      desc=f"EPOCH_VALID:{epoch}",
      total=len(dataloader),
      bar_format="{l_bar}{bar}{r_bar}"
    )

    for i, batch in data_iter:
      #send tensors to DEVICE
      inputs = {
        "input_ids": batch['input_ids'].to(DEVICE),
        "attention_mask": batch['attention_mask'].to(DEVICE),
        "segment_info": batch['token_type_ids'].to(DEVICE),
      }
      outputs = {
          "nsp_labels": batch['nsp_labels'].to(DEVICE),
          "mlm_labels": batch['labels'].to(DEVICE)
      }

      #train mode
      self.model.eval()

      with torch.no_grad():
        #forward pass
        NSP_prediction, MLM_prediction = self.model(inputs['input_ids'], inputs['segment_info'], inputs['attention_mask'])

        #compute losses
        NSP_loss = self.loss_fn(NSP_prediction, outputs['nsp_labels'])
        MLM_loss = self.loss_fn(MLM_prediction.transpose(1,2), outputs['mlm_labels'])
        loss = NSP_loss + MLM_loss

        epoch_loss += loss.item()

      total_correct += NSP_prediction.argmax(dim=-1).eq(outputs['nsp_labels']).sum().item()
      total_samples += NSP_prediction.shape[0]

      post_fix = {
          "epoch": epoch,
          "iter": i,
          "avg_loss": epoch_loss / (i + 1),
          "avg_acc": total_correct / total_samples * 100,
          "loss": loss.item(),
          "total_correct": total_correct,
          "total_samples": total_samples
      }

      if i % self.log_frequency == 0:
          data_iter.write(str(post_fix))
          # print(
          #   f"EP{epoch}, VALID: \
          #   avg_loss={epoch_loss / len(data_iter)}, \
          #   total_acc={total_correct * 100.0 / total_samples}"
          # )



In [80]:
D_MODEL = 384 #in the paper this is 768
N_LAYERS = 2 #in the paper this is 12
HEADS = 6 #in the paper this is 12
DROPOUT = 0.1
SEQ_LEN = 512 #512 is a good number for this
BATCH_SIZE = 32

LR= 1e-4,
WEIGHT_DECAY=0.01,
BETAS=(0.9, 0.999),
WARMUP_STEPS=10_000,
LOG_FREQ=100

train_dataset = BERTDataset(train_sentence_pairs, tokenizer)
valid_dataset = BERTDataset(valid_sentence_pairs, tokenizer)

train_dataloader = DataLoader(train_dataset, collate_fn=collate_for_mlm, batch_size=BATCH_SIZE, shuffle=True, pin_memory=True)
valid_dataloader = DataLoader(valid_dataset, collate_fn=collate_for_mlm, batch_size=BATCH_SIZE, shuffle=True, pin_memory=True)

bert_encoder = BERT(
  vocab_size=len(tokenizer.vocab)+1, #because we added an additional special token
  d_model=D_MODEL,
  blocks=N_LAYERS,
  num_heads=HEADS,
  seq_len=SEQ_LEN,
  dropout= DROPOUT
).to(DEVICE)

bert_lm = BERTmodel(bert_encoder).to(DEVICE)
bert_trainer = BERTTrainer(
    bert_lm,
    train_dataloader,
    valid_dataloader,
    # lr = LR,
    # weight_decay = WEIGHT_DECAY,
    # betas = (BETAS),
    # warmup_steps = WARMUP_STEPS,
    log_frequency = LOG_FREQ
)

In [None]:
epochs = 10

bert_trainer.train(epochs)

EPOCH_TRAIN:0:   0%|          | 0/2503 [00:00<?, ?it/s]

{'epoch': 0, 'iter': 0, 'avg_loss': 10.484803199768066, 'avg_acc': 59.375, 'loss': 10.484803199768066}


EPOCH_TRAIN:0:   4%|▍         | 101/2503 [00:51<21:02,  1.90it/s]

{'epoch': 0, 'iter': 100, 'avg_loss': 10.214528867513826, 'avg_acc': 48.32920792079208, 'loss': 9.941876411437988}


EPOCH_TRAIN:0:   8%|▊         | 201/2503 [01:44<20:14,  1.90it/s]

{'epoch': 0, 'iter': 200, 'avg_loss': 9.936137033339163, 'avg_acc': 49.424751243781095, 'loss': 9.292469024658203}


EPOCH_TRAIN:0:  12%|█▏        | 301/2503 [02:37<19:23,  1.89it/s]

{'epoch': 0, 'iter': 300, 'avg_loss': 9.62289420473219, 'avg_acc': 49.52242524916943, 'loss': 8.81309986114502}


EPOCH_TRAIN:0:  16%|█▌        | 401/2503 [03:30<18:28,  1.90it/s]

{'epoch': 0, 'iter': 400, 'avg_loss': 9.350067757014324, 'avg_acc': 49.45448877805486, 'loss': 8.286011695861816}


EPOCH_TRAIN:0:  20%|██        | 501/2503 [04:22<17:32,  1.90it/s]

{'epoch': 0, 'iter': 500, 'avg_loss': 9.115131813133072, 'avg_acc': 49.73802395209581, 'loss': 7.94939661026001}


EPOCH_TRAIN:0:  24%|██▍       | 601/2503 [05:15<16:42,  1.90it/s]

{'epoch': 0, 'iter': 600, 'avg_loss': 8.902047424665504, 'avg_acc': 49.77641430948419, 'loss': 7.607085227966309}


EPOCH_TRAIN:0:  28%|██▊       | 701/2503 [06:08<15:51,  1.89it/s]

{'epoch': 0, 'iter': 700, 'avg_loss': 8.707529137376032, 'avg_acc': 49.6478245363766, 'loss': 7.242195129394531}


EPOCH_TRAIN:0:  32%|███▏      | 801/2503 [07:00<15:00,  1.89it/s]

{'epoch': 0, 'iter': 800, 'avg_loss': 8.535817580871964, 'avg_acc': 49.64107365792759, 'loss': 7.410141468048096}


EPOCH_TRAIN:0:  36%|███▌      | 901/2503 [07:53<14:08,  1.89it/s]

{'epoch': 0, 'iter': 900, 'avg_loss': 8.387710659141414, 'avg_acc': 49.93756936736959, 'loss': 6.747497081756592}


EPOCH_TRAIN:0:  40%|███▉      | 1001/2503 [08:46<13:12,  1.89it/s]

{'epoch': 0, 'iter': 1000, 'avg_loss': 8.265129548567277, 'avg_acc': 49.94692807192807, 'loss': 6.836628437042236}


EPOCH_TRAIN:0:  44%|████▍     | 1101/2503 [09:38<12:22,  1.89it/s]

{'epoch': 0, 'iter': 1100, 'avg_loss': 8.162600028309143, 'avg_acc': 49.77577202543142, 'loss': 7.186311721801758}


EPOCH_TRAIN:0:  48%|████▊     | 1201/2503 [10:31<11:29,  1.89it/s]

{'epoch': 0, 'iter': 1200, 'avg_loss': 8.07509328959685, 'avg_acc': 49.729392173189005, 'loss': 6.756588935852051}


EPOCH_TRAIN:0:  49%|████▉     | 1224/2503 [10:43<11:21,  1.88it/s]

In [78]:
del bert_trainer.model
del bert_trainer.train_data
del bert_trainer.valid_data
del bert_trainer
torch.cuda.empty_cache()

In [52]:
def save_model(trainer, version=0):
  directory = os.path.join(DATA_DIRECTORY, 'models', f'model_{DATASET}_v{version}')
  os.makedirs(directory, exist_ok=True)
  torch.save(trainer.model.state_dict(), os.path.join(directory, 'model.pth'))
  torch.save(trainer.optimizer.state_dict(), os.path.join(directory, 'optimizer.pth'))

  optim_dict = trainer.optimizer.__dict__
  optim_dict['optimizer'] = {key: value for key, value in optim_dict['optimizer'].items()}
  specs_dictionary = {
      'trainer': str(trainer.__dict__),
      'model': str(trainer.model.__dict__),
      'train_data': str(trainer.train_data.__dict__),
      'valid_data': str(trainer.valid_data.__dict__),
      'optimizer': str(trainer.optimizer.__dict__),
  }
  with open(os.path.join(directory, 'specs.json'), 'w') as f:
      json.dump(specs_dictionary, f, indent=4)



In [57]:
save_model(bert_trainer)

In [75]:
def DOWNLOAD_SESSION():
  from google.colab import files
  import shutil
  shutil.make_archive('/content/session_data', 'zip', '/content')
  files.download('/content/session_data.zip')

In [76]:
DOWNLOAD_SESSION()

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>