# BERT: Pre-training of Deep Bidirectional Transformers for Language Understanding

Implementation of the "[BERT: Pre-training of Deep Bidirectional Transformers for Language Understanding](https://arxiv.org/pdf/1810.04805.pdf)" paper.

## Contents
1. [Libraries](#libs)
2. [Dataset](#dataset)

    2.1. [Initializing The SpaCy English Tokenizer](#spacy)

    2.2. [Creating a vocab](#vocab)

    2.3. [Masked Language Modeling Dataset](#mlmdataset)

3. [BERT](#bert)

4. [Training session](#training)

<a id="libs"></a>
## 1. Libraries

In [11]:
import os
import re
import math
import random
from typing import List, Dict, Tuple, Generator, Any

import spacy
import torch
import torchtext
import numpy as np
import pandas as pd
from torch import nn
from spacy.symbols import ORTH
from torch.utils.data import random_split
from torchtext.data.utils import get_tokenizer
from torchtext.vocab import build_vocab_from_iterator
from torch.utils.data import Dataset, DataLoader
from torchmetrics.classification import Accuracy


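# Controlling the randomness in Python, NumPy and PyTorch.
RANDOM_SEED = 42
random.seed(RANDOM_SEED)
np.random.seed(RANDOM_SEED)
torch.manual_seed(RANDOM_SEED)
torch.cuda.manual_seed(RANDOM_SEED)
# cuDNN autotuning can pick different kernels between runs, so it is
# disabled here for reproducibility (at some cost in speed).
torch.backends.cudnn.benchmark = False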

<a id="dataset"></a>
## 2. The Dataset

I am going to use the text of [Mary Shelley's Frankenstein](https://www.gutenberg.org/ebooks/84).
First, the whole text is read into a single string variable (this does not scale, but the corpus is really small). Then, I define a SpaCy Tokenizer and use it to tokenize the whole book.
Afterwards, a Vocabulary is built from the tokens.
Finally, a PyTorch Dataset is defined on top of the Tokenizer and the Vocabulary.

In [12]:
DATASET_PATH = os.path.join(
    "..",
    "..",
    "nlp",
    "datasets",
    "frankenstein",
    "frankenstein.txt"
)
with open(DATASET_PATH, "r", encoding="utf8") as book_file:
    book_text = book_file.read()

<a id="spacy"></a>
### 2.1. Initializing The SpaCy English Tokenizer
The SpaCy Tokenizer is updated with multiple special tokens:
- `[CLS]` - classification token; used as a beginning-of-sequence token
- `[SEP]` - separation token; used for end of sequence and as a separation in NSP pre-training
- `[MASK]` - mask token; used when the tokens are masked during MLM pre-training
- `[PAD]` - padding token
- `[UNK]` - unknown token

In [13]:
SPACY_TOKENIZER = spacy.load("en_core_web_sm")

SPECIAL_TOKENS = [
    "[CLS]",
    "[SEP]",
    "[MASK]",
    "[PAD]",
    "[UNK]"
]
# Adding a special-case rule for each special token.
# Without these rules, the tokenizer would split a token such as "[CLS]"
# into "[", "CLS" and "]".
for spec_token in SPECIAL_TOKENS:
    rule = [{ORTH: spec_token}]
    SPACY_TOKENIZER.tokenizer.add_special_case(spec_token, rule)

def en_tokenizer(text):
    return [token.orth_ for token in SPACY_TOKENIZER(text)]
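
To see why the special-case rules matter, the sketch below compares tokenization before and after adding them. It uses `spacy.blank("en")`, a tokenizer-only English pipeline, instead of `en_core_web_sm` so that no model download is needed; that swap is an assumption made for brevity.

```python
import spacy
from spacy.symbols import ORTH

# Assumption: a blank English pipeline stands in for `en_core_web_sm`.
nlp = spacy.blank("en")
text = "[CLS] Time is [MASK] [SEP]"

# Without a special case, the brackets are split off as punctuation.
before = [token.text for token in nlp(text)]

for spec_token in ["[CLS]", "[SEP]", "[MASK]"]:
    nlp.tokenizer.add_special_case(spec_token, [{ORTH: spec_token}])

# With the special cases, each special token survives as a single token.
after = [token.text for token in nlp(text)]

print(before)
print(after)
```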

<a id="vocab"></a>
### 2.2. Creating a vocab

In [14]:
def prepare_text(text: str) -> str:
    """Preparing the text for tokenization.
    It includes:
    1. Surrounding punctuation with whitespace
    2. Converting multiple spaces into a single one

    Args:
        text (str): A text.

    Returns:
        str: The parsed text.
    """
    pattern = r"([.,!?:;]+)"
    text = re.sub(pattern, r" \1 ", text)

    pattern = r"\s+"
    text = re.sub(pattern, " ", text)

    return text
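
As a quick check of what `prepare_text` does, the example below applies the same two substitutions to a short made-up string. A run of punctuation such as `...` is kept together, and trailing whitespace is collapsed but not stripped.

```python
import re


def prepare_text(text: str) -> str:
    # Surround runs of punctuation with spaces, then collapse whitespace.
    text = re.sub(r"([.,!?:;]+)", r" \1 ", text)
    return re.sub(r"\s+", " ", text)


sample = "Wait...what?  Yes,really!"
result = prepare_text(sample)
print(repr(result))
# 'Wait ... what ? Yes , really ! '
```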

In [None]:
def iterate_corpus(tokens: List[str], seq_size: int) -> Generator[List[str], None, None]:
    """Iterate through the corpus in non-overlapping chunks of tokens.

    Args:
        tokens (List[str]): Word tokens.
        seq_size (int): Number of tokens per chunk.

    Yields:
        List[str]: The next chunk of at most `seq_size` tokens.
    """
    # A step of `seq_size` visits every token exactly once, including the tail.
    for i in range(0, len(tokens), seq_size):
        yield tokens[i:i + seq_size]


book_text = prepare_text(book_text)
tokens = en_tokenizer(book_text)

vocab = build_vocab_from_iterator(
    iterator=iterate_corpus(tokens, seq_size=100),
    specials=SPECIAL_TOKENS
)
vocab.set_default_index(vocab["[UNK]"])
print(f"Vocab size: {len(vocab)}")
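
For intuition about what the vocabulary holds, here is a minimal pure-Python stand-in: the special tokens take the first indices, the remaining tokens follow by descending frequency, and unknown tokens fall back to `[UNK]`. `build_toy_vocab` and the toy token list are hypothetical, and the sketch is not meant to reproduce torchtext's exact ordering or API.

```python
from collections import Counter
from typing import Dict, List

SPECIALS = ["[CLS]", "[SEP]", "[MASK]", "[PAD]", "[UNK]"]


def build_toy_vocab(tokens: List[str], specials: List[str]) -> Dict[str, int]:
    # Hypothetical helper: specials first, then tokens by descending count
    # (ties broken alphabetically so the result is deterministic).
    counts = Counter(tokens)
    ordered = sorted(counts, key=lambda tok: (-counts[tok], tok))
    return {tok: idx for idx, tok in enumerate(specials + ordered)}


toy_tokens = "the monster saw the fire and the monster fled".split()
stoi = build_toy_vocab(toy_tokens, SPECIALS)
unk = stoi["[UNK]"]

# Unknown words map to the index of [UNK], similar to a default index.
ids = [stoi.get(tok, unk) for tok in ["the", "monster", "wept"]]
print(ids)  # [5, 6, 4]
```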

<a id="mlmdataset"></a>
### 2.3. Masked Language Modeling Dataset

Let's see what the authors of the paper said about MLM:
> we simply mask some percentage of the input
> tokens at random, and then predict those masked
> tokens. We refer to this procedure as a “masked
> LM” (MLM), although it is often referred to as a
> Cloze task

They mask $15\%$ of the tokens in each sequence:
> In all of our experiments, we mask 15% of all WordPiece tokens
> in each sequence at random.

The selected tokens are not always replaced with the `[MASK]` token:
> If the i-th token is chosen, we replace
> the i-th token with (1) the [MASK] token 80% of
> the time (2) a random token 10% of the time (3)
> the unchanged i-th token 10% of the time.

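This $80$-$10$-$10$ split exists because the `[MASK]` token never appears during fine-tuning. If every selected token were replaced with `[MASK]`, the inputs seen in pre-training and in fine-tuning would be mismatched.
Occasionally using a random token, or leaving the original token in place, mitigates that mismatch.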
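
The procedure in the quotes above can be sketched in plain Python. In the paper the 80-10-10 choice is made independently for each selected token, and only the selected positions are predicted. `mask_tokens`, the sample sentence and the use of `[PAD]` as an "ignore this position" label are assumptions made for this illustration.

```python
import random
from typing import List, Tuple


def mask_tokens(
    tokens: List[str],
    vocab_tokens: List[str],
    masked_frac: float = 0.15,
    mask_token: str = "[MASK]",
    ignore_label: str = "[PAD]",
    rng: random.Random = None,
) -> Tuple[List[str], List[str]]:
    """Hypothetical helper: returns (model input, per-position labels)."""
    rng = rng or random.Random()
    n_selected = max(1, round(len(tokens) * masked_frac))
    selected = set(rng.sample(range(len(tokens)), n_selected))

    inputs, labels = [], []
    for i, token in enumerate(tokens):
        if i not in selected:
            # Unselected positions are left alone and carry no label.
            inputs.append(token)
            labels.append(ignore_label)
            continue

        labels.append(token)  # The model must recover the original token.
        roll = rng.random()   # The 80-10-10 choice is made per token.
        if roll < 0.8:
            inputs.append(mask_token)
        elif roll < 0.9:
            inputs.append(rng.choice(vocab_tokens))
        else:
            inputs.append(token)

    return inputs, labels


sentence = "I beheld the wretch the miserable monster whom I had created".split()
inputs, labels = mask_tokens(sentence, vocab_tokens=sentence, rng=random.Random(0))
print(inputs)
print(labels)
```

With labels shaped like this, a loss that ignores the `[PAD]` label scores only the selected positions.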

In [None]:
class MlmDataset(Dataset):

    def __init__(
        self, 
        text: str, 
        tokenizer, 
        vocab, 
        seq_len: int, 
        mask_token="[MASK]", masked_frac=0.15,
        pad_token="[PAD]"
    ):
        self._text = text
        self._tokens = tokenizer(text)
        self._vocab = vocab
        self._itos = vocab.get_itos()
        self._seq_len = seq_len

        self._mask_token = mask_token
        self._masked_frac = masked_frac
        self._pad_token = pad_token

        self.x, self.y = self._getinout()

    def __getitem__(self, idx: int):
        return self.x[idx], self.y[idx]

    def __len__(self):
        return len(self.x)

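    def _getinout(self) -> Tuple[torch.LongTensor, torch.LongTensor]:
        """Get the input and output PyTorch tensors.
        For each sequence, 15% of the token positions are selected and then:
        - 80% of the time they are replaced with `self._mask_token`
        - 10% of the time they are replaced with random tokens
        - 10% of the time the input sequence is not changed

        Note: the 80-10-10 choice is made once per sequence here, whereas
        the paper makes it independently for each selected token. The target
        is the full original sequence, not only the selected positions.

        Returns:
            Tuple[torch.LongTensor, torch.LongTensor]: Input and output tensors.
        """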
        x = []
        y = []

        for i in range(0, len(self._tokens) - self._seq_len - 1):
            sequence = self._tokens[i:i + self._seq_len]

            # The random token indices that will be masked.
            random_indices = random.sample(
                range(0, len(sequence)),
                int(len(sequence) * self._masked_frac)
            )
            random_indices.sort()

            percentile = random.uniform(0, 1)

            # 80% of the time - replace the selected tokens with [MASK].
            if percentile <= 0.8:
                replacement_tokens = [self._mask_token] * len(random_indices)

            # 10% of the time - replace the selected tokens with random tokens.
            elif percentile <= 0.9:
                replacement_tokens = [
                    self._itos[random.randint(0, len(self._vocab) - 1)]
                    for _ in range(len(random_indices))
                ]

            # 10% of the time - don't change the sequence of tokens.
            else:
                replacement_tokens = None

            masked_tokens = ["[CLS]"] + sequence + ["[SEP]"]
            sequence = ["[CLS]"] + self._replace_tokens(
                sequence, random_indices, 
                replacement_tokens=replacement_tokens
            ) + ["[SEP]"]

            x.append([self._vocab[token] for token in sequence])
            y.append([
                self._vocab[token]
                for token in masked_tokens
            ])

        return torch.LongTensor(x), torch.LongTensor(y)
    
    def _replace_tokens(self, sequence: List[str], indices: List[int], replacement_tokens: List[str]) -> List[str]:
        """Replace tokens in `sequence` with the `replacement_tokens`, based on `indices`.

        Args:
            sequence (List[str]): The input sequence.
            indices (List[int]): The indices of the tokens that should be replaced.
            replacement_tokens (List[str]): The tokens that are going to replace the original ones.

        Returns:
            List[str]: The transformed sequence.
        """
        # If there are no replacement tokens, the sequence stays the same.
        if replacement_tokens:
            # `indices` and `replacement_tokens` have the same length,
            # so they can be paired up directly.
            for i, replacement in zip(indices, replacement_tokens):
                sequence[i] = replacement

        return sequence
    
    def _pad_output(self, masked_sequence: List[str]) -> List[str]:
        # Since this dataset is set up for MLM, there is no chance that the output
        # sequence is longer than the input one (i.e. in the BERT paper, the MLM
        # output sequence, without the padding, is 15% of the input length), we 
        # directly add padding.
        padding_size = self._seq_len - len(masked_sequence)
        return ["[CLS]"] + masked_sequence + ["[SEP]"] + [self._pad_token] * padding_size


dataset = MlmDataset(
    text=book_text, 
    tokenizer=en_tokenizer, 
    vocab=vocab,
    seq_len=20
)

# Index to string
ITOS = dict(enumerate(vocab.get_itos()))
# String to index
STOI = vocab

batch = dataset[:10]

x, y = batch[0][1].tolist(), batch[1][1].tolist()
print("Input:", [ITOS[el] for el in x])
print("Target:", [ITOS[el] for el in y])
print()
print("Input shape:", dataset[:3][0].shape)
print("Target shape:", dataset[:3][1].shape)

Input: ['[CLS]', 'Gutenberg', 'eBook', 'of', 'Frankenstein', ',', 'by', 'Mary', 'Wollstonecraft', 'Shelley', 'This', 'eBook', 'is', 'for', 'the', 'use', 'of', 'anyone', 'anywhere', 'in', 'the', '[SEP]']
Target: ['[CLS]', 'Gutenberg', 'eBook', 'of', 'Frankenstein', ',', 'by', 'Mary', 'Wollstonecraft', 'Shelley', 'This', 'eBook', 'is', 'for', 'the', 'use', 'of', 'anyone', 'anywhere', 'in', 'the', '[SEP]']

Input shape: torch.Size([3, 22])
Target shape: torch.Size([3, 22])
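
In the printed example the target holds every original token, so a cross-entropy loss would be computed over all positions. To score only the selected positions, as the paper does, one option is to fill the other target positions with an index that `nn.CrossEntropyLoss` ignores. The shapes, the random logits and the choice of `PAD_IDX = 3` below are illustrative assumptions.

```python
import torch
from torch import nn

torch.manual_seed(0)

PAD_IDX = 3      # Assumed index of [PAD]; used as the "ignore" label.
VOCAB_SIZE = 10  # Toy vocabulary size.

# Stand-in for model output: (batch_size=1, seq_len=5, vocab_size).
logits = torch.randn(1, 5, VOCAB_SIZE)
# Only positions 1 and 3 were selected for prediction.
target = torch.tensor([[PAD_IDX, 7, PAD_IDX, 2, PAD_IDX]])

loss_func = nn.CrossEntropyLoss(ignore_index=PAD_IDX)
loss = loss_func(logits.reshape(-1, VOCAB_SIZE), target.reshape(-1))

# The same value, computed by hand over the two selected positions.
log_probs = torch.log_softmax(logits[0], dim=-1)
manual = -(log_probs[1, 7] + log_probs[3, 2]) / 2

print(float(loss), float(manual))
```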


<a id="bert"></a>
## 3. BERT

The BERT model uses only the Encoder of the classic Transformer architecture from the "[Attention Is All You Need](https://arxiv.org/pdf/1706.03762.pdf)" paper. Let's go through the architecture layer by layer.

- *Embedding layers* - BERT uses three types of embedding layers, whose outputs are summed. *Token Embedding* - embeds the meaning of each token, based on its usage; *Segment Embedding* - denotes which segment of the input a token belongs to; *Position Embedding* - encodes the position of each token in the sequence

- *Encoder layers* - Transformer Encoder layers. This is what the paper authors say about the configuration of these layers:
    > We primarily report results on two model sizes:
    > $BERT_{BASE}$ (L=12, H=768, A=12, Total Parameters=110M) and $BERT_{LARGE}$ (L=24, H=1024,
    > A=16, Total Parameters=340M).

- *Feed-forward layer* - Maps the Encoder output to the expected output size (the vocab size). For the NSP pre-training task, only the pooler output is used, i.e. the output at the first position of the sequence, which corresponds to the `[CLS]` token
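
As a sanity check on the quoted model size, the parameter count of $BERT_{BASE}$ can be estimated from L and H plus a few further sizes. The vocabulary of 30,522 entries, the 512 positions, the 2 segment types, the feed-forward size of 4H and the pooler layer are assumptions taken from the commonly released base configuration, not from the quote itself.

```python
def bert_param_count(
    num_layers: int,
    hidden: int,
    vocab_size: int = 30522,   # Assumed WordPiece vocabulary size.
    max_positions: int = 512,  # Assumed maximum sequence length.
    num_segments: int = 2,     # Assumed number of segment types.
) -> int:
    """Hypothetical helper: rough parameter count of a BERT encoder."""
    ffn = 4 * hidden  # Feed-forward size, assumed to be 4H.

    # Token, position and segment embeddings, plus one LayerNorm.
    embeddings = (vocab_size + max_positions + num_segments) * hidden + 2 * hidden

    # Q, K, V and output projections (weights + biases).
    attention = 4 * (hidden * hidden + hidden)
    # Two linear layers of the feed-forward block (weights + biases).
    feed_forward = (hidden * ffn + ffn) + (ffn * hidden + hidden)
    # Two LayerNorms per encoder layer.
    layer_norms = 2 * 2 * hidden
    per_layer = attention + feed_forward + layer_norms

    # Dense layer applied to the [CLS] position.
    pooler = hidden * hidden + hidden

    return embeddings + num_layers * per_layer + pooler


base = bert_param_count(num_layers=12, hidden=768)
print(f"{base:,}")  # 109,482,240
```

The result is close to the quoted 110M. The number of heads A does not appear in the count, because the heads split H between them rather than add to it.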

In [None]:
class Bert(nn.Module):
    
    def __init__(
        self, 
        d_model: int, nheads: int, num_layers: int,
        vocab_size: int,
        max_seq_len=128,
        dropout=0.1
    ):
        super().__init__()

        assert d_model % nheads == 0, "'d_model' has to be divisible by 'nheads'."

        # Token embedding
        self.token_embed = nn.Embedding(
            num_embeddings=vocab_size, embedding_dim=d_model
        )

        # Segmentation embedding
        # The difference here is that we are setting 'num_embeddings' to 3 because
        # we will use 0 as padding index, 1 as sent1 index and 2 as sent2 index.
        self.segm_embed = nn.Embedding(
            num_embeddings=3, embedding_dim=d_model
        )

        self.pos_embed = nn.Embedding(
            num_embeddings=max_seq_len, 
            embedding_dim=d_model
        )

        # Dropout after all embeddings
        self.dropout1 = nn.Dropout(p=dropout)

        enc_layer = nn.TransformerEncoderLayer(
            d_model=d_model, nhead=nheads, batch_first=True
        )
        self.encoder = nn.TransformerEncoder(
            encoder_layer=enc_layer, num_layers=num_layers
        )
        self.dropout2 = nn.Dropout(p=dropout)

        self.project = nn.Linear(d_model, vocab_size)

        self.init_weights()

    def forward(self, x: torch.LongTensor, segmentations: torch.LongTensor = None):
        # x shape: (batch_size, seq_len)
        seq_len = x.shape[1]
        position_ids = torch.arange(seq_len, device=x.device).long()

        # Input embedding
        x = self.token_embed(x)
        # Positional embedding
        x += self.pos_embed(position_ids)
        
        # Optional segmentational encoding/embedding.
        if segmentations is not None:
            x += self.segm_embed(segmentations)

        # x shape: (batch_size, seq_len, d_model)

        out = self.encoder(self.dropout1(x))
        # out shape: (batch_size, seq_len, d_model)

        out = self.project(out)
        # out shape: (batch_size, seq_len, vocab_size)

        return out

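    def init_weights(self):
        initrange = 0.1
        # Re-initialize the weight matrices of the encoder; 1-D parameters
        # (biases and LayerNorm weights) keep their defaults.
        for param in self.encoder.parameters():
            if param.dim() > 1:
                nn.init.uniform_(param, a=-initrange, b=initrange)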


model = Bert(
    d_model=64,
    nheads=8,
    num_layers=2,
    vocab_size=100
)

# `high` is exclusive, so 100 covers the whole vocabulary (indices 0..99).
x = torch.randint(low=0, high=100, size=(2, 5))
segmentations = torch.randint(low=0, high=3, size=(2, 5))

print("Input shape:", x.shape)
print("Segmentations shape:", segmentations.shape)
y_pred = model(x, segmentations=segmentations)
print("Output shape:", y_pred.shape)

Input shape: torch.Size([2, 5])
Segmentations shape: torch.Size([2, 5])
Output shape: torch.Size([2, 5, 100])
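
The `Bert` class above returns per-token vocabulary logits only. For the NSP task mentioned earlier, a pooler head on the `[CLS]` position could look like the sketch below. `NspHead` is a hypothetical module, and the dense layer followed by `tanh` is an assumption modeled on the common BERT pooler design.

```python
import torch
from torch import nn


class NspHead(nn.Module):
    """Hypothetical NSP head operating on encoder outputs."""

    def __init__(self, d_model: int):
        super().__init__()
        self.pooler = nn.Sequential(nn.Linear(d_model, d_model), nn.Tanh())
        # Two classes: "is next sentence" and "is not next sentence".
        self.classifier = nn.Linear(d_model, 2)

    def forward(self, hidden: torch.Tensor) -> torch.Tensor:
        # hidden shape: (batch_size, seq_len, d_model)
        cls_state = hidden[:, 0]  # The first position holds [CLS].
        return self.classifier(self.pooler(cls_state))


hidden = torch.randn(2, 5, 64)  # Stand-in for the encoder output.
nsp_logits = NspHead(d_model=64)(hidden)
print("NSP logits shape:", nsp_logits.shape)  # torch.Size([2, 2])
```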


<a id="training"></a>
## 4. Training session

In [None]:
class TrainingSession:

    def __init__(
        self, 
        model: nn.Module, 
        loss: nn.Module, 
        optimizer: torch.optim.Optimizer,
        itos: Dict[int, str],
        device="cpu"
    ):
        self._model = model.to(device)
        self._loss_func = loss
        self._opt = optimizer
        self._itos = itos
        self.device = device

    def start(
        self, 
        train_dataset: Dataset, valid_dataset: Dataset, 
        epochs: int, batch_size: int,
        fixed_input: torch.LongTensor = None,
        metrics: Dict[str, Any] = None,
        save_model: bool = True,
        model_path: str = "./model.pt"
    ):
        train_dl = DataLoader(
            dataset=train_dataset,
            batch_size=batch_size,
            shuffle=True,
            num_workers=0
        )
        valid_dl = DataLoader(
            dataset=valid_dataset,
            batch_size=batch_size,
            shuffle=False,
            num_workers=0
        )
        # Fall back to an empty dict so that `_calc_metrics` works without metrics.
        self._metrics = metrics or {}
        self._metric_results = {}

        for epoch in range(epochs):
            self._train_epoch(train_dl)
            self._valid_epoch(valid_dl)

            print(f"Epoch: {epoch + 1}")
            print(self._metric_results)

            if fixed_input is not None:
                self._print_fixed_pred(fixed_input)

            # Saving the model for the epoch.
            if save_model:
                print(f"Saving model to '{model_path}'...")
                torch.save(self._model.state_dict(), model_path)

            print()

    def _train_epoch(self, dataloader: DataLoader):
        for x, y in dataloader:
            x, y = x.to(self.device), y.to(self.device)

            y_pred = self._model(x)
            self._opt.zero_grad()

            y_pred = y_pred.reshape(-1, y_pred.shape[-1])
            # y_pred shape: [batch_size * seq_len, vocab_size]
            y = y.reshape(-1)
            # y shape: [batch_size * seq_len]

            loss = self._loss_func(y_pred, y)
            loss.backward()

            self._opt.step()

        # Adding the loss to the metrics.
        # Note: the loss and the metrics are taken from the last batch only.
        self._metric_results["Training Loss"] = loss.item()
        self._metric_results = self._calc_metrics(y_pred, y, type_="Training")

        return self._metric_results

    def _valid_epoch(self, dataloader: DataLoader):
        self._model.eval()

        with torch.no_grad():
            for x, y in dataloader:
                x, y = x.to(self.device), y.to(self.device)

                y_pred = self._model(x)

                y_pred = y_pred.reshape(-1, y_pred.shape[-1])
                y = y.reshape(-1)

                loss = self._loss_func(y_pred, y)

        self._model.train()

        self._metric_results["Validation Loss"] = loss.item()
        self._metric_results = self._calc_metrics(y_pred, y, type_="Validation")

        return self._metric_results

    def _print_fixed_pred(self, fixed_input: torch.LongTensor):
        self._model.eval()

        with torch.no_grad():
            pred = self._model(fixed_input.to(self.device))
            pred = pred.argmax(-1)

        print(f"Input: {[self._itos[int(idx)] for idx in fixed_input[0]]}")
        print(f"Prediction: {[self._itos[int(idx)] for idx in pred[0]]}")

        self._model.train()

    def _calc_metrics(self, y_pred: torch.Tensor, y: torch.LongTensor, type_: str):
        for name, metric in self._metrics.items():
            self._metric_results[f"{type_} {name}"] = float(metric(y_pred, y))

        return self._metric_results


# Dataset params:
TRAIN_FRAC = 0.8
SEQ_LEN = 20
FIXED_INPUT = en_tokenizer("[CLS] Time is [MASK] [SEP]")
FIXED_INPUT = [
    STOI[token]
    for token in FIXED_INPUT
]
FIXED_INPUT = torch.LongTensor(FIXED_INPUT).unsqueeze(0)


# Model hyperparams:
# One of the fine-tuning learning rates suggested in the paper
# (the paper's pre-training runs use 1e-4).
L_RATE = 3e-5
EPS = 1e-8
BATCH_SIZE = 64
EPOCHS = 30
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
SAVE_MODEL = True
MODEL_FILENAME = "2-bert.pt"
MODEL_PATH = os.path.join("..", "models", MODEL_FILENAME)

HIDDEN_SIZE = 768
NHEADS = 12
LAYERS = 12

dataset = MlmDataset(
    text=book_text, 
    tokenizer=en_tokenizer, 
    vocab=vocab,
    seq_len=SEQ_LEN
)
train_size = int(TRAIN_FRAC * len(dataset))

train_dataset, valid_dataset = random_split(
    dataset=dataset,
    lengths=[train_size, len(dataset) - train_size]
)


bert = Bert(
    d_model=HIDDEN_SIZE,
    nheads=NHEADS,
    num_layers=LAYERS,
    vocab_size=len(vocab)
)
print(bert)
optimizer = torch.optim.AdamW(bert.parameters(), lr=L_RATE, eps=EPS)
loss = nn.CrossEntropyLoss(ignore_index=vocab["[PAD]"])


session = TrainingSession(
    model=bert,
    loss=loss,
    optimizer=optimizer,
    itos=ITOS,
    device=DEVICE
)

print(f"Starting training session on device: {DEVICE}...")
# TODO: Use a larger dataset for the pretraining process.
# Currently I use a small dataset
# session.start(
#     train_dataset=train_dataset,
#     valid_dataset=valid_dataset,
#     epochs=EPOCHS,
#     batch_size=BATCH_SIZE,
#     fixed_input=FIXED_INPUT,
#     metrics={
#         "Accuracy": Accuracy(
#             task="multiclass",
#             num_classes=len(vocab)
#         ).to(DEVICE)
#     },
#     save_model=SAVE_MODEL,
#     model_path=MODEL_PATH
# )
print("Training session has finished!")

Bert(
  (token_embed): Embedding(7829, 768)
  (segm_embed): Embedding(3, 768)
  (pos_embed): Embedding(128, 768)
  (dropout1): Dropout(p=0.1, inplace=False)
  (encoder): TransformerEncoder(
    (layers): ModuleList(
      (0): TransformerEncoderLayer(
        (self_attn): MultiheadAttention(
          (out_proj): NonDynamicallyQuantizableLinear(in_features=768, out_features=768, bias=True)
        )
        (linear1): Linear(in_features=768, out_features=2048, bias=True)
        (dropout): Dropout(p=0.1, inplace=False)
        (linear2): Linear(in_features=2048, out_features=768, bias=True)
        (norm1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
        (norm2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
        (dropout1): Dropout(p=0.1, inplace=False)
        (dropout2): Dropout(p=0.1, inplace=False)
      )
      (1): TransformerEncoderLayer(
        (self_attn): MultiheadAttention(
          (out_proj): NonDynamicallyQuantizableLinear(in_features=768, out_