# Bidirectional Encoder Representations from Transformers (BERT)

This notebook shows a basic implementation of BERT pre-training using the Wikipedia dataset.

In [None]:
# black formatting with jupyter-black
import jupyter_black

# enable black auto-formatting of cells (JupyterLab mode) with a 140-column line limit
jupyter_black.load(lab=True, line_length=140)

In [None]:
import os

# turn off the tokenizers library's own parallelism
# NOTE(review): presumably set because DataLoader worker processes are used later — confirm
os.environ.update({"TOKENIZERS_PARALLELISM": "false"})

In [None]:
# import libraries
import torch
import re

import numpy as np
import seaborn as sns
import pandas as pd
import torch.nn.functional as F
from torch import nn
import multiprocessing as mp

from typing import List
from torch.nn.utils.rnn import pad_sequence
from sklearn.model_selection import train_test_split
from tqdm.notebook import tqdm
from torch.utils.data import DataLoader, Dataset, TensorDataset
import re

from tempfile import NamedTemporaryFile

from datasets import load_dataset


from utils import text_preprocessing

In [None]:
# download the first 500 articles of the 2022-03-01 English Wikipedia dump as a DataFrame
wikipedia_subset = load_dataset(
    "wikipedia",
    "20220301.en",
    split="train[0:500]",
    trust_remote_code=True,
)
data = wikipedia_subset.to_pandas()
# alternative corpus:
# data = load_dataset("karpathy/tiny_shakespeare", split="train", trust_remote_code=True).to_pandas()

In [None]:
# split every article on "." and keep one row per fragment, remembering which article
# (paragraph_id) it came from and its position inside that article (sentence_id)
texts = [
    {"paragraph_id": article_pos, "sentence_id": fragment_pos, "text": fragment}
    for article_pos, article in enumerate(data.text.tolist())
    for fragment_pos, fragment in enumerate(article.split("."))
]

data = pd.DataFrame(texts)

# Tokenizer

We have implemented a Byte-Pair Encoding (BPE) tokenizer. However, that pure-Python implementation is really slow, so we will use the WordPiece tokenizer from the Hugging Face `tokenizers` library instead.

In [None]:
from tokenizers import Tokenizer, models, trainers
from tokenizers.pre_tokenizers import BertPreTokenizer

In [None]:
# WordPiece model; pieces that are not in the vocabulary map to [UNK]
tokenizer = Tokenizer(models.WordPiece(unk_token="[UNK]"))

# BERT-style pre-tokenization is applied before WordPiece
tokenizer.pre_tokenizer = BertPreTokenizer()

# tokenizer-level padding stays disabled; batches are padded later by the data collator
# tokenizer.enable_padding()

# trainer: target vocabulary size plus the special tokens used throughout this notebook
special_tokens = ["[PAD]", "[UNK]", "[CLS]", "[SEP]", "[MASK]"]
trainer = trainers.WordPieceTrainer(vocab_size=30000, special_tokens=special_tokens)

# normalise the sentences in place with the project helper, then collect them as plain strings
data.text = data.text.apply(text_preprocessing)
train_data = data.text.tolist()

# fit the vocabulary on the preprocessed sentences
tokenizer.train_from_iterator(train_data, trainer=trainer)

In [None]:
# sanity check: show how the trained tokenizer splits a sample sentence into tokens
tokenizer.encode("praying, she's a good person").tokens

In [None]:
# count the tokens of every sentence and keep only sentences with 5 to 60 tokens (inclusive)
MIN_TOKENS, MAX_TOKENS = 5, 60
data["text_length"] = data.text.map(lambda sentence: len(tokenizer.encode(sentence).tokens))
data = data[data.text_length.between(MIN_TOKENS, MAX_TOKENS)].reset_index(drop=True)

In [None]:
# plot the distribution of sentence lengths (text_length is measured in tokens)
sns.histplot(data=data, x="text_length", kde=True)

## Create a dataset with positive and negative samples for the Next Sentence Prediction (NSP) task

