# Exercise 5: Debugging a BERT Model Training Script 🐞

In this exercise, you will debug a script designed to train a [BERT](https://arxiv.org/abs/1810.04805) model using a portion of the [wikitext](https://huggingface.co/datasets/wikitext) dataset. We don't expect you to find any bug in the code provided in `utils`.

🔍 **Mission:** Identify and correct at least 8 issues (🐞) in this notebook. Found more? Awesome – tell us about them!😄

### Debugging Tips:

**Memory Usage:** Identify which parameters impact the model's memory usage. The finalized script should run on Colab using a free account.

**Input/Output Dimensions:** How are the input and output dimensions defined in a transformer?

**Training Metrics:** Is the model learning? Why not?

**Evaluation Metrics:** Are they behaving as expected?

**Start Small:** Use a smaller dataset initially to identify and solve issues faster and without using extensive computational resources. Two script-running options are provided to help you. Running 200 epochs on the large dataset will take more than an hour. No need to do it – but if you are up to, go for it!

**Performance Indicator:** A training perplexity around 20 after 50 epochs in `DEBUG` mode means you are on the right path.

Happy debugging! 🚀

In [1]:
# ENABLE IF RUNNING ON GOOGLE COLAB

# !pip install transformers
# !pip install datasets

In [2]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset
from transformers import BertTokenizer, BertConfig, DataCollatorForLanguageModeling
from datasets import load_dataset

from utils import visualization, models

import math
import random
import numpy as np

In [3]:
# Set seeds
def reset_seed():
  seed = 5
  torch.backends.cudnn.deterministic = True
  torch.backends.cudnn.benchmark = False
  random.seed(seed)
  np.random.seed(seed)
  torch.manual_seed(seed)
  torch.cuda.manual_seed(seed)

# Set device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
device = torch.device("mps" if torch.has_mps else "cpu")
print(f"Using device: {device}")

mode_dropdown = visualization.init()

Using device: mps


Dropdown(description='Mode:', options=('DEBUG', 'RUN'), value='DEBUG')

In [4]:
from typing import Optional

class Ex5BertEmbeddings(nn.Module):
    """Construct the embeddings from learnable word and position embeddings."""

    def __init__(self, config):
        super().__init__()
        self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id, device=config.device)
        self.positional_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size, device=config.device)

        self.norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

        # Compute possible positions
        self.positions= torch.arange(config.max_position_embeddings, dtype=torch.long, device=config.device)

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None
    ) -> torch.Tensor:
        # Compute word embeddings
        word_embeddings = self.word_embeddings(input_ids)
        
        # Compute positional embeddings
        position_ids = torch.cat([self.positions] * input_ids.shape[0]).reshape(input_ids.shape)
        pos_embeddings = self.positional_embeddings(position_ids)

        # Sum word and positional embeddings
        embeddings = word_embeddings + pos_embeddings

        embeddings = self.norm(embeddings)
        embeddings = self.dropout(embeddings)
        return embeddings

In [5]:
class Ex5BertSelfAttentionLinearOutput(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.dense = nn.Linear(config.intermediate_size, config.intermediate_size)
        self.norm = nn.LayerNorm(config.intermediate_size, eps=config.layer_norm_eps)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

    def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
        hidden_states = self.dense(hidden_states) # (B, S, H)
        hidden_states = self.dropout(hidden_states)
        hidden_states = self.norm(hidden_states + input_tensor)
        return hidden_states

In [6]:
# Custom Debug / Run configurations setup
RUN_MODE = mode_dropdown.value

dataset_config = 'dataset'
model_config = 'model'
training = 'training'

DEBUG_CONFIG =  {
        dataset_config : { 'train' : 5, 'validation': 5},
        model_config: {'num_heads': 6, 'num_layers': 1},
        training: {'track_every' : 10, 'lr': 0.001, 'epochs': 200} # Fix 4: Reduce learning rate to 0.001
    }
RUN_CONFIG = {
        dataset_config : { 'train' : 36718, 'validation': 1000},
        model_config: {'num_heads': 6, 'num_layers': 3},
        training: {'track_every' : 1, 'lr': DEBUG_CONFIG[training]['lr'] / 10, 'epochs': 70 }
    }
RUN_CONFIG_DEFAULTS =  {
    'DEBUG' : DEBUG_CONFIG,
    'RUN' : RUN_CONFIG
  }
CONFIG = RUN_CONFIG_DEFAULTS[RUN_MODE]

reset_seed()

In [7]:
# load and create dataset
train_wikitext_subset = load_dataset('wikitext', 'wikitext-2-v1', split=f'train[:{CONFIG[dataset_config]["train"]}]')
train_text_data = train_wikitext_subset['text']

validation_wikitext_subset = load_dataset('wikitext', 'wikitext-2-v1', split=f'train[:{CONFIG[dataset_config]["validation"]}]')
validation_text_data = validation_wikitext_subset['text']

max_len = 1024
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased', max_len=max_len)

class TextDataset(Dataset):
    def __init__(self, texts, tokenizer, max_length=max_len):
        self.tokenizer = tokenizer
        self.input_ids = []
        self.attention_masks = []
        for text in texts:
            encoded_text = tokenizer(text, max_length=max_length, padding='max_length', truncation=True, return_tensors='pt')
            self.input_ids.append(encoded_text.input_ids)
            self.attention_masks.append(encoded_text.attention_mask)

    def __len__(self):
        return len(self.input_ids)

    def __getitem__(self, idx):
        return {'input_ids': self.input_ids[idx].squeeze(), 'attention_mask': self.attention_masks[idx].squeeze()}

train_dataset = TextDataset(train_text_data, tokenizer)
validation_dataset = TextDataset(validation_text_data, tokenizer)

data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm_probability=0.15)

