In [None]:
%load_ext nb_black

# Mini-BERT: Inspired by and adapted from [MinGPT](https://github.com/karpathy/minGPT/) and [HuggingFace BERT](https://huggingface.co/transformers/model_doc/bert.html).

In [None]:
import math
from typing import Optional

import torch
import torch.nn as nn
from torch.nn import functional as F

## One config object to initialize everything

In [None]:
class Config:
    """A single object holding every hyperparameter needed to build the model.

    Attributes:
        vocab_size: Number of words in the vocabulary.
        max_seq_len: The maximal number of tokens the model can accept as input.
        hidden_size: The hidden size of the model.
        num_layers: The number of layers in the model.
        num_heads: The number of self-attention heads.
        drop_prob: Dropout probability.
        token_types: Number of token types accepted by the transformer.
    """

    def __init__(
        self,
        vocab_size=30000,
        max_seq_len=256,
        hidden_size=768,
        num_layers=12,
        num_heads=12,
        drop_prob=0.1,
        token_types=2,
    ):
        # Vocabulary and input-window sizes
        self.vocab_size = vocab_size
        self.max_seq_len = max_seq_len

        # Dimensions of the Transformer stack
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        self.num_heads = num_heads

        # Regularization
        self.drop_prob = drop_prob

        # Size of the token-type vocabulary
        self.token_types = token_types

## Scaled dot-product attention and multi-head attention

![Multi-head attention](imgs/attention.png)