Now we create the dataset that will be used for Next Sentence Prediction. In NSP, the model is given two sentences and must predict whether the second sentence logically follows the first sentence in a coherent narrative or if the two sentences are unrelated. This task helps the model understand the relationships between sentences and improve its comprehension of text flow, context, and continuity.

In [None]:
def get_nsp_sample(dataset, idx):
    """
    Builds one (sentence_a, sentence_b, label) triple for Next Sentence Prediction (NSP).

    With probability 0.5 the real follow-up sentence is requested; otherwise a sentence from
    another paragraph is drawn. When no second sentence was obtained (e.g. the follow-up does
    not exist in the dataset), a negative pair is drawn instead.

    Arguments:
    ----------
        dataset (pd.DataFrame): Frame with text, paragraph_id and sentence_id columns.
        idx (int): Index label of the row used as the first sentence.

    Returns:
    --------
        tuple: (sentence_a, sentence_b, nsp_label) where nsp_label is 1 when sentence_b
               follows sentence_a and 0 otherwise.
    """
    first_sentence, paragraph, position = (dataset.loc[idx, column] for column in ("text", "paragraph_id", "sentence_id"))

    # a single uniform draw decides between a positive and a negative pair
    if np.random.random() >= 0.5:
        pair = get_positive_pair(dataset, paragraph, position)
    else:
        pair = get_negative_pair(dataset, paragraph, position, hard_negative=False)

    # no second sentence available: fall back to a negative pair
    if pair[0] is None:
        pair = get_negative_pair(dataset, paragraph, position, hard_negative=False)

    return (first_sentence, *pair)


def get_positive_pair(dataset, paragraph_id_a, sentence_id_a):
    """
    Retrieves a positive sentence pair for the NSP task, where the second sentence follows the first.

    Only the legitimate "no next sentence" case yields None. Genuine problems, such as a dataset
    without the expected columns, raise instead of being silently swallowed.

    Arguments:
    ----------
        dataset (pd.DataFrame): The dataset containing text, paragraph_id, and sentence_id columns.
        paragraph_id_a (int): The paragraph ID of the first sentence.
        sentence_id_a (int): The sentence ID of the first sentence.

    Returns:
    --------
        tuple: A tuple containing:
            - sentence_b (str or None): The next sentence if it exists, otherwise None.
            - nsp_label (int): The label indicating this is a follow-up sentence (1).

    Raises:
    -------
        KeyError: If a required column is missing from the dataset.
    """
    nsp_label = 1
    sentence_id_b = sentence_id_a + 1

    # rows of the same paragraph whose position is exactly one after sentence_a
    is_next = (dataset["paragraph_id"] == paragraph_id_a) & (dataset["sentence_id"] == sentence_id_b)
    candidates = dataset.loc[is_next, "text"]

    # the follow-up may be absent (last sentence, or removed by earlier filtering)
    sentence_b = candidates.iloc[0] if len(candidates) else None

    return sentence_b, nsp_label


def get_negative_pair(dataset, paragraph_id_a, sentence_id_a, hard_negative=False):
    """
    Retrieves a negative sentence pair for the NSP task, where the second sentence does not follow the first.

    Arguments:
    ----------
        dataset (pd.DataFrame): The dataset containing text, paragraph_id, and sentence_id columns.
        paragraph_id_a (int): The paragraph ID of the first sentence.
        sentence_id_a (int): The sentence ID of the first sentence.
        hard_negative (bool): If True, the second sentence is chosen from the same paragraph, excluding
                              both sentence_a itself and its real follow-up. When the paragraph has no
                              such sentence, a sentence from a different paragraph is used instead.
                              If False, the second sentence is chosen from a different paragraph.

    Returns:
    --------
        tuple: A tuple containing:
            - sentence_b (str): The unrelated sentence.
            - nsp_label (int): The label indicating this is not a follow-up sentence (0).

    Raises:
    -------
        ValueError: If the dataset holds no sentence that can serve as a negative.
    """
    nsp_label = 0

    same_paragraph = dataset["paragraph_id"] == paragraph_id_a
    candidates = dataset.loc[~same_paragraph, "text"]

    if hard_negative:
        sentence_ids = dataset["sentence_id"]
        # the real next sentence must be excluded, otherwise a positive pair would get label 0
        is_hard = same_paragraph & (sentence_ids != sentence_id_a) & (sentence_ids != sentence_id_a + 1)
        hard_candidates = dataset.loc[is_hard, "text"]
        if len(hard_candidates):
            candidates = hard_candidates

    if candidates.empty:
        raise ValueError(f"no negative sentence available for paragraph_id={paragraph_id_a}")

    sentence_b = candidates.sample(1).iloc[0]

    return sentence_b, nsp_label