train_dataloader = DataLoader(train_dataset, batch_size=32, shuffle=True, collate_fn=data_collator)
validation_dataloader = DataLoader(validation_dataset, batch_size=32, shuffle=False, collate_fn=data_collator)


# BERT model setup
config = BertConfig(
        hidden_size=300, # Embedding dimensions (Fix 1: Change to multiple of num_heads=6)
        max_position_embeddings=max_len,
        type_vocab_size=1,
        num_attention_heads=CONFIG[model_config]['num_heads'],
        hidden_act="gelu",
        intermediate_size=300, # dimension of feedforward expansion (Fix 2: Change to same as hidden_size=60)
        num_hidden_layers=CONFIG[model_config]['num_layers'],
        initializer_range=0.02,
        device=device
)
model = models.Ex5BertForMaskedLM(config=config, embeddings=Ex5BertEmbeddings, selfoutput=Ex5BertSelfAttentionLinearOutput)
model.to(device)

# Optimizer set up
optimizer = optim.AdamW(model.parameters(), lr=CONFIG[training]['lr'])

# Train loop
epochs, train_losses, val_losses = [], [], []
for epoch in range(CONFIG[training]['epochs']):
    epoch_loss, total_masked_tokens, correct_masked_predictions = 0.0, 0, 0
    model.train()
    for batch in train_dataloader:
        input_ids = batch['input_ids'].to(device)
        attention_mask = batch['attention_mask'].to(device)
        labels = batch['labels'].to(device)

        outputs = model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            labels=labels)
        loss = outputs.loss

        # Fix 3: Clear the gradients before computing gradients w.r.t loss
        optimizer.zero_grad()

        loss.backward()

        epoch_loss += loss.item()
        predictions = outputs.logits
        masked_positions = (labels != -100)
        total_masked_tokens += torch.sum(masked_positions)
        predicted_tokens = torch.argmax(F.softmax(predictions, dim=-1), dim=-1)
        correct_masked_predictions += torch.sum(predicted_tokens[masked_positions] == labels[masked_positions])

        optimizer.step()

    # Train and Validation tracking
    if (epoch + 1) % CONFIG[training]['track_every'] == 0:
      model.eval()
      train_accuracy = correct_masked_predictions.float() / total_masked_tokens.float()
      avg_train_loss = epoch_loss / len(train_dataloader)
      train_perplexity = math.exp(avg_train_loss)

      total_eval_loss = 0
      for batch in validation_dataloader:
        with torch.no_grad():
          input_ids = batch['input_ids'].to(device)
          attention_mask = batch['attention_mask'].to(device)
          labels = batch['labels'].to(device)

          outputs = model(
              input_ids=input_ids,
              attention_mask=attention_mask,
              labels=labels)
          total_eval_loss += loss.item()

          predictions = outputs.logits
          masked_positions = (labels != -100)
          total_masked_tokens += torch.sum(masked_positions)
          predicted_tokens = torch.argmax(F.softmax(predictions, dim=-1), dim=-1)
          correct_masked_predictions += torch.sum(predicted_tokens[masked_positions] == labels[masked_positions])
      avg_eval_loss = total_eval_loss / len(validation_dataloader)
      eval_perplexity = math.exp(avg_eval_loss)
      eval_accuracy = correct_masked_predictions.float() / total_masked_tokens.float()
      
      epochs.append(epoch + 1)
      train_losses.append(avg_train_loss)
      val_losses.append(avg_eval_loss)
      print(
          'Epoch:', '%04d' % (epoch + 1),
          'train cost =', '{:.3e}'.format(avg_train_loss),
          'train ppl =', '{:.3e}'.format(train_perplexity),
          'train acc =', '{:.3f}'.format(train_accuracy),
          'eval cost =', '{:.3e}'.format(avg_eval_loss),
          'eval ppl =', '{:.3e}'.format(eval_perplexity),
          'eval acc =', '{:.3f}'.format(eval_accuracy))

