In [None]:
# | default_exp project.models.bert.model

# BERT

![](/Users/mcaro/workbench-ai/dman-nebula-micromachines-build/notebooks/models/bert.png)

In [None]:
# | export

import numpy as np

# Imports
import torch
import torch.nn as nn
import torch.nn.functional as F

## Embedding

The code below defines a PyTorch module called `Embedding` that represents an embedding layer. The module takes three arguments: 
- `max_length`, the maximum length of the input sequence; 
- `vocab_size`, the size of the vocabulary; and 
- `d_model`, the size of the embedding vector.

The `__init__` method creates two `nn.Embedding` modules, one for the token embeddings and one for the positional embeddings, plus an `nn.LayerNorm` module. The positional table has `max_length` rows, so inputs longer than `max_length` cannot be embedded. Unlike the original BERT, this implementation has no segment embedding; the corresponding term is commented out in `forward`.

The `forward` method takes one input tensor, `x`, of shape `(batch_size, seq_len)`, which holds the token IDs of the input sequence. 
 
The method first reads the sequence length from `x`. It then builds the position indices 0 to `seq_len-1` with `torch.arange`, broadcasts them across the batch, looks up the token and positional embeddings, and combines them using element-wise addition. Finally, the method applies layer normalization to the sum and returns a tensor of shape `(batch_size, seq_len, d_model)`.

Overall, this code defines an embedding layer that can be used in a neural network for natural language processing tasks such as text classification and language modeling. The layer produces one embedding per token that encodes both the identity and the position of that token. The layer normalization step helps to stabilize training.
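
The steps described above can be sketched in isolation. This is a minimal illustration rather than the notebook's exported class: the sizes and token IDs are made up, and the position indices are created on the input's device as an assumption about how the layer should behave on CPU and GPU alike.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Illustrative sizes, not values from the notebook.
max_length, vocab_size, d_model = 8, 20, 4

tok_embed = nn.Embedding(vocab_size, d_model)  # one row per token ID
pos_embed = nn.Embedding(max_length, d_model)  # one row per position
norm = nn.LayerNorm(d_model)

x = torch.tensor([[5, 7, 9], [1, 2, 0]])  # (batch_size=2, seq_len=3)

# Position indices are created on the same device as the input.
pos = torch.arange(x.size(1), dtype=torch.long, device=x.device)  # (seq_len,)
pos = pos.unsqueeze(0).expand_as(x)  # (batch_size, seq_len)

out = norm(tok_embed(x) + pos_embed(pos))  # (batch_size, seq_len, d_model)

print(pos)
print(out.shape)
```

Every row of `pos` is the same, because the position index depends only on where a token sits in the sequence, not on which sequence it belongs to.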

In [None]:
# | export


class Embedding(nn.Module):
    def __init__(
        self,
        max_length,
        vocab_size,
        d_model,
    ):
        super(
            Embedding,
            self,
        ).__init__()
        # self.tok_embed.shape = (vocab_size, d_model)
        self.tok_embed = nn.Embedding(
            vocab_size,
            d_model,
        )  # token embedding
        self.pos_embed = nn.Embedding(
            max_length,
            d_model,
        )  # position embedding
        self.norm = nn.LayerNorm(d_model)  # layer normalization

    def forward(
        self,
        x,
    ):
        """
        x: (batch_size, seq_len)
        """
        seq_len = x.size(1)
        pos = torch.arange(
            seq_len,
            dtype=torch.long,
            device=x.device,  # follow the input's device instead of assuming CUDA
        )  # 0, 1, 2,... , seq_len-1
        pos = pos.unsqueeze(0).expand_as(x)  # (seq_len,) -> (1, seq_len) -> (batch_size, seq_len)
        embedding = self.tok_embed(x) + self.pos_embed(pos)  # + self.seg_embed(seg) # (batch_size, seq_len, d_model)
        return self.norm(embedding)  # (batch_size, seq_len, d_model)

![Alt text](/Users/mcaro/workbench-ai/micro-machines-research/nbs/models/image.png)

### Attention Mask

In the BERT model, the attention mask hides padding tokens from the attention mechanism. In this implementation the mask is a boolean tensor with the same shape as the attention scores tensor, `(batch_size, seq_len, seq_len)`, where `True` marks a key position that holds padding. The attention scores are the scaled dot products of the query and key tensors and measure how strongly each query token attends to each key token. Before the softmax, every score at a `True` position is replaced with `-inf`, so padding tokens receive zero attention weight.
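
To see why filling with `-inf` removes padding from the attention weights, here is a small sketch. The scores and the mask are hypothetical values chosen for illustration, not outputs of the model.

```python
import torch
import torch.nn.functional as F

# Hypothetical scores for one query over four keys; the last two keys are padding.
scores = torch.tensor([[2.0, 1.0, 0.5, 3.0]])
pad_mask = torch.tensor([[False, False, True, True]])  # True = masked

# Replace the scores of padded keys with -inf, then normalize.
masked_scores = scores.masked_fill(pad_mask, float("-inf"))
weights = F.softmax(masked_scores, dim=-1)

print(weights)
```

Since `exp(-inf)` is 0, the masked keys get exactly zero weight and the remaining weights still sum to 1.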

In [None]:
# | export


def get_attn_pad_mask(
    seq_q,
):
    """For masking out the padding part of key sequence.
    This ensures that the attention mechanism does not attend to padding tokens, which do not contain any useful information.
    """
    (
        batch_size,
        len_q,
    ) = seq_q.size()  # (batch_size, len_q)
    # token ID 0 is the PAD token
    pad_attn_mask = seq_q.eq(0).unsqueeze(1)  # (batch_size, 1, len_q), True is masked
    return pad_attn_mask.expand(
        batch_size,
        len_q,
        len_q,
    )  # (batch_size, len_q, len_q)

The attention mask is created by the `get_attn_pad_mask` function, which takes the input sequence tensor and returns a boolean tensor with the same shape as the attention scores tensor. The mask is `True` at padding positions and `False` everywhere else; `ScaledDotProductAttention` later fills the `True` positions with `-inf`. This ensures that the attention mechanism does not attend to padding tokens, which do not contain any useful information.

Note that BERT does not use a "look-ahead" (causal) mask. Look-ahead masks belong to autoregressive decoders, where each token may attend only to earlier tokens. BERT's encoder is bidirectional, so every token can attend to every non-padding token on both sides.

The padding mask is therefore the only attention mask in this implementation. It has the same shape as the attention scores tensor and is `True` for padding tokens and `False` for all other tokens.

The `get_attn_pad_mask` function is a utility function used in this implementation of the BERT model. 

It takes: 

- `seq_q`, an input tensor of token IDs of shape `(batch_size, seq_len)`
  
and returns a boolean tensor of shape `(batch_size, seq_len, seq_len)` that can be used as a mask for the attention mechanism. Because BERT uses self-attention, the same sequence supplies both the queries and the keys.

The function first compares `seq_q` with the padding ID 0, which gives a boolean tensor of shape `(batch_size, seq_len)` that is `True` for padding tokens and `False` for non-padding tokens. It then calls `unsqueeze(1)` to add a query dimension, giving shape `(batch_size, 1, seq_len)`.

Finally, `expand` repeats that row for every query position without copying data, so no explicit loop is needed. In the returned tensor, column `j` is `True` in every row exactly when key token `j` is padding, so no query attends to a padding token. Rows that belong to padding queries are not masked.

In [None]:
# Example usage
seq_q = torch.tensor(
    [
        [
            1,
            2,
            3,
        ],
        [
            4,
            5,
            0,
        ],
    ]
)
pad_attn_mask = get_attn_pad_mask(seq_q)
print(pad_attn_mask)

# Encoder

## Scaled Dot Product Attention
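
For reference, the operation implemented in the cell below can be written as a formula. The symbol $M$ is a notational assumption: it is the additive form of the boolean padding mask, and $d_k$ is the value passed to the module's constructor.

$$
\mathrm{Attention}(Q, K, V) = \mathrm{softmax}\!\left(\frac{Q K^{\top}}{\sqrt{d_k}} + M\right) V,
\qquad
M_{ij} =
\begin{cases}
-\infty & \text{if key } j \text{ is padding} \\
0 & \text{otherwise}
\end{cases}
$$

The code gets the same result by overwriting the masked scores with `-inf` instead of adding $M$.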

In [None]:
# | export


class ScaledDotProductAttention(nn.Module):
    def __init__(
        self,
        d_k,
    ):
        super(
            ScaledDotProductAttention,
            self,
        ).__init__()
        self.d_k = d_k

    def forward(
        self,
        Q,
        K,
        V,
        attn_mask,
    ):
        """
        Q, K, V: (batch_size, seq_len, d_model)
        attn_mask: (batch_size, seq_len, seq_len)
        """
        scores = torch.matmul(
            Q,
            K.transpose(
                -1,
                -2,
            ),
        ) / np.sqrt(
            self.d_k
        )  # scores: (batch_size, seq_len, seq_len)
        scores.masked_fill_(
            attn_mask,
            float("-inf"),
        )  # Fills elements of self tensor with value where mask is True.
        attn = F.softmax(scores, dim=-1)  # normalize over the key dimension
        context = torch.matmul(
            attn,
            V,
        )  # (batch_size, seq_len, d_model)
        return (
            scores,
            context,
            attn,
        )

In [None]:
# example

emb = Embedding(
    6,
    10,
    5,
)  # max_length, vocab_size, d_model

inputs = torch.tensor(
    [
        [
            1,
            2,
            0,
            0,
            0,
            0,
        ],
        [
            1,
            2,
            3,
            4,
            0,
            0,
        ],
    ]
)  # (batch_size=2, seq_len=6)

embeds = emb(inputs)  # (batch_size, seq_len, d_model)

AttMask = get_attn_pad_mask(inputs)  # (batch_size, seq_len, seq_len)

SDPA = ScaledDotProductAttention(d_k=5)  # scale by the size of the last dimension of Q and K (d_model=5 here)

(
    Score,
    Context,
    Attention,
) = SDPA(
    embeds,
    embeds,
    embeds,
    AttMask,
)  # Q, K, V, attn_mask

print(
    "Mask: ",
    AttMask,
)
# print('Score: ', Score)
print(
    "Attention: ",
    Attention,
)

## Multi-Head Attention
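
The class in the next cell reshapes each projected tensor so that every head attends independently. The sketch below isolates that reshaping. The sizes are made up, and `x` is a random stand-in for a projection output, not a tensor from the model.

```python
import torch

batch_size, seq_len, num_heads, d_k = 2, 5, 3, 4

# Stand-in for the output of a projection: (batch_size, seq_len, num_heads * d_k)
x = torch.randn(batch_size, seq_len, num_heads * d_k)

# Split the last dimension into heads, then move the head axis next to the batch axis.
heads = x.view(batch_size, -1, num_heads, d_k).transpose(1, 2)  # (batch_size, num_heads, seq_len, d_k)

# Merge the heads back; contiguous() is needed because transpose returns a non-contiguous view.
merged = heads.transpose(1, 2).contiguous().view(batch_size, -1, num_heads * d_k)

print(heads.shape)
print(merged.shape)
```

Head `h` sees the slice `x[..., h * d_k:(h + 1) * d_k]`, and merging is the exact inverse of splitting.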

In [None]:
# | export


class MultiHeadAttention(nn.Module):
    def __init__(
        self,
        d_k,
        d_model,
        num_heads,
    ):
        super(
            MultiHeadAttention,
            self,
        ).__init__()
        self.d_k = d_k
        self.num_heads = num_heads
        self.d_model = d_model
        self.M_Q = nn.Linear(
            d_model,
            d_k * num_heads,
        )
        self.M_K = nn.Linear(
            d_model,
            d_k * num_heads,
        )
        self.M_V = nn.Linear(
            d_model,
            d_k * num_heads,
        )
        self.scaled_dot_product_attention = ScaledDotProductAttention(d_k)
        self.output_linear = nn.Linear(
            num_heads * d_k,
            d_model,
        )
        self.norm = nn.LayerNorm(d_model)

    def forward(
        self,
        Q,
        K,
        V,
        attn_mask,
    ):
        """
        Q, K, V: (batch_size, seq_len, d_model)
        attn_mask: (batch_size, seq_len, seq_len)
        """
        (
            residual,
            batch_size,
        ) = (
            Q,
            Q.size(0),
        )
        q_s = self.M_Q(Q)  # q_s: (batch_size, seq_len, d_k * num_heads)
        q_s = q_s.view(
            batch_size,
            -1,
            self.num_heads,
            self.d_k,
        )  # q_s: (batch_size, seq_len, num_heads, d_k)
        q_s = q_s.transpose(
            1,
            2,
        )  # q_s: (batch_size, num_heads, seq_len, d_k)

        k_s = self.M_K(K).view(
            batch_size,
            -1,
            self.num_heads,
            self.d_k,
        )  # k_s: (batch_size, seq_len, num_heads, d_k)
        k_s = k_s.transpose(
            1,
            2,
        )  # k_s: (batch_size, num_heads, seq_len, d_k)

        v_s = self.M_V(V).view(
            batch_size,
            -1,
            self.num_heads,
            self.d_k,
        )  # v_s: (batch_size, seq_len, num_heads, d_k)
        v_s = v_s.transpose(
            1,
            2,
        )  # v_s: (batch_size, num_heads, seq_len, d_k)

        attn_mask = attn_mask.unsqueeze(1)  # attn_mask: (batch_size, 1, seq_len, seq_len)
        attn_mask = attn_mask.repeat(
            1,
            self.num_heads,
            1,
            1,
        )  # attn_mask: (batch_size, num_heads, seq_len, seq_len)

        (
            scores,
            context,
            attn,
        ) = self.scaled_dot_product_attention(
            q_s,
            k_s,
            v_s,
            attn_mask,
        )  # context: (batch_size, num_heads, seq_len, d_k)
        context = (
            context.transpose(
                1,
                2,
            )
            .contiguous()
            .view(
                batch_size,
                -1,
                self.num_heads * self.d_k,
            )
        )  # context: (batch_size, seq_len, num_heads * d_k)

        output = self.output_linear(context)  # output: (batch_size, seq_len, d_model)

        return (
            self.norm(output + residual),
            attn,
        )  # output: (batch_size, seq_len, d_model), attn: (batch_size, num_heads, seq_len, seq_len)


class PositionwiseFeedForward(nn.Module):
    "Implements FFN equation."

    def __init__(
        self,
        d_model,
        d_ff,
        dropout=0.1,
    ):
        super(
            PositionwiseFeedForward,
            self,
        ).__init__()
        self.w_1 = nn.Linear(
            d_model,
            d_ff,
        )
        self.w_2 = nn.Linear(
            d_ff,
            d_model,
        )
        self.dropout = nn.Dropout(dropout)

    def forward(
        self,
        x,
    ):
        return self.w_2(self.dropout(F.relu(self.w_1(x))))

In [None]:
# | export


class EncoderLayer(nn.Module):
    def __init__(
        self,
        d_model=1024,
        d_k=64,
        num_heads=12,
    ):
        super(
            EncoderLayer,
            self,
        ).__init__()
        self.enc_self_attn = MultiHeadAttention(
            d_k=d_k,
            d_model=d_model,
            num_heads=num_heads,
        )
        self.feed_forward = PositionwiseFeedForward(
            d_model,
            d_ff=d_model * 4,
        )

    def forward(
        self,
        enc_input,
        enc_self_attn_mask,
    ):
        (
            enc_output,
            attn,
        ) = self.enc_self_attn(
            Q=enc_input,
            K=enc_input,
            V=enc_input,
            attn_mask=enc_self_attn_mask,
        )  # enc_input to same Q,K,V
        enc_output = self.feed_forward(enc_output)
        return (
            enc_output,
            attn,
        )

# BERT

In [None]:
# | export
import math


def gelu(
    x,
):
    return x * 0.5 * (1.0 + torch.erf(x / math.sqrt(2.0)))


class BERT(nn.Module):
    def __init__(
        self,
        max_length,
        vocab_size,
        d_k=64,
        d_model=1024,
        num_heads=12,
        n_layers=6,
    ):
        super(
            BERT,
            self,
        ).__init__()
        self.embedding = Embedding(
            max_length,
            vocab_size,
            d_model,
        )
        self.layers = nn.ModuleList(
            [
                EncoderLayer(
                    d_model,
                    d_k,
                    num_heads,
                )
                for _ in range(n_layers)
            ]
        )
        self.fc = nn.Linear(
            d_model,
            d_model,
        )
        self.activ1 = nn.Tanh()
        self.linear = nn.Linear(
            d_model,
            d_model,
        )
        self.activ2 = gelu
        self.norm = nn.LayerNorm(d_model)

        embed_weight = self.embedding.tok_embed.weight

        (
            n_vocab,
            n_dim,
        ) = embed_weight.size()
        self.decoder = nn.Linear(
            n_dim,
            n_vocab,
            bias=False,
        )
        self.decoder.weight = embed_weight
        self.decoder.bias = nn.Parameter(torch.zeros(n_vocab))

    def forward(
        self,
        inputs,
        masked_pos,
    ):
        output = self.embedding(inputs)  # (batch_size, seq_len, d_model)
        enc_self_attn_mask = get_attn_pad_mask(inputs)  # (batch_size, seq_len, seq_len), on the same device as inputs
        for layer in self.layers:
            (
                output,
                enc_self_attn,
            ) = layer(
                output,
                enc_self_attn_mask,
            )  # output: (batch_size, seq_len, d_model)

        # `masked_pos` holds the positions of the masked tokens in the input sequence, shape `(batch_size, n_masked)`.
        # `output` is the output of the last encoder layer, a tensor of shape `(batch_size, seq_len, d_model)`.
        masked_pos = masked_pos[
            :,
            :,
            None,
        ].expand(
            -1,
            -1,
            output.size(-1),
        )  # (batch_size, n_masked, d_model)

        h_masked = torch.gather(
            output,
            1,
            masked_pos,
        )  # (batch_size, n_masked, d_model), gather the masked tokens
        h_masked = self.norm(self.activ2(self.linear(h_masked)))  # Linear(d_model, d_model) -> GELU -> LayerNorm, shape stays (batch_size, n_masked, d_model)
        logits_lm = self.decoder(h_masked)  # (batch_size, n_masked, n_vocab); the layer already adds self.decoder.bias
        return logits_lm

## Get the masked token

In the context of the BERT model, `torch.gather(output, 1, masked_pos)` is used to extract the hidden states of the masked tokens from the output of the BERT model.

The tensor `output` is the output of the last encoder layer, which is a tensor of shape `(batch_size, seq_len, d_model)`. The tensor `masked_pos` holds the positions of the masked tokens in the input sequence. It is passed in with shape `(batch_size, n_masked)` and expanded to `(batch_size, n_masked, d_model)` before the gather, so that each index selects a full hidden vector.

The `torch.gather` function is used to extract the hidden states of the masked tokens from the tensor `output`. The function takes three arguments: the input tensor, the dimension along which to index the tensor, and the index tensor.

In this case, the `output` tensor is indexed along the second dimension (i.e., `dim=1`) using the expanded `masked_pos` tensor as the index tensor. The resulting tensor has shape `(batch_size, n_masked, d_model)` and contains the hidden states of the masked tokens.

Here's an example of how to use `torch.gather`:
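
The sketch below is a minimal illustration with made-up sizes. The name `n_masked`, the tensor values, and the chosen positions are assumptions, not taken from the notebook's training data.

```python
import torch

batch_size, seq_len, d_model = 2, 4, 3

# Fake encoder output whose values make it easy to see which rows were picked.
output = torch.arange(batch_size * seq_len * d_model, dtype=torch.float32)
output = output.view(batch_size, seq_len, d_model)

# Hypothetical masked positions: two per sequence.
masked_pos = torch.tensor([[1, 3], [0, 2]])  # (batch_size, n_masked)

# Repeat each index across the hidden dimension so gather selects whole vectors.
index = masked_pos[:, :, None].expand(-1, -1, d_model)  # (batch_size, n_masked, d_model)
h_masked = torch.gather(output, 1, index)  # (batch_size, n_masked, d_model)

print(h_masked)
```

Here `h_masked[b, i]` equals `output[b, masked_pos[b, i]]`, which is the hidden state of the `i`-th masked token in sequence `b`.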



## Decoder

In the context of the BERT model, these lines of code define the output layer of the model, which maps the hidden states of the masked tokens to logits over the vocabulary. Applying a softmax to the logits gives a probability distribution.

The tensor `embed_weight` is the weight matrix of the model's own token embedding, `self.embedding.tok_embed.weight`. It is not pre-trained; it is learned together with the rest of the model. The tensor has shape `(vocab_size, d_model)`, where `vocab_size` is the size of the vocabulary and `d_model` is the size of the embedding vector.

The code `n_vocab, n_dim = embed_weight.size()` extracts the size of the vocabulary and the size of the embedding vector from the `embed_weight` tensor.

The code `self.decoder = nn.Linear(n_dim, n_vocab, bias=False)` defines a linear layer that maps a hidden state of size `n_dim` to a vector of size `n_vocab`. The layer is created without a bias term (`bias=False`); its weight and bias are assigned in the next two lines.

The code `self.decoder.weight = embed_weight` ties the weight matrix of the linear layer to the embedding matrix. Both layers then share a single parameter, so gradients from both uses accumulate in it.

The code `self.decoder.bias = nn.Parameter(torch.zeros(n_vocab))` sets the bias term of the linear layer to a tensor of zeros with shape `(n_vocab,)`. The bias term is an `nn.Parameter` object, which allows it to be learned during training. Because the bias is registered on the layer, calling `self.decoder(h_masked)` applies it.

Together, these lines of code define the output layer of the BERT model, which maps the hidden states of the masked tokens to logits over the vocabulary. The output layer is trained to predict the original tokens at the masked positions during the masked language modeling task.
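
The weight tying can be checked in isolation. This is a minimal sketch with illustrative sizes and a random stand-in `h` for the hidden states; it mirrors the assignments above but is not the notebook's `BERT` class.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
vocab_size, d_model = 10, 4  # illustrative sizes

tok_embed = nn.Embedding(vocab_size, d_model)
decoder = nn.Linear(d_model, vocab_size, bias=False)
decoder.weight = tok_embed.weight  # share one Parameter between both layers
decoder.bias = nn.Parameter(torch.zeros(vocab_size))  # learnable bias, zero at initialization

h = torch.randn(2, 3, d_model)  # stand-in for hidden states of masked tokens
logits = decoder(h)  # (2, 3, vocab_size); the bias is applied inside the layer

print(decoder.weight is tok_embed.weight)
print(logits.shape)
```

Because the two modules hold the same `Parameter` object, an optimizer step that updates the embedding also updates the output projection.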