In [None]:
# prepare dataset for NSP
# Seed NumPy's global RNG first: get_nsp_sample draws from np.random, and the pandas
# `.sample` calls use the same global RNG when no random_state is passed, so seeding here
# makes the generated pairs reproducible on Restart & Run All.
NSP_SEED = 42
np.random.seed(NSP_SEED)

# one (sentence_a, sentence_b, label) triple per row of `data`
nsp_samples = [get_nsp_sample(data, i) for i in range(len(data))]

data["sentence_a"] = [sample[0] for sample in nsp_samples]
data["sentence_b"] = [sample[1] for sample in nsp_samples]
data["label"] = [sample[2] for sample in nsp_samples]

In [None]:
# check the balance between follow-up (1) and non-follow-up (0) pairs
data.label.value_counts()

# Custom Dataset And Data Collator

We define a custom dataset that prepares the data for MLM and NSP pre-training.

In [None]:
import numpy as np
from torch.nn.utils.rnn import pad_sequence
import torch

In [None]:
class CustomDataset(Dataset):
    """
    A custom dataset class for handling sentence pairs and generating masked language model (MLM) and
    next sentence prediction (NSP) labels.

    NOTE(review): a class named CustomDataset is defined again in the next cell; that later
    definition is the one in effect after Run All. Consider deleting one of the two cells.

    Attributes:
        dataset (pd.DataFrame): The dataset containing sentence pairs and labels.
        hard_negative (bool): A flag to indicate if hard negative sampling is used (stored only; not read here).
        tokenizer: The tokenizer used to encode sentences and look up special-token IDs.
        sep_token_id (list): Token IDs obtained by encoding "[SEP]".
        cls_token_id (list): Token IDs obtained by encoding "[CLS]".
    """

    def __init__(self, dataset: pd.DataFrame, hard_negative=False, tokenizer=None) -> None:
        """
        Initializes the CustomDataset with the provided dataset and optional hard negative sampling flag.

        Arguments:
        ----------
            dataset (pd.DataFrame): The input dataset with sentence_a, sentence_b and label columns.
            hard_negative (bool): A flag indicating whether to use hard negative sampling.
            tokenizer: Optional tokenizer exposing encode(...).ids and token_to_id(...). When omitted,
                       the notebook-level `tokenizer` variable is used, as before.
        """
        super().__init__()

        self.dataset = dataset
        self.hard_negative = hard_negative
        # the parameter shadows the notebook-level name, so the fallback goes through globals()
        self.tokenizer = tokenizer if tokenizer is not None else globals()["tokenizer"]
        self.sep_token_id = self.tokenizer.encode("[SEP]").ids
        self.cls_token_id = self.tokenizer.encode("[CLS]").ids

    def __len__(self) -> int:
        """
        Returns the number of samples in the dataset.

        Returns:
            int: The number of samples in the dataset.
        """
        return len(self.dataset)

    def __getitem__(self, idx):
        """
        Retrieves the sample at the specified position and processes it to generate input IDs, MLM labels,
        and NSP label.

        Arguments:
        ----------
            idx (int): The position of the sample to retrieve.

        Returns:
            tuple: A tuple containing the input IDs, MLM labels, and NSP label.
        """
        # positional lookup of a single cell; avoids converting whole columns to lists per item
        sentence_a = self.dataset["sentence_a"].iloc[idx]
        sentence_b = self.dataset["sentence_b"].iloc[idx]
        nsp_label = int(self.dataset["label"].iloc[idx])

        # Construct the input sentence with special tokens
        sentence = f"[CLS] {sentence_a} [SEP] {sentence_b}"
        ids, mlm_labels = self.get_masked_sentence(sentence)

        return ids, mlm_labels, nsp_label

    def get_masked_sentence(self, sentence):
        """
        Applies masking to the input sentence to generate MLM labels.

        About 15% of the tokens (at least one when possible) are replaced by [MASK]. Position 0 and the
        position of the first [SEP] are never masked. The number of masked tokens is capped by the
        number of maskable positions, so very short inputs no longer raise.

        Arguments:
        ----------
            sentence (str): The input sentence with special tokens.

        Returns:
            tuple: A tuple containing the masked input IDs and the original input IDs as MLM labels.
        """
        # Encode the sentence into token IDs
        encoded_sentence = self.tokenizer.encode(sentence, add_special_tokens=False)
        ids = np.array(encoded_sentence.ids, dtype=np.int64)
        if ids.size == 0:
            return [], []

        # Identify the index of the [SEP] token (argmax yields 0 when there is none)
        sep_index = (ids == self.tokenizer.token_to_id("[SEP]")).argmax()

        # Candidate positions for masking, excluding position 0 and the [SEP] position
        candidate_mask = np.arange(len(ids))
        candidate_mask = candidate_mask[~np.isin(candidate_mask, [0, sep_index])]

        # 15% of the tokens, but never more than the positions that may be masked
        n_mask_tokens = min(max(1, round(len(ids) * 0.15)), len(candidate_mask))

        # NOTE(review): labels keep every original token, so a loss that only ignores the pad id
        # is also computed on unmasked positions — confirm this is intended.
        mlm_labels = ids.copy()

        if n_mask_tokens > 0:
            selected_tokens = np.random.choice(candidate_mask, size=n_mask_tokens, replace=False)
            ids[selected_tokens] = self.tokenizer.token_to_id("[MASK]")

        return ids.tolist(), mlm_labels.tolist()