Image source: [Vaswani et al., 2017. Attention is all you need](https://papers.nips.cc/paper/2017/file/3f5ee243547dee91fbd053c1c4a845aa-Paper.pdf).

In [None]:
class MultiHeadSelfAttention(nn.Module):
    """Multi-head scaled dot-product self-attention with padding masking.

    The input is projected to queries, keys and values, split into
    `num_heads` heads, attended per head, concatenated and projected back to
    `hidden_size`. Dropout is applied to the attention weights and to the
    projected output.

    Args:
        config: Object providing `hidden_size`, `num_heads` and `drop_prob`.

    Raises:
        ValueError: If `hidden_size` is not divisible by `num_heads`.
    """

    # The annotation is quoted so that defining this class does not require
    # `Config` to already exist in the namespace.
    def __init__(self, config: "Config"):
        super().__init__()
        # Fail early with a clear message: the hidden dimension is split evenly
        # across the heads, which otherwise only surfaces as a `view` error
        # inside `forward`.
        if config.hidden_size % config.num_heads != 0:
            raise ValueError(
                f"hidden_size ({config.hidden_size}) must be divisible by "
                f"num_heads ({config.num_heads})"
            )
        # Layers for projecting the queries, keys and values
        self.key = nn.Linear(config.hidden_size, config.hidden_size)
        self.query = nn.Linear(config.hidden_size, config.hidden_size)
        self.value = nn.Linear(config.hidden_size, config.hidden_size)
        # Dropout layers
        self.attn_drop = nn.Dropout(config.drop_prob)
        self.resid_drop = nn.Dropout(config.drop_prob)
        # Output projection layer
        self.proj = nn.Linear(config.hidden_size, config.hidden_size)
        self.num_heads = config.num_heads

    def transpose_for_scores(self, x):
        """Split the last dimension of `x` into heads.

        Args:
            x: Tensor of shape (Batch size, Seq. len., Hidden size).

        Returns:
            Tensor of shape (Batch size, Num. heads, Seq. len., Head size),
            where Head size = Hidden size // Num. heads.
        """
        batch_size, seq_len, hidden_size = x.size()
        # (Batch size, Seq. len., Num. heads, Head size)
        x = x.view(batch_size, seq_len, self.num_heads, hidden_size // self.num_heads)
        # (Batch size, Num. heads, Seq. len., Head size)
        x = x.transpose(1, 2)

        return x

    def forward(self, inputs: torch.Tensor, padding_mask: torch.Tensor):
        """Apply self-attention to `inputs`, ignoring padded key positions.

        Args:
            inputs: Tensor of shape (Batch size, Seq. len., Hidden size).
            padding_mask: Boolean tensor of shape (Batch size, Seq. len.),
                True at the positions that are padding and must not be
                attended to.

        Returns:
            Tensor of shape (Batch size, Seq. len., Hidden size).
        """
        # Obtain dimensions
        batch_size, seq_len, hidden_size = inputs.size()
        # Project queries, keys and values, then reshape and transpose to
        # prepare for dot-product attention: (BS, NH, SL, head size)
        keys = self.transpose_for_scores(self.key(inputs))
        queries = self.transpose_for_scores(self.query(inputs))
        values = self.transpose_for_scores(self.value(inputs))
        # Self-attention
        # (BS, NH, SL, head size) x (BS, NH, head size, SL) -> (BS, NH, SL, SL)
        attn = queries @ keys.transpose(-2, -1)
        # Scale by 1 / sqrt(head size)
        attn = attn * (1.0 / math.sqrt(keys.size(-1)))
        # (BS, SL) -> (BS, 1, 1, SL): broadcast over heads and query positions
        key_mask = padding_mask.unsqueeze(1).unsqueeze(1)
        # Mask the padding tokens with the most negative finite value instead
        # of -inf. After the softmax the masked weights still underflow to
        # exactly zero, but a sequence made up entirely of padding yields
        # uniform weights rather than NaN (softmax over a row of -inf is NaN).
        attn = attn.masked_fill(key_mask, torch.finfo(attn.dtype).min)
        # Normalize across the last dimension
        attn = F.softmax(attn, dim=-1)
        # This is actually dropping out entire tokens to attend to, which might
        # seem a bit unusual, but is taken from the original Transformer paper
        attn = self.attn_drop(attn)
        # Attend
        # (BS, NH, SL, SL) x (BS, NH, SL, head size) -> (BS, NH, SL, head size)
        output = attn @ values
        # (Batch size, Seq. len., Num. heads, Head size)
        output = output.transpose(1, 2)
        # Concatenate all the heads one next to each other
        # (Batch size, Seq. len., Hidden size)
        output = output.contiguous().view(batch_size, seq_len, hidden_size)
        # Project
        output = self.proj(output)
        output = self.resid_drop(output)
        return output

![Transformer encoder](imgs/encoder.png)

Image source: [Vaswani et al., 2017. Attention is all you need](https://papers.nips.cc/paper/2017/file/3f5ee243547dee91fbd053c1c4a845aa-Paper.pdf).

In [None]:
class Layer(nn.Module):
    """One Transformer encoder layer: self-attention followed by a feed-forward block.

    The output of each of the two sub-layers is added to its input and the sum
    is layer-normalized.
    """

    def __init__(self, config):
        super().__init__()
        width = config.hidden_size
        # The feed-forward block expands to four times the model width
        inner_width = 4 * width

        self.ln1 = nn.LayerNorm(width)
        self.ln2 = nn.LayerNorm(width)
        self.attn = MultiHeadSelfAttention(config)

        feed_forward = [
            nn.Linear(width, inner_width),
            nn.GELU(),
            nn.Linear(inner_width, width),
            nn.Dropout(config.drop_prob),
        ]
        self.mlp = nn.Sequential(*feed_forward)

    def forward(self, inputs: torch.Tensor, padding_mask: torch.Tensor):
        """Run the layer on `inputs`; `padding_mask` is handed to the attention module."""
        # Multi-head self-attention, then Add & Norm
        attended = self.attn(inputs, padding_mask)
        hidden = self.ln1(inputs + attended)
        # Feed-forward, then Add & Norm
        transformed = self.mlp(hidden)
        return self.ln2(hidden + transformed)

## BERT embeddings

![Bert embeddings](imgs/bert_embeddings.png)

Image source: [Devlin et al., 2018. BERT: Pre-training of Deep Bidirectional Transformers for Language Understanding](https://arxiv.org/abs/1810.04805).

In [None]:
class BertEmbeddings(nn.Module):
    """Sum of word, position and token-type embeddings, then LayerNorm and Dropout.

    Adapted from:
    https://github.com/huggingface/transformers/blob/master/src/transformers/models/bert/modeling_bert.py#L165

    Args:
        config: Object providing `vocab_size`, `max_seq_len`, `token_types`,
            `hidden_size` and `drop_prob`.
    """

    def __init__(self, config):
        super().__init__()
        # Word embeddings (yellow in the figure)
        self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size)
        # Position embeddings (grey in the figure)
        self.position_embeddings = nn.Embedding(config.max_seq_len, config.hidden_size)
        # Token-type embeddings (green in the figure)
        self.token_type_embeddings = nn.Embedding(
            config.token_types, config.hidden_size
        )
        # LayerNorm and Dropout are applied to the sum as well
        # (presumably what the original "not included" note meant: they are
        # not drawn in the figure)
        self.ln = nn.LayerNorm(config.hidden_size)
        self.dropout = nn.Dropout(config.drop_prob)

    def forward(
        self,
        input_ids: torch.Tensor,
        position_ids: Optional[torch.Tensor] = None,
        token_type_ids: Optional[torch.Tensor] = None,
    ):
        """Embed `input_ids`.

        Args:
            input_ids: Integer tensor of word indices, (Batch size, Seq. len.).
            position_ids: Integer tensor of position indices with the same
                shape as `input_ids`. Defaults to 0, 1, ..., Seq. len. - 1 for
                every sequence.
            token_type_ids: Integer tensor of token-type indices with the same
                shape as `input_ids`. Defaults to all zeros.

        Returns:
            Tensor of shape (Batch size, Seq. len., Hidden size).
        """
        if position_ids is None:
            # Consecutive positions along the last (sequence) dimension,
            # created on the same device as the inputs.
            position_ids = torch.arange(
                input_ids.size(-1), dtype=torch.long, device=input_ids.device
            ).expand_as(input_ids)
        if token_type_ids is None:
            # Every token belongs to token type 0.
            token_type_ids = torch.zeros_like(input_ids)

        # (Batch size, Seq. len., Hidden size)
        inputs_embeds = self.word_embeddings(input_ids)
        pos_embeds = self.position_embeddings(position_ids)
        token_type_embeds = self.token_type_embeddings(token_type_ids)
        # Summing all three together
        embeddings = inputs_embeds + pos_embeds + token_type_embeds
        # Normalize and dropout
        embeddings = self.ln(embeddings)
        embeddings = self.dropout(embeddings)
        return embeddings

## Defining the complete model: BERT

In [None]:
class BERT(nn.Module):
    """The complete model: BERT embeddings followed by a stack of encoder layers.

    The stack holds `config.num_layers` instances of `Layer`, applied one
    after the other.
    """

    def __init__(self, config):
        super().__init__()
        # Input representation
        self.embeddings = BertEmbeddings(config)
        # Stack of Transformer encoder layers
        stack = []
        for _ in range(config.num_layers):
            stack.append(Layer(config))
        self.encoder = nn.ModuleList(stack)

    def forward(
        self,
        input_ids: torch.Tensor,
        position_ids: torch.Tensor,
        token_type_ids: torch.Tensor,
        padding_mask: torch.Tensor,
    ):
        """Embed the inputs and return the hidden states of the last encoder layer."""
        hidden_states = self.embeddings(input_ids, position_ids, token_type_ids)
        # Each encoder layer consumes the output of the previous one and
        # receives the same padding mask
        for block in self.encoder:
            hidden_states = block(hidden_states, padding_mask)

        return hidden_states

## Testing everything with dummy inputs

In [None]:
# Seed the global RNG so the randomly initialized weights (and therefore the
# outputs below) are the same on every Restart & Run All.
torch.manual_seed(0)

# Default hyperparameters: 12 layers, 12 heads, hidden size 768
config = Config()
bert = BERT(config)

In [None]:
# A freshly constructed module is in training mode, where dropout is active
# and makes the outputs random. Run the check in eval mode and put the module
# back into whatever mode it was in afterwards.
was_training = bert.training
bert.eval()
try:
    with torch.no_grad():
        # Two sequences of four tokens each
        input_ids = torch.tensor([[1, 2, 3, 4], [8, 8, 10, 15]], dtype=torch.long)
        # Positions 0..3 for every sequence
        position_ids = torch.arange(4, dtype=torch.long).unsqueeze(0).repeat(2, 1)
        # Every token is of type 1
        token_ids = torch.ones(2, 4, dtype=torch.long)
        # True marks padding: the last token of the first sequence and the
        # last two tokens of the second one
        padding_mask = torch.tensor(
            [[False, False, False, True], [False, False, True, True]]
        )
        outputs = bert(input_ids, position_ids, token_ids, padding_mask)
finally:
    bert.train(was_training)

# One hidden vector per input token, and no numerical blow-ups
assert outputs.shape == (2, 4, config.hidden_size)
assert not torch.isnan(outputs).any()
outputs.shape