In [None]:
# Colab setup: upgrade the Hugging Face `datasets` package (quiet output).
# NOTE(review): the version is not pinned, so a fresh runtime may resolve a different release.
!pip install --upgrade datasets --quiet

In [None]:
from google.colab import drive
import os

# Mount Google Drive and make the Drive root the working directory, so that
# os.getcwd() below resolves to /content/drive/MyDrive.
drive.mount("/content/drive")
%cd /content/drive/MyDrive

# All project paths are derived from the Drive root.
PARENT_DIR = os.getcwd()
# Project folder for this BERT re-implementation.
MAIN_DIR = os.path.join(
    PARENT_DIR, "Implement Classic Papers", "BERT"
)
# Local copy of the BookCorpus dataset files.
DATA_DIR = os.path.join(
    PARENT_DIR, "Data", "Text", "bookcorpus"
)

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).
/content/drive/MyDrive


In [None]:
import torch
import torch.nn as nn
import math
import torch.nn.functional as F
import random
from torch.optim import AdamW
from torch.optim.lr_scheduler import LambdaLR

from typing import List, Union, Callable, Optional, Dict
from transformers import AutoTokenizer
from datasets import load_dataset
from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm

# NOTE(review): this rebinds MAIN_DIR to the current working directory. After the
# `%cd /content/drive/MyDrive` in the setup cell that is the Drive root, not the
# ".../Implement Classic Papers/BERT" folder assigned earlier, so anything saved
# under MAIN_DIR lands in the Drive root. Confirm this override is intended.
MAIN_DIR = os.getcwd()
# Run on the GPU when one is available, otherwise fall back to CPU.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# Tokenizer

In [None]:
# Load the WordPiece tokenizer by its Hub identifier. The commented-out path is an
# alternative that would load a copy stored under MAIN_DIR instead.
tokenizer = AutoTokenizer.from_pretrained(
    # os.path.join(MAIN_DIR, "tokenizers", "google-bert", "bert-large-uncased")
    "bert-large-uncased"
    )

The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


tokenizer_config.json:   0%|          | 0.00/48.0 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/571 [00:00<?, ?B/s]

vocab.txt:   0%|          | 0.00/232k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/466k [00:00<?, ?B/s]



# Pretraining

- Masked Language Modelling (MLM)
- Next Sentence Prediction (NSP)

## Architecture

### Embeddings
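- Token embeddings: one learned vector for each token in the vocabulary.
- Absolute positional embedding: fixed sinusoidal functions encode each token's absolute position in the sequence (note: the original BERT uses learned position embeddings instead).
- Segment embedding: marks which sentence (A or B) a token belongs to; used for next sentence prediction and NLI-style sentence-pair tasks.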

In [None]:
class SinusoidalPositionalEncoding(nn.Module):
    """Fixed (non-learned) sinusoidal position table.

    Even feature indices hold ``sin(pos * w_i)`` and odd indices hold
    ``cos(pos * w_i)`` with ``w_i = 10000 ** (-2i / d_model)``. The table is
    registered as a buffer, so it follows the module across devices and is
    saved in the state dict, but it is never trained.

    Args:
        d_model: Embedding width. Both even and odd widths are supported.
        dropout: Dropout probability applied to the returned encodings.
        max_len: Number of positions pre-computed in the table.
    """

    def __init__(
        self, d_model: int, dropout: float=0.1, max_len: int=512
    ):
        super(SinusoidalPositionalEncoding, self).__init__()
        self.d_model = d_model
        self.dropout = nn.Dropout(p=dropout)
        pe = torch.zeros(size=(max_len, d_model)) # (max_len, d_model)
        const_term = math.log(10000) / d_model
        div_terms = torch.exp(-torch.arange(0, d_model, 2) * const_term) # (ceil(d_model / 2),)
        positions = torch.arange(0, max_len).unsqueeze(1) # (max_len, 1)
        angles = positions * div_terms # (max_len, ceil(d_model / 2))
        pe[:, ::2] = torch.sin(angles)  # sin(pos * div_term)
        # For odd d_model there is one fewer odd column than even columns, so
        # trim the angle table to the number of cosine slots.
        pe[:, 1::2] = torch.cos(angles[:, : d_model // 2])  # cos(pos * div_term)
        pe = pe.unsqueeze(0) # (1, max_len, d_model)
        self.register_buffer("pe", pe)

    def forward(
        self, inputs: torch.Tensor
    ):
        """Return the encodings for the first ``inputs.size(1)`` positions.

        Args:
            inputs: Tensor whose second dimension is the sequence length; only
                its shape is used.

        Returns:
            Tensor of shape ``(1, seq_len, d_model)`` (after dropout), meant to
            be broadcast-added to the token embeddings.

        Raises:
            ValueError: If the sequence is longer than the pre-computed table.
        """
        # Input sequence: (batch_size, seq_length)
        seq_len = inputs.size(1)
        if seq_len > self.pe.size(1):
            raise ValueError(
                f"Sequence length {seq_len} exceeds the maximum supported length {self.pe.size(1)}"
            )
        # Buffers never require grad, so the slice can be used directly.
        x = self.pe[:, : seq_len, :]
        return self.dropout(x)

class TokenEmbeddings(nn.Module):
    """Vocabulary lookup table whose output is scaled by ``sqrt(d_model)``.

    Args:
        d_model: Width of each embedding vector.
        vocab_size: Number of rows in the lookup table.
        padding_idx: Token id handed to ``nn.Embedding`` as its padding index.
    """

    def __init__(
        self,
        d_model: int,
        vocab_size: int=30522,
        padding_idx: int=0
    ):
        super().__init__()
        self.d_model = d_model
        self.vocab_size = vocab_size
        self.embeddings = nn.Embedding(
            vocab_size, d_model, padding_idx=padding_idx
        )

    def forward(
        self, inputs: torch.Tensor
    ):
        """Look up ``inputs`` (token ids) and multiply by ``sqrt(d_model)``."""
        scale = math.sqrt(self.d_model)
        looked_up = self.embeddings(inputs)
        return looked_up * scale

class SegmentEmbedding(nn.Module):
    """Embedding of segment (token-type) ids, with a dedicated padding id.

    The table has three rows. With the default ``padding_idx=2``, rows 0 and 1
    are the two segment ids and row 2 is the padding row.

    Args:
        d_model: Width of each embedding vector.
        padding_idx: Row used for padded positions (must be 0, 1 or 2).
    """

    def __init__(
        self,
        d_model: int = 768,
        padding_idx: int = 2,
    ):
        super(SegmentEmbedding, self).__init__()
        self.padding_idx = padding_idx
        self.embedding = nn.Embedding(
                3, embedding_dim=d_model, padding_idx=padding_idx
            )

    def forward(
        self,
        token_types_id: torch.Tensor, # (batch_size, seq_len)
        attn_mask: Optional[torch.Tensor] = None
        ):
        """Embed the segment ids.

        Args:
            token_types_id: Integer segment ids, ``(batch_size, seq_len)``.
            attn_mask: Optional mask with the same shape; positions where it is
                0 are remapped to ``padding_idx`` before the lookup. When
                ``None`` the ids are embedded as given.

        Returns:
            Tensor of shape ``(batch_size, seq_len, d_model)``.
        """
        # Without this guard, `None == 0` is a plain bool and masked_fill raises.
        if attn_mask is not None:
            token_types_id = token_types_id.masked_fill(
                attn_mask==0, self.padding_idx
            )
        segment_embeddings = self.embedding(token_types_id)
        return segment_embeddings # (batch_size, seq_len, d_model)

class BERTEmbedding(nn.Module):
    """Input embedding stack: token + sinusoidal position (+ optional segment).

    The three embeddings are summed and passed through dropout. The positional
    module is built with its own default dropout, because ``dropout`` is not
    forwarded to it.

    Args:
        d_model: Embedding width.
        vocab_size: Token vocabulary size.
        use_segment_embedding: Whether to add a segment embedding.
        dropout: Dropout probability applied to the summed embeddings.
        token_padding_idx: Padding id for the token table.
        segment_padding_idx: Padding id for the segment table.
    """

    def __init__(
        self,
        d_model: int = 768,
        vocab_size: int = 30522,
        use_segment_embedding: bool = False,
        dropout: float = 0.1,
        token_padding_idx: int = 0,
        segment_padding_idx: int = 2
    ):
        super().__init__()
        self.token_embedding = TokenEmbeddings(
            d_model=d_model, vocab_size=vocab_size, padding_idx=token_padding_idx
        )
        # Position table is fixed at 512 positions.
        self.positional_embedding = SinusoidalPositionalEncoding(
            d_model=d_model, max_len=512
        )
        # Stays None when segment embeddings are disabled.
        self.segment_embedding = (
            SegmentEmbedding(d_model=d_model, padding_idx=segment_padding_idx)
            if use_segment_embedding
            else None
        )
        self.dropout = nn.Dropout(p=dropout)

    def forward(
        self,
        inputs: torch.Tensor, # (batch_size, seq_len)
        token_type_ids: torch.Tensor,
        attn_mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Embed ``inputs`` and return the dropped-out sum of all embeddings.

        ``token_type_ids`` and ``attn_mask`` are only consulted when the
        segment embedding is enabled.
        """
        summed = self.token_embedding(inputs) + self.positional_embedding(inputs)
        if self.segment_embedding is not None:
            segments = self.segment_embedding(token_type_ids, attn_mask)
            summed = summed + segments
        return self.dropout(summed)

### Model Architecture

In [None]:
class GroupAttentionHead(nn.Module):
    """One attention group: several query projections sharing one key/value projection.

    Args:
        d_model: Width of the incoming hidden states.
        d_q: Output width of each query projection.
        d_k: Output width of the shared key projection.
        d_v: Output width of the shared value projection.
        n_heads: Number of query heads in this group.
        dropout: Attention dropout probability used while training.
    """

    def __init__(
        self,
        d_model: int,
        d_q: int,
        d_k: int,
        d_v: int,
        n_heads: int = 1,
        dropout: float = 0.1
    ):
        super(GroupAttentionHead, self).__init__()
        self.d_q, self.d_k, self.d_v = d_q, d_k, d_v
        self.Q = nn.ModuleList([nn.Linear(d_model, d_q) for _ in range(n_heads)])
        self.K = nn.Linear(d_model, d_k)
        self.V = nn.Linear(d_model, d_v)
        self.dropout = dropout

    def forward(
        self,
        query: torch.Tensor, key: torch.Tensor, value: torch.Tensor,
        attn_mask: Optional[torch.Tensor]=None,
        dropout: Optional[float]=None,
        is_causal: bool=False
    ):
        """Calculate attention values for the heads of this group.
        For self attention, q = k = v; Dimension = (batch_size, seq_len, d_model)
        For seq2seq cross attention
        - q = output of masked attention previous component of the decoder. Dimension = (batch_size, tgt_seq_len, d_model)
        - k = v = output of the encoder. Dimension = (batch_size, src_seq_len, d_model)

        Args:
            attn_mask: Optional key-padding mask, cast to bool (non-zero = attend).
                Two singleton dims are inserted after the batch dim, so a
                (batch_size, src_seq_len) mask is presumably expected.
            dropout: Per-call override of the attention dropout probability;
                ``None`` uses the value given at construction. An explicit
                ``0.0`` is honoured.
            is_causal: Forwarded to ``F.scaled_dot_product_attention``.

        Returns:
            Tensor of shape (batch_size, n_heads_in_group, tgt_seq_len, d_v).
        """
        dropout_p = self.dropout if dropout is None else dropout
        # F.scaled_dot_product_attention applies dropout_p unconditionally, so
        # it has to be switched off by hand outside training mode.
        if not self.training:
            dropout_p = 0.0
        query = torch.cat(
            [q(query).unsqueeze(1) for q in self.Q], dim = 1
        ) # (batch_size, n_heads_in_group, seq_len, d_q)
        key = self.K(key).unsqueeze(1) # (batch_size, 1, seq_len, d_k)
        value = self.V(value).unsqueeze(1) # (batch_size, 1, seq_len, d_v)
        if isinstance(attn_mask, torch.Tensor):
            attn_mask = attn_mask.unsqueeze(1).unsqueeze(1).type(torch.bool)

        values = F.scaled_dot_product_attention(
            query, key, value,
            attn_mask=attn_mask, dropout_p=dropout_p, is_causal=is_causal
            )

        return values

class MultiHeadedAttention(nn.Module):
    """Grouped multi-head attention built from ``GroupAttentionHead`` modules.

    ``n_heads`` query heads are split into ``n_groups`` groups, and the heads
    of a group share one key/value projection. ``n_groups == n_heads`` gives
    one key/value projection per head; ``n_groups == 1`` gives a single
    key/value projection shared by all heads.

    Args:
        d_model: Model width; must be divisible by ``n_heads``.
        n_heads: Total number of query heads; must be divisible by ``n_groups``.
        n_groups: Number of key/value groups.
        dropout: Attention dropout probability.
        d_v: Per-head value width; defaults to ``d_model // n_heads``.
    """

    def __init__(
        self,
        d_model: int = 768,
        n_heads: int = 12,
        n_groups: int = 12,
        dropout: float = 0.1,
        d_v: Optional[int] = None
    ):
        super(MultiHeadedAttention, self).__init__()
        self.d_model = d_model
        self.n_heads = n_heads
        self.n_groups = n_groups
        assert d_model % n_heads == 0, "d_model must be divisible by n_heads"
        assert n_heads % n_groups == 0, "Total number of heads must be divisible by number of groups"
        self.heads_per_group = n_heads // n_groups
        self.d_q = d_model // n_heads
        self.d_k = self.d_q
        self.d_v = d_v or self.d_q
        self.attn_heads = nn.ModuleList(
            [
                GroupAttentionHead(d_model, self.d_q, self.d_k, self.d_v, n_heads=self.heads_per_group, dropout=dropout)
                for _ in range(n_groups)
                ]
            )
        # Output projection over the concatenation of all heads.
        self.O = nn.Linear(self.d_v * self.n_heads, self.d_model)

    def forward(
        self,
        query: torch.Tensor, key: torch.Tensor, value: torch.Tensor,
        attn_mask: Optional[torch.Tensor]=None,
        dropout: Optional[float]=None,
        is_causal: bool=False
    ):
        """Run every group, concatenate the heads and apply the output projection.

        Returns:
            Tensor of shape (batch_size, tgt_seq_len, d_model), where
            tgt_seq_len is the length of ``query``.
        """
        batch_size = query.size(0)
        # The attention output has one row per *query* position. Using the
        # value length here would mis-shape cross-attention outputs whenever
        # source and target lengths differ.
        tgt_seq_len = query.size(1)
        values = torch.cat(
            [
                att_head(
                    query, key, value,
                    attn_mask=attn_mask, dropout=dropout, is_causal=is_causal)
                for att_head in self.attn_heads
                ], dim = 1
            ).transpose(1, 2) # (bs, tgt_seq_len, n_heads, d_v)
        values = values.contiguous().view(batch_size, tgt_seq_len, -1)
        return self.O(values)

class PointwiseFeedForward(nn.Module):
    """Two-layer position-wise MLP: ``linear2(activation(linear1(x)))``.

    Args:
        d_model: Input and output width.
        d_feedfoward: Hidden width; falls back to ``4 * d_model`` when falsy.
        activation: Name of a function in ``torch.nn.functional`` (e.g. "gelu").
    """

    def __init__(
        self,
        d_model: int = 768,
        d_feedfoward: Optional[int] = None,
        activation: str = "gelu",
    ):
        super().__init__()
        hidden_width = d_feedfoward if d_feedfoward else d_model * 4
        self.d_feedforward = hidden_width
        self.linear1 = nn.Linear(d_model, hidden_width)
        self.linear2 = nn.Linear(hidden_width, d_model)
        # Resolve the activation by name once, at construction time.
        self.activation = getattr(nn.functional, activation)

    def forward(
        self, inputs: torch.Tensor
    ):
        """Apply the MLP over the last dimension of ``inputs``."""
        hidden = self.activation(self.linear1(inputs))
        return self.linear2(hidden)

class ResidualLayer(nn.Module):
    """Post-norm residual wrapper: ``LayerNorm(x + Dropout(sublayer(x)))``.

    Args:
        d_model: Width normalised by the layer norm.
        dropout: Dropout probability applied to the sublayer output.
    """

    def __init__(
        self,
        d_model: int = 768,
        dropout: float = 0.1
    ):
        super().__init__()
        self.d_model = d_model
        self.layer_norm = nn.LayerNorm(normalized_shape=d_model, eps=1e-5)
        self.dropout = nn.Dropout(p=dropout)

    def forward(
        self, inputs: torch.Tensor, sublayer: Callable
    ):
        """Apply ``sublayer`` to ``inputs``, add the skip connection, then normalise."""
        branch = sublayer(inputs)
        branch = self.dropout(branch)
        return self.layer_norm(inputs + branch)

class EncoderBlock(nn.Module):
    """One Transformer encoder layer: self-attention then feed-forward.

    Each sublayer is wrapped in a post-norm residual connection.

    Args:
        d_model: Model width.
        n_heads: Total number of attention heads.
        n_groups: Number of key/value groups.
        dropout: Dropout probability for attention and for both residual wrappers.
        d_v: Optional per-head value width.
        d_feedfoward: Optional hidden width of the feed-forward sublayer.
        activation: Feed-forward activation name.
    """

    def __init__(
        self,
        d_model: int=768,
        n_heads: int=12,
        n_groups: int=3,
        dropout: float=0.1,
        d_v: Optional[int] = None,
        d_feedfoward: Optional[int] = None,
        activation: str = "gelu",
    ):
        super(EncoderBlock, self).__init__()
        self.d_model = d_model
        self.multihead_group_attention = MultiHeadedAttention(
            d_model=d_model, n_heads=n_heads, n_groups=n_groups, dropout=dropout, d_v=d_v,
        )
        self.feedforward = PointwiseFeedForward(d_model=d_model, d_feedfoward=d_feedfoward, activation=activation)
        # Forward the configured dropout. Without it the residual wrappers fall
        # back to their default of 0.1 regardless of `dropout`.
        self.att_residual = ResidualLayer(d_model=d_model, dropout=dropout)
        self.ff_residual = ResidualLayer(d_model=d_model, dropout=dropout)

    def forward(
        self,
        inputs: torch.Tensor,
        attn_mask: Optional[torch.Tensor] = None,
        is_causal: bool = False
    ):
        """Return the transformed hidden states, same shape as ``inputs``."""
        x = self.att_residual(
            inputs,
            # Self-attention: query, key and value are the same tensor.
            lambda x: self.multihead_group_attention(
                x, x, x, attn_mask=attn_mask, is_causal=is_causal
            )
        )
        x = self.ff_residual(x, self.feedforward)
        return x

class TransformerEncoder(nn.Module):
    """A stack of ``n_layers`` identically configured ``EncoderBlock`` modules.

    All constructor arguments other than ``n_layers`` are forwarded unchanged
    to every block.
    """

    def __init__(
        self,
        d_model: int=768,
        n_layers: int=12,
        n_heads: int=12,
        n_groups: int=3,
        dropout: float=0.1,
        d_v: Optional[int] = None,
        d_feedfoward: Optional[int] = None,
        activation: str = "gelu",
    ):
        super().__init__()
        self.d_model = d_model
        blocks = []
        for _ in range(n_layers):
            blocks.append(
                EncoderBlock(
                    d_model=d_model,
                    n_heads=n_heads,
                    n_groups=n_groups,
                    dropout=dropout,
                    d_v=d_v,
                    d_feedfoward=d_feedfoward,
                    activation=activation,
                )
            )
        self.layers = nn.ModuleList(blocks)

    def forward(
        self,
        x: torch.Tensor,
        attn_mask: Optional[torch.Tensor] = None,
        is_causal: bool = False
    ):
        """Feed ``x`` through every block in order and return the final hidden states."""
        hidden = x
        for block in self.layers:
            hidden = block(hidden, attn_mask=attn_mask, is_causal=is_causal)
        return hidden

class BERTBackbone(nn.Module):
    """Embedding stack followed by the Transformer encoder (no task heads).

    Args:
        d_model: Model width.
        vocab_size: Token vocabulary size.
        n_layers: Number of encoder blocks.
        n_heads: Attention heads per block.
        n_groups: Key/value groups per block.
        dropout: Dropout probability used throughout.
        d_v: Optional per-head value width.
        d_feedfoward: Optional feed-forward hidden width.
        activation: Feed-forward activation name.
        use_segment_embedding: Whether to add segment embeddings.
    """

    def __init__(
        self,
        d_model: int=768,
        vocab_size: int=30522,
        n_layers: int=12,
        n_heads: int=12,
        n_groups: int=3,
        dropout: float=0.1,
        d_v: Optional[int] = None,
        d_feedfoward: Optional[int] = None,
        activation: str = "gelu",
        use_segment_embedding: bool = False
    ):
        super(BERTBackbone, self).__init__()

        self.embedding = BERTEmbedding(
            d_model=d_model,
            vocab_size=vocab_size,
            use_segment_embedding=use_segment_embedding,
            dropout=dropout
        )

        self.encoder = TransformerEncoder(
            d_model=d_model,
            n_layers=n_layers,
            n_heads=n_heads,
            n_groups=n_groups,
            dropout=dropout,
            d_v=d_v,
            d_feedfoward=d_feedfoward,
            activation=activation
        )

        self.d_model = d_model
        self.vocab_size = vocab_size

        # Initialize weights: Xavier-uniform for every matrix-shaped parameter.
        for p in self.parameters():
            if p.dim() > 1:
                nn.init.xavier_uniform_(p)

        # The loop above also overwrites the rows nn.Embedding zeroed for
        # `padding_idx`. Those rows get no gradient, so they would stay at
        # their random values forever; restore them to zero.
        with torch.no_grad():
            for module in self.modules():
                if isinstance(module, nn.Embedding) and module.padding_idx is not None:
                    module.weight[module.padding_idx].zero_()

    def forward(
        self,
        inputs: torch.Tensor, # (batch_size, seq_len)
        token_type_ids: torch.Tensor,
        attn_mask: Optional[torch.Tensor] = None,
        is_causal: bool = False
    ):
        """Encode a batch of token ids.

        Args:
            inputs: Token ids, (batch_size, seq_len).
            token_type_ids: Segment ids; only used when segment embeddings are enabled.
            attn_mask: Optional padding mask (non-zero = real token). ``None``
                means no padding mask is passed to the encoder.
            is_causal: Forwarded to the attention layers.

        Returns:
            Final hidden states, (batch_size, seq_len, d_model).
        """
        x = self.embedding(
            inputs=inputs, token_type_ids=token_type_ids, attn_mask=attn_mask
        )
        # Cast only when a mask was supplied; calling .type() on None raises.
        encoder_mask = attn_mask.type(torch.float) if attn_mask is not None else None
        x = self.encoder(
            x,
            attn_mask=encoder_mask,
            is_causal=is_causal
            )
        return x

class MaskLanguageHead(nn.Module):
    """Masked-LM head: linear projection to the vocabulary, then log-softmax.

    Args:
        d_model: Width of the incoming hidden states.
        vocab_size: Number of output classes.
    """

    def __init__(
        self,
        d_model: int = 768,
        vocab_size: int = 30522
    ):
        super().__init__()
        self.d_model = d_model
        self.proj = nn.Linear(d_model, vocab_size)

    def forward(
        self,
        inputs: torch.Tensor
    ):
        """Return log-probabilities over the vocabulary along the last dimension."""
        logits = self.proj(inputs)
        return F.log_softmax(logits, dim=-1)

class NextSentencePredictionHead(nn.Module):
    """Binary classifier over the hidden state of the first token of each sequence.

    Args:
        d_model: Width of the incoming hidden states.
    """

    def __init__(
        self,
        d_model: int = 768
    ):
        super().__init__()
        self.proj = nn.Linear(d_model, 2)

    def forward(
        self,
        inputs: torch.Tensor
    ):
        """Return two-class log-probabilities, shape (batch_size, 2)."""
        first_token = inputs[:, 0, :] # (batch_size, d_model)
        logits = self.proj(first_token) # (batch_size, 2)
        return F.log_softmax(logits, dim=-1)

class BERTForPretraining(nn.Module):
    """BERT backbone with the two pre-training heads (MLM and NSP).

    Args:
        bert: Backbone producing (batch_size, seq_len, d_model) hidden states.
        vocab_size: Output size of the MLM head; falls back to ``bert.vocab_size``.
    """

    def __init__(
        self,
        bert: BERTBackbone,
        vocab_size: Optional[int] = None
    ):
        super().__init__()
        self.bert = bert
        self.vocab_size = vocab_size or self.bert.vocab_size
        self.d_model = self.bert.d_model
        self.mlm = MaskLanguageHead(self.d_model, self.vocab_size)
        self.nsp = NextSentencePredictionHead(self.d_model)

    def forward(
        self,
        input_ids: torch.Tensor,
        token_type_ids: torch.Tensor,
        attn_mask: Optional[torch.Tensor] = None,
        is_causal: bool = False,
        **kwargs,
    ):
        """Run the backbone and both heads.

        Extra keyword arguments are accepted and ignored.

        Returns:
            Tuple ``(mlm_log_probs, nsp_log_probs)``.
        """
        hidden_states = self.bert(
            inputs=input_ids,
            token_type_ids=token_type_ids,
            attn_mask=attn_mask,
            is_causal=is_causal,
        )
        mlm_log_probs = self.mlm(hidden_states)
        nsp_log_probs = self.nsp(hidden_states)
        return (mlm_log_probs, nsp_log_probs)

## Data Preparation

The original BERT was pretrained on BookCorpus (800M words) and English Wikipedia (2,500M words). Here we only use a subset of BookCorpus.

- Bookcorpus
- NSP: 50% of samples use the actual next text span; 50% use a randomly sampled text span
- MLM: 15% of content (non-special) tokens are selected for prediction
    - 80% are replaced with the [MASK] token
    - 10% are replaced with a random token
    - 10% are left unchanged

In [None]:
class BookCorpusPretrainDataset(Dataset):
    """Builds MLM + NSP pre-training examples on the fly from a sentence list.

    Each item is a sentence pair (A, B) packed into ``seq_len`` tokens.
    ``nsp_label`` is 0 when B is the true continuation of A and 1 when B was
    drawn from a random location in the corpus. ``mlm_label`` holds the
    original token id at positions selected for prediction and the pad-token
    id everywhere else.

    Args:
        corpus: List of sentences (strings), in document order.
        tokenizer: Hugging Face style tokenizer.
        corpus_len: Pre-computed token count per sentence (without special
            tokens). Computed here when missing or empty.
        seq_len: Total padded sequence length, special tokens included.
        prob: Probability of selecting a non-special token for prediction.
        mask_prob: Probability that a selected token becomes the mask token.
        random_prob: Probability that a selected token becomes a random token;
            the remainder is left unchanged.
        **kwargs: Ignored.
    """

    def __init__(
        self,
        corpus: List[str],
        tokenizer: Callable,
        corpus_len: Optional[List[int]] = None,
        seq_len: int = 128,
        prob: float = 0.15,
        mask_prob: float = 0.80,
        random_prob: float = 0.10,
        **kwargs
    ):
        self.corpus = corpus
        self.corpus_len = corpus_len
        if not self.corpus_len:
            batch_size = 256
            self.corpus_len = []

            # Iterate over this dataset's own corpus, not a notebook global.
            for start_idx in tqdm(range(0, len(self.corpus), batch_size)):
                batch_text = self.corpus[start_idx:start_idx+batch_size]
                # A sliced Hugging Face dataset is a dict of columns; a plain
                # list of strings can be tokenized as is.
                if isinstance(batch_text, dict):
                    batch_text = batch_text["text"]
                batch_data = tokenizer(batch_text, add_special_tokens=False)["input_ids"]
                for seq in batch_data:
                    self.corpus_len.append(len(seq))

        self.tokenizer = tokenizer
        self.seq_len = seq_len
        self.vocab_size = tokenizer.vocab_size
        # Maps special-token role -> (token string, token id).
        self.special_token_map = {}
        for token_type, token_str in self.tokenizer.special_tokens_map.items():
            token_id = self.tokenizer.convert_tokens_to_ids(token_str)
            self.special_token_map[token_type] = (token_str, token_id)

        # Cached as a set so the per-token membership test in the masking loop is O(1).
        self._special_ids = set(self.tokenizer.all_special_ids)
        self.non_special_token_ids = [
            token_id for token_id in range(self.vocab_size) if token_id not in self._special_ids
        ]
        self.prob = prob
        self.mask_prob = mask_prob
        self.random_prob = random_prob

    def get_text_span(self, idx, max_len: int):
        """Greedily collect consecutive sentences starting at ``idx``.

        Sentences are added while the running token count stays within
        ``max_len``. Returns an empty list when the first sentence alone is
        longer than ``max_len``.
        """
        sentences = []
        token_counts = 0
        if self.corpus_len[idx] > max_len:
            return sentences

        sentences.append(self.corpus[idx])
        token_counts += self.corpus_len[idx]

        curr_idx = idx
        while token_counts < max_len and curr_idx < (len(self) - 1):
            next_sentence = self.corpus[curr_idx + 1]
            curr_idx += 1
            if token_counts + self.corpus_len[curr_idx] <= max_len:
                sentences.append(next_sentence)
                token_counts += self.corpus_len[curr_idx]
            else:
                break

        return sentences

    def get_seq_len(self, input_str):
        """Number of tokens in ``input_str``, special tokens excluded."""
        return len(self.tokenizer.encode(input_str, add_special_tokens=False))

    def __len__(self):
        return len(self.corpus)

    def __sample_mask(
        self,
        input_ids: List[int],
        prob: Optional[float] = None,
        mask_prob: Optional[float] = None,
        random_prob: Optional[float] = None,
    ):
        """Apply BERT-style corruption to ``input_ids``.

        Special tokens are never selected. Arguments left as ``None`` fall back
        to the values stored on the dataset; explicit zeros are honoured.

        Returns:
            ``(masked_seq, masked_labels)``: the corrupted ids, and labels that
            hold the original id at selected positions and the pad-token id
            elsewhere.
        """
        prob = self.prob if prob is None else prob
        mask_prob = self.mask_prob if mask_prob is None else mask_prob
        random_prob = self.random_prob if random_prob is None else random_prob

        mask_id = self.special_token_map["mask_token"][1]
        ignore_id = self.special_token_map["pad_token"][1]

        masked_seq = []
        masked_labels = []
        for token_id in input_ids:
            # Use the configured selection rate (previously hard-coded to 0.15).
            # The short-circuit means no random number is drawn for special tokens.
            if token_id in self._special_ids or random.random() >= prob:
                masked_seq.append(token_id)
                masked_labels.append(ignore_id)
                continue

            draw = random.random()
            if draw < mask_prob: # Change to MASK token (80%)
                masked_seq.append(mask_id)
            elif draw < (mask_prob + random_prob): # Change to a random token (10%)
                masked_seq.append(random.choice(self.non_special_token_ids))
            else: # Does not change (10%)
                masked_seq.append(token_id)
            masked_labels.append(token_id)
        return (masked_seq, masked_labels)

    def __getitem__(self, idx):
        """Build one example; see the class docstring for the returned fields."""
        # The last sentence has no successor, so redirect to a random earlier index.
        if idx == (len(self) - 1):
            idx = random.randint(0, len(self) - 2)
        # Resample until the span holds at least two sentences
        # (3 tokens are reserved for [CLS] and the two [SEP]).
        while True:
            sentences = self.get_text_span(idx, max_len=self.seq_len - 3)
            if len(sentences) >= 2:
                break
            else:
                idx = random.randint(0, len(self) - 2)

        max_len = random.randint(1, len(sentences) - 1)

        split_threshold = random.randint(1, max_len)
        nsp_prob = random.random()
        sent_A = " ".join(sentences[:split_threshold])
        if nsp_prob > 0.5:
            sent_B = " ".join(sentences[split_threshold:])
            nsp_label = 0
        else: # Need to sample a random span
            while True:
                random_span = self.get_text_span(random.randint(0, len(self)-1), max_len=self.seq_len - self.get_seq_len(sent_A) - 3)
                if len(random_span) >= 2:
                    break
                else:
                    # Shrink sentence A to leave more room for B. Clamp at 0:
                    # a negative threshold would make the slice wrap around and
                    # grow A again.
                    split_threshold = max(split_threshold - 1, 0)
                    sent_A = " ".join(sentences[:split_threshold])

            random_split_threshold = random.randint(1, len(random_span) - 1)
            sent_B = " ".join(random_span[random_split_threshold:])
            nsp_label = 1

        tokens_info = self.tokenizer(
            sent_A, sent_B, truncation=False, padding='max_length', max_length=self.seq_len
        )
        input_ids = tokens_info["input_ids"]
        attn_mask = tokens_info["attention_mask"]
        token_type_ids = tokens_info["token_type_ids"]

        input_ids, labels = self.__sample_mask(input_ids=input_ids)
        return {
            "input_ids": torch.tensor(input_ids),
            "attn_mask": torch.tensor(attn_mask),
            "token_type_ids": torch.tensor(token_type_ids),
            "mlm_label": torch.tensor(labels),
            "nsp_label": torch.tensor([nsp_label]),
        }

## Trainer Class

In [None]:
def custom_lr_schedule(
    step_no: int, warm_up: int = 10000, total_steps: int = 100000
):
    """Learning-rate multiplier: linear warm-up, then linear decay to zero.

    Args:
        step_no: Current scheduler step; values below 1 are treated as 1 so the
            first step does not use a zero learning rate.
        warm_up: Number of warm-up steps; the multiplier reaches 1.0 at this step.
        total_steps: Step at which the multiplier reaches 0.0.

    Returns:
        A multiplier in [0, 1] for the base learning rate. It stays at 0.0 for
        steps beyond ``total_steps`` instead of going negative.
    """
    step_no = max(step_no, 1)
    if step_no <= warm_up:
        return step_no / warm_up

    post_warm_up_steps = total_steps - warm_up
    # No decay phase configured: anything after warm-up is past the end.
    if post_warm_up_steps <= 0:
        return 0.0
    return max(0.0, (total_steps - step_no) / post_warm_up_steps)

class TrainState:
    """Mutable bookkeeping for a training run.

    Attributes:
        n_steps: Number of training batches processed so far.
        n_updates: Number of optimizer updates performed so far.
        train_loss / val_loss: ``{"mlm"|"nsp": {"step": [...], "epoch": [...]}}``
            loss histories.
        lr: Learning rate recorded at every training step.
    """
    n_steps: int = 0
    n_updates: int = 0

    def __init__(self):
        # The containers are created per instance. As class attributes they
        # were shared, so every TrainState appended to the same lists.
        self.n_steps = 0
        self.n_updates = 0
        self.train_loss = {
            "mlm": {"step": [], "epoch": []},
            "nsp": {"step": [], "epoch": []}
            }
        self.val_loss = {
            "mlm": {"step": [], "epoch": []},
            "nsp": {"step": [], "epoch": []}
            }
        self.lr = []

class BERTTrainer:
    """Training / evaluation loop for the MLM + NSP pre-training objective.

    Args:
        model: Module returning ``(mlm_log_probs, nsp_log_probs)`` when called
            with ``input_ids``, ``token_type_ids`` and ``attn_mask``.
        optimizer_args: Dict with keys ``lr``, ``betas``, ``weight_decay``,
            ``warmup``, ``total_steps`` and ``accumulation_steps``.
        sampling_args: Stored on the trainer; not used by the loop itself.
        device: Device string the model and batches are moved to.
    """

    def __init__(
        self,
        model: nn.Module,
        optimizer_args: Dict,
        sampling_args: Dict,
        device: str,
    ):
        self.model = model.to(device)
        self.optimizer_args = optimizer_args
        self.sampling_args = sampling_args
        self.device = device
        self.train_state = TrainState()
        self.accumulation_steps = optimizer_args["accumulation_steps"]

        self.optimizer = AdamW(
            params = self.model.parameters(),
            lr = optimizer_args["lr"],
            betas = optimizer_args["betas"],
            weight_decay = optimizer_args["weight_decay"]
        )

        self.scheduler = LambdaLR(
            optimizer=self.optimizer,
            lr_lambda=lambda x: custom_lr_schedule(
                x, warm_up=optimizer_args["warmup"], total_steps=optimizer_args["total_steps"]
                )
            )
        # MLM labels use id 0 (the pad id) for "not a prediction target", so
        # index 0 is ignored in the MLM loss.
        self.loss_criterion = nn.NLLLoss(ignore_index=0)
        # NSP label 0 is a real class; ignoring it would drop those samples
        # from the NSP loss, so NSP gets its own criterion.
        self.nsp_criterion = nn.NLLLoss()
        print("Total Parameters:", sum([p.nelement() for p in self.model.parameters()]))

    def run_epoch(
        self,
        epoch: int,
        dataloader: DataLoader,
        train_state: Optional[TrainState] = None,
        max_steps: Optional[int] = None,
        is_train: bool = True
    ):
        """Run one pass over ``dataloader`` in train or eval mode.

        Args:
            epoch: Epoch number, used for logging only.
            dataloader: Yields dicts with ``input_ids``, ``token_type_ids``,
                ``attn_mask``, ``mlm_label`` and ``nsp_label`` tensors.
            train_state: State object to update; defaults to ``self.train_state``.
            max_steps: Stop once ``train_state.n_steps`` reaches this value;
                ``None`` disables the limit.
            is_train: When False the model is in eval mode, no gradients are
                computed and no optimizer or scheduler step is taken.
        """
        if train_state is None:
            train_state = self.train_state
        if is_train:
            self.model.train()
        else:
            self.model.eval()

        mode = "train" if is_train else "eval"

        data_iter = tqdm(
            enumerate(dataloader),
            desc="EP_%s:%d" % (mode, epoch),
            total=len(dataloader),
            bar_format="{l_bar}{r_bar}"
        )

        epoch_mlm_loss = 0.0
        epoch_nsp_loss = 0.0
        epoch_nsp_preds = 0
        epoch_nsp_correct = 0
        epoch_steps = 0

        for idx, batch in data_iter:
            if max_steps is not None and train_state.n_steps >= max_steps:
                break

            batch = {k: v.to(self.device) for k, v in batch.items()}
            input_ids = batch["input_ids"]
            token_type_ids = batch["token_type_ids"]
            attn_mask = batch["attn_mask"]
            mlm_labels = batch["mlm_label"]
            nsp_labels = batch["nsp_label"]

            # Skip autograd bookkeeping entirely during evaluation.
            with torch.set_grad_enabled(is_train):
                mlm_outs, nsp_outs = self.model(
                    input_ids=input_ids,
                    token_type_ids=token_type_ids,
                    attn_mask=attn_mask,
                )

                mlm_loss = self.loss_criterion(mlm_outs.view(-1, mlm_outs.shape[-1]), mlm_labels.view(-1))
                nsp_loss = self.nsp_criterion(nsp_outs, nsp_labels.view(-1))
                loss = mlm_loss + nsp_loss

            nsp_preds = nsp_outs.argmax(dim=-1)
            nsp_correct = nsp_preds.eq(nsp_labels.view(-1)).sum().item()

            epoch_nsp_preds += nsp_preds.size(0)
            epoch_nsp_correct += nsp_correct

            if is_train:
                train_state.n_steps += 1
                # Scale so the accumulated gradient is the mean over the
                # micro-batches rather than their sum.
                (loss / self.accumulation_steps).backward()
                if train_state.n_steps % self.accumulation_steps == 0:
                    self.optimizer.step()
                    self.optimizer.zero_grad()
                    train_state.n_updates += 1

                lr = self.optimizer.param_groups[0]["lr"]
                train_state.lr.append(lr)
                train_state.train_loss["mlm"]["step"].append(mlm_loss.item())
                train_state.train_loss["nsp"]["step"].append(nsp_loss.item())
                # The schedule is expressed in batches, so it advances every step.
                self.scheduler.step()

            else:
                train_state.val_loss["mlm"]["step"].append(mlm_loss.item())
                train_state.val_loss["nsp"]["step"].append(nsp_loss.item())

            epoch_mlm_loss += mlm_loss.item()
            epoch_nsp_loss += nsp_loss.item()
            epoch_steps += 1

            if (idx + 1) % 100 == 0:
                avg_mlm_loss = epoch_mlm_loss / epoch_steps
                avg_nsp_loss = epoch_nsp_loss / epoch_steps
                avg_nsp_acc = epoch_nsp_correct / epoch_nsp_preds

                print(
                    f"Epoch {epoch} - Step {train_state.n_steps}, mode {mode}: avg_mlm_loss={avg_mlm_loss}, avg_nsp_loss={avg_nsp_loss}, avg_nsp_acc={avg_nsp_acc}")

        # Nothing was processed (empty loader or step budget already used up):
        # there are no averages to report, and dividing would raise.
        if epoch_steps == 0:
            return

        avg_mlm_loss = epoch_mlm_loss / epoch_steps
        avg_nsp_loss = epoch_nsp_loss / epoch_steps
        avg_nsp_acc = epoch_nsp_correct / epoch_nsp_preds

        if is_train:
            train_state.train_loss["mlm"]["epoch"].append(avg_mlm_loss)
            train_state.train_loss["nsp"]["epoch"].append(avg_nsp_loss)
        else:
            train_state.val_loss["mlm"]["epoch"].append(avg_mlm_loss)
            train_state.val_loss["nsp"]["epoch"].append(avg_nsp_loss)

        print(
            f"Whole Epoch {epoch}, mode {mode}: avg_mlm_loss={avg_mlm_loss}, avg_nsp_loss={avg_nsp_loss}, avg_nsp_acc={avg_nsp_acc}")

    def train(
        self,
        n_step: int,
        dataloader: DataLoader,
        max_steps: Optional[int] = None,
        is_train: bool = True
        ):
        """Run epochs over ``dataloader`` until ``n_step`` more training steps are done.

        Args:
            n_step: Number of additional training steps to run.
            dataloader: Data source, re-iterated once per epoch.
            max_steps: Accepted for interface compatibility; the limit is always
                recomputed as the current step count plus ``n_step``.
            is_train: Forwarded to ``run_epoch``. With ``False`` a single
                evaluation pass is made, since the step counter cannot advance.
        """
        max_steps = self.train_state.n_steps + n_step
        epoch_no = 0
        while self.train_state.n_steps < max_steps:
            epoch_no += 1
            steps_before = self.train_state.n_steps
            self.run_epoch(
                epoch=epoch_no,
                dataloader=dataloader,
                train_state=self.train_state,
                max_steps=max_steps,
                is_train=is_train
            )
            # No progress (evaluation mode or an empty dataloader): stop
            # instead of looping forever.
            if self.train_state.n_steps == steps_before:
                break

## Pretrain !!!

### Pretrain Arguments

- A linear LR schedule is used: linear warm-up over the first 10,000 steps, then linear decay from the peak LR down to 0.
- Two stages:
    - Stage 1: 90% of the steps, with sequence length 128
    - Stage 2: 10% of the steps, with sequence length 512

In [None]:
# model_args = dict(
#     d_model=32,
#     n_layers=2,
#     n_heads=4,
#     n_groups=1,
#     dropout=0.1,
#     use_segment_embedding=True
# ) # DEBUG

# Backbone hyper-parameters (BERT BASE sized, with 4 key/value groups).
model_args = {
    "d_model": 768,
    "n_layers": 12,
    "n_heads": 12,
    "n_groups": 4,
    "dropout": 0.1,
    "use_segment_embedding": True,
} # BERT BASE

# AdamW settings plus the warm-up / decay schedule lengths (in steps).
optimizer_args = {
    "lr": 1e-4,
    "betas": (0.9, 0.999),
    "weight_decay": 0.01,
    "warmup": 10000,
    "total_steps": 100000,
    "accumulation_steps": 1,
}

# MLM corruption probabilities passed to the dataset.
sampling_args = {
    "prob": 0.15,
    "mask_prob": 0.80,
    "random_prob": 0.10,
}

# Two-stage curriculum: short sequences first, long sequences at the end.
training_stages = [
    {"n_steps": 90000, "batch_size": 32, "seq_len": 128}, # Stage 1
    {"n_steps": 10000, "batch_size": 32, "seq_len": 512}, # Stage 2
]

In [None]:
# Load the BookCorpus copy stored under DATA_DIR and keep its "train" split.
bookcorpus_dataset = load_dataset(os.path.join(DATA_DIR))["train"]

# Work on the first 100,000 rows only; `bookcorpus` is a plain list of strings.
bookcorpus = bookcorpus_dataset[:100000]["text"]

# Pre-compute the token count of every sentence (special tokens excluded), in
# batches of 256, so the dataset does not have to tokenize the corpus again.
batch_size = 256
corpus_len = []

for start_idx in tqdm(range(0, len(bookcorpus), batch_size)):
    batch_data = tokenizer(bookcorpus[start_idx:start_idx+batch_size], add_special_tokens=False)["input_ids"]
    for seq in batch_data:
        corpus_len.append(len(seq))

100%|██████████| 391/391 [00:05<00:00, 68.74it/s]


In [None]:
# Build the backbone from model_args and wrap it with the MLM + NSP heads.
bert_backbone = BERTBackbone(**model_args)
bert_pretrain_model = BERTForPretraining(
    bert=bert_backbone,
)
model = bert_pretrain_model
# The trainer moves the model to DEVICE and creates the optimizer and LR scheduler.
bert_trainer = BERTTrainer(
    model=model,
    optimizer_args=optimizer_args,
    sampling_args=sampling_args,
    device=DEVICE,
)

Total Parameters: 2006332


In [None]:
# Run the stages back to back with the same trainer, so the optimizer, the LR
# schedule and the step counter carry over from one stage to the next.
for stage in training_stages:
    # Rebuild the dataset for each stage because the sequence length changes.
    bert_pretrain_dataset = BookCorpusPretrainDataset(
        corpus=bookcorpus,
        tokenizer=tokenizer,
        corpus_len=corpus_len,
        seq_len = stage["seq_len"],
        **sampling_args
    )
    bert_pretrain_dataloader = DataLoader(
        dataset=bert_pretrain_dataset,
        batch_size=stage["batch_size"],
        shuffle=True,
        num_workers=8,
        pin_memory=True,
        drop_last=True
    )
    # Release cached GPU memory before starting the stage.
    torch.cuda.empty_cache()
    bert_trainer.train(
        n_step=stage["n_steps"],
        dataloader=bert_pretrain_dataloader,
        is_train=True
    )

In [None]:
# Save only the backbone weights (model.bert); the MLM and NSP heads are not saved.
# NOTE(review): MAIN_DIR was rebound to os.getcwd() in the imports cell, so this
# writes under the Drive root rather than the BERT project folder — confirm.
output_dir = os.path.join(
    MAIN_DIR, "artifacts", "checkpoints"
)
os.makedirs(output_dir, exist_ok=True)
torch.save(model.bert.state_dict(), os.path.join(output_dir, "pretrain_checkpoint.pt"))