In [None]:
class CustomDataset(Dataset):
    """
    A custom dataset class for handling sentence pairs and generating masked language model (MLM) and
    next sentence prediction (NSP) labels.

    NOTE(review): this cell redefines the CustomDataset class from the previous cell; this
    definition is the one in effect after Run All. Consider deleting the earlier copy.

    Attributes:
        dataset (pd.DataFrame): The dataset containing sentence pairs and labels.
        hard_negative (bool): A flag to indicate if hard negative sampling is used (stored only; not read here).
        tokenizer: The tokenizer used to encode sentences and look up special-token IDs.
        sep_token_id (list): Token IDs obtained by encoding "[SEP]".
        cls_token_id (list): Token IDs obtained by encoding "[CLS]".
    """

    def __init__(self, dataset: pd.DataFrame, hard_negative=False, tokenizer=None) -> None:
        """
        Initializes the CustomDataset with the provided dataset and optional hard negative sampling flag.

        Arguments:
        ----------
            dataset (pd.DataFrame): The input dataset with sentence_a, sentence_b and label columns.
            hard_negative (bool): A flag indicating whether to use hard negative sampling.
            tokenizer: Optional tokenizer exposing encode(...).ids and token_to_id(...). When omitted,
                       the notebook-level `tokenizer` variable is used, as before.
        """
        super().__init__()

        self.dataset = dataset
        self.hard_negative = hard_negative
        # the parameter shadows the notebook-level name, so the fallback goes through globals()
        self.tokenizer = tokenizer if tokenizer is not None else globals()["tokenizer"]
        self.sep_token_id = self.tokenizer.encode("[SEP]").ids
        self.cls_token_id = self.tokenizer.encode("[CLS]").ids

    def __len__(self) -> int:
        """
        Returns the number of samples in the dataset.

        Returns:
            int: The number of samples in the dataset.
        """
        return len(self.dataset)

    def __getitem__(self, idx):
        """
        Retrieves the sample at the specified position and processes it to generate input IDs, MLM labels,
        and NSP label.

        Arguments:
        ----------
            idx (int): The position of the sample to retrieve.

        Returns:
            tuple: A tuple containing the input IDs, MLM labels, and NSP label.
        """
        # positional lookup of a single cell; avoids converting whole columns to lists per item
        sentence_a = self.dataset["sentence_a"].iloc[idx]
        sentence_b = self.dataset["sentence_b"].iloc[idx]
        nsp_label = int(self.dataset["label"].iloc[idx])

        # Construct the input sentence with special tokens
        sentence = f"[CLS] {sentence_a} [SEP] {sentence_b}"
        ids, mlm_labels = self.get_masked_sentence(sentence)

        return ids, mlm_labels, nsp_label

    def get_masked_sentence(self, sentence):
        """
        Applies masking to the input sentence to generate MLM labels.

        About 15% of the tokens (at least one when possible) are replaced by [MASK]. Position 0 and the
        position of the first [SEP] are never masked. The number of masked tokens is capped by the
        number of maskable positions, so very short inputs no longer raise.

        Arguments:
        ----------
            sentence (str): The input sentence with special tokens.

        Returns:
            tuple: A tuple containing the masked input IDs and the original input IDs as MLM labels.
        """
        # Encode the sentence into token IDs
        encoded_sentence = self.tokenizer.encode(sentence, add_special_tokens=False)
        ids = np.array(encoded_sentence.ids, dtype=np.int64)
        if ids.size == 0:
            return [], []

        # Identify the index of the [SEP] token (argmax yields 0 when there is none)
        sep_index = (ids == self.tokenizer.token_to_id("[SEP]")).argmax()

        # Candidate positions for masking, excluding position 0 and the [SEP] position
        candidate_mask = np.arange(len(ids))
        candidate_mask = candidate_mask[~np.isin(candidate_mask, [0, sep_index])]

        # 15% of the tokens, but never more than the positions that may be masked
        n_mask_tokens = min(max(1, round(len(ids) * 0.15)), len(candidate_mask))

        # NOTE(review): labels keep every original token, so a loss that only ignores the pad id
        # is also computed on unmasked positions — confirm this is intended.
        mlm_labels = ids.copy()

        if n_mask_tokens > 0:
            selected_tokens = np.random.choice(candidate_mask, size=n_mask_tokens, replace=False)
            ids[selected_tokens] = self.tokenizer.token_to_id("[MASK]")

        return ids.tolist(), mlm_labels.tolist()

