# Replicate the Transformer architecture with PyTorch + Lightning

In this notebook, we're going to replicate the Transformer architecture from the paper [*Attention Is All You Need*](https://arxiv.org/abs/1706.03762) with PyTorch and Lightning. The goal is to train a transformer model to translate English sentences into Spanish (**machine translation**).

In [271]:
import torch
import torch.nn as nn
import torch.nn.functional as F

from torch.optim import Adam
from torch.utils.data import TensorDataset, DataLoader

import lightning as L

## Set up device-agnostic code

In [272]:
device = (
    "cuda" if torch.cuda.is_available()
    else "mps" if torch.backends.mps.is_available()
    else "cpu"
)
device

'mps'

## Create tokenizers

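We need two separate tokenizers: one for the source language (English in our case) and one for the target language (Spanish). Here each tokenizer is simply a dictionary that maps every word (token) in a tiny vocabulary to an integer id.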

In [273]:
# Create a tokenizer for the source language (English)
token_to_id_src = {
    "<SOS>": 0,
    "i": 1,
    "love": 2,
    "you": 3,
    "me": 4,
    "<EOS>": 5
}

# Create a tokenizer for the target language (Spanish)
token_to_id_target = {
    "<SOS>": 0, # tells the decoder to start generating tokens
    "te": 1,
    "amo": 2,
    "me": 3,
    "amas": 4,
    "<EOS>": 5 # tells the decoder to stop generating tokens
}

# Create reverse mappings from ids back to tokens to interpret the transformer's output later
id_to_token_src = dict(map(reversed, token_to_id_src.items())) # not necessary here
id_to_token_target = dict(map(reversed, token_to_id_target.items()))
id_to_token_src, id_to_token_target

({0: '<SOS>', 1: 'i', 2: 'love', 3: 'you', 4: 'me', 5: '<EOS>'},
 {0: '<SOS>', 1: 'te', 2: 'amo', 3: 'me', 4: 'amas', 5: '<EOS>'})

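With these dictionaries, encoding a sentence is one lookup per word plus the `<SOS>`/`<EOS>` markers, and decoding reverses it. The sketch below assumes lowercase, whitespace-separated input. `encode` and `decode` are hypothetical helper names that the rest of this notebook does not use.

```python
import torch

# Same vocabularies as above
token_to_id_src = {"<SOS>": 0, "i": 1, "love": 2, "you": 3, "me": 4, "<EOS>": 5}
token_to_id_target = {"<SOS>": 0, "te": 1, "amo": 2, "me": 3, "amas": 4, "<EOS>": 5}
id_to_token_target = {i: t for t, i in token_to_id_target.items()}

def encode(sentence: str, token_to_id: dict) -> torch.Tensor:
    """Hypothetical helper: lowercase, split on whitespace, wrap in <SOS>/<EOS>."""
    tokens = ["<SOS>"] + sentence.lower().split() + ["<EOS>"]
    return torch.tensor([token_to_id[t] for t in tokens])

def decode(token_ids, id_to_token: dict) -> str:
    """Hypothetical helper: map ids back to words, dropping the special tokens."""
    words = [id_to_token[int(i)] for i in token_ids]
    return " ".join(w for w in words if w not in ("<SOS>", "<EOS>"))

print(encode("I love you", token_to_id_src))                   # tensor([0, 1, 2, 3, 5])
print(decode(torch.tensor([0, 1, 2, 5]), id_to_token_target))  # te amo
```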
## Create a training dataset

In [274]:
inputs = torch.tensor([
    [
        token_to_id_src["<SOS>"],
        token_to_id_src["i"],
        token_to_id_src["love"],
        token_to_id_src["you"],
        token_to_id_src["<EOS>"]
    ],
    
    [
        token_to_id_src["<SOS>"],
        token_to_id_src["you"],
        token_to_id_src["love"],
        token_to_id_src["me"],
        token_to_id_src["<EOS>"],
    ]
])

labels = torch.tensor([
    [
        token_to_id_target["<SOS>"],
        token_to_id_target["te"],
        token_to_id_target["amo"],
        token_to_id_target["<EOS>"]
    ],

    [
        token_to_id_target["<SOS>"],
        token_to_id_target["me"],
        token_to_id_target["amas"],
        token_to_id_target["<EOS>"]
    ]
])
inputs, labels

(tensor([[0, 1, 2, 3, 5],
         [0, 3, 2, 4, 5]]),
 tensor([[0, 1, 2, 5],
         [0, 3, 4, 5]]))

## Review of the transformer architecture

Let's break the transformer architecture down into smaller components:
1. Tokenizer: maps text to numbers (token ids).
2. Embedding layer: maps token ids to embeddings.
3. Positional encoding: adds positional information to the embeddings.
4. Multi-head attention block: the core of the transformer architecture. It computes how much attention each token pays to every other token in the sequence (including itself). For simplicity, we'll implement single-head attention in this notebook.<br>
    **Note:** The paper uses three different types of attention:
    - (Standard) self-attention: bidirectional, used by the encoder.
    - Masked self-attention: unidirectional; tokens that come after the query token are masked out. Used by the decoder.
    - Encoder-decoder attention: each decoder position can attend to every position in the encoder's output.
5. MLP block: a simplified version of the paper's position-wise feed-forward network.

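All three attention types above use the scaled dot-product attention defined in the paper. They differ only in where $Q$, $K$ and $V$ come from and whether a mask is applied. The paper describes masking as setting illegal connections to $-\infty$ before the softmax. Below, that masking is written as an additive matrix $M$, and $d_k$ is the key dimension ($d_k = d_{model}$ for the single-head attention in this notebook):

$$
\mathrm{Attention}(Q, K, V) = \mathrm{softmax}\left(\frac{QK^\top}{\sqrt{d_k}} + M\right)V,
\qquad
M_{ij} =
\begin{cases}
0 & \text{if query } i \text{ may attend to key } j,\\
-\infty & \text{otherwise.}
\end{cases}
$$

For encoder self-attention and encoder-decoder attention, $M = 0$. For the decoder's masked self-attention, $M_{ij} = -\infty$ whenever $j > i$. The attention code in this notebook approximates $-\infty$ with $-10^9$.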
### Figure 1
<img src="transformer_architecture.png" width=300 alt="figure 1 from transformer paper"/>

## Create position encoding

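We'll use the positional encoding formula from the paper *Attention Is All You Need*. Here $pos$ is the token's position in the sequence and $i$ indexes pairs of embedding dimensions:
* $PE_{(pos, 2i)} = \sin\left(pos / 10000^{2i / d_{model}}\right)$
* $PE_{(pos, 2i+1)} = \cos\left(pos / 10000^{2i / d_{model}}\right)$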

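To make the interleaving concrete, the sketch below evaluates the formula directly with Python's `math` module. It is a plain reference, not the notebook's vectorised `PositionEncoding`, and `positional_encoding_row` is a hypothetical name. Even columns use sine and odd columns use cosine, and each sine/cosine pair shares the same frequency.

```python
import math

def positional_encoding_row(pos: int, d_model: int) -> list[float]:
    """Evaluate PE(pos, 2i) = sin(pos / 10000^(2i/d_model)) and
    PE(pos, 2i+1) = cos(pos / 10000^(2i/d_model)) for one position."""
    row = []
    for two_i in range(0, d_model, 2):          # two_i plays the role of 2i
        angle = pos / 10000 ** (two_i / d_model)
        row.append(math.sin(angle))             # even column 2i
        if two_i + 1 < d_model:
            row.append(math.cos(angle))         # odd column 2i + 1
    return row

print(positional_encoding_row(0, 6))  # [0.0, 1.0, 0.0, 1.0, 0.0, 1.0]
print([round(x, 4) for x in positional_encoding_row(1, 6)])
```

Position 0 always encodes as alternating 0s and 1s. For larger positions, the low-index columns change quickly and the high-index columns change slowly.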
In [275]:
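class PositionEncoding(nn.Module):

    def __init__(self, d_model=6, max_len=6):

        super().__init__()

        # One row per position, one column per embedding dimension
        pe = torch.zeros(max_len, d_model)

        # Column vector of positions: [[0], [1], ..., [max_len - 1]]
        position = torch.arange(start=0, end=max_len, step=1).float().unsqueeze(1)

        # The even embedding indices 2i: [0, 2, 4, ...]
        embedding_index = torch.arange(start=0, end=d_model, step=2).float()

        # 10000^(2i / d_model), one value per sine/cosine pair
        div_term = torch.tensor(10000)**(embedding_index / d_model)

        # Sine on the even columns, cosine on the odd columns (this assumes d_model is even)
        pe[:, 0::2] = torch.sin(position / div_term)
        pe[:, 1::2] = torch.cos(position / div_term)

        # A buffer moves with the module (e.g. on .to(device)) but is not a trainable parameter
        self.register_buffer("pe", pe)
    
    def forward(self, word_embeddings):

        # word_embeddings: [seq_len, d_model]; add the encodings of the first seq_len positions
        return word_embeddings + self.pe[:word_embeddings.size(0), :]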

## Create Attention layers

In [276]:
class AttentionBlock(nn.Module):
    """Single-head scaled dot-product attention for unbatched inputs of shape [seq_len, d_model]."""
    
    def __init__(self, d_model=6):

        super().__init__()

        # Linear projections that create the queries, keys and values
        self.W_q = nn.Linear(in_features=d_model, out_features=d_model, bias=False)
        self.W_k = nn.Linear(in_features=d_model, out_features=d_model, bias=False)
        self.W_v = nn.Linear(in_features=d_model, out_features=d_model, bias=False)

        self.row_dim = 0
        self.col_dim = 1

    def forward(self, encodings_for_q, encodings_for_k, encodings_for_v, mask=None, device=device):

        # Create the Q, K and V matrices
        q = self.W_q(encodings_for_q)
        k = self.W_k(encodings_for_k)
        v = self.W_v(encodings_for_v)

        # Calculate the similarity scores between the queries and keys
        sims = torch.matmul(q, k.transpose(dim0=self.row_dim, dim1=self.col_dim))

        # Scale the similarity scores by the square root of the key dimension (d_model here)
        scaled_sims = sims / (k.size(self.col_dim) ** 0.5)

        # Mask the scores of later tokens so that a token can't "cheat" by attending to future tokens during training
        if mask is not None:
            mask = mask.to(device) # a manually created tensor lives on the CPU by default, so move it to the target device
            scaled_sims = scaled_sims.masked_fill(mask=mask, value=-1e9)  # -1e9 approximates negative infinity

        # Turn the scores into attention weights that sum to 1 for each query
        attention_percents = F.softmax(scaled_sims, dim=self.col_dim)

        # Weighted sum of the values: the contextualised embeddings
        attention_scores = torch.matmul(attention_percents, v)

        return attention_scores

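One way to sanity-check the scaled dot-product logic is to compare a manual computation with PyTorch's fused `torch.nn.functional.scaled_dot_product_attention` (PyTorch 2.0+). The manual version uses the same `tril(...) == 0` mask convention and `-1e9` fill as above. The fused call uses `is_causal=True`, which applies the same lower-triangular mask. This sketch uses random tensors in place of the projected Q, K and V. `manual_attention` is a hypothetical standalone function, not part of `AttentionBlock`.

```python
import torch
import torch.nn.functional as F

def manual_attention(q, k, v, mask=None):
    """Same steps as AttentionBlock.forward, for [seq_len, d] inputs."""
    scaled_sims = q @ k.T / (k.size(1) ** 0.5)
    if mask is not None:
        scaled_sims = scaled_sims.masked_fill(mask, -1e9)
    return F.softmax(scaled_sims, dim=1) @ v

torch.manual_seed(0)
seq_len, d_model = 4, 6
q, k, v = (torch.randn(seq_len, d_model) for _ in range(3))

# True above the diagonal, exactly like the decoder's mask
causal_mask = torch.tril(torch.ones(seq_len, seq_len)) == 0

manual = manual_attention(q, k, v, mask=causal_mask)
# Give SDPA explicit batch and head dimensions of size 1, then drop them again
fused = F.scaled_dot_product_attention(
    q[None, None], k[None, None], v[None, None], is_causal=True
)[0, 0]

print(torch.allclose(manual, fused, atol=1e-5))  # True
```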
## Create MLP layer

In [277]:
class MlpBlock(nn.Module):
    """
    A simplified version of the feed-forward (MLP) block used in the transformer paper. It contains a single linear layer followed by an "add & norm" step (residual connection and layer normalisation).
    """
    def __init__(self, d_model: int = 2):
        
        super().__init__()

        self.linear_layer = nn.Linear(in_features=d_model,
                                      out_features=d_model)
        
        self.layer_norm = nn.LayerNorm(normalized_shape=d_model)

    def forward(self, x):

        # Residual connection (x + linear(x)) followed by layer normalisation
        x = self.layer_norm(x + self.linear_layer(x))
        return x

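For comparison, the paper's position-wise feed-forward network uses two linear layers with a ReLU in between, $\mathrm{FFN}(x) = \max(0, xW_1 + b_1)W_2 + b_2$. Its inner dimension is $d_{ff}$ (2048 in the paper, with $d_{model} = 512$), and it is followed by the same add & norm step. In this minimal sketch, the class name `PaperFeedForwardBlock` and the small default `d_ff` are illustrative assumptions, and dropout is omitted.

```python
import torch
import torch.nn as nn

class PaperFeedForwardBlock(nn.Module):
    """Hypothetical paper-style FFN sublayer: LayerNorm(x + FFN(x))."""
    def __init__(self, d_model: int = 6, d_ff: int = 24):
        super().__init__()
        self.ffn = nn.Sequential(
            nn.Linear(d_model, d_ff),   # expand to the inner dimension
            nn.ReLU(),                  # max(0, x W1 + b1)
            nn.Linear(d_ff, d_model),   # project back to d_model
        )
        self.layer_norm = nn.LayerNorm(d_model)

    def forward(self, x):
        # Residual connection around the FFN, then layer normalisation
        return self.layer_norm(x + self.ffn(x))

torch.manual_seed(0)
block = PaperFeedForwardBlock(d_model=6, d_ff=24)
out = block(torch.randn(5, 6))   # [seq_len, d_model] in, same shape out
print(out.shape)                 # torch.Size([5, 6])
```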
## Create the Encoder

Workflow:
1. Tokenize the source text
2. Pass the tokens through the embedding layer
3. Add positional encodings
4. Pass the embeddings through the self-attention layer
5. Pass the contextualised embeddings through the MLP block
6. Return the results

In [278]:
class EncoderBlock(nn.Module):
    """
    Replicating the encoder block in the transformer paper. It returns the contextualised embeddings of the source text tokens.
    """
    def __init__(self, 
                 num_tokens: int, 
                 d_model: int = 2, 
                 max_len: int = 6,
                 device: torch.device = device):

        super().__init__()

        # Word Embeddings
        self.we = nn.Embedding(num_embeddings=num_tokens, embedding_dim=d_model)

        # Position Encodings
        self.pe = PositionEncoding(d_model=d_model, max_len=max_len)

        # Self-Attention
        self.attention = AttentionBlock(d_model=d_model)

        # MLP block
        self.mlp_block = MlpBlock(d_model=d_model)

    def forward(self, token_ids):

        # Create word embeddings
        word_embeddings = self.we(token_ids)

        # Add position encodings to the word embeddings
        position_encoded = self.pe(word_embeddings)

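        # Self-attention: queries, keys and values all come from the same (position-encoded) sequence
        self_attention_values = self.attention(position_encoded,
                                               position_encoded,
                                               position_encoded,
                                               mask=None) # no mask: every source token can attend to every other token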
        
        # Add residual connections and normalise
        mlp_block_out = self.mlp_block(self_attention_values)

        return mlp_block_out

In [279]:
# Test the encoder
encoder = EncoderBlock(num_tokens=len(token_to_id_src),
                       d_model=6,
                       max_len=6)

# Forward pass
encoder_output = encoder(inputs[0])
encoder_output

tensor([[-1.2405,  0.0951,  1.3264, -1.0392, -0.3596,  1.2178],
        [-1.3465,  0.0830,  1.3067, -0.9236, -0.3445,  1.2249],
        [-1.2227,  0.0731,  1.4018, -0.8164, -0.6422,  1.2064],
        [-0.1274, -0.0895,  1.1554,  0.6726, -2.0077,  0.3966],
        [-0.5583, -0.0343,  1.4319,  0.1238, -1.7379,  0.7749]],
       grad_fn=<NativeLayerNormBackward0>)

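In the paper, each encoder sublayer has its own residual connection and layer norm, $\mathrm{LayerNorm}(x + \mathrm{Sublayer}(x))$, and the attention is multi-head. The sketch below shows that structure, using PyTorch's built-in `nn.MultiheadAttention` in place of the single-head `AttentionBlock`. `PaperEncoderLayer`, `num_heads=2` and `d_ff=24` are illustrative choices, and dropout is omitted.

```python
import torch
import torch.nn as nn

class PaperEncoderLayer(nn.Module):
    """Hypothetical post-norm encoder layer: self-attention and FFN, each with add & norm."""
    def __init__(self, d_model: int = 6, num_heads: int = 2, d_ff: int = 24):
        super().__init__()
        self.self_attn = nn.MultiheadAttention(embed_dim=d_model, num_heads=num_heads)
        self.norm1 = nn.LayerNorm(d_model)
        self.ffn = nn.Sequential(nn.Linear(d_model, d_ff), nn.ReLU(), nn.Linear(d_ff, d_model))
        self.norm2 = nn.LayerNorm(d_model)

    def forward(self, x):
        # x: [seq_len, d_model] (unbatched); queries, keys and values are all x
        attn_out, _ = self.self_attn(x, x, x, need_weights=False)
        x = self.norm1(x + attn_out)        # add & norm around attention
        return self.norm2(x + self.ffn(x))  # add & norm around the FFN

torch.manual_seed(0)
layer = PaperEncoderLayer()
print(layer(torch.randn(5, 6)).shape)  # torch.Size([5, 6])
```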
## Create the Decoder

Workflow:
1. Tokenize the target text
2. Pass the tokens through the embedding layer
3. Add positional encodings
4. Pass them through the masked self-attention layer
5. Pass them through the MLP block
6. Pass them through the encoder-decoder attention layer
7. Pass them through the MLP block again
8. Return the result

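`DecoderBlock.forward` builds the mask for step 4 with `torch.tril(...) == 0`. For a 4-token target sequence, the mask is `True` (masked) strictly above the diagonal. After the softmax, each row therefore puts weight only on the current and earlier tokens. This small standalone check uses all-zero scores, so the unmasked weights come out uniform.

```python
import torch
import torch.nn.functional as F

seq_len = 4
mask = torch.tril(torch.ones(seq_len, seq_len)) == 0
print(mask)
# tensor([[False,  True,  True,  True],
#         [False, False,  True,  True],
#         [False, False, False,  True],
#         [False, False, False, False]])

# With equal raw scores, each query spreads its attention evenly over the tokens it may see
scores = torch.zeros(seq_len, seq_len).masked_fill(mask, -1e9)
weights = F.softmax(scores, dim=1)
print(weights)
# rows: [1, 0, 0, 0], [1/2, 1/2, 0, 0], [1/3, 1/3, 1/3, 0], [1/4, 1/4, 1/4, 1/4]
```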
In [280]:
class DecoderBlock(nn.Module):
    """
    Replicating the decoder block in the transformer paper. It returns the contextualised embeddings of the target text tokens.
    """
    def __init__(self,
                 num_tokens: int, 
                 d_model: int = 2, 
                 max_len: int = 6,
                 device: torch.device = device):

        super().__init__()

        # Word Embeddings
        self.we = nn.Embedding(num_embeddings=num_tokens, embedding_dim=d_model)

        # Position Encodings
        self.pe = PositionEncoding(d_model=d_model, max_len=max_len)

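        # Attention block (simplification: the same weights are reused for masked self-attention
        # and encoder-decoder attention, whereas the paper gives each sublayer its own parameters)
        self.attention = AttentionBlock(d_model=d_model)

        # MLP block (also reused after both attention steps)
        self.mlp_block = MlpBlock(d_model=d_model)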

        self.device = device

    def forward(self, 
                token_ids,
                encoder_embeddings):

        # Create word embeddings
        word_embeddings = self.we(token_ids)

        # Add position encodings to the word embeddings
        decoder_embedding = self.pe(word_embeddings)

        # Create the causal mask for masked self-attention, shape [seq_len, seq_len]:
        # True above the diagonal, i.e. for the later tokens each query is not allowed to see
        seq_len = token_ids.size(dim=0)
        mask = torch.tril(torch.ones(seq_len, seq_len)) == 0

        # Masked self-attention over the target tokens
        decoder_attention_values = self.attention(decoder_embedding,
                                                  decoder_embedding,
                                                  decoder_embedding,
                                                  mask=mask,
                                                  device=self.device)
        
        # Add residual connections and normalise
        mlp_block_out1 = self.mlp_block(decoder_attention_values)

        # Encoder-decoder attention: queries from the decoder, keys and values from the encoder output
        encoder_decoder_attention_values = self.attention(mlp_block_out1, # decoder's queries
                                                          encoder_embeddings, # encoder's keys
                                                          encoder_embeddings, # encoder's values
                                                          mask=None) # no mask: every target token may attend to the whole source sentence
        
        # Add residual connections and normalise
        mlp_block_out2 = self.mlp_block(encoder_decoder_attention_values)

        return mlp_block_out2

In [281]:
# Test the decoder
decoder = DecoderBlock(num_tokens=len(token_to_id_target),
                       d_model=6,
                       max_len=6,
                       device="cpu")
decoder(token_ids=labels[0],
        encoder_embeddings=encoder_output)

tensor([[ 0.1349, -0.4699,  0.0315, -1.7874,  1.4807,  0.6102],
        [ 0.1382, -0.4721,  0.0276, -1.7883,  1.4746,  0.6200],
        [ 0.1535, -0.4826,  0.0079, -1.7921,  1.4431,  0.6702],
        [ 0.1941, -0.5069, -0.0429, -1.7965,  1.3586,  0.7937]],
       grad_fn=<NativeLayerNormBackward0>)

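In the paper, the decoder layer has three sublayers, and each has its own parameters: masked self-attention, encoder-decoder attention and the feed-forward network. The sketch below shows that structure, again using `nn.MultiheadAttention`. Its boolean `attn_mask` also uses `True` for positions that may not be attended to. `PaperDecoderLayer` and its defaults are illustrative assumptions, and dropout is omitted.

```python
import torch
import torch.nn as nn

class PaperDecoderLayer(nn.Module):
    """Hypothetical post-norm decoder layer with three separate sublayers."""
    def __init__(self, d_model: int = 6, num_heads: int = 2, d_ff: int = 24):
        super().__init__()
        self.self_attn = nn.MultiheadAttention(d_model, num_heads)   # masked self-attention
        self.cross_attn = nn.MultiheadAttention(d_model, num_heads)  # encoder-decoder attention
        self.ffn = nn.Sequential(nn.Linear(d_model, d_ff), nn.ReLU(), nn.Linear(d_ff, d_model))
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.norm3 = nn.LayerNorm(d_model)

    def forward(self, y, memory):
        # y: [target_len, d_model] decoder input; memory: [source_len, d_model] encoder output
        target_len = y.size(0)
        causal_mask = torch.tril(torch.ones(target_len, target_len)) == 0  # True = not allowed
        attn_out, _ = self.self_attn(y, y, y, attn_mask=causal_mask, need_weights=False)
        y = self.norm1(y + attn_out)
        # Queries from the decoder, keys and values from the encoder output
        cross_out, _ = self.cross_attn(y, memory, memory, need_weights=False)
        y = self.norm2(y + cross_out)
        return self.norm3(y + self.ffn(y))

torch.manual_seed(0)
layer = PaperDecoderLayer()
out = layer(torch.randn(4, 6), torch.randn(5, 6))  # 4 target tokens, 5 source tokens
print(out.shape)  # torch.Size([4, 6])
```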
## Create the Transformer

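Now it's time to put all the pieces of the puzzle together.

Workflow:
1. Pass the tokenized source-language sequence through the encoder block
2. Pass the tokenized target-language sequence, together with the encoder output, through the decoder block
3. Pass the decoder output through the final linear layer to get one logit per token in the target vocabulary
4. Apply a softmax to turn the logits into probabilities (during training, `nn.CrossEntropyLoss` does this internally, so `forward` returns the raw logits)
5. Return the result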

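`nn.CrossEntropyLoss` combines `log_softmax` with the negative log-likelihood loss. As a result, `forward` can return raw logits, and the softmax only needs to be applied explicitly when we want probabilities or predicted tokens. The sketch below uses made-up logits for a 4-token target and the 6-word Spanish vocabulary. The numbers are illustrative, not model output.

```python
import torch
import torch.nn.functional as F

id_to_token_target = {0: "<SOS>", 1: "te", 2: "amo", 3: "me", 4: "amas", 5: "<EOS>"}

# Illustrative logits: one row per target position, one column per vocabulary entry
logits = torch.tensor([[4.0, 0.5, 0.1, 0.2, 0.3, 0.1],
                       [0.2, 3.0, 0.4, 0.1, 0.2, 0.1],
                       [0.1, 0.3, 2.5, 0.2, 0.1, 0.4],
                       [0.3, 0.1, 0.2, 0.1, 0.2, 3.5]])
targets = torch.tensor([0, 1, 2, 5])

probs = F.softmax(logits, dim=-1)      # each row now sums to 1
predicted = probs.argmax(dim=-1)       # same as logits.argmax(dim=-1)
print([id_to_token_target[i] for i in predicted.tolist()])  # ['<SOS>', 'te', 'amo', '<EOS>']

# Cross-entropy on logits equals NLL loss on log-softmax outputs
ce = F.cross_entropy(logits, targets)
nll = F.nll_loss(F.log_softmax(logits, dim=-1), targets)
print(torch.allclose(ce, nll))  # True
```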
In [294]:
class Transformer(L.LightningModule):
    """Replicate the transformer architecture from the paper."""
    def __init__(self,
                 num_tokens_src: int, # vocabulary size of the source language
                 num_tokens_target: int, # vocabulary size of the target language
                 d_model: int = 2,
                 max_len: int = 6,
                 device: torch.device = device):
        super().__init__()

        # Create an encoder block
        self.encoder_block = EncoderBlock(
            num_tokens=num_tokens_src,
            d_model=d_model,
            max_len=max_len,
            device=device
        )

        # Create a decoder block
        self.decoder_block = DecoderBlock(
            num_tokens=num_tokens_target,
            d_model=d_model,
            max_len=max_len,
            device=device
        )

        # Create the final linear layer
        self.final_linear_layer = nn.Linear(in_features=d_model,
                                            out_features=num_tokens_target) # one logit (unnormalised score) per token in the target vocabulary
        
        # Set up the loss function (it applies log-softmax to the logits internally)
        self.loss = nn.CrossEntropyLoss()
        
    def forward(self, token_ids_src, token_ids_target): # the target sequence is also fed to the decoder during training (teacher forcing)
        
        # Pass the source tokens to the encoder block
        contextualised_embeddings_encoder = self.encoder_block(token_ids_src)

        # Pass the target sequence to the decoder block
        contextualised_embeddings_decoder = self.decoder_block(token_ids=token_ids_target,
                                                               encoder_embeddings=contextualised_embeddings_encoder)
        
        # Pass the decoder output through the final linear layer and return the raw logits (no softmax here)
        return self.final_linear_layer(contextualised_embeddings_decoder)
    
    def configure_optimizers(self):
        return Adam(self.parameters(), lr=0.1)
    
    def training_step(self, batch, batch_idx):
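        input_tokens, labels = batch
        # The DataLoader adds a leading batch dimension of size 1; index [0] because the
        # encoder and decoder blocks expect unbatched [seq_len] token ids
        outputs = self.forward(token_ids_src=input_tokens[0],
                               token_ids_target=labels[0])
        # The labels are not shifted, so each decoder position already sees the token it is trained to predict
        loss = self.loss(outputs, labels[0])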

        return loss

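The paper notes that the decoder's output embeddings are offset by one position. In the usual teacher-forcing setup, the decoder therefore receives the target without its last token and is trained to predict the target without `<SOS>`, so position $t$ predicts token $t + 1$. The sketch below shows that shift. `shift_for_teacher_forcing` is a hypothetical helper, and the commented lines show how it might be used inside `training_step`. That usage is an assumption, not the code trained in this notebook.

```python
import torch

def shift_for_teacher_forcing(target_ids: torch.Tensor):
    """Split an unbatched [seq_len] target into decoder input and prediction target."""
    decoder_input = target_ids[:-1]      # all but the last token, e.g. <SOS> te amo
    prediction_target = target_ids[1:]   # all but <SOS>, e.g. te amo <EOS>
    return decoder_input, prediction_target

labels_0 = torch.tensor([0, 1, 2, 5])  # <SOS> te amo <EOS>
decoder_input, prediction_target = shift_for_teacher_forcing(labels_0)
print(decoder_input, prediction_target)  # tensor([0, 1, 2]) tensor([1, 2, 5])

# Possible use inside training_step (assumption, not the notebook's code):
# decoder_input, prediction_target = shift_for_teacher_forcing(labels[0])
# outputs = self.forward(token_ids_src=input_tokens[0], token_ids_target=decoder_input)
# loss = self.loss(outputs, prediction_target)
```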
In [301]:
# Test the transformer
model = Transformer(num_tokens_src=len(token_to_id_src),
                    num_tokens_target=len(token_to_id_target),
                    d_model=6,
                    max_len=6,
                    device=device).to(device)
model(token_ids_src=inputs[0].to(device),
      token_ids_target=labels[0].to(device))

tensor([[ 0.2121,  0.0463, -0.5916, -0.0649, -0.1992, -0.1099],
        [ 0.2146,  0.0508, -0.5983, -0.0682, -0.1910, -0.1035],
        [ 0.2152,  0.0506, -0.5924, -0.0672, -0.1980, -0.1032],
        [ 0.2148,  0.0473, -0.5745, -0.0625, -0.2190, -0.1069]],
       device='mps:0', grad_fn=<LinearBackward0>)

## Train the model with `lightning.Trainer.fit()`

In [296]:
### Create dataloader
from torch.utils.data import TensorDataset, DataLoader

dataset = TensorDataset(inputs, labels)
dataloader = DataLoader(dataset)
len(dataloader)

2

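The `DataLoader` above uses the default `batch_size=1`, so each batch holds one sentence pair with an extra leading batch dimension. That is why `training_step` indexes `input_tokens[0]` and `labels[0]`: the encoder and decoder blocks in this notebook expect unbatched `[seq_len]` tensors. A quick check with the same tensors:

```python
import torch
from torch.utils.data import TensorDataset, DataLoader

inputs = torch.tensor([[0, 1, 2, 3, 5], [0, 3, 2, 4, 5]])
labels = torch.tensor([[0, 1, 2, 5], [0, 3, 4, 5]])

dataloader = DataLoader(TensorDataset(inputs, labels))  # batch_size defaults to 1
input_tokens, label_tokens = next(iter(dataloader))
print(input_tokens.shape, label_tokens.shape)  # torch.Size([1, 5]) torch.Size([1, 4])
print(input_tokens[0])                         # tensor([0, 1, 2, 3, 5])
```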
In [302]:
# Create an instance of trainer
trainer = L.Trainer(max_epochs=60,
                    accelerator="auto")
trainer.fit(model,
            train_dataloaders=dataloader)

GPU available: True (mps), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs

  | Name               | Type             | Params | Mode 
----------------------------------------------------------------
0 | encoder_block      | EncoderBlock     | 198    | train
1 | decoder_block      | DecoderBlock     | 198    | train
2 | final_linear_layer | Linear           | 42     | train
3 | loss               | CrossEntropyLoss | 0      | train
----------------------------------------------------------------
438       Trainable params
0         Non-trainable params
438       Total params
0.002     Total estimated model params size (MB)


`Trainer.fit` stopped: `max_epochs=60` reached.


In [303]:
# Test the transformer
model.to(device)
model(token_ids_src=inputs[0].to(device),
      token_ids_target=labels[0].to(device))

tensor([[ 0.4755, -0.2123, -0.2173, -0.2164, -0.2187,  0.4778],
        [ 0.4755, -0.2123, -0.2173, -0.2164, -0.2187,  0.4778],
        [ 0.4755, -0.2123, -0.2173, -0.2164, -0.2187,  0.4778],
        [ 0.4755, -0.2123, -0.2173, -0.2164, -0.2187,  0.4778]],
       device='mps:0', grad_fn=<LinearBackward0>)
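
At inference time the Spanish sentence is not available, so the decoder has to generate it one token at a time. It starts from `<SOS>`, runs the model, appends the most likely next token, and stops at `<EOS>` or after `max_len` tokens. The sketch below shows greedy decoding. `greedy_decode` and the toy `next_token_logits` function are hypothetical stand-ins. With the trained `model`, the step function could be something like `lambda tgt: model(token_ids_src=src, token_ids_target=tgt)[-1]`. This assumes the model was trained with shifted targets, so that the last position predicts the next token.

```python
import torch

def greedy_decode(step_fn, sos_id: int, eos_id: int, max_len: int) -> list[int]:
    """Generate token ids one at a time, always picking the highest-scoring next token.

    step_fn(target_ids) must return the logits for the next token, shape [vocab_size].
    """
    generated = [sos_id]
    while len(generated) < max_len:
        logits = step_fn(torch.tensor(generated))
        next_id = int(logits.argmax())
        generated.append(next_id)
        if next_id == eos_id:
            break
    return generated

# Toy step function standing in for the model: it "knows" the translation "te amo"
def next_token_logits(target_ids: torch.Tensor) -> torch.Tensor:
    script = {0: 1, 1: 2, 2: 5}             # <SOS> -> te -> amo -> <EOS>
    logits = torch.zeros(6)
    logits[script[int(target_ids[-1])]] = 1.0
    return logits

print(greedy_decode(next_token_logits, sos_id=0, eos_id=5, max_len=6))  # [0, 1, 2, 5]
```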