<a href="https://colab.research.google.com/github/jaroorhmodi/word2vec-and-BERT/blob/main/BERT.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

#BERT (Bidirectional Encoder Representations from Transformers)

In this notebook I will be replicating the model in the paper [**BERT: Pre-Training of Deep Bidirectional Transformers for Language Understanding**](https://arxiv.org/pdf/1810.04805.pdf).

While I will be creating the model from (mostly) scratch in PyTorch, I will not go into too much detail about why Multi-Head Attention is designed the way it is and how exactly the original [Transformer](https://arxiv.org/abs/1706.03762) architecture works. I have made another (*albeit messy*) [notebook that covers that paper](https://github.com/jaroorhmodi/transformer-from-scratch).

The model will be trained on the [**WikiText-2**](https://paperswithcode.com/dataset/wikitext-2) and [**WikiText-103**](https://paperswithcode.com/dataset/wikitext-103) datasets.

In [1]:
!pip install portalocker transformers tokenizers

Collecting portalocker
  Downloading portalocker-2.8.2-py3-none-any.whl (17 kB)
Collecting transformers
  Downloading transformers-4.34.0-py3-none-any.whl (7.7 MB)
Collecting tokenizers
  Downloading tokenizers-0.14.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (3.8 MB)
Collecting huggingface-hub<1.0,>=0.16.4 (from transformers)
  Downloading huggingface_hub-0.17.3-py3-none-any.whl (295 kB)
Collecting safetensors>=0.3.1 (from transformers)
  Downloading safetensors-0.4.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (1.3 MB)

In [27]:
import os

import torch
from torch import nn
from torch.optim.lr_scheduler import LambdaLR
from torch.utils.data import Dataset, DataLoader

import nltk
import numpy as np
import pandas as pd
import pickle
import random
import spacy

from tokenizers import BertWordPieceTokenizer
from transformers import BertTokenizer, DataCollatorForLanguageModeling

from torchtext.data import to_map_style_dataset
from torchtext.data.utils import get_tokenizer
from torchtext.vocab import build_vocab_from_iterator
from torchtext.datasets import WikiText2, WikiText103 #our datasets for this project

from tqdm.auto import tqdm

DATASET_small = "WikiText2"
DATASET_large = "WikiText103"

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
TOKENIZER="basic_english"
DATA_DIRECTORY = "."

##Model Objective and Data

###Who (What) is BERT?


While BERT is a much more complex model and what it accomplishes isn't exactly akin to Word2Vec, the intuition behind both is similar. We pass in sentences and attempt to make a model learn how to represent text in a way that captures not only information about the tokens themselves but also something about their *meaning*.

Word2Vec does this by training a model on words and their context in sentences, learning about their relationships by either predicting context from words (*Skip-Gram*) or words from context (*CBOW*). The embeddings it produces are static: each word gets a single vector regardless of the sentence it appears in.
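To make the two Word2Vec setups concrete, here is a minimal sketch of how training examples could be generated from a tokenized sentence. The function names (`skipgram_pairs`, `cbow_examples`) and the `window` parameter are illustrative choices for this notebook, not part of any library.

```python
def skipgram_pairs(tokens, window=2):
    """Skip-Gram: one (center, context) pair per neighbor within `window`."""
    pairs = []
    for i, center in enumerate(tokens):
        lo, hi = max(0, i - window), min(len(tokens), i + window + 1)
        for j in range(lo, hi):
            if j != i:
                pairs.append((center, tokens[j]))
    return pairs


def cbow_examples(tokens, window=2):
    """CBOW: one (context_words, center) example per position."""
    examples = []
    for i, center in enumerate(tokens):
        lo, hi = max(0, i - window), min(len(tokens), i + window + 1)
        context = [tokens[j] for j in range(lo, hi) if j != i]
        examples.append((context, center))
    return examples


tokens = "the man went to the store".split()
sg = skipgram_pairs(tokens, window=1)
cb = cbow_examples(tokens, window=1)
print(sg[:3])  # Skip-Gram predicts each neighbor from the center word
print(cb[1])   # CBOW predicts the center word from its neighbors
```

In both cases the model only ever sees a word together with a small window of neighbors, which is why the learned vector for a word is the same in every sentence.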

BERT trains a Transformer Encoder model on two specific objectives: *Masked Language Modeling* and *Next Sentence Prediction* to learn a wealth of information about tokens in their context and provide representations of them. Note that BERT is not simply learning static embeddings but rather representations that change based on context. Tokens in BERT are embedded using *WordPiece* embeddings.

The goal of the BERT paper was to introduce a way to represent words with a pre-trained Transformer, not to build a model for one specific predictive task. To this end it is trained in an unsupervised manner with the MLM and NSP objectives mentioned above (both are explained below).

###Data Processing

In [3]:
#We need to pull in the dataset and break it into sentence pairs for the NSP objective
#and we need to mask random words and create objectives for the MLM objective.
DATASET = DATASET_small
dataset_class = WikiText2 if DATASET == DATASET_small else WikiText103
data_train = dataset_class(DATA_DIRECTORY, split = "train")

OPT_VERSION = ''

TOKENS_LOCATION = os.path.join(DATA_DIRECTORY, "datasets", DATASET, DATASET.lower()[:8]+f"-{DATASET[8:]}")
TOKENIZER_LOCATION = os.path.join(DATA_DIRECTORY, "tokenizers", )
TOKENIZER_NAME = f"bert-wordpiece-{DATASET}{OPT_VERSION}" #just here to standardize naming for later
os.makedirs(TOKENIZER_LOCATION, exist_ok = True)

MAX_TOKENIZED_SENTENCE_LEN = 128 #maximum number of tokens in sentence

gen = iter(data_train)

sample = []
for i in range(20):
  sample.append(next(gen))

In [4]:
print(sample[:5])
print(len(sample))

[' \n', ' = Valkyria Chronicles III = \n', ' \n', ' Senjō no Valkyria 3 : <unk> Chronicles ( Japanese : 戦場のヴァルキュリア3 , lit . Valkyria of the Battlefield 3 ) , commonly referred to as Valkyria Chronicles III outside Japan , is a tactical role @-@ playing video game developed by Sega and Media.Vision for the PlayStation Portable . Released in January 2011 in Japan , it is the third game in the Valkyria series . <unk> the same fusion of tactical and real @-@ time gameplay as its predecessors , the story runs parallel to the first game and follows the " Nameless " , a penal military unit serving the nation of Gallia during the Second Europan War who perform secret black operations and are pitted against the Imperial unit " <unk> Raven " . \n', " The game began development in 2010 , carrying over a large portion of the work done on Valkyria Chronicles II . While it retained the standard features of the series , it also underwent multiple adjustments , such as making the game more <unk> for s

In the WikiText data we see a lot of control characters such as newlines, and we want to make sure we do not treat these as tokens. We also see that each article starts with a title line wrapped in single **=** signs, while subheadings are wrapped in two, three, or more equals signs.

We don't want header lines or whitespace-only lines, so we need to preprocess the data before we split it into sentence pairs for our NSP task.

In [5]:
tokenizer = BertWordPieceTokenizer(
    clean_text = True, #removes control chars like \n
    handle_chinese_chars = False, #not anticipating chinese chars
    strip_accents = False, #keep accents in text
    lowercase = True #ignore case
)

tokenizer.train(
    files = os.path.join(TOKENS_LOCATION, "wiki.train.tokens"),
    vocab_size = 30_000 if DATASET == DATASET_small else 90_000, #bigger vocab for bigger dataset
    min_frequency = 10 if DATASET == DATASET_small else 50, #require higher freq for bigger dataset
    limit_alphabet = 1000,
    wordpieces_prefix = '##',
    special_tokens=['[PAD]', '[CLS]', '[SEP]', '[MASK]', '[UNK]']
)

tokenizer.save_model(
    TOKENIZER_LOCATION,
    TOKENIZER_NAME
)

#This is the tokenizer we will use
tokenizer = BertTokenizer.from_pretrained(os.path.join(TOKENIZER_LOCATION, TOKENIZER_NAME)+'-vocab.txt')

['./tokenizers/bert-wordpiece-WikiText2-vocab.txt']


All of the input text for BERT training is to be of the form

    "[CLS] <SENTENCE1> [SEP] <SENTENCE2> [SEP]"
with `[PAD]` tokens added at the end as needed.

**Notice below how `token_type_ids` denotes where the first sequence ends and the next begins.**

Data preprocessing will require us to set up the two training tasks: NSP and MLM.
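As a rough illustration of the input layout described above, the sketch below assembles a sentence pair by hand from already-tokenized pieces. The helper name `build_pair_input` is hypothetical and works on token strings rather than ids; the real tokenizer does the equivalent on ids.

```python
def build_pair_input(tokens_a, tokens_b, max_len, pad_token="[PAD]"):
    """Lay out a pair as [CLS] A [SEP] B [SEP] and pad up to max_len."""
    tokens = ["[CLS]"] + tokens_a + ["[SEP]"] + tokens_b + ["[SEP]"]
    if len(tokens) > max_len:
        raise ValueError("pair is longer than max_len; truncate it first")
    # Segment 0 covers [CLS], sentence A and the first [SEP];
    # segment 1 covers sentence B and the final [SEP].
    token_type_ids = [0] * (len(tokens_a) + 2) + [1] * (len(tokens_b) + 1)
    attention_mask = [1] * len(tokens)
    pad = max_len - len(tokens)
    return {
        "tokens": tokens + [pad_token] * pad,
        "token_type_ids": token_type_ids + [0] * pad,
        "attention_mask": attention_mask + [0] * pad,  # 0 marks padding
    }


example = build_pair_input(["the", "man"], ["he", "left"], max_len=8)
for key, val in example.items():
    print(key, val)
```

The `token_type_ids` switch from 0 to 1 right after the first `[SEP]`, which is how the model can tell the two sentences apart.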

In [73]:
test = (sample[1], sample[3])

# test = "\n"
print(f"{test=}")
#automatically adds [CLS] before first sentence and [SEP] after each sentence
#this is how Bert separates two sentences in the input
tokenized = tokenizer.encode_plus(test[0], test[1], add_special_tokens = True, return_tensors = "pt")
for key, val in tokenized.items():
  print(f"__{key}__")
  print(f"len={len(val)}")
  print(val)

test=(' = Valkyria Chronicles III = \n', ' Senjō no Valkyria 3 : <unk> Chronicles ( Japanese : 戦場のヴァルキュリア3 , lit . Valkyria of the Battlefield 3 ) , commonly referred to as Valkyria Chronicles III outside Japan , is a tactical role @-@ playing video game developed by Sega and Media.Vision for the PlayStation Portable . Released in January 2011 in Japan , it is the third game in the Valkyria series . <unk> the same fusion of tactical and real @-@ time gameplay as its predecessors , the story runs parallel to the first game and follows the " Nameless " , a penal military unit serving the nation of Gallia during the Second Europan War who perform secret black operations and are pitted against the Imperial unit " <unk> Raven " . \n')
__input_ids__
len=1
tensor([[    1,    33,  7754,  7833,  2907,    33,     2,  2957,  5195,   746,
          7754,    23,    30,    32,   396,    34,  7833,    12,  2709,    30,
           236,   234,     4,    16,  1244,    18,  7754,   394,   381,  9510,
   

In [None]:
def new_article(para):
  #an article title looks like " = Title = ", with exactly one "=" on each side
  stripped = para.strip()
  return stripped.startswith("= ") and not stripped.startswith("= =")

def new_heading(para):
  #any heading line (article title or subheading) starts with "="
  return para.strip().startswith("=")

def split_sentences(para):
  #naive sentence split on ". ", see note below
  return para.strip().split('. ')

def article_sentence_pairs(article_sentences):
  #pair each sentence with the one that follows it, see note below
  return list(zip(article_sentences, article_sentences[1:]))

def collect_sentence_pairs(dataset_iter):
  sentence_pairs = []
  current_article = []
  for para in tqdm(dataset_iter):
    if new_article(para):
      #new article: pair up the sentences collected so far and start over
      sentence_pairs += article_sentence_pairs(current_article)
      current_article = []
      continue
    if para.strip() == '' or new_heading(para):
      #subheading or empty line, skip and continue collecting
      continue
    current_article += split_sentences(para)
  #flush the final article, which has no following title line to trigger pairing
  sentence_pairs += article_sentence_pairs(current_article)
  return sentence_pairs


NOTE: there are a few ways we could have split this dataset. I chose to split at `". "`, but we could instead have tokenized the data first and split according to maximum sequence length.
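For comparison, the length-based alternative mentioned in the note might look something like the sketch below. It assumes the text has already been converted to a flat list of token ids; `chunk_token_ids` and `segment_pairs` are hypothetical helpers, and the "3" accounts for the `[CLS]` and two `[SEP]` tokens added to every pair.

```python
def chunk_token_ids(token_ids, seg_len):
    """Split a flat list of token ids into consecutive chunks of at most seg_len."""
    return [token_ids[i:i + seg_len] for i in range(0, len(token_ids), seg_len)]


def segment_pairs(token_ids, max_len):
    """Pair up adjacent chunks so that each pair fits in max_len tokens."""
    seg_len = (max_len - 3) // 2  # leave room for [CLS] and two [SEP]
    if seg_len < 1:
        raise ValueError("max_len is too small to hold a pair")
    segments = chunk_token_ids(token_ids, seg_len)
    return list(zip(segments, segments[1:]))


ids = list(range(10))  # stand-in for real token ids
pairs = segment_pairs(ids, max_len=9)
for first, second in pairs:
    print(first, second)
```

The trade-off is that chunks no longer line up with sentence boundaries, but every example uses the available sequence length more evenly.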

In [None]:
#proper pipeline management, since we iterated once on the old data_train object we reset iteration
gen = iter(dataset_class(DATA_DIRECTORY, split = "train"))
sentence_pairs = collect_sentence_pairs(gen)
len(sentence_pairs)

In [None]:
sentence_pairs[900000:900010]

In [None]:
#Persist sentence pairs to file
os.makedirs(os.path.join(DATA_DIRECTORY, "datasets", "pairs_data"), exist_ok = True)


SENTENCE_PAIRS_LOCATION = os.path.join(TOKENS_LOCATION, "pairs")
os.makedirs(SENTENCE_PAIRS_LOCATION, exist_ok = True) #exist_ok so the cell can be re-run


#pickle 200k sentence pairs at a time into numbered pickle files
def pickle_pairs(pairs, chunk_size = 200_000):
  #step through the list chunk by chunk; the last chunk may be shorter
  for i, start in enumerate(range(0, len(pairs), chunk_size)):
    with open(os.path.join(SENTENCE_PAIRS_LOCATION, f"{i}.pkl"), "wb") as f:
      pickle.dump(pairs[start:start+chunk_size], f)

pickle_pairs(sentence_pairs)


We can now build pairs of adjacent sentences. Note that with this method the first sentence of the following paragraph is treated as a "next sentence", but the first sentence of the following article is not.
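Since the pairs were pickled into numbered files (`0.pkl`, `1.pkl`, ...), a later session needs a way to read them back. The following is a minimal sketch of such a loader; `load_pairs` is a hypothetical helper, and the demo writes to a temporary directory rather than the real `SENTENCE_PAIRS_LOCATION`.

```python
import os
import pickle
import tempfile


def load_pairs(pairs_dir):
    """Concatenate every '<index>.pkl' chunk in pairs_dir, in numeric order."""
    names = [n for n in os.listdir(pairs_dir) if n.endswith(".pkl")]
    # Sort numerically so that 10.pkl comes after 9.pkl, not after 1.pkl.
    names.sort(key=lambda n: int(os.path.splitext(n)[0]))
    pairs = []
    for name in names:
        with open(os.path.join(pairs_dir, name), "rb") as f:
            pairs.extend(pickle.load(f))
    return pairs


# Round-trip demo with toy data, chunked two pairs per file.
demo_pairs = [(f"sentence {i}", f"sentence {i + 1}") for i in range(5)]
with tempfile.TemporaryDirectory() as tmp:
    for n, start in enumerate(range(0, len(demo_pairs), 2)):
        with open(os.path.join(tmp, f"{n}.pkl"), "wb") as f:
            pickle.dump(demo_pairs[start:start + 2], f)
    restored = load_pairs(tmp)

print(len(restored))
```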

###The Two Training Objectives

The following examples are taken straight from the paper, where they illustrate the NSP objective, but they can be used to explain the MLM objective as well.
    
    Input: [CLS] the man went to [MASK] store [SEP] he bought a gallon [MASK] milk [SEP]
    Label: IsNext

    Input: [CLS] the man went to [MASK] store [SEP] penguin [MASK] are flight ##less birds [SEP]
    Label: NotNext

####Masked Language Model
The **Masked Language Model** task selects a fraction of the input tokens (15% in the paper), replaces most of them with `[MASK]` (with some small probability it substitutes a random token or leaves the token unchanged instead), and asks the model to predict the original tokens at the selected positions.
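To make the masking procedure concrete, here is a small pure-Python sketch of it operating on token strings. It follows the 15% selection rate and the 80% / 10% / 10% split described in the paper; the function name `mask_tokens`, the `NEVER_MASK` set, and the toy vocabulary are assumptions made for illustration, and a real pipeline would operate on id tensors instead.

```python
import random

# Assumption: only structural tokens are excluded from masking.
NEVER_MASK = {"[PAD]", "[CLS]", "[SEP]"}


def mask_tokens(tokens, vocab, mlm_probability=0.15, rng=None):
    """Return (masked_tokens, labels); labels are None where nothing is predicted."""
    rng = rng or random.Random()
    masked = list(tokens)
    labels = [None] * len(tokens)
    for i, tok in enumerate(tokens):
        if tok in NEVER_MASK or rng.random() >= mlm_probability:
            continue  # position not selected for prediction
        labels[i] = tok  # the model must recover the original token here
        roll = rng.random()
        if roll < 0.8:
            masked[i] = "[MASK]"           # 80%: replace with [MASK]
        elif roll < 0.9:
            masked[i] = rng.choice(vocab)  # 10%: replace with a random token
        # remaining 10%: leave the token unchanged
    return masked, labels


vocab = ["the", "man", "went", "to", "store", "he", "bought", "a", "gallon", "of", "milk"]
tokens = ["[CLS]", "the", "man", "went", "to", "the", "store", "[SEP]",
          "he", "bought", "a", "gallon", "of", "milk", "[SEP]"]
masked, labels = mask_tokens(tokens, vocab, rng=random.Random(0))
print(masked)
print(labels)
```

Keeping some selected tokens unchanged or swapping in random ones means the model cannot rely on `[MASK]` being present, since that token never appears at fine-tuning time.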

####Next Sentence Prediction

The **Next Sentence Prediction** task takes sentence pairs and creates a balanced classification task, where half of the time the following sentence is actually the next sentence and half of the time it is a random sentence. Then the model is trained to correctly predict whether or not the second sentence is the actual next sentence.

We have already done half of the work for this task by creating positive next sentence pairs for each sentence in the text. All we have to do is replace the second sentence in each pair with a random sentence half of the time.

In [None]:
#We will create a custom Dataset class for our purposes
class BERTDataset(Dataset):
  def __init__(self, sentence_pairs, tokenizer, max_len=512):
    self.sentence_pairs = sentence_pairs
    self.num_pairs = len(sentence_pairs)
    self.tokenizer = tokenizer
    self.max_len = max_len

  def __len__(self):
    return self.num_pairs

  #The key component of a Dataset is the __getitem__ function
  def __getitem__(self, idx):
    sentence1, sentence2, isNextLabel = self.get_nsp_entry(idx)
    #One possible completion (an assumption, not taken from the paper):
    #tokenize the pair to a fixed length and return tensors, leaving the
    #MLM masking to a collator at batch time
    encoded = self.tokenizer(
        sentence1, sentence2,
        truncation = True, max_length = self.max_len,
        padding = "max_length", return_tensors = "pt"
    )
    item = {key: val.squeeze(0) for key, val in encoded.items()} #drop the batch dim
    item["is_next"] = torch.tensor(isNextLabel) #1 = IsNext, 0 = NotNext
    return item

  def get_nsp_entry(self, idx):
    #Implement NSP randomization here
    sent1, sent2 = self.sentence_pairs[idx]
    if random.random() >= 0.5:
      #this is the case where we give positive nsp example
      return sent1, sent2, 1
    else:
      return sent1, self.get_non_next_sentence(idx), 0

  def get_non_next_sentence(self, idx):
    random_idx = random.randrange(self.num_pairs)
    #resample on the small chance that the random index maps to the same pair,
    #which would mislabel the actual next sentence as NotNext
    while random_idx == idx:
      random_idx = random.randrange(self.num_pairs)
    return self.sentence_pairs[random_idx][1]



#####Example using the Transformers Library for MLM

We can use the `transformers.DataCollatorForLanguageModeling` class to handle the masking for us.

We see [in the documentation](https://github.com/huggingface/transformers/blob/b71f20a7c9f3716d30f6738501559acf863e2c5c/src/transformers/data/data_collator.py#L751C1-L751C108) that, exactly like in the paper, `DataCollatorForLanguageModeling` selects each token for prediction with probability `mlm_probability` (`0.15` in the paper). Of the selected tokens, 80% are replaced with `[MASK]`, 10% are replaced with a random token, and 10% are left as is.

In [86]:
collator = DataCollatorForLanguageModeling(tokenizer = tokenizer, mlm = True, mlm_probability = 0.15)

In [91]:
#We recreate our tokenization example from earlier,
#we just want to create a sentence pair for illustration
feats = tokenizer.encode_plus(sample[1], sample[3], add_special_tokens = True, return_tensors = "pt")

In [92]:
masked = collator([feats])
masked

{'input_ids': tensor([[[    1,    33,  7754,  7833,  2907,    33,     2,  2957,  5195,   746,
           7754,    23,    30,    32,   396,     3,  7833,    12, 15373,    30,
            236,   234,     4,    16,  1244,    18,  7754,   394,   381,  9510,
             23,    13,    16,  4627,  3155,   403,   428,  7754,  7833,  2907,
           2364,     3,    16,   445,    42,     3,  1554,    36,    17,    36,
           2202,  1380,   741,  1919,   450,  5160,   252,   399,  2639,    18,
           5480,   424,   381,  4501, 10289,    18,  1146,     3,  1425,  1719,
            391,  1826,    16,   444,     3,   381,  1361,   741,   391,     3,
           7754,   885,    18,    32,   396,    34,     3,  1145,  8559,   394,
          10526,   399,  1511,    36,    17,    36,   689,  5059,   428,   577,
          12405,    16,   381,  1422,  3291,  6416,     3,     3,   573,   741,
            399,  4397,   381,     3, 18266,     6,    16,    42, 11551,  1594,
           3147,  5034,   

In [93]:
tokenizer.decode(masked['input_ids'].squeeze())

'[CLS] = valkyria chronicles iii = [SEP] senjo no valkyria 3 : < unk [MASK] chronicles ( criticizing : 戦 場 [UNK], lit. valkyria of the battlefield 3 ), commonly referred to as valkyria chronicles iii outside [MASK], is a [MASK] role @ - @ playing video game developed by sega and media. vision for the playstation portable. released [MASK] january 2011 in japan, it [MASK] the third game in [MASK] valkyria series. < unk > [MASK] same fusion of tactical and real @ - @ time gameplay as its predecessors, the story runs parallel [MASK] [MASK] first game and follows the [MASK] nameless ", a penal military unit serving the [MASK] of gall [MASK] during the second europan war [MASK] performanced black operations and are pitted [MASK] the [MASK] unit " < unk > raven ". [SEP]'

In [94]:
tokenizer.decode(feats['input_ids'].squeeze())

'[CLS] = valkyria chronicles iii = [SEP] senjo no valkyria 3 : < unk > chronicles ( japanese : 戦 場 [UNK], lit. valkyria of the battlefield 3 ), commonly referred to as valkyria chronicles iii outside japan, is a tactical role @ - @ playing video game developed by sega and media. vision for the playstation portable. released in january 2011 in japan, it is the third game in the valkyria series. < unk > the same fusion of tactical and real @ - @ time gameplay as its predecessors, the story runs parallel to the first game and follows the " nameless ", a penal military unit serving the nation of gallia during the second europan war who perform secret black operations and are pitted against the imperial unit " < unk > raven ". [SEP]'

##Model Architecture

##Training