In [None]:
class DataCollatorForMLMAndNSP:
    """
    Collates (token IDs, MLM labels, NSP label) samples into padded batch tensors for
    Masked Language Modeling (MLM) and Next Sentence Prediction (NSP).

    Attributes:
        pad_token_id (int): The token ID used to pad both the inputs and the MLM labels.
    """

    def __init__(self, pad_token_id=0) -> None:
        """
        Stores the padding token ID.

        Args:
            pad_token_id (int): The token ID to use for padding sequences. Default is 0.
        """
        self.pad_token_id = pad_token_id

    def __call__(self, sentences):
        """
        Turns a list of samples into batch tensors.

        Args:
            sentences (list of tuples): Each tuple holds token IDs, MLM labels and an NSP label.

        Returns:
            tuple: (token_ids, attention_masks, mlm_labels, nsp_labels) where token_ids and
                   mlm_labels are padded to the longest sample and attention_masks has shape
                   (batch, 1, 1, sequence length).
        """
        pad = self.pad_token_id

        # right-pad inputs and labels to the longest sequence in the batch
        token_ids = pad_sequence([torch.LongTensor(sample[0]) for sample in sentences], padding_value=pad, batch_first=True)
        mlm_labels = pad_sequence([torch.LongTensor(sample[1]) for sample in sentences], padding_value=pad, batch_first=True)

        # one NSP label per sample, as a 1-D tensor
        nsp_labels = torch.LongTensor([[sample[2]] for sample in sentences]).squeeze(1)

        # 1 where the input holds a real token, 0 on padding; two singleton axes are inserted after the batch axis
        attention_masks = (token_ids != pad).long()[:, None, None, :]

        return token_ids, attention_masks, mlm_labels, nsp_labels