Epoch: 0010 train cost = 7.500e-02 train ppl = 1.078e+00 train acc = 0.089 eval cost = 7.500e-02 eval ppl = 1.078e+00 eval acc = 0.105
Epoch: 0020 train cost = 5.530e-02 train ppl = 1.057e+00 train acc = 0.093 eval cost = 5.530e-02 eval ppl = 1.057e+00 eval acc = 0.052
Epoch: 0030 train cost = 3.419e-02 train ppl = 1.035e+00 train acc = 0.075 eval cost = 3.419e-02 eval ppl = 1.035e+00 eval acc = 0.120
Epoch: 0040 train cost = 2.894e-02 train ppl = 1.029e+00 train acc = 0.278 eval cost = 2.894e-02 eval ppl = 1.029e+00 eval acc = 0.231
Epoch: 0050 train cost = 2.662e-02 train ppl = 1.027e+00 train acc = 0.368 eval cost = 2.662e-02 eval ppl = 1.027e+00 eval acc = 0.400
Epoch: 0060 train cost = 2.354e-02 train ppl = 1.024e+00 train acc = 0.513 eval cost = 2.354e-02 eval ppl = 1.024e+00 eval acc = 0.444
Epoch: 0070 train cost = 2.193e-02 train ppl = 1.022e+00 train acc = 0.477 eval cost = 2.193e-02 eval ppl = 1.022e+00 eval acc = 0.471
Epoch: 0080 train cost = 1.702e-02 train ppl = 1.017e+0

## HINTS

If you're finding the task challenging, hints are available for each bug. We recommend attempting to solve the issues on your own first. Remember, the teaching assistants are also a resource for any questions you have. We're here to assist you!

### HINT 1: Dimension mismatch 1

When using multi-head attention, how should the dimensions of the input relate to the number of heads?

**Answer**: The input dimension should be divisible by the number of heads as the input is split and then individually processed by each head bfore being concatenated back together. As the number of heads is set to 6 in the configuration, I have changed the input dimension to from `50` to `300`.

### HINT 2: Dimension mismatch 2

What is the output dimension of the multi-head self-attention layer?

**Answer**: The output dimension of the multi-head attention layer is the same as the input dimension, thus $B \times S \times D$, where $B$. Thus, we need to set the `intermediate_size` also to `60`.

### HINT 3: Flat loss 

Is your model learning anything? Are your update steps correct? Carefully check the training loop...

**Answer**: No, in the code we are clearing out the gradients using `optimizer.zero_grad()` right after we have compute the gradient w.r.t. the loss value. Therefore, when calling `optimizer.step()` the parameters do not change and we do not learn. To fix the issue, we simply clear out the gradients before the `loss.backward()` call.

### HINT 4: Oscillating loss

Is your loss oscillating or diverging? What parameter settings could be causing this behavior?

**Answer**: The learning rate is too high at `0.05`. We can stabilise the training by reducing the learning rate. A sensible defaul for the `AdamW` optimiser is `0.001` ($1e^{-3}$).

### HINT 5: High perplexity

Is the loss/perplexity at the first epoch too large? What parameter setting influences the training dynamics at the start of training besides the learning rate?

**Answer**: Could be the weight initialisation, the initial embeddings, or the batch size. It seems that the `BertConfig` class from the `transformers` library allows to control the weight initialisation via the `initializer_range` parameter. I have set it to `0.02` as is the default for the BERT model.

### HINT 6: Mediocre results

Is your model training, but does not reach a good performance? Is the model architecture correct? Is there some module/element missing in your layers? Note that you can expect to find all bugs in the notebook itself.

**Answer**: The embedding module does not any positional information. As a quick fix, we can add a positional embedding using a regular nn.Embedding module which learns a positional embedding for each position in the sequence. We can then add the positional embedding to the token embedding before feeding it into the transformer.

### HINT 7: Memory overflow

There might be a layer that uses a large amount of memory... Can you think of which one? What parameter could you change to mitigate this problem?

**Answer**: The embedding layer `Ex5Embedding` uses a lot of memory as it has to store the embeddings over the entire vocabulary. This is a `VOCAB_SIZE * EMBEDDING_DIM` tensor. We can reduce the memory usage by reducing the size of the vocabulary or the embedding dimension.

### HINT 8: Too good to be true...

Considering the small amout of data used to train our model, does the validation performance behave as expected? When does this happen?

**Answer**: We are achieving a training and validation performance of 100% accuracy. This is too good to be true. We are overfitting easily because we are using a very small dataset. We can fix this by using a larger dataset.