Pre-Training BERT from Scratch

In [4]:
%pip install torch
%pip install wandb


In [2]:
#basic imports
import os
import math
import numpy as np
import random
import logging

# Bring in PyTorch
import torch
import torch.nn as nn

# Most of the examples have typing on the signatures for readability
from typing import Optional, Callable, List, Tuple
from copy import deepcopy

# For data loading
from torch.utils.data import Dataset, IterableDataset, TensorDataset, DataLoader
import json
import glob
import gzip
import bz2

#import wandb
import matplotlib.pyplot as plt

# For progress and timing
from tqdm.auto import tqdm, trange
import time
import wandb
import torch.nn.functional as F
wandb.login()  # reads WANDB_API_KEY from the environment or prompts; avoid hard-coding API keys



# Part 1: Tokenization

To start with, we'll load in one of the tokenizers for the original BERT from the `tokenizers` package. 
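
To build intuition for what a WordPiece tokenizer does, here is a minimal sketch of greedy longest-match-first subword splitting. The toy vocabulary and the `wordpiece_split` helper are made up for illustration; the real `tokenizers` implementation also handles normalization, punctuation splitting, and special tokens, so treat this as a simplification rather than the library's algorithm.

```python
from typing import List, Set


def wordpiece_split(word: str, vocab: Set[str], unk: str = "[UNK]") -> List[str]:
    """Greedy longest-match-first split of one lowercased word into wordpieces."""
    pieces = []
    start = 0
    while start < len(word):
        end = len(word)
        match = None
        # Try the longest remaining substring first, then shrink from the right
        while end > start:
            candidate = word[start:end]
            if start > 0:
                candidate = "##" + candidate  # continuation pieces carry a "##" prefix
            if candidate in vocab:
                match = candidate
                break
            end -= 1
        if match is None:
            # No piece fits at this position, so the whole word becomes unknown
            return [unk]
        pieces.append(match)
        start = end
    return pieces


# Hypothetical toy vocabulary, not the homework's vocab.txt
toy_vocab = {"token", "##ize", "learn", "##ing"}
tokenize_pieces = wordpiece_split("tokenize", toy_vocab)
learning_pieces = wordpiece_split("learning", toy_vocab)
pizza_pieces = wordpiece_split("pizza", toy_vocab)
print(tokenize_pieces, learning_pieces, pizza_pieces)
```

Words that are missing from the vocabulary are broken into smaller known pieces where possible, which is why misspelled or rare words usually produce several `##` pieces instead of a single token.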

In [None]:
################################################################
#                     TODO: YOUR CODE HERE                     #
#
#  1. Create a new BertWordPieceTokenizer using the specified vocab.txt file in the homework. 
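#     Make sure that all tokens are lower cased.
from tokenizers import BertWordPieceTokenizer
# Use a forward slash: in "text\vocab.txt" Python reads "\v" as a vertical-tab escape
tokenizer = BertWordPieceTokenizer("text/vocab.txt", lowercase=True)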

################################################################

# Test the tokenizer
print(tokenizer.encode("Hello, I am learning to tokenize").tokens)

In [None]:
print(tokenizer.encode("I really wished my model would work.").tokens)
print(tokenizer.encode("Would you please put the pizza on the table?").tokens)
print(tokenizer.encode("Wow your dress looks incrdible!").tokens)
print(tokenizer.encode("I can't wait to drive home tonight on my motorcycle.").tokens)
print(tokenizer.encode("GRRRRAAAHHHH This is so frustrating...").tokens)

# Part 2: Building a Transformer Encoder and BERT

The BERT model is a pre-trained transformer network. Creating a small BERT model will take two big pieces: (1) building the transformer itself and (2) writing the code that incorporates the transformer into BERT and sets it up for training. We'll try to simplify building this by thinking of the different pieces as building blocks that we can put together. Remember, all neural networks are _functions_ and you can compose functions together to make a new function. 

Let's take a look at the overall diagram for Transformers/BERT:
![The diagram of the transformer network](https://devopedia.org/images/article/235/5113.1573652896.png)
The trickiest part is the left where we need to deal with the scaled dot-product attention. You've already seen attention though in Homework 2, so some of this should be familiar. 

You'll implement the following pieces to put it all together.

Steps:
1. Embedding: BERT will learn word embeddings that are similar to word2vec's _but_ also incorporate the position of the embedding in the sequence
2. Multi-headed Attention: The core part of the network that learns how much each token should pay attention to all other tokens
3. A Feed-forward Layer: The layer that transforms the attention-combined representations
4. The Transformer Encoder: The unified transformer network that combines attention with the feed-forward layer
5. A BERT Classification Layer: The classification part of BERT
6. A BERT Masked Language Modeling (MLM) Layer: The part of BERT that deals with MLM training
7. The overall BERT model: The final BERT model architecture that supports both MLM and Classification

Feel free to read the instructions for all steps in Part 2 before getting started to see how the pieces might fit together. To get everything working, you'll need all parts, which build on each other, so we recommend starting with the first and moving on from there.

## Part 2.1: Embedding Layer
BERT's input embeddings are normally the sum of three embeddings:
- Token Embeddings: The input token embeddings
- Position Embeddings: The position of the token in the sequence
- Token Type Embeddings: The segment (sentence) the token belongs to

The second piece lets BERT distinguish the same token appearing in different positions. Remember that the attention mechanism works independently of where each token is; without positional information added to the word embedding, the model can't distinguish between words in different orders!
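
A small experiment can make the order-blindness concrete. The sketch below uses a stripped-down, single-head self-attention with no learned projections (the `plain_self_attention` helper is hypothetical and is not the layer you will build in Part 2.2). Shuffling the input tokens only shuffles the output rows in the same way, so nothing in the output depends on position.

```python
import torch

torch.manual_seed(0)


def plain_self_attention(x: torch.Tensor) -> torch.Tensor:
    """Single-head self-attention with identity Q/K/V projections (illustration only)."""
    scores = x @ x.transpose(-1, -2) / x.size(-1) ** 0.5  # (B, T, T)
    return torch.softmax(scores, dim=-1) @ x              # (B, T, D)


x = torch.randn(1, 4, 8)            # one sequence of 4 "token" vectors
perm = torch.tensor([2, 0, 3, 1])   # an arbitrary reordering of the tokens

out = plain_self_attention(x)
out_perm = plain_self_attention(x[:, perm])

# Reordering the input only reorders the output rows in the same way
same_up_to_order = torch.allclose(out[:, perm], out_perm, atol=1e-5)
print(same_up_to_order)
```

Adding a position embedding to each token vector before attention breaks this symmetry, because the same word at two different positions no longer has the same input vector.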

The third piece was designed for the Next Sentence Prediction (NSP) task during pre-training. Here, two sentences are provided as input with the special `[SEP]` token between them. The NSP task is a classification task based on whether the two sentences did or did not actually follow each other. Just like in word2vec, we would sample random sentences as not-next. The hope for this pretraining task was that it would help BERT learn discourse coherence. However, some later works have shown NSP doesn't actually help that much and training time is probably better spent on doing more MLM, so some more advanced models dropped this.
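
Although this homework skips NSP, the way its training pairs are built is easy to sketch. The `make_nsp_pairs` helper below is hypothetical, and the 50/50 split between true-next and random sentences as well as the `[CLS] A [SEP] B [SEP]` layout are assumptions based on the usual BERT setup, not part of this assignment.

```python
import random
from typing import List, Tuple


def make_nsp_pairs(sentences: List[str], rng: random.Random) -> List[Tuple[str, int]]:
    """Build (text, label) pairs: label 1 = true next sentence, 0 = randomly sampled."""
    pairs = []
    for i in range(len(sentences) - 1):
        if rng.random() < 0.5:
            second, label = sentences[i + 1], 1
        else:
            # Sample any sentence other than the current one and its true successor
            choices = [s for j, s in enumerate(sentences) if j not in (i, i + 1)]
            second, label = rng.choice(choices), 0
        pairs.append((f"[CLS] {sentences[i]} [SEP] {second} [SEP]", label))
    return pairs


sentences = [
    "the pizza arrived late",
    "we ate it anyway",
    "the model finally converged",
    "everyone went home",
]
pairs = make_nsp_pairs(sentences, random.Random(0))
for text, label in pairs:
    print(label, text)
```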

For simplicity, in Homework 3, you only need to deal with token embeddings and positional embeddings, but do not need to deal with token-type embeddings.

**NOTE:** When talking about embeddings, people (and these instructions) will talk about various things being embedded, e.g., words, tokens, wordpieces, subwords, etc. In practice, these are all based on whatever the tokenizer is producing, and BERT (like you) is agnostic to what is actually being embedded. When the instructions talk about "token embeddings", these are still just the output of the BERT WordPiece tokenizer.

In [None]:
class BertPositionalEmbedding(nn.Module):
    def __init__(self, vocab_dim: int, 
                 hidden_dim: int = 768, 
                 padding_idx: int = 0, 
                 max_seq_length: int = 512):
        
        super().__init__()

        '''
        Initialize the Embedding Layers
        '''

        ################################################################################
        #                             TODO: YOUR CODE HERE                             #
        #                                                                              #
        #  1. Create two Embedding objects for the words and the positions.            #
        #     For the word embeddings keep track of which index is the padding index.  #        
        # 
        self.word_embedding = nn.Embedding(vocab_dim, hidden_dim, padding_idx=padding_idx)
        self.position_embedding = nn.Embedding(max_seq_length, hidden_dim)
        ################################################################################

    def forward(self, token_ids: torch.Tensor, 
                ) -> torch.Tensor:
        
        '''
        Define the forward pass of the Embedding Layers
        '''
        
        ############################################################################
        #                               TODO: YOUR CODE HERE                       #
        #                                                                          #
        # 1. Look up the relevant token embeddings from the word_embeddings layer  #
        words = self.word_embedding(token_ids)
        seq_ids = torch.arange(token_ids.size(1), dtype = torch.long, device=token_ids.device)
        seq_ids = seq_ids.unsqueeze(0).expand_as(token_ids)
        pos = self.position_embedding(seq_ids)
        # 2. Return the sum of the token embeddings and the positional embeddings  #
        return words + pos
        #                                                                          #
        ############################################################################

## Part 2.2: Multi-Head Attention

This is the trickiest part of the homework where we implement the attention part of the transformer. Let's take a look at the attention:

![The attention network](https://www.tutorialexample.com/wp-content/uploads/2021/03/The-structure-of-Multi-Head-Attention.png)
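
For reference, the computation in the diagram is usually written as below. The symbols follow common notation rather than anything defined in this notebook: $d_k$ is the per-head dimension (the hidden size divided by the number of heads), and $M$ is an assumed additive mask that is $0$ for valid positions and $-\infty$ for padded ones.

$$
\mathrm{Attention}(Q, K, V) = \mathrm{softmax}\!\left(\frac{QK^\top}{\sqrt{d_k}} + M\right)V
$$

Each head applies this formula to its own slice of the query, key, and value projections, and the per-head outputs are laid side by side to form a vector of the original hidden size.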




In [None]:
# Optional TODO: Try implementing attention with some embeddings to get a sense of the core parts
# Note that this is not a complete implementation of the attention mechanism in the transformer,
# which also includes some scaling and masking operations as well as dealing with multiple heads.

# Generate some Q, K, V embeddings for a batch of 10 sequences of length 4 with embedding size 7
q_emb = torch.randn(10, 4, 7)
k_emb = torch.randn(10, 4, 7)
v_emb = torch.randn(10, 4, 7)

# Start by computing the dot product of the Q and K embeddings
dot_prod = (q_emb @ k_emb.transpose(-1, -2)) 

# Then apply a softmax to get the attention weights
attn = nn.functional.softmax(dot_prod, dim=-1)

# This should be torch.Size([10, 4, 4]) --- i.e., how much each word (4 words) pays attention to each other word
print(attn.shape)

# Now compute the weighted words by multiplying the attention weights by the V embeddings
weighted_output = attn @ v_emb

# This should be torch.Size([10, 4, 7]) --- i.e., the same shape as our inputs, but the embeddings are weighted combinations of the V embeddings!
print(weighted_output.shape)

In [None]:
# Optional TODO: How to implicitly represent attention weights in a single layer

# Let's start with a single projection of the input embedding for Q for a sequence of 3 words with embedding size 8
# NOTE: we'll use randint here to make it easier to see the reshaping when we print the tensors
q_emb = torch.randint(10, (3, 8))
print(q_emb)
# If we have two attention heads we can reshape this tensor so that the first dimension is the number of heads
q_s = q_emb.shape
# Should be torch.Size([3, 8])
print(q_s)
n_heads = 2
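# Split the embedding axis into heads, (3, 8) -> (3, 2, 4), then move the head axis to the front
multihead_q_emb = q_emb.view(q_s[0], n_heads, q_s[1]//n_heads).transpose(0, 1)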

# Should be torch.Size([2, 3, 4])
print(multihead_q_emb.shape)

# Check that the reshaping is correct by looking at which values went where
print(multihead_q_emb)

# Remember: Each head has its own attention! So we'll need to use this kind of reshaping for K and V so we can eventually compute
# the per-head specific attention weights. 

# Optional TODO: try creating the K projection and re-shaping it and then the calculate its attention weights using the logic from the above cell
# Note that you'll now have another dimension in the tensor! 

In [None]:
class MultiHeadedAttention(nn.Module):
    def __init__(self, hidden_size: int, num_heads: int):
        '''
        Arguments:
        hidden_size: The total size of the hidden layer (across all heads)
        num_heads: The number of attention heads to use
        '''

        super().__init__()

        '''
        Initialize the Multi-Headed Attention Layer
        '''

        ###########################################################################################################################
        #                                                     TODO: YOUR CODE HERE                                                #
        #
        self.num_heads = num_heads
        # 1. Figure out how many dimensions each head should have      
        self.head_dim = hidden_size // num_heads
        # 2. Create linear layers to turn the input embeddings into the query, key, and value projections  
        self.K = nn.Linear(hidden_size, hidden_size)
        self.Q = nn.Linear(hidden_size, hidden_size)
        self.V = nn.Linear(hidden_size, hidden_size)
        # 3. Calculate the scale factor (1 / sqrt(per-head embedding size))
        self.scale = 1 / math.sqrt(self.head_dim)  # plain Python float
        #                                                                                                                         #                                            
        # NOTE: Each of the Q, K, V projections represents the projections of *each* of the heads as one long sequence.           #
        #       Each of the layers is implicitly representing each head in different parts of its dimensions. E.g., if you        #
        #       have 4 heads and 16 dimensions, the first 4 dimensions are the first head, the second 4 dimensions are            #
        #       the second head, etc.                                                                                             #                        
        ###########################################################################################################################

    def forward(self, embeds: torch.Tensor, 
                mask: Optional[torch.Tensor] = None
                ) -> torch.Tensor:
        '''
        Arguments:
        embeds: The input embeddings to compute the attention over
        mask: A boolean mask of which tokens are valid to use for computing attention (see collate below)
        '''

        #####################################################################################################################################################
        #                                                   TODO: YOUR CODE HERE                                                                            #
        #                                                                                                                                                   #
        # This is the hard part of the assignment where you'll need to implement the multi-headed attention mechanism.                                      #
        # The cell above has the core logic for the attention mechanism to get you started with getting the tensor                                          #
        # shapes lined up correctly. We recommend working through that manually with some small examples to get a sense                                     #
        # of what is happening at each step.                                                                                                                #
        #                                                                                                                                                   #
        # One of the key parts to work through is how to handle the shapes of the different embeddings. For attention to work                               #
        # and be efficient, we'll only need to do a few matrix multiplications and a softmax. The key is to figure out how to                               #
        # do this by getting the tensors into the right shapes. We strongly recommend trying to write comments at each step                                 #
        # that describe the shape of the tensors at each step in terms of what each dimension represents. This will help you understand                     #
        # what is happening and debug.                                                                                                                      #
        #                                                                                                                                                   #
        # We recommend using notation like the following which you'll also see in papers and blogs:                                                         #
        #  - B: Batch size                                                                                                                                  #
        #  - H: Number of heads                                                                                                                             #
        #  - T: Sequence length                                                                                                                             #
        #  - D: Embedding size                                                                                                                              #
        #                                                                                                                                                   #
        # 1. Figure out what are the dimensions of the input embeddings and which dimensions                                                                #
        #    represent what (e.g., the batch size, sequence length, etc.)
        B = embeds.shape[0]  # batch size
        T = embeds.shape[1]  # sequence length
        D = embeds.shape[2]  # embedding size (across all heads)
        H = self.num_heads
        # 2. Project the input embeddings into the Q, K, and V spaces.
        #    Split the last axis into heads first, (B, T, D) -> (B, T, H, D/H), then move
        #    the head axis forward to get (B, H, T, D/H). Viewing straight to (B, H, T, D/H)
        #    would mix values from different tokens into the same head.
        Q = self.Q(embeds).view(B, T, H, self.head_dim).transpose(1, 2)
        K = self.K(embeds).view(B, T, H, self.head_dim).transpose(1, 2)
        V = self.V(embeds).view(B, T, H, self.head_dim).transpose(1, 2)

        # 3. Compute the attention weights from the Q and K projections (be sure to scale the dot product by the scale factor!)
        dot_prod = self.scale * (Q @ K.transpose(-2, -1))  # (B, H, T, T)
        # 4. *If there is a mask*, apply the mask to the attention weights where masked values are set to -inf
        if mask is not None:
            # Assuming mask has shape (B, T): (B, T) -> (B, 1, 1, T), which broadcasts
            # over heads and query positions so that padded keys are ignored
            expanded_mask = mask.unsqueeze(1).unsqueeze(1)
            dot_prod = dot_prod.masked_fill(expanded_mask == 0, float('-inf'))
        attn = F.softmax(dot_prod, dim=-1)
        # 5. Compute the weighted sum of the V embeddings using the attention weights
        attention = attn @ V  # (B, H, T, D/H)
        # 6. Return the re-weighted output values *and* the attention weights in the shape (Batch, Heads, SeqLen, SeqLen)
        #    We'll use the attention weights for visualization later
        # Merge the heads back: (B, H, T, D/H) -> (B, T, H, D/H) -> (B, T, D)
        return attention.transpose(1, 2).reshape(B, T, D), attn
        # NOTE: when we say dimension, we're referring to how many axes are in a tensor. E.g., a vector                                                     #
        #       is a one-dimensional tensor, a matrix is a two-dimensional tensor, etc. Each dimension has a size too (number of components),               #
        #       e.g., a 3x4 tensor has 2 dimensions, where the first has 3 elements/components and the second has 4 elements/components.                    #
        # NOTE: You will probably want to use the .view() method to reshape the input embeddings with respect to the number of heads                        #
        #       so that you can get a tensor where one dimension corresponds to each head                                                                   #
        # NOTE: You may want to use the .transpose() method to swap dimensions (e.g., to move the heads dimension to the front)                             #
        # NOTE: Check out masked_fill for applying the mask to the attention weights                                                                        #
        # NOTE: You will not need to concatenate anything in practice because the Q, K, V projections already represent concatenated head-specific values   #
        #    in the right dimensions.                                                                                                                       #
        # HINT: You can reshape any tensor with view to "add a dimension" to it.                                                                            #
        #    E.g., if you have a tensor with shape (B, T, D) and you want to add a dimension to the front, you can do tensor.view(B, 1, T, D)               #
        #####################################################################################################################################################
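
The NOTE in the attention layer says that contiguous slices of the last dimension belong to each head. A quick way to convince yourself that a reshape respects that layout is to run it on a tensor of consecutive integers. The sketch below uses toy sizes chosen for illustration and compares two ways of getting a heads-first tensor.

```python
import torch

T, H, head_dim = 3, 2, 4   # toy sizes: 3 tokens, 2 heads, 4 dimensions per head
D = H * head_dim

# Each row is one token; columns 0-3 are meant for head 0, columns 4-7 for head 1
x = torch.arange(T * D).view(T, D)

# Split the feature axis, then move heads to the front: (T, D) -> (T, H, d) -> (H, T, d)
split = x.view(T, H, head_dim).transpose(0, 1)

# Viewing straight to (H, T, d) just re-chunks memory and mixes tokens across heads
naive = x.view(H, T, head_dim)

print(split[0])  # the first 4 features of every token
print(naive[0])  # all 8 features of token 0 followed by half of token 1

# Undo the split: (H, T, d) -> (T, H, d) -> (T, D)
merged = split.transpose(0, 1).reshape(T, D)
```

After a `transpose` the tensor is no longer contiguous in memory, so `reshape` (or `.contiguous().view(...)`) is used to merge the heads back.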

## Part 2.3 Define the Feed-Forward Layer

The feed forward layer is a simple two-layer feed forward network (FFN) with an activation function between the layers. This network usually follows the multi-headed attention layer and further transforms its output representations. Below, we'll make a function that returns this network using `nn.Sequential`, which takes in a tuple or list of layers and activation functions where any input is passed through each function in order and then the output is returned.

In a transformer, the FFN is typically wider than the embedding size, usually with an expansion factor of 4. This increased number of neurons allows the model to capture more interactions between dimensions of the embedding.
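
To see what the expansion factor costs, the arithmetic below counts the weights and biases of a two-layer FFN. The `ffn_parameter_count` helper is hypothetical and assumes both linear layers have bias terms, which is the default for `nn.Linear`.

```python
def ffn_parameter_count(hidden_size: int, expansion: int = 4) -> int:
    """Count parameters of Linear(hidden, inner) followed by Linear(inner, hidden)."""
    inner = expansion * hidden_size
    first = hidden_size * inner + inner         # weights + biases of the first layer
    second = inner * hidden_size + hidden_size  # weights + biases of the second layer
    return first + second


small = ffn_parameter_count(256)   # the homework-sized hidden dimension
base = ffn_parameter_count(768)    # the hidden dimension of BERT-base
print(small)  # 525568
print(base)   # 4722432
```

The count grows with the square of the hidden size, which is one reason the homework model keeps `hidden_size` small.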


In [None]:
def feed_forward_layer(
    hidden_size: int, 
    feed_forward_size: Optional[int] = None, 
    activation: nn.Module = nn.GELU()
):
    '''
    Arguments:
      - hidden_size: The size of the input and output of the feed forward layer. 
      - feed_forward_size: The size of the hidden layer in the feed forward network. If None, defaults to 4 * hidden_size. This size
        specifies the size of the middle layer in the feed forward network.
      - activation: The activation function to use in the feed forward network

    Returns:
      - An nn.Sequential that maps hidden_size -> feed_forward_size -> hidden_size
    '''
    ################################################################
    #                     TODO: YOUR CODE HERE                     #
    # Implement the feed forward layer as described in the slides  #
    # The feed forward layer is a simple two-layer neural network  #
    # with an activation function.                                 #
    if feed_forward_size is None:
        feed_forward_size = 4 * hidden_size
    return nn.Sequential(
        nn.Linear(hidden_size, feed_forward_size),
        activation,
        nn.Linear(feed_forward_size, hidden_size),
    )
    # NOTE: It maps from hidden_size to feed_forward_size and then #
    #       back to hidden_size.                                   #
    ################################################################

## Part 2.4 Building a Transformer Block as an Encoder Layer

Let's finish putting together our transformer pieces into a single network that (1) computes the self-attention to get contextualized word representations and (2) passes those representations through our feed-forward neural network layers. 

During training, we'll also add some probability of using [dropout](https://machinelearningmastery.com/using-dropout-regularization-in-pytorch-models/) which is also implemented in pytorch. Dropout will randomly zero-out some values when training so the model learns to be robust to the value of any one neuron. In practice, when you switch between `train()` and `eval()`, under the hood, this will turn on/off things like dropout automatically!
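
The effect of `train()` and `eval()` on dropout is easy to check directly. This is a minimal sketch using `nn.Dropout` on a vector of ones; the vector length and the drop probability are arbitrary choices for illustration.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
drop = nn.Dropout(p=0.5)
x = torch.ones(1000)

drop.train()
train_out = drop(x)  # some entries are zeroed, survivors are scaled by 1 / (1 - p) = 2

drop.eval()
eval_out = drop(x)   # dropout is a no-op in eval mode

zero_fraction = (train_out == 0).float().mean().item()
print(zero_fraction)              # roughly 0.5
print(torch.equal(eval_out, x))
```

The rescaling during training keeps the expected value of each activation the same in both modes, so nothing needs to be adjusted when switching to evaluation.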

In [None]:
class TransformerEncoderLayer(nn.Module):

    def __init__(
        self,
        hidden_size: int = 256, # NOTE: normally 768, but keep it small for homework
        num_heads: int = 12,
        dropout: float = 0.1,
        activation: nn.Module = nn.GELU(),
        feed_forward_size: Optional[int] = None
    ):
        super().__init__()

        ################################################################
        #                     TODO: YOUR CODE HERE                     #
        # Now we can put it all together to create one layer of the    #
        # transformer encoder.                                         #
        #                                                              #
        # 1. Create a MultiHeadedAttention layer with the specified    #
        #    number of heads.                                          #
        self.mha = MultiHeadedAttention(hidden_size, num_heads)
        # 2. Create a feed forward layer with the specified activation #
        #    function.                                                 #
        self.ffl = feed_forward_layer(hidden_size, feed_forward_size, activation)
        # 3. Save the hidden_size, dropout, and feed_forward_size      #
        #    as attributes.                                            #
        self.hidden_size = hidden_size
        self.dropout = dropout
        self.feed_forward_size = feed_forward_size
        
        # 4. Define a forward method that takes in an input tensor     #
        #    and an optional mask tensor and returns the output tensor #
        #    and the attention weights that go through the multi-      #
        #    headed attention layer and the feed forward layer.        #
        ################################################################

    def maybe_dropout(self, x: torch.Tensor) -> torch.Tensor:

        ################################################################
        #                     TODO: YOUR CODE HERE                     #
        # Implement the dropout layer to be used in the forward pass   #
        # of the transformer encoder layer (if dropout was specified)  #
        if self.training and self.dropout > 0:
            return F.dropout(x, p=self.dropout, training=True)
        return x
        ################################################################
        
    def forward(self, x: torch.Tensor, mask: Optional[torch.Tensor] = None):
        '''
        Returns the output of the transformer encoder layer and the attention weights from the self-attention layer
        '''
        
        ################################################################
        #                     TODO: YOUR CODE HERE                     #
        #
        # 1. Pass the input tensor through the self-attention layer 
        output, attn_weights = self.mha(x, mask)
        # 2. Call maybe_dropout on the output of the self-attention
        output = self.maybe_dropout(output) + x 
        # 3. Pass the output of the self-attention through the feed forward network
        output = self.ffl(output) + output
        # 4. return the output of the feed forward network and the attention weights
        return output, attn_weights   
        #
        ################################################################

## Part 2.5: Masked Language Modeling Head

A BERT model usually comes with multiple "heads" that are used for different tasks. One of these heads is used for the masked language model (MLM) task during pretraining. 
        
In practice, a "head" is just a linear layer that maps the hidden size to the output size. This linear layer allows the model to adapt the token representations to a particular task, while keeping the core transformer model parameters the same across tasks. A model could have multiple heads for different tasks.  In our implementation, we'll have two heads, one for the masked language modeling task and a second for classification.

The MLM head maps the contextualized token embedding to the vocabulary, so if we have some embedding $e_i$, we're learning a weight matrix $W_{mlm}$ of size $|e_i| \times |V|$. In BERT pre-training, this weight matrix is said to be _tied_ to the input embeddings; in practice, that means we use the same weights from the `Embedding` (!), except this time the weights are used as a linear layer! (NOTE: This means you're not defining a separate linear layer). You might also see this called "parameter sharing" where the same parameters (i.e., weights) are used in different parts of the neural network. 
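
Weight tying can be checked on a toy example. In the sketch below the sizes are made up and `h` stands in for the encoder output. Note that `nn.Embedding` stores its weights with shape $|V| \times |e_i|$, so the matrix is transposed before it is used as the $|e_i| \times |V|$ projection.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
vocab_size, hidden = 50, 16    # toy sizes for illustration
emb = nn.Embedding(vocab_size, hidden)
h = torch.randn(2, 5, hidden)  # stand-in for encoder outputs, shape (B, T, D)

# Reuse the embedding matrix as the output projection: (B, T, D) @ (D, V) -> (B, T, V)
logits = h @ emb.weight.T

# Equivalent to a bias-free linear layer whose weight is the embedding matrix
logits_linear = F.linear(h, emb.weight)

# Gradients from the output side flow into the very same parameter
logits.sum().backward()
print(logits.shape, emb.weight.grad.shape)
```

Because there is only one parameter tensor, updates from predicting masked tokens and updates from embedding the input tokens both land on the same weights.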

In [None]:
class MLMHead(nn.Module):
    def __init__(self, word_embeddings: nn.Embedding):
        '''
        Arguments:
            word_embeddings: The word embeddings to use for the prediction
        '''

        super().__init__()
        self.word_embeddings = word_embeddings

    def forward(self, x):
        '''
        x: The input tensor to the MLM head containing a batch of sequences of
           contextualized word embeddings (activations from the transformer encoder 
           layers)
        '''
        
        ################################################################
        #                 TODO: YOUR CODE HERE                         #
        #                                                              #
        # The MLM head is used to predict the original token from      #
        # the masked token. The prediction is over the whole 
        # vocabulary so we'll need an activation. To make this work,
        # we'll generate a tensor the length of the vocabulary size
        # that we can push through a softmax to get the probabilities
        # of each token being present.
        return (x @ self.word_embeddings.weight.transpose(0,1))
        #                                                              #
        # NOTE: The head should not have an activation function.       #
        # NOTE: The head should be tied to the input embeddings (i.e., #
        #       the head should map the embeddings to the vocab size). #
        #       In other words, the MLM head directly predicts the     #
        #       token from the learned embeddings instead of the       #
        #       last hidden states.                                    #
        # HINT: You can get the tensor of the word embeddings to use   #
        #       for prediction from the word_embeddings object.        #
        # HINT: Despite all this writing, this function is only a 
        #       single line of code.                                   #
        ################################################################


## Part 2.6: Classification Pooler Head

When we pre-train, we'll learn a special embedding called `[CLS]` that is the first embedding in any sequence and loosely approximates the overall meaning/semantics of the input text. Frequently, when we fine-tune a BERT classifier, we're using this `[CLS]` token as the summary of the input and updating the weights accordingly. 

Here, we'll create a network that will add a `Linear` layer on top of the `[CLS]` token which can be fine-tuned to do classification later. This linear layer allows us to adapt/project the `[CLS]` to a representation more suitable for classification.  This is the classification head of BERT.

We'll use this network later when defining the `BERT` class below.

In [None]:
class Pooler(nn.Module):
    def __init__(self, hidden_size: int = 768):
        super().__init__()

        ################################################################
        #                     TODO: YOUR CODE HERE                     #
        # The BERT model usually uses the first token of the sequence  #
        # to represent the entire sequence (this is the [CLS] token).  #
        #                                                              #
        # The pooler layer is a simple linear layer that maps the      #
        # representation of the [CLS] token to the hidden size.        #
        self.linear = nn.Linear(hidden_size, hidden_size)
        #                                                              #
        # This pooled representation is used as the input to the       #
        # classification layer defined in the later cell.              #
        ################################################################

    def forward(self, x: torch.Tensor) -> torch.Tensor:

        ################################################################
        #                     TODO: YOUR CODE HERE                     #
        #
        # 1. Pass the [CLS] embedding through the dense layer 
        output = self.linear(x[:,0])
        # 2. Pass the output of the dense layer through the activation
        active = torch.sigmoid(output)
        # 3. Return the output of the activation
        return active
        # OPTIONAL TODO: One other way you can represent the contents of 
        #             the entire sequence is to use the mean of all the non-special
        #             token embeddings (also known as "mean pooling"). 
        #             You can try implementing that here and add a flag to the
        #             Pooler __init__ function to switch between the two approaches.
        #
        ################################################################
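
The optional mean-pooling variant mentioned in the TODO above could look roughly like the sketch below. It is an illustration only, not the assignment's reference solution: the function name `mean_pool` and its `attention_mask` argument are assumptions (the `Pooler.forward` above only receives `x`), and this version excludes padding positions only, not `[CLS]`/`[SEP]`.

```python
import torch


def mean_pool(hidden: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
    """Average token embeddings over the non-padded positions.

    hidden:         (batch, seq_len, hidden_size)
    attention_mask: (batch, seq_len), True/nonzero where the token is real
    """
    # Broadcast the mask over the hidden dimension: (batch, seq_len, 1)
    mask = attention_mask.unsqueeze(-1).to(hidden.dtype)
    summed = (hidden * mask).sum(dim=1)
    # Guard against division by zero for an all-padding row
    counts = mask.sum(dim=1).clamp(min=1.0)
    return summed / counts


# One sequence of three tokens; the last one is padding and must be ignored.
hidden = torch.tensor([[[1.0, 1.0], [3.0, 3.0], [100.0, 100.0]]])
attention_mask = torch.tensor([[True, True, False]])
pooled = mean_pool(hidden, attention_mask)
print(pooled)
```

To switch between the two approaches, a flag on `Pooler.__init__` could select either `x[:, 0]` or a function like this one before the linear layer is applied.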

## Part 2.7: The BERT network!

Finally! We have arrived at putting all the pieces together into a single neural network. Our BERT model will have the main attention plus feed-forward network components and then two heads: one for MLM and one for Classification. At model creation time, we'll specify which "mode" the model should be in (MLM or classification).

This implementation is a good reminder that neural networks (e.g., `nn.Module`) are just functions on inputs. We can compose and stack them together to get some really useful (and cool) outputs as a result.

In [None]:
class BERT(nn.Module):
    def __init__(
        self,
        vocab_size: int,
        padding_idx: int = 0,
        hidden_size: int = 768,
        num_heads: int = 12,
        num_layers: int = 12,
        dropout: float = 0.1,
        activation: nn.Module = nn.GELU(),
        feed_forward_size: Optional[int] = None,
        mode: str = "mlm",
        num_classes: Optional[int] = None
    ):
        '''
        Defines BERT model architecture. Note that the arguments are the same as the default
        BERT model in HuggingFace but we'll be training a *much* smaller model for this homework.

        Arguments:
        vocab_size: The size of the vocabulary (determined by the tokenizer)
        padding_idx: The index of the padding token in the vocabulary (defined by the tokenizer)
        hidden_size: The size of the hidden layer and embeddings in the transformer encoder
        num_heads: The number of attention heads to use in the transformer encoder
        num_layers: The number of layers to use in the transformer encoder (each layer is a TransformerEncoderLayer)
        dropout: The dropout rate to use in the transformer encoder (what % of times to randomly zero out activations)
        activation: The activation function to use in the transformer encoder
        feed_forward_size: The size of the hidden layer in the feed forward network in the transformer encoder. If None, defaults to 4 * hidden_size
        mode: The mode of the BERT model. Either "mlm" for masked language modeling or "classification" for sequence classification
        num_classes: The number of classes to use in the classification layer.
        '''


        super().__init__()

        ################################################################
        #                     TODO: YOUR CODE HERE                     #
        # Now we can put it all together to create the BERT model.     #
        #                                                              #
        # A BERT model is just a stack of transformer encoder layers   #
        #                                                              #
        # followed by a pooler layer and a classification layer.       #
        # 1. Create a BertPositionalEmbedding layer with the specified #
        #    vocab size, hidden size, and padding index.               #
        self.bpe = BertPositionalEmbedding(vocab_size, hidden_size, padding_idx)
        # 2. Create a stack of transformer encoder layers with the     #
        #    specified number of layers, hidden size, number of heads, #
        #    dropout, activation function, and feed forward size.      #
        self.tel_layers = nn.ModuleList([
            TransformerEncoderLayer(hidden_size, num_heads, dropout, activation, feed_forward_size) 
            for _ in range(num_layers)])
        # 3. Create an MLMHead layer with the specified vocab size and #
        #    padding index.                                            #
        self.mlm = MLMHead(self.bpe.word_embedding)
        # 4. Create a Pooler layer with the specified hidden size.     #
        self.pooler = Pooler(hidden_size)
        # 5. Create a classification layer with the specified number   #
        #    of classes.                                               #
        #
        if num_classes is not None:
            self.linear = nn.Linear(hidden_size, num_classes)
        self.mode = mode
        # HINT: you can use nn.ModuleList to stack layers.  
        #
        ################################################################

    def forward(
        self, 
        x: torch.Tensor, 
        mask: Optional[torch.Tensor] = None, 
    ) -> torch.Tensor:
        '''
        arguments:
        x: The input token ids
        mask: The attention mask to apply to the input (see the collate function below)
        '''
        
        ################################################################
        #                     TODO: YOUR CODE HERE                     #
        #
        # 1. Get the embeddings for the input token ids
        embeds = self.bpe(x)
        # 2. Calculate the attention weights for each layer (be sure to save the attention weights to return)
        weights = []
        for layer in self.tel_layers:
            embeds, attn_weights = layer(embeds, mask=mask)
            weights.append(attn_weights)
        # 3a. If the mode is "mlm", pass the embeddings through the MLM head and return the output
        output = None
        if self.mode == 'mlm':
            output = self.mlm(embeds)
        # 3b. If the mode is "classification", pass the embeddings through the classification head
        elif self.mode == "classification":
            pooled = self.pooler(embeds)
            output = self.linear(pooled)
        # 4. Return the output (from the relevant head) and the attention weights
        return output, torch.stack(weights)
        #
        ################################################################

    def init_layer_weights(self, module):
        ################################################################
        #                     TODO: YOUR CODE HERE                     #
        #
        #  Initialize the weights of the model with mean 0 and std 0.02
        if isinstance(module, nn.Linear):
            module.weight.data.normal_(mean=0.0, std=0.02)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.Embedding):
            module.weight.data.normal_(mean=0.0, std=0.02)
            if module.padding_idx is not None:
                module.weight.data[module.padding_idx].zero_()
        # 
        ################################################################
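
`init_layer_weights` takes a single module as its argument, which is the callback shape that `nn.Module.apply` expects: `apply` visits every submodule and calls the function on each. The notebook does not show where the method is invoked, so the following is only a sketch of one plausible usage, using a stand-in function `init_weights` and a toy network instead of the `BERT` class.

```python
import torch
import torch.nn as nn


def init_weights(module: nn.Module) -> None:
    # Same rule as init_layer_weights above: N(0, 0.02) weights, zero biases,
    # and a zero vector for the padding embedding.
    if isinstance(module, nn.Linear):
        module.weight.data.normal_(mean=0.0, std=0.02)
        if module.bias is not None:
            module.bias.data.zero_()
    elif isinstance(module, nn.Embedding):
        module.weight.data.normal_(mean=0.0, std=0.02)
        if module.padding_idx is not None:
            module.weight.data[module.padding_idx].zero_()


# Toy stand-in for the real model (hypothetical sizes)
toy = nn.Sequential(nn.Embedding(10, 4, padding_idx=0), nn.Linear(4, 3))
toy.apply(init_weights)  # recursively calls init_weights on every submodule
```

For the class above, the equivalent call would be `bert.apply(bert.init_layer_weights)`, assuming that is how the method is meant to be used.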


In [None]:
'''
We can verify that the model is working by running a quick test.
'''

sentence = "The quick brown fox jumps over the lazy dog."
tokens = tokenizer.encode(sentence)
# to tensor
token_ids = torch.tensor(tokens.ids)
bert = BERT(vocab_size=tokenizer.get_vocab_size(), 
            hidden_size=768, 
            num_heads=4, 
            num_layers=2,
        )

attention_mask = (token_ids != 0).float()

output, attn = bert(token_ids.unsqueeze(0), attention_mask.unsqueeze(0))
print(tokenizer.get_vocab_size())
print(output.shape)

# Part 3: Training

After all that implementation, it's time to pre-train our BERT model. We'll focus specifically on MLM training, and then in a separate notebook you'll use the pre-trained BERT for Parts 5 and 6.

To pre-train, we'll need to accomplish three pieces:
1. First, we'll need to create a `Dataset` class that says how to load and process our text data for MLM training. 
2. Second, we'll need to create a `collate` function that tells the `DataLoader` how to combine multiple training examples from our dataset into a batch. The `collate` function is critical because not all of our input texts have the same size (different sequence lengths), which creates a wrinkle when we need to give the model a single `Tensor`.
3. Third, we'll write the core training loop.

Each of the pieces below goes over more details. As you progress as a practitioner, you'll frequently need to write `Dataset` and `collate` functions for more bespoke kinds of training tasks. This part of the assignment is intended to show you how to do some simple implementations that you can re-use later. You'll likely reuse some/all of this code in Parts 5 and 6 when setting up these functions for classification!

## Part 3.1: Create A Dataloader for Masked Language Modeling

A `Dataset` is kind of like a glorified `list` object in Python. The `Dataset` class helps us load and process the data, while providing functionality that lets us access the data.

The main function of the Dataset class is to get the number of samples and to get a sample from the dataset. The core functionality you'll need to implement in this part of the assignment is the following:
- Load the data and tokenize it using the tokenizer.
- The `__len__` method to return the number of samples in the dataset.
- The `__getitem__` method to return a sample from the dataset.

As a side note, the `Dataset` class provides a very important abstraction for training. More sophisticated `Dataset` implementations do useful things like keeping only some of the dataset in memory and proactively fetching data from disk to keep the overall memory footprint low. Others may do pre-processing on the fly to avoid having large `Tensor` objects in memory (e.g., loading images and preparing them for image-based learning). Just think of how large some datasets might get: as a practitioner, these implementations are critical for efficient training!
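
To make the "fetch from disk" idea concrete, here is a minimal sketch of a map-style dataset that stores only byte offsets in memory and reads each line on demand. The class name `LazyLineDataset` is hypothetical and the class is plain Python (it does not subclass `torch.utils.data.Dataset`); it assumes an uncompressed text file with one example per line, so it would not work directly on the `.gz` files used later.

```python
import os
import tempfile


class LazyLineDataset:
    """Keeps one integer offset per line in memory instead of the text itself."""

    def __init__(self, path: str):
        self.path = path
        self.offsets = []
        # One pass over the file to record where each line starts
        with open(path, "rb") as f:
            while True:
                pos = f.tell()
                line = f.readline()
                if not line:
                    break
                self.offsets.append(pos)

    def __len__(self) -> int:
        return len(self.offsets)

    def __getitem__(self, idx: int) -> str:
        # Jump straight to the requested line and read only that line
        with open(self.path, "rb") as f:
            f.seek(self.offsets[idx])
            return f.readline().decode("utf-8").strip()


# Tiny demonstration on a temporary file
with tempfile.TemporaryDirectory() as tmp:
    demo_path = os.path.join(tmp, "reviews.txt")
    with open(demo_path, "w", encoding="utf-8") as f:
        f.write("great book\nnot my favorite\nwould read again\n")
    lazy = LazyLineDataset(demo_path)
    n_items = len(lazy)
    second = lazy[1]

print(n_items, second)
```

The trade-off is extra file I/O per item in exchange for a much smaller memory footprint; for the dataset sizes in this homework, holding everything in memory as `MLMDataset` does below is simpler.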

In [None]:
class MLMDataset(Dataset):
    def __init__(self, tokenizer, data: List[str], max_seq_length=128, mlm_probability=0.15):

        ##################################################################################################################
        # TODO: YOUR CODE HERE 
        # 
        # 1. Store the arguments as fields
        self.tokenizer = tokenizer
        self.data = data
        self.max_seq_length = max_seq_length
        self.mlm_probability = mlm_probability
        self.tokens = []
        #                                                                                                                #                                    #
        ##################################################################################################################

    def __len__(self):
        return len(self.data)
    
    def tokenize(self):
        '''
        Tokenizes the text in self.data, performing any preprocessing and storing the tokenized data in a new list.
        '''

        ##################################################################################################################
        # TODO: YOUR CODE HERE 
        #                                                                                                                #                                                     #
        tokenized_data = []
        for text in self.data:
            # 1. Tokenize the data using the tokenizer.  
            token_ids = self.tokenizer.encode(text).ids  
            # 2. Truncate the sequence to the maximum sequence length.                                                       #
            token_ids = token_ids[:self.max_seq_length]
            # 3. Add the tokenized data to a list.  
            tokenized_data.append(token_ids)
        self.tokenized_data = deepcopy(tokenized_data)
        return tokenized_data
        # NOTE: To save memory, you could delete self.data after tokenizing, since you'll have the
        #       tokenized data (as ids) and won't need the raw text later. It is kept here because
        #       __len__ still reads len(self.data); deleting it would require __len__ to use
        #       len(self.tokenized_data) instead.
        #
        
        ##################################################################################################################

    def __getitem__(self, idx):
        '''
        Returns the list of the token ids of an instance in the dataset and a list of the labels for MLM (one label per token).
        '''

        ################################################################
        #                     TODO: YOUR CODE HERE                     #
        #
        # 1. Get the tokenized data at the specified index
        token_ids = deepcopy(self.tokenized_data[idx])
        # 2. Create the mask (i.e., which words the model has to predict). Special tokens are never masked.
        #    Non-masked tokens should be set to -100, which will be ignored in the loss function. 
        #    The masked tokens should be set to the original token ids. Use the specified masking probability.  
        labels = []
        for i, token_id in enumerate(token_ids):
            #exclude special tokens 
            if token_id == 101 or token_id == 102 or token_id == 0:
                labels.append(-100)
            else: 
                if np.random.uniform() < self.mlm_probability:
                    labels.append(token_id)
                    token_ids[i] = self.tokenizer.token_to_id('[MASK]')
                else:
                    labels.append(-100)

        return token_ids, labels     
        # 
        # Hint: Use the tokenizer's functions to get IDs as needed
        ################################################################
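
The labelling convention used in `__getitem__` can be checked in isolation with a small, seeded sketch. The helper name `mask_tokens` and the explicit `rng` argument are assumptions added for reproducibility, and the ids `101`, `102`, and `103` are assumed to be `[CLS]`, `[SEP]`, and `[MASK]` as in the hard-coded values above; your tokenizer may differ.

```python
import random
from typing import List, Set, Tuple

IGNORE_INDEX = -100  # positions with this label are skipped by the loss


def mask_tokens(
    token_ids: List[int],
    mask_id: int,
    special_ids: Set[int],
    mlm_probability: float,
    rng: random.Random,
) -> Tuple[List[int], List[int]]:
    inputs = list(token_ids)                 # copy so the caller's list is untouched
    labels = [IGNORE_INDEX] * len(inputs)
    for i, tok in enumerate(token_ids):
        if tok in special_ids:
            continue                         # special tokens are never masked
        if rng.random() < mlm_probability:
            labels[i] = tok                  # the model must predict the original id
            inputs[i] = mask_id              # ...from a [MASK] in the input
    return inputs, labels


original = [101, 2023, 2003, 1037, 3231, 102]
masked_inputs, masked_labels = mask_tokens(
    original,
    mask_id=103,
    special_ids={0, 101, 102},
    mlm_probability=0.5,                     # higher than 0.15 so the tiny demo masks something
    rng=random.Random(0),
)
print(masked_inputs)
print(masked_labels)
```

Every position is in exactly one of two states: unchanged input with label `-100`, or `[MASK]` input with the original id as its label.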

In [None]:
'''
Verify that the dataset is working by running a quick test.
'''

# create a fake dataset of 1000 sentences, each using the first 10-20 letters of the alphabet as words
data = [' '.join([chr(97 + i) for i in range(random.randint(10, 20))]) for _ in range(1000)]

dataset = MLMDataset(tokenizer, data)
dataset.tokenize()

# get the first item
input_ids, labels = dataset[0]
print('input_ids:', input_ids)
print('labels:', labels)

In [None]:
## test with one sentence (repeated three times)
data = ["The quick brown fox jumps over the lazy dog.",
        "The quick brown fox jumps over the lazy dog.",
        "The quick brown fox jumps over the lazy dog.",]
dataset = MLMDataset(tokenizer, data)
dataset.tokenize()
print(dataset[0])


## Part 3.2: Create A `collate` function to prepare a batch for training

Once the dataset is ready, we can use the `DataLoader` class to load the data from the `Dataset` class, much like you did for Homework 2. Remember that the `Dataset` class has the `__getitem__` method that returns the `input_ids` and the `labels`.
    
Like Homework 2, we'll want to train using a _batch_ of items. This is good for two reasons. First, batching helps us learn from multiple examples at the same time, so the gradient is a bit smoother. Batching provides a good trade-off between SGD (one item at a time) and full GD (all items at once). Second, and perhaps more importantly, batching allows us to maximize the throughput of our computing resources. Depending on the hardware, some matrix operations run at the same speed for different sized matrices, so if we can fit more training examples into each step, we reduce the overall number of steps. You may have seen this in Homework 2, where increasing the batch size dropped the training time, up to some point.
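
As a quick arithmetic check on the "fewer steps" point, the number of optimizer steps per epoch is the dataset size divided by the batch size, rounded up. The helper name `steps_per_epoch` is made up for this sketch, and it assumes the final partial batch is kept (the `DataLoader` default, `drop_last=False`).

```python
import math


def steps_per_epoch(num_examples: int, batch_size: int) -> int:
    # The last batch may be smaller than batch_size, hence the ceiling
    return math.ceil(num_examples / batch_size)


for b in (1, 8, 16, 64):
    print(f"batch_size={b:>2}: {steps_per_epoch(1000, b)} steps")
```

Whether fewer steps translates into less wall-clock time depends on how well the hardware parallelizes the larger matrices, which is why the speed-up levels off at some batch size.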

The `DataLoader` draws $b$ items at a time from the `Dataset`, where $b$ is the batch size (in order by default, or in random order when `shuffle=True`), and turns those into a single `Tensor` to pass as input to the model. To create the batch itself, the `collate` function tells the `DataLoader` how to turn multiple items into a single `Tensor`.

However, we have a major wrinkle here. When we train BERT for MLM over sequences, not all the sequences have the same length. A batch itself is represented as a `Tensor`. When all the instances in the batch have the same length, we can concatenate/stack them. For example, if we had 10 instances of sequences of length 5, we can create a tensor that is size (10, 5). However, if some of those sequences have different lengths, we no longer can create a Tensor with a single length dimension! What to do?

To solve this, we'll need to write the `collate` function so that it makes all the sequences have the same length. Typically this is done by adding a special `[PAD]` token to the sequences so that they all have the same number of tokens. However, this extra token is meaningless!  If we don't recognize this, then our model will learn to predict `[PAD]` tokens (yikes!). Therefore, we also need to create an _attention mask_ that tells us which tokens to ignore in the input because they are padding. 
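
A pure-Python sketch of what padding and the attention mask look like for two sequences of different lengths. The helper name `pad_batch` is hypothetical, and the token ids are made up; `0` is assumed to be the `[PAD]` id, matching the hard-coded value used elsewhere in this notebook.

```python
from typing import List, Tuple


def pad_batch(
    sequences: List[List[int]], pad_id: int = 0
) -> Tuple[List[List[int]], List[List[bool]]]:
    max_len = max(len(s) for s in sequences)
    # Right-pad every sequence to the length of the longest one
    padded = [s + [pad_id] * (max_len - len(s)) for s in sequences]
    # True for real tokens, False for padding
    mask = [[True] * len(s) + [False] * (max_len - len(s)) for s in sequences]
    return padded, mask


padded, mask = pad_batch([[101, 7, 8, 102], [101, 9, 102]])
print(padded)
print(mask)
```

Both outputs are now rectangular, so they can be turned into tensors of shape `(batch, max_len)`.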

In [None]:
# create a fake dataset of 1000 sentences, each using the first 10-20 letters of the alphabet as words
data = [' '.join([chr(97 + i) for i in range(random.randint(10, 20))]) for _ in range(1000)]

dataset = MLMDataset(tokenizer, data)
dataset.tokenize()

# get the first item
input_ids, labels = dataset[0]
print('input_ids:', input_ids)
print('labels:', labels)

In [None]:
def collate_fn(batch: List[Tuple[List[int], List[int]]]):
    '''
    A function that takes a list of instances in the dataset and collates them into a batch.
    '''

    ################################################################
    #                    TODO: YOUR CODE HERE                      #
    # 1. Separate the input_ids and the labels into two lists      #
    input_ids_list = [torch.tensor(i[0]) for i in batch]
    labels_list = [torch.tensor(i[1]) for i in batch]
    # 2. Pad both lists using the tokenizer's padding value so 
    #    that all lists are the same length.
    padded_input_ids = torch.nn.utils.rnn.pad_sequence(input_ids_list, batch_first=True, padding_value=0)
    padded_labels = torch.nn.utils.rnn.pad_sequence(labels_list, batch_first=True, padding_value=-100)  # -100 is used for padding in MLM loss
    
    # 3. Create a boolean attention mask for the input_ids which specifies
    #    which tokens are non-padded elements.
    mask = (padded_input_ids != 0)

    # 4. Return the padded input_ids, the attention mask, and the padded labels.
    return padded_input_ids, mask, padded_labels
    # NOTE: Look at nn.utils.rnn.pad_sequence
    #
    ################################################################

In [None]:
# test the collate function
batch_size = 8
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=False, collate_fn=collate_fn)

# All of the items in the batch should have the same shape!
for input_ids, attention_mask, labels in dataloader:
    print(input_ids.shape)
    print(attention_mask.shape)
    print(labels.shape)
    break

## Part 3.3: Train BERT for MLM!

Once we have our `Dataset` and `collate` function, it's time to pre-train the BERT model for MLM. Working on your laptop to start, try training for 1-2 epochs on the `med` dataset. There's no guarantee the model will learn anything useful, but we can do some experiments later in Part 4 to take a look.

The core steps for training are the following:
1. Define your model
2. Set up `wandb` for tracking its progress (this will be useful for monitoring when running on Great Lakes)
3. Define the optimizers, learning rate, etc.
4. Define the core training loop

The code won't look too dissimilar from past pytorch training loops in Homeworks 1 and 2.
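
One detail that differs from a plain classifier is the loss: the model outputs one distribution over the vocabulary per token, and only the masked positions should contribute. The sketch below uses random logits as a stand-in for model output (the shapes and values are made up) to show that flattening to `(batch * seq_len, vocab)` and using `ignore_index=-100` gives the same value as averaging the loss over just the labelled positions.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
batch, seq_len, vocab = 2, 3, 5
logits = torch.randn(batch, seq_len, vocab)      # stand-in for the MLM head output
labels = torch.tensor([[-100, 2, -100],
                       [4, -100, -100]])         # only two positions are labelled

loss_fn = nn.CrossEntropyLoss(ignore_index=-100)
# CrossEntropyLoss expects (N, C) logits and (N,) targets, so flatten both
loss = loss_fn(logits.view(-1, vocab), labels.view(-1))

# The same value computed by hand from the two labelled positions only
picked = torch.stack([logits[0, 1], logits[1, 0]])
manual = nn.functional.cross_entropy(picked, torch.tensor([2, 4]))
print(loss.item(), manual.item())
```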

Training times can vary, but on an M1 Mac and the default hyperparameters, one epoch takes ~50min on the large dataset with "mps" and 6 hours on "cpu". The model and training seem to fit at 12GB of memory. 

For training, we recommend trying to run either the medium or large for a few epochs. If your CPU is slower, try medium just for one epoch to get a sense of what it's learned in Part 4.  For getting a sense of whether the model is working, we strongly recommend testing your CPU-trained model with the masked language modeling task in Task 4.4 to see whether it can correctly fill in common words. The most similar words aren't always a good indicator that it's working.

Don't worry if you can't train too long on your own machine. Once you've gotten the model working (e.g., some preliminary analysis in Part 4 looks "okay" -- doesn't have to be great), it's time to convert this to a script and run it on Great Lakes in Part 3.5 (more details there). You'll use the advanced GPUs on the cluster to go through 5 epochs (or more) during training, which will give you a good enough model that you can use it for the rest of the homework in Parts 4 and 5.


#### VERY OPTIONAL PARTS:

If you are feeling _really_ adventurous, you can try to speed up your model with a few of the much fancier features in PyTorch. In practice, you should not need these for the homework. However, they can be fun to explore even with a CPU, though they make the biggest impact if you have access to a GPU too.

- Try using [torch.compile](https://pytorch.org/tutorials/intermediate/torch_compile_tutorial.html) to optimize the `nn.Module` code (this tries to pre-compile the computation graph)
- Rather than train with 32-bit floating point, use [amp](https://pytorch.org/blog/what-every-user-should-know-about-mixed-precision-training-in-pytorch/) to do mixed-precision training (i.e., fewer bits and faster). If you want to test this on Great Lakes, the GPUs there support "fp16" which will greatly speed up training. At the moment, `amp` isn't supported on "mps" devices, though it might show up [soon](https://github.com/pytorch/pytorch/issues/88415).
- Implement an alternative training loop using [`accelerate`](https://github.com/huggingface/accelerate) and try using mixed precision (fp8, fp16, bf16) training
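
For the mixed-precision option, a single training step with `torch.autocast` might look like the sketch below. It is not tuned for this assignment: the toy `nn.Linear` model, the sizes, and the learning rate are placeholders, it uses fp16 with a gradient scaler only when CUDA is available and falls back to bf16 on CPU, and it does not attempt "mps" at all.

```python
import torch
import torch.nn as nn

use_cuda = torch.cuda.is_available()
device = "cuda" if use_cuda else "cpu"
amp_dtype = torch.float16 if use_cuda else torch.bfloat16

toy_model = nn.Linear(8, 4).to(device)           # placeholder for the BERT model
optimizer = torch.optim.AdamW(toy_model.parameters(), lr=1e-3)
# Loss scaling only matters for fp16; when disabled the scaler is a pass-through
scaler = torch.cuda.amp.GradScaler(enabled=use_cuda)
loss_fn = nn.CrossEntropyLoss()

x = torch.randn(16, 8, device=device)
y = torch.randint(0, 4, (16,), device=device)

optimizer.zero_grad()
with torch.autocast(device_type=device, dtype=amp_dtype):
    logits = toy_model(x)                        # matmul runs in reduced precision
loss = loss_fn(logits.float(), y)                # compute the loss in fp32
scaler.scale(loss).backward()
scaler.step(optimizer)                           # calls optimizer.step() when disabled
scaler.update()
print(logits.dtype, loss.item())
```

The parameters themselves stay in fp32; only the operations inside the `autocast` block run in the lower-precision type.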

In [None]:
# Let's generate with our real data!

#review_data_path = './reviews-word2vec.med.txt' # <- for sanity checking / debugging
#review_data_path = './reviews-word2vec.large.txt.gz' # <- for CPU pre-training and validating
review_data_path = './reviews-word2vec.larger.txt.gz' #<- for GPU pre-training and validating (Part 3.5)


# NOTE: when you eventually deploy this code to Great Lakes, you'll need to use the larger dataset 
# (see the PDF for notes/details)

ofunc = gzip.open if review_data_path.endswith('gz') else open
with ofunc(review_data_path, 'rt') as f:
    reviews = f.readlines()
    reviews = [review.strip() for review in reviews]

dataset = MLMDataset(tokenizer, reviews)
dataset.tokenize()

# get the first item
print(dataset[0])

In [None]:
# test the collate function
batch_size = 8
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=False, collate_fn=collate_fn)

# All of the items in the batch should have the same shape!
for input_ids, attention_mask, labels in dataloader:
    print(input_ids.shape)
    print(attention_mask.shape)
    print(labels.shape)
    break

#Should be torch.Size([8, 119]) for large

In [None]:
'''
Now, let's put it all together in the training loop.
'''

# check if gpu is available
device = 'cpu' 
if torch.backends.mps.is_available():
    device = 'mps'
if torch.cuda.is_available():
    device = 'cuda'
print(f"Using '{device}' device")

################################################################
#         TODO: YOUR CODE HERE    
#
# 1. Define the loss function, optimizer, and the BERT model. Use the model hyperparameters from the PDF.
model = BERT(vocab_size=tokenizer.get_vocab_size(),
        feed_forward_size= 256,
        hidden_size= 128,
        num_layers= 2,
        num_heads= 4)
model.to(device)
model.train()

optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
f_loss = nn.CrossEntropyLoss(ignore_index=-100)

# 2. Specify the hyperparameters for training (e.g., learning rate, batch size, etc.)
batch_size = 16
num_epochs = 10

# 3. Initialize wandb for logging.
wandb.init(project='EECS595-BERT', name=f'experiment-{time.time()}')
wandb.watch(model, log_freq=100)
#
################################################################


losses = []
print("Training")
################################################################
#         TODO: YOUR CODE HERE       #
# The training loop is where we train the model.              #
for epoch in trange(num_epochs, desc="Epoch"):
    #                                                             #
    # You should implement the following:                         #
    # 1. Load the input_ids, attention_mask, and labels to the    #
    #    device.                                                  #
    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=False, collate_fn=collate_fn)
    for step, (input_ids, attention_mask, labels) in enumerate(tqdm(dataloader, position=1, leave=True, desc="Step")):
        #send to device
        input_ids = input_ids.to(device)
        attention_mask = attention_mask.to(device)
        labels = labels.to(device)
        # 2. Zero the gradients of the optimizer.                     #
        optimizer.zero_grad()
        # 3. Get the output from the BERT model.                      #
        outputs, attn = model(input_ids, attention_mask)
        # 4. Calculate the loss using the output and the labels.      #
        #labels = labels.unsqueeze(2).expand(outputs.shape[0], outputs.shape[1], outputs.shape[2])
        outputs = outputs.view(-1, outputs.size(-1))
        labels = labels.view(-1)
        loss = f_loss(outputs, labels)
        # 5. Backpropagate the loss.                                  #
        loss.backward()
        # 6. Update the optimizer.                                    #
        optimizer.step()
        # 7. Log the loss to wandb.                                   #
        if (step % 100 == 0 and step != 0):
            wandb.log({"loss": loss.item()})
            losses.append(loss.item())

    # 8. Save the model every epoch. Use separate files per epoch #
    torch.save(model.state_dict(), f"bert8_{epoch}.sd")
#
# HINT: The tokenizer can be helpful here.
################################################################
torch.save(model.state_dict(), "bert8_pretrained.sd")
plt.plot(losses)

In [None]:
'''
Now that we have trained the model, we can use it to predict the masked tokens.
'''

# Predict the masked tokens
def noise_inputs(inputs, mask_token_id, mlm_probability=0.15):
    inputs = deepcopy(inputs)
    labels = [-100] * len(inputs)
    masked_indices = np.random.choice(len(inputs), int(len(inputs) * mlm_probability), replace=False)
    for i in masked_indices:
        if inputs[i] not in [mask_token_id, 101, 102, 0]:
            labels[i] = inputs[i]  # record the original id before overwriting it
            inputs[i] = mask_token_id
    return inputs, labels

def noise_and_predict_tokens(query, tokenizer, model) -> None:
    # Run on whatever device the model's weights live on
    device = next(model.parameters()).device
    model.eval()  # disable dropout for prediction
    with torch.no_grad():
        tokenized_input = tokenizer.encode(query)        
        tokens = tokenized_input.tokens
        print('Original:', ' '.join(tokens))
        ids = np.array(tokenized_input.ids)
        inputs, labels = noise_inputs(ids, tokenizer.token_to_id('[MASK]'))
        print('Noised: ', ' '.join([tokenizer.id_to_token(int(at_i)) for at_i in inputs]))
        
        response, attns = model(torch.from_numpy(inputs).unsqueeze(0).to(device))
        response = response.argmax(-1).squeeze(0).tolist()
        print('Guess:  ', ' '.join([tokenizer.id_to_token(at_i) for at_i in response[1:-1]]).replace(' ##', ''))

s = 'I really like the book it was great and I loved reading it.'
# Use the trained `model` here, not the untrained `bert` from the Part 2.7 quick test
noise_and_predict_tokens(s, tokenizer, model)

## Part 3.4: Save the Pre-Trained Model

At this point, save the model's parameters in its `state_dict`. See pytorch's [documentation](https://pytorch.org/tutorials/beginner/saving_loading_models.html) for some guidance here. We'll be using this pre-trained model in later notebooks so once you save it, test that you can load it in another notebook (try `BERT_Inference.ipynb` to start) before moving on. 

In [None]:
################################################################
#                     TODO: YOUR CODE HERE                     #
#
# 1. Save the BERT model to a file
torch.save(model.state_dict(), "bert8_pretrained.sd")
#
# NOTE: Before you close this notebook, verify you can load the model and 
#       use it to predict the masked tokens in the BERT_Inference notebook.
#
################################################################
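
Before moving on, it helps to see the full save/load round trip in one place. The sketch below uses a small stand-in network built by a hypothetical `make_model` helper rather than the `BERT` class, and a temporary file name; the point is that a `state_dict` only stores tensors, so the loading side must first construct a model with the same architecture arguments.

```python
import os
import tempfile

import torch
import torch.nn as nn


def make_model() -> nn.Module:
    # Stand-in for BERT(...); the constructor arguments must match at load time
    return nn.Sequential(nn.Embedding(20, 8), nn.Linear(8, 20))


trained = make_model()

with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "toy_pretrained.sd")
    torch.save(trained.state_dict(), path)

    restored = make_model()
    # map_location lets a checkpoint saved on a GPU be loaded on a CPU-only machine
    state = torch.load(path, map_location="cpu")
    restored.load_state_dict(state)
    restored.eval()

same = all(
    torch.equal(a, b)
    for a, b in zip(trained.state_dict().values(), restored.state_dict().values())
)
print("weights match:", same)
```

If `load_state_dict` raises a size-mismatch error in the inference notebook, the usual cause is a difference in `hidden_size`, `num_layers`, or `vocab_size` between the two model constructions.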



## Part 3.5: Convert the notebook to a script and submit to Great Lakes for final training

Once you have your model debugged and can verify that it works on a small dataset (manual exploration in Part 4 will help), it's time to train it on more data and for a longer time period. To do this, we'll use the Great Lakes cluster at U-M, which will give you access to a GPU that will make training run ~10x faster; this means more epochs and more data in the same amount of time, so you get a better model. The Homework PDF has documentation on how to use Great Lakes if you haven't seen it. The course account is limited to 4 hours of wallclock time and 16GB of memory, which were tuned specifically for this assignment.

Great Lakes supports interactive mode with Jupyter and running a script as a job. We **strongly** encourage the latter. To get a GPU, you'll need to submit a job to the cluster, which uses [SLURM for scheduling](https://arc.umich.edu/greatlakes/slurm-user-guide/). If you attempt to queue for an interactive job, you will have no control over when it starts, so you may end up having your notebook run for 4 hours from 3am to 7am and then it ends, at which point you have to get back in the queue. If you submit a job as a script (i.e., a .py file that runs the code in this notebook), it will run for the specified amount of time and save the BERT model without you having to interact with anything. 

SLURM and cluster scheduling are very common in industries where there is a single cluster resource that people share by submitting jobs, so that no one can monopolize the system and jobs can run in parallel. Given that Great Lakes will be useful to you in future assignments and projects, we strongly encourage you to learn how to use it effectively in this assignment.

Depending on how you're working on this file, there are a few ways to directly convert the notebook to a script if you use [Jupyter or the command line](https://mljar.com/blog/convert-jupyter-notebook-python/) or [VSCode](https://stackoverflow.com/questions/64297272/best-way-to-convert-ipynb-to-py-in-vscode). Once you convert it, you'll modify the file to change the epochs and text file as specified in the PDF. **We also strongly recommend having your script save the model at the end of every epoch.**  That way, if your script takes longer than 4 hours and gets killed, you still have the best saved model you could get based on the amount of training you could do.
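
Once the notebook is a `.py` file, one convenient pattern is to expose the values you expect to change (data file, epochs, output location) as command-line arguments instead of editing the script each time. The sketch below is only a skeleton: the flag names, defaults, and the `checkpoint_path` naming scheme are all hypothetical, and the actual values to use are the ones given in the PDF.

```python
import argparse
from typing import List, Optional


def parse_args(argv: Optional[List[str]] = None) -> argparse.Namespace:
    parser = argparse.ArgumentParser(description="Pre-train BERT with MLM")
    parser.add_argument("--data-path", default="./reviews-word2vec.larger.txt.gz")
    parser.add_argument("--epochs", type=int, default=5)
    parser.add_argument("--batch-size", type=int, default=16)
    parser.add_argument("--save-dir", default=".")
    return parser.parse_args(argv)


def checkpoint_path(save_dir: str, epoch: int) -> str:
    # One file per epoch, so a job killed at the time limit still leaves checkpoints behind
    return f"{save_dir}/bert_epoch{epoch}.sd"


# In the real script this would be parse_args() with no arguments, reading sys.argv
args = parse_args(["--epochs", "5", "--save-dir", "checkpoints"])
paths = [checkpoint_path(args.save_dir, e) for e in range(args.epochs)]
print(args.data_path, args.batch_size, paths[-1])
```

Inside the training loop, the per-epoch save would then become `torch.save(model.state_dict(), checkpoint_path(args.save_dir, epoch))`.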