In [None]:
# small demo loader (batch of 2) used only to inspect what the collator produces
ds = CustomDataset(data, False)
data_loader = DataLoader(ds, batch_size=2, collate_fn=DataCollatorForMLMAndNSP(pad_token_id=0))

In [None]:
# fetch the first batch; the collator returns (token IDs, attention mask, MLM labels, NSP labels)
tokens, attention_mask, mlm_labels, nsp_labels = next(iter(data_loader))

In [None]:
# padded input token IDs of the batch (selected positions hold the [MASK] id)
tokens

In [None]:
# 1 for real tokens, 0 for padding; the collator adds two singleton axes after the batch axis
attention_mask

In [None]:
# original (unmasked) token IDs, padded with the pad id
mlm_labels

In [None]:
# one label per sample: 1 = sentence_b follows sentence_a, 0 = it does not
nsp_labels

In [None]:
# decode the first sample, keeping special tokens visible in the output
tokenizer.decode(tokens[0].tolist(), skip_special_tokens=False)

## BERT Model implementation

In [None]:
from utils import EncoderTransformer
from tqdm.notebook import tqdm
from sklearn.metrics import f1_score
import torch

In [None]:
class BERT(torch.nn.Module):
    """
    A BERT model class for handling masked language modeling (MLM) and next sentence prediction (NSP).

    Attributes:
        encoder (EncoderTransformer): The encoder transformer network.
        nsp_dnn (torch.nn.Linear): The linear layer for NSP.
        mlm_dnn (torch.nn.Linear): The linear layer for MLM.
        device (str): The device to run the model on.
        optimizer: The optimizer, available after config_training_args() has been called.
    """

    def __init__(self, embed_dim, num_heads, dropout, pf_dim, vocab_size, max_length, n_layers, device) -> None:
        """
        Initializes the BERT model with the specified parameters.

        Args:
            embed_dim (int): The embedding dimension.
            num_heads (int): The number of attention heads.
            dropout (float): The dropout rate.
            pf_dim (int): The position-wise feed-forward dimension.
            vocab_size (int): The size of the vocabulary.
            max_length (int): The maximum sequence length.
            n_layers (int): The number of layers in the transformer.
            device (str): The device to run the model on.
        """
        super().__init__()

        self.encoder = EncoderTransformer(embed_dim, num_heads, dropout, pf_dim, vocab_size, max_length, n_layers, device=device)
        self.nsp_dnn = torch.nn.Linear(embed_dim, 2, device=device)
        self.mlm_dnn = torch.nn.Linear(embed_dim, vocab_size, device=device)

        self.device = device

    def config_training_args(self, optimizer, optimizer_kwargs=None) -> None:
        """
        Configures the optimizer for training.

        Args:
            optimizer (torch.optim.Optimizer): The optimizer class.
            optimizer_kwargs (dict | None): Additional keyword arguments for the optimizer.
        """
        # None instead of a shared mutable `{}` default
        self.optimizer = optimizer(self.parameters(), **(optimizer_kwargs or {}))

    def forward(self, x, mask):
        """
        Forward pass through the BERT model.

        Args:
            x (torch.Tensor): The input tensor containing token IDs.
            mask (torch.Tensor): The attention mask tensor.

        Returns:
            tuple: A tuple containing the MLM logits and NSP logits.
        """
        # No explicit torch.cuda.synchronize() here: it is not needed for correctness and
        # raises when the model runs on a machine without CUDA.

        # Pass through the encoder
        x = self.encoder(x, mask)

        # Compute MLM and NSP predictions
        y_mlm = self.mlm_dnn(x)
        y_nsp = self.nsp_dnn(x[:, 0, :])  # Only use the [CLS] token for NSP

        return y_mlm, y_nsp

    def train_one_epoch(self, train_loader) -> None:
        """
        Trains the model for one epoch.

        Args:
            train_loader (DataLoader): The DataLoader for the training data.
        """
        # make sure dropout & co. are in training mode, even after an inference call
        super().train(True)

        bar = tqdm(train_loader)
        running_total_loss = 0
        running_mlm_loss = 0
        running_nsp_loss = 0
        running_f1 = 0

        for step, (ids, attention_mask, mlm_labels, nsp_labels) in enumerate(bar, 1):

            # Map to device
            ids = ids.to(self.device)
            attention_mask = attention_mask.to(self.device)
            mlm_labels = mlm_labels.to(self.device)
            nsp_labels = nsp_labels.to(self.device)

            # Forward pass; flatten (batch, seq, vocab) -> (batch * seq, vocab) for the token-level loss
            y_mlm, y_nsp = self(ids, attention_mask)
            y_mlm = y_mlm.reshape(-1, y_mlm.shape[-1])
            mlm_labels = mlm_labels.reshape(-1)

            # Compute MLM loss (label 0 is skipped; the collator pads labels with pad id 0)
            mlm_loss = torch.nn.functional.cross_entropy(y_mlm, mlm_labels, ignore_index=0)

            # Compute NSP loss
            nsp_loss = torch.nn.functional.cross_entropy(y_nsp, nsp_labels)

            # Total loss
            loss = nsp_loss + mlm_loss

            # Clear gradients
            self.optimizer.zero_grad()

            # Backward pass
            loss.backward()

            # Gradient clipping
            torch.nn.utils.clip_grad_norm_(self.parameters(), 1.0)

            # Optimization step
            self.optimizer.step()

            # Running total losses
            running_mlm_loss += mlm_loss.item()
            running_nsp_loss += nsp_loss.item()
            running_total_loss += loss.item()

            # Running F1 score (macro F1 of each batch, averaged over the steps seen so far)
            y_nsp = torch.argmax(y_nsp, dim=-1)
            running_f1 += f1_score(nsp_labels.cpu(), y_nsp.cpu(), average="macro")

            # Update progress bar description
            bar.set_description(
                f"total loss: {running_total_loss / step:0.3f} | "
                + f"mlm loss: {running_mlm_loss / step:0.3f} | "
                + f"nsp loss: {running_nsp_loss / step:0.3f} | "
                + f"f1: {running_f1 / step:0.3f}"
            )

    def train(self, train_data=True, num_epochs=None):
        """
        Trains the model for a specified number of epochs, or switches the train/eval mode.

        This method shares its name with torch.nn.Module.train(mode), which Module.eval() relies on.
        To keep both usages working, a boolean first argument is forwarded to the base-class
        implementation; anything else is treated as the training DataLoader.

        Args:
            train_data (DataLoader | bool): The DataLoader for the training data, or a bool mode flag.
            num_epochs (int): The number of epochs to train for (required with a DataLoader).

        Returns:
            The module itself when called with a bool (as torch.nn.Module.train does), otherwise None.

        Raises:
            TypeError: If a DataLoader is given without num_epochs.
        """
        if isinstance(train_data, bool):
            return super().train(train_data)

        if num_epochs is None:
            raise TypeError("train() requires num_epochs when it is given training data")

        bar = tqdm(range(num_epochs))

        for epoch in bar:
            self.train_one_epoch(train_data)
            bar.set_description(f"epoch: {epoch}")

    def _predict(self, token_ids, tokenizer):
        """Runs one gradient-free forward pass in eval mode and restores the previous mode afterwards."""
        pad_id = tokenizer.encode("[PAD]").ids[0]
        # same mask layout as the training collator: two singleton axes after the batch axis
        attention_mask = torch.ne(token_ids, pad_id).long().unsqueeze(1).unsqueeze(2)

        was_training = self.training
        self.eval()  # disable dropout for inference
        try:
            with torch.no_grad():
                return self(token_ids, attention_mask)
        finally:
            super().train(was_training)

    def mask_filling(self, sentence: str, tokenizer) -> str:
        """
        Fills masked tokens in a given sentence.

        Only the [MASK] positions are replaced by the model's prediction; all other tokens are
        kept exactly as given.

        Args:
            sentence (str): The input sentence with masked tokens.
            tokenizer: The tokenizer used to encode and decode the sentence.

        Returns:
            str: The sentence with masked tokens filled.
        """
        sentence = "[CLS] " + sentence + " [SEP]"
        token_ids = torch.LongTensor([tokenizer.encode(sentence).ids]).to(self.device)

        y_mlm, _ = self._predict(token_ids, tokenizer)
        predicted_ids = y_mlm.argmax(dim=-1)

        mask_id = tokenizer.token_to_id("[MASK]")
        filled_ids = torch.where(token_ids == mask_id, predicted_ids, token_ids)

        return tokenizer.decode(filled_ids[0].tolist())

    def nsp_prediction(self, sentence_a: str, sentence_b: str, tokenizer) -> int:
        """
        Predicts whether the second sentence is the next sentence of the first.

        Args:
            sentence_a (str): The first sentence.
            sentence_b (str): The second sentence.
            tokenizer: The tokenizer used to encode the sentences.

        Returns:
            int: The NSP prediction (0 or 1).
        """
        sentence = "[CLS] " + sentence_a + " [SEP] " + sentence_b
        token_ids = torch.LongTensor([tokenizer.encode(sentence).ids]).to(self.device)

        _, y_nsp = self._predict(token_ids, tokenizer)

        return torch.argmax(y_nsp, dim=-1).item()

In [None]:
ds = CustomDataset(data, False)
# shuffle=True: without it every epoch iterates the sentences in document order, so each
# batch is made of consecutive sentences from the same few articles.
data_loader = DataLoader(
    ds,
    batch_size=32,
    shuffle=True,
    collate_fn=DataCollatorForMLMAndNSP(pad_token_id=0),
    num_workers=5,
    prefetch_factor=5,
)

In [None]:
# define model parameters
embed_dim = 768  # embedding size; also the input width of the MLM and NSP heads
max_length = 128  # maximum sequence length passed to the encoder
num_heads = 8  # attention heads
vocab_size = 30000  # same value as the vocab_size given to the WordPieceTrainer above
n_layers = 12  # transformer layers
dropout = 0.3  # dropout rate
pf_dim = 512  # position-wise feed-forward dimension

In [None]:
# use the GPU when one is available, otherwise fall back to the CPU instead of failing
device = "cuda" if torch.cuda.is_available() else "cpu"
bert_model = BERT(embed_dim, num_heads, dropout, pf_dim, vocab_size, max_length, n_layers, device=device)

In [None]:
# config model: the optimizer class is passed uninstantiated; BERT.config_training_args
# builds it over the model parameters with the given keyword arguments
optimizer = torch.optim.Adam

bert_model.config_training_args(optimizer=optimizer, optimizer_kwargs={"lr": 2e-5})

In [None]:
# run the joint MLM + NSP pre-training loop for 20 passes over the data loader
bert_model.train(data_loader, num_epochs=20)

In [None]:
# qualitative check of the MLM head: ask the model to fill the two [MASK] tokens
bert_model.mask_filling(
    "humans lived in societies without formal hierarchies long before the [MASK] of formal states, realms, or [MASK]",
    tokenizer,
)

In [None]:
# qualitative check of the NSP head on two sentences about different topics
# (returns 1 when sentence_b is predicted to follow sentence_a, 0 otherwise)
bert_model.nsp_prediction(
    sentence_a="humans lived in societies without formal hierarchies long before the establishment of formal states, realms, or empires",
    sentence_b="domesticated almonds appear in the early bronze age (bc), such as the archaeological sites of numeira (jordan), or possibly earlier",
    tokenizer=tokenizer,
)