# Replicate the Transformer architecture with PyTorch + Lightning

In this notebook, we're going to replicate the Transformer architecture from the paper [*Attention Is All You Need*](https://arxiv.org/abs/1706.03762) with PyTorch and Lightning. The goal is to train a transformer model to translate English sentences into Spanish (**machine translation**).

In [179]:
import torch
import torch.nn as nn
import torch.nn.functional as F

from torch.optim import Adam
from torch.utils.data import TensorDataset, DataLoader

import lightning as L

## Set up device-agnostic code

In [180]:
device = (
    "cuda" if torch.cuda.is_available()
    else "mps" if torch.backends.mps.is_available()
    else "cpu"
)
device

'mps'

## Create tokenizers

We need two separate tokenizers (vocabularies that map each token to an integer ID): one for the source language (English in our case) and one for the target language (Spanish).

In [181]:
# Create a tokenizer for the source language (English)
token_to_id_src = {
    "<SOS>": 0,
    "i": 1,
    "love": 2,
    "you": 3,
    "me": 4,
    "see": 5,
    "like": 6,
    "pizza": 7,
    "cake": 8,
    "eat": 9,
    "<EOS>": 10,
    "<PAD>": 11
}

# Create a tokenizer for the target language (Spanish)
token_to_id_target = {
    "<SOS>": 0,
    "te": 1,
    "amo": 2,
    "me": 3,
    "amas": 4,
    "veo": 5,
    "ves": 6,
    "gusta": 7,
    "pizza": 8,
    "pastel": 9,
    "yo": 10,
    "como": 11,
    "tú": 12,
    "comes": 13,
    "<EOS>": 14,
    "<PAD>": 15
}

# Create the reverse mappings (IDs back to tokens) so we can interpret the transformer's output later
id_to_token_src = dict(map(reversed, token_to_id_src.items())) # used to print the source sentences
id_to_token_target = dict(map(reversed, token_to_id_target.items()))
id_to_token_src, id_to_token_target

({0: '<SOS>',
  1: 'i',
  2: 'love',
  3: 'you',
  4: 'me',
  5: 'see',
  6: 'like',
  7: 'pizza',
  8: 'cake',
  9: 'eat',
  10: '<EOS>',
  11: '<PAD>'},
 {0: '<SOS>',
  1: 'te',
  2: 'amo',
  3: 'me',
  4: 'amas',
  5: 'veo',
  6: 'ves',
  7: 'gusta',
  8: 'pizza',
  9: 'pastel',
  10: 'yo',
  11: 'como',
  12: 'tú',
  13: 'comes',
  14: '<EOS>',
  15: '<PAD>'})

## Create a training dataset

In [182]:
inputs = torch.tensor([
    [token_to_id_src["<SOS>"], token_to_id_src["i"],   token_to_id_src["love"], token_to_id_src["you"],   token_to_id_src["<EOS>"]],
    [token_to_id_src["<SOS>"], token_to_id_src["you"], token_to_id_src["love"], token_to_id_src["me"],    token_to_id_src["<EOS>"]],
    [token_to_id_src["<SOS>"], token_to_id_src["i"],   token_to_id_src["see"],  token_to_id_src["you"],   token_to_id_src["<EOS>"]],
    [token_to_id_src["<SOS>"], token_to_id_src["you"], token_to_id_src["see"],  token_to_id_src["me"],    token_to_id_src["<EOS>"]],
    [token_to_id_src["<SOS>"], token_to_id_src["i"],   token_to_id_src["like"], token_to_id_src["pizza"], token_to_id_src["<EOS>"]],
    [token_to_id_src["<SOS>"], token_to_id_src["you"], token_to_id_src["like"], token_to_id_src["cake"],  token_to_id_src["<EOS>"]],
    [token_to_id_src["<SOS>"], token_to_id_src["i"],   token_to_id_src["eat"],  token_to_id_src["pizza"], token_to_id_src["<EOS>"]],
    [token_to_id_src["<SOS>"], token_to_id_src["you"], token_to_id_src["eat"],  token_to_id_src["cake"],  token_to_id_src["<EOS>"]],
])

labels = torch.tensor([
    [token_to_id_target["<SOS>"], token_to_id_target["te"],  token_to_id_target["amo"],   token_to_id_target["<EOS>"], token_to_id_target["<PAD>"]],
    [token_to_id_target["<SOS>"], token_to_id_target["me"],  token_to_id_target["amas"],  token_to_id_target["<EOS>"], token_to_id_target["<PAD>"]],
    [token_to_id_target["<SOS>"], token_to_id_target["te"],  token_to_id_target["veo"],   token_to_id_target["<EOS>"], token_to_id_target["<PAD>"]],
    [token_to_id_target["<SOS>"], token_to_id_target["me"],  token_to_id_target["ves"],   token_to_id_target["<EOS>"], token_to_id_target["<PAD>"]],
    [token_to_id_target["<SOS>"], token_to_id_target["me"],  token_to_id_target["gusta"], token_to_id_target["pizza"],  token_to_id_target["<EOS>"]],
    [token_to_id_target["<SOS>"], token_to_id_target["te"],  token_to_id_target["gusta"], token_to_id_target["pastel"], token_to_id_target["<EOS>"]],
    [token_to_id_target["<SOS>"], token_to_id_target["yo"],  token_to_id_target["como"],  token_to_id_target["pizza"],  token_to_id_target["<EOS>"]],
    [token_to_id_target["<SOS>"], token_to_id_target["tú"],  token_to_id_target["comes"], token_to_id_target["pastel"], token_to_id_target["<EOS>"]],
])
inputs, labels

(tensor([[ 0,  1,  2,  3, 10],
         [ 0,  3,  2,  4, 10],
         [ 0,  1,  5,  3, 10],
         [ 0,  3,  5,  4, 10],
         [ 0,  1,  6,  7, 10],
         [ 0,  3,  6,  8, 10],
         [ 0,  1,  9,  7, 10],
         [ 0,  3,  9,  8, 10]]),
 tensor([[ 0,  1,  2, 14, 15],
         [ 0,  3,  4, 14, 15],
         [ 0,  1,  5, 14, 15],
         [ 0,  3,  6, 14, 15],
         [ 0,  3,  7,  8, 14],
         [ 0,  1,  7,  9, 14],
         [ 0, 10, 11,  8, 14],
         [ 0, 12, 13,  9, 14]]))
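The rows above were written out by hand. As a minimal sketch, assuming lowercase input and plain whitespace tokenization, the same rows could be built from raw strings. `encode` and `decode` are hypothetical helpers that are not part of the original notebook, and `demo_vocab` is a subset of `token_to_id_target` with the same IDs.

```python
import torch

def encode(sentence, token_to_id, max_len):
    """Hypothetical helper: whitespace-split a sentence, wrap it in
    <SOS>/<EOS>, and right-pad it with <PAD> up to max_len."""
    ids = [token_to_id["<SOS>"]]
    ids += [token_to_id[word] for word in sentence.lower().split()]
    ids.append(token_to_id["<EOS>"])
    if len(ids) > max_len:
        raise ValueError(f"sentence needs {len(ids)} tokens, max_len is {max_len}")
    ids += [token_to_id["<PAD>"]] * (max_len - len(ids))
    return torch.tensor(ids)

def decode(ids, id_to_token):
    """Hypothetical helper: map a tensor of token IDs back to a list of tokens."""
    return [id_to_token[i] for i in ids.tolist()]

# A subset of the target vocabulary defined earlier (same IDs)
demo_vocab = {"<SOS>": 0, "te": 1, "amo": 2, "<EOS>": 14, "<PAD>": 15}
demo_inverse = {i: tok for tok, i in demo_vocab.items()}

encoded = encode("te amo", demo_vocab, max_len=5)
print(encoded)                        # tensor([ 0,  1,  2, 14, 15])
print(decode(encoded, demo_inverse))  # ['<SOS>', 'te', 'amo', '<EOS>', '<PAD>']
```

With the full source tokenizer, `encode("i love you", token_to_id_src, max_len=5)` would give the first row of `inputs`.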

## Review of the transformer architecture

Let's break down the transformer architecture into smaller components:
1. Tokenizer: map the text to numbers (token IDs).
2. Embedding layer: map the token IDs to embeddings.
3. Positional encoding: encode the position of each token in the sequence.
4. Multi-Head Attention block: the core part of the transformer architecture. It computes the attention values for each token (how much each token should attend to the others in the sequence, including itself). After every attention layer, we add a residual connection and normalise the sum with layer normalisation ("Add & Norm"). Note: for simplicity, we'll only use single-head attention in this notebook.
    The paper uses three types of attention:
    - (Standard) Self-Attention: bidirectional, used in the encoder.
    - Masked Self-Attention: unidirectional, so each token can only attend to itself and the preceding tokens (future tokens are masked). It's used in the decoder to prevent information leaking from future tokens during training.
    - Encoder-Decoder Attention (cross-attention): the queries come from the decoder's tokens, while the keys and values come from the encoder's output (the contextualised source embeddings), so every decoder token can attend to every source position. It's used in the decoder, after masked self-attention, to enrich the target embeddings with context from the source sentence.
5. MLP (feed-forward) block: fully-connected layers followed by another "Add & Norm" step.
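The causal mask behind masked self-attention can be inspected on its own. A minimal sketch using the same convention as the decoder later in this notebook (`torch.tril(...) == 0`, so `True` marks a position to hide, which is filled with `-1e9` before the softmax); the all-zero scores are an arbitrary choice that makes the resulting weights easy to read.

```python
import torch
import torch.nn.functional as F

seq_len = 4
# True above the diagonal marks "future" positions a token must not attend to
causal_mask = torch.tril(torch.ones(seq_len, seq_len)) == 0

# All-zero similarity scores, so only the mask shapes the attention weights
scores = torch.zeros(seq_len, seq_len)
masked_scores = scores.masked_fill(causal_mask, -1e9)
weights = F.softmax(masked_scores, dim=1)
print(weights)
# Row i spreads its attention evenly over positions 0..i:
# tensor([[1.0000, 0.0000, 0.0000, 0.0000],
#         [0.5000, 0.5000, 0.0000, 0.0000],
#         [0.3333, 0.3333, 0.3333, 0.0000],
#         [0.2500, 0.2500, 0.2500, 0.2500]])
```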

### Figure 1
Figure 1 visualizes the model architecture of the transformer used in the paper:<br>
<img src="transformer_architecture.png" width=300 alt="figure 1 from transformer paper"/>

## Create position encoding

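We will use the positional encoding formula from the paper *Attention Is All You Need*, where $pos$ is the position of the token in the sequence and $i$ indexes the pairs of embedding dimensions:
* $PE_{(pos, 2i)} = \sin\left(pos / 10000^{2i / d_{model}}\right)$
* $PE_{(pos, 2i+1)} = \cos\left(pos / 10000^{2i / d_{model}}\right)$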
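To make the formula concrete, a small pure-Python sketch can evaluate one row of the encoding by hand. `positional_encoding` is a hypothetical helper that assumes an even `d_model`; its output should match the corresponding row of the `pe` buffer built by `PositionEncoding` below.

```python
import math

def positional_encoding(pos, d_model):
    """Compute one row of the sinusoidal encoding directly from the formula.

    Dimension 2i uses sin and dimension 2i+1 uses cos; both share the
    frequency 1 / 10000^(2i / d_model). Assumes d_model is even.
    """
    row = []
    for i in range(d_model // 2):
        angle = pos / 10000 ** (2 * i / d_model)
        row += [math.sin(angle), math.cos(angle)]
    return row

# Position 0 is always [0, 1, 0, 1, ...] because sin(0) = 0 and cos(0) = 1
print([round(v, 4) for v in positional_encoding(pos=0, d_model=6)])
# [0.0, 1.0, 0.0, 1.0, 0.0, 1.0]
print([round(v, 4) for v in positional_encoding(pos=1, d_model=6)])
# [0.8415, 0.5403, 0.0464, 0.9989, 0.0022, 1.0]
```

Higher dimension pairs oscillate more slowly, which is why the last pair barely changes between positions 0 and 1.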

In [183]:
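class PositionEncoding(nn.Module):

    def __init__(self, d_model=6, max_len=6):

        super().__init__()

        # One row per position, one column per embedding dimension
        pe = torch.zeros(max_len, d_model)

        # Column vector of positions: [[0.], [1.], ..., [max_len - 1]]
        position = torch.arange(start=0, end=max_len, step=1).float().unsqueeze(1)

        # Even dimension indices 0, 2, 4, ... play the role of 2i in the formula (assumes an even d_model)
        embedding_index = torch.arange(start=0, end=d_model, step=2).float()

        # 10000^(2i / d_model), shared by each sin/cos pair
        div_term = torch.tensor(10000)**(embedding_index / d_model)

        # sin on even dimensions, cos on odd dimensions (broadcasts to [max_len, d_model / 2])
        pe[:, 0::2] = torch.sin(position / div_term)
        pe[:, 1::2] = torch.cos(position / div_term)

        # A buffer moves with the module across devices but is not a trainable parameter
        self.register_buffer("pe", pe)
    
    def forward(self, word_embeddings):

        # word_embeddings: [seq_len, d_model]; add the encodings of the first seq_len positions
        return word_embeddings + self.pe[:word_embeddings.size(0), :]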

## Create Attention layers

In [184]:
class Attention(nn.Module):
    
    def __init__(self,
                 d_model=2,
                 mask: torch.Tensor = None):
                         
        super().__init__()

        self.W_q = nn.Linear(in_features=d_model, out_features=d_model, bias=False)
        self.W_k = nn.Linear(in_features=d_model, out_features=d_model, bias=False)
        self.W_v = nn.Linear(in_features=d_model, out_features=d_model, bias=False)

        self.register_buffer(name="mask",
                             tensor=mask)

        self.row_dim = 0
        self.col_dim = 1

    def forward(self, encodings_for_q, encodings_for_k, encodings_for_v):

        # Create the Q, K and V matrices
        q = self.W_q(encodings_for_q)
        k = self.W_k(encodings_for_k)
        v = self.W_v(encodings_for_v)

        # Calculate the similarity scores between the queries and keys
        sims = torch.matmul(q, k.transpose(dim0=self.row_dim, dim1=self.col_dim))

        # Scale the similarity scores by the square root of d_k (equal to d_model here, since we use a single head)
        scaled_sims = sims / torch.tensor((k.size(self.col_dim))**0.5)

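        # Mask the scores for later tokens so that no token can attend to (i.e. cheat by looking at) future tokens during training.
        # The mask is built for max_len, so slice it to the actual sequence lengths.
        if self.mask is not None:
            mask = self.mask[:scaled_sims.size(self.row_dim), :scaled_sims.size(self.col_dim)]
            scaled_sims = scaled_sims.masked_fill(mask=mask,
                                                  value=-1e9)  # -1e9 is an approximation of negative infinity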

        attention_percents = F.softmax(scaled_sims, dim=self.col_dim)

        # attention_scores are the contextualised embeddings
        attention_scores = torch.matmul(attention_percents, v)

        return attention_scores
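As a sanity check on the attention math, a sketch can compare a manual single-head computation (the same steps as `Attention.forward`: scores scaled by the square root of the key dimension, causal mask filled with `-1e9`, softmax over each row) against PyTorch's built-in `F.scaled_dot_product_attention`. This assumes PyTorch 2.0 or later, where that function exists; the random `q`, `k` and `v` stand in for the projected embeddings.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
seq_len, d_model = 5, 6
q = torch.randn(seq_len, d_model)
k = torch.randn(seq_len, d_model)
v = torch.randn(seq_len, d_model)

# Manual single-head attention with a causal mask (True = hide)
mask = torch.tril(torch.ones(seq_len, seq_len)) == 0
scores = (q @ k.T) / d_model ** 0.5
scores = scores.masked_fill(mask, -1e9)
manual = F.softmax(scores, dim=1) @ v

# Built-in reference; add a leading batch dimension, then remove it again
reference = F.scaled_dot_product_attention(
    q.unsqueeze(0), k.unsqueeze(0), v.unsqueeze(0), is_causal=True
).squeeze(0)

print(torch.allclose(manual, reference, atol=1e-5))  # True
```

`F.scaled_dot_product_attention` reads a boolean `attn_mask` the opposite way from `masked_fill` (`True` means "may attend"), which is why the sketch uses `is_causal=True` instead of passing the notebook-style mask directly.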

## Create the Encoder

Workflow:
1. Take the tokenized source text (token IDs)
2. Pass the tokens through the embedding layer
3. Add positional encodings
4. Pass the embeddings through the self-attention layer, then apply Add & Norm
5. Pass the contextualised embeddings through the MLP block (linear layer + Add & Norm)
6. Return the results

In [186]:
class EncoderBlock(nn.Module):
    """
    Replicating the encoder block in the transformer paper. It returns the contextualised embeddings of the source text tokens.
    """
    def __init__(self,
                 num_tokens: int, 
                 d_model: int = 6,
                 max_len: int = 5):

        super().__init__()

        # Word Embeddings
        self.we = nn.Embedding(num_embeddings=num_tokens, embedding_dim=d_model)

        # Position Encodings
        self.pe = PositionEncoding(d_model=d_model, max_len=max_len)

        # Attention block
        self.attention = Attention(d_model=d_model,
                                   mask=None)

        # Layer norm 1
        self.layer_norm1 = nn.LayerNorm(normalized_shape=d_model)

        # Layer norm 2
        self.layer_norm2 = nn.LayerNorm(normalized_shape=d_model)

        # Linear layer
        self.linear_layer = nn.Linear(in_features=d_model,
                                      out_features=d_model)

    def forward(self, x):

        # Create word embeddings
        word_embeddings = self.we(x)

        # Add position encodings to the word embeddings
        position_encoded = self.pe(word_embeddings)

        # Run the positional encoded embeddings through a self-attention layer
        attention_values = self.attention(position_encoded,
                                          position_encoded,
                                          position_encoded)
        
        # Add the residual connection (the attention layer's input) and then apply layer norm
        normalized_attention_values = self.layer_norm1(attention_values + position_encoded)

        # Run the normalised attention values through an MLP block (linear layer + Add & Norm)
        return self.layer_norm2(self.linear_layer(normalized_attention_values) + normalized_attention_values)
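The paper's position-wise feed-forward network is two linear layers with a ReLU in between, FFN(x) = max(0, xW1 + b1)W2 + b2, whereas `EncoderBlock` uses a single `nn.Linear`. A minimal sketch of a module that follows the paper more closely; the class name `PositionwiseFeedForward` and the choice d_ff = 4 × d_model (the ratio the paper uses, 2048 for d_model = 512) are assumptions for this toy setup.

```python
import torch
import torch.nn as nn

class PositionwiseFeedForward(nn.Module):
    """Hypothetical feed-forward block: Linear -> ReLU -> Linear, then Add & Norm."""

    def __init__(self, d_model: int = 6, d_ff: int = 24):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(d_model, d_ff),   # expand to the hidden size d_ff
            nn.ReLU(),
            nn.Linear(d_ff, d_model),   # project back to d_model
        )
        self.layer_norm = nn.LayerNorm(d_model)

    def forward(self, x):
        # Residual connection around the FFN, then layer norm ("Add & Norm")
        return self.layer_norm(self.net(x) + x)

ffn = PositionwiseFeedForward(d_model=6, d_ff=24)
x = torch.randn(5, 6)      # [seq_len, d_model], like the encoder's output
print(ffn(x).shape)        # torch.Size([5, 6])
```

Swapping it in would mean replacing `self.linear_layer` and `self.layer_norm2` in the encoder (and the corresponding layers in the decoder).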

In [187]:
# Test the encoder
encoder = EncoderBlock(num_tokens=len(token_to_id_src))

# Forward pass
encoder_output = encoder(inputs[0])
encoder_output

tensor([[ 1.4281, -1.3818, -0.9599,  1.0252, -0.2459,  0.1343],
        [ 0.4991,  0.1097, -0.8558, -0.4280, -1.1783,  1.8533],
        [-0.7663,  0.0183, -0.1691,  1.6283,  0.7578, -1.4691],
        [ 0.3784, -1.6927, -0.2539, -0.3325,  0.2397,  1.6610],
        [-0.1332,  1.7643, -0.4780,  0.3334,  0.1007, -1.5873]],
       grad_fn=<NativeLayerNormBackward0>)

## Create the Decoder

Workflow:
1. Take the tokenized target text (token IDs)
2. Pass the tokens through the embedding layer
3. Add positional encodings
4. Pass them through the masked self-attention layer, then apply Add & Norm
5. Pass them through the encoder-decoder attention layer (queries from the decoder, keys and values from the encoder's output), then apply Add & Norm
6. Pass them through the MLP block (linear layer + Add & Norm)
7. Return the result

In [188]:
class DecoderBlock(nn.Module):
    """
    Replicating the decoder block in the transformer paper. It returns the contextualised embeddings of the target text tokens.
    """
    def __init__(self,
                 num_tokens: int,
                 d_model: int = 6, 
                 max_len: int = 5):
        
        super().__init__()

        # Word Embeddings
        self.we = nn.Embedding(num_embeddings=num_tokens, embedding_dim=d_model)

        # Positional Encodings
        self.pe = PositionEncoding(d_model=d_model, max_len=max_len)

        # Masked Attention layer
        self.decoder_attention = Attention(d_model=d_model,
                                           mask=(torch.tril(torch.ones(max_len, max_len)) == 0)) # shape: [max_len, max_len]; True marks future positions to mask
        
        # Layer Norm 1
        self.layer_norm1 = nn.LayerNorm(normalized_shape=d_model)
        
        # Layer Norm 2
        self.layer_norm2 = nn.LayerNorm(normalized_shape=d_model)
        
        # Layer Norm 3
        self.layer_norm3 = nn.LayerNorm(normalized_shape=d_model)
        
        # Encoder-Decoder Attention layer
        self.encoder_decoder_attention = Attention(d_model=d_model,
                                                   mask=None)
        # Linear layer
        self.linear_layer = nn.Linear(in_features=d_model,
                                      out_features=d_model)

    def forward(self, 
                x,
                encoder_embeddings):

        # Create word embeddings
        word_embeddings = self.we(x)

        # Add positional encodings to the word embeddings
        decoder_positional_embeddings = self.pe(word_embeddings)

        # Run the positionally encoded embeddings through the masked attention layer
        decoder_attention_values = self.decoder_attention(decoder_positional_embeddings,
                                                          decoder_positional_embeddings,
                                                          decoder_positional_embeddings)

        # Add & Norm layer 1 (residual connection from the masked attention layer's input)
        normalized_decoder_attention_values = self.layer_norm1(decoder_attention_values + decoder_positional_embeddings)

        # Cross attention layer
        cross_attention_values = self.encoder_decoder_attention(normalized_decoder_attention_values,
                                                                encoder_embeddings,
                                                                encoder_embeddings)
        # Add & Norm layer 2
        normalized_cross_attention_values = self.layer_norm2(cross_attention_values + normalized_decoder_attention_values)

        # MLP block
        return self.layer_norm3(self.linear_layer(normalized_cross_attention_values) + normalized_cross_attention_values)
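In the toy data every source and target sentence has exactly 5 tokens, which hides a property of encoder-decoder attention: the queries come from the decoder while the keys and values come from the encoder, so the two sequences can have different lengths, and the output has one row per target token. A minimal standalone sketch with arbitrary lengths (3 target tokens, 7 source tokens) and freshly initialised projection layers shows the shapes. There is no causal mask here, because every target token may look at the whole source sentence.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
d_model = 6
W_q = nn.Linear(d_model, d_model, bias=False)
W_k = nn.Linear(d_model, d_model, bias=False)
W_v = nn.Linear(d_model, d_model, bias=False)

decoder_states = torch.randn(3, d_model)   # 3 target tokens -> queries
encoder_states = torch.randn(7, d_model)   # 7 source tokens -> keys and values

q = W_q(decoder_states)
k = W_k(encoder_states)
v = W_v(encoder_states)

weights = F.softmax(q @ k.T / d_model ** 0.5, dim=1)  # [3, 7]: one row per target token
output = weights @ v                                  # [3, 6]: one vector per target token
print(weights.shape, output.shape)  # torch.Size([3, 7]) torch.Size([3, 6])
```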

In [189]:
labels[0]

tensor([ 0,  1,  2, 14, 15])

In [190]:
# Test the decoder
decoder = DecoderBlock(num_tokens=len(token_to_id_target))
decoder(labels[0],
        encoder_embeddings=encoder_output)

tensor([[-1.1319,  0.7287,  1.5256, -0.5534, -1.1188,  0.5498],
        [ 0.7780, -1.1195, -1.0119,  1.7358, -0.0662, -0.3164],
        [ 0.5016, -1.1174, -1.4974,  0.3349,  1.4203,  0.3578],
        [ 1.7898, -1.1354, -0.4221,  0.1237,  0.6128, -0.9688],
        [-2.0327,  0.1249, -0.1078,  0.7596,  1.1153,  0.1408]],
       grad_fn=<NativeLayerNormBackward0>)

## Create the Transformer

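Now it's time to put all the pieces together.

Workflow:
1. Pass the tokenized sequence in the source language through the encoder block
2. Pass the tokenized sequence in the target language, together with the encoder's output, through the decoder block
3. Pass the result from the decoder through the final linear layer to get one logit per token in the target vocabulary
4. Return the logits. We don't apply a softmax in `forward`: `nn.CrossEntropyLoss` applies log-softmax internally, and at inference time we simply take the argmax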
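Because `forward` returns raw logits, it's worth confirming that `nn.CrossEntropyLoss` on logits is the same as applying log-softmax followed by the negative log-likelihood loss, and that the argmax of the logits gives the same prediction as the argmax after a softmax. A small sketch with random logits shaped like the model's output (5 positions, 16 target tokens):

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
logits = torch.randn(5, 16)                # [seq_len, target vocab size]
targets = torch.tensor([0, 1, 2, 14, 15])  # e.g. labels[0]

ce = nn.CrossEntropyLoss()(logits, targets)
manual = F.nll_loss(F.log_softmax(logits, dim=1), targets)
print(torch.allclose(ce, manual))  # True

# Softmax preserves the ordering of the logits, so the argmax is unchanged
probs = F.softmax(logits, dim=1)
print(torch.equal(logits.argmax(dim=1), probs.argmax(dim=1)))  # True
```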

In [191]:
class Transformer(L.LightningModule):
    """Replicate the transformer architecture from the paper."""
    def __init__(self,
                 num_tokens_src: int, # vocabulary size of the source language
                 num_tokens_target: int, # vocabulary size of the target language
                 d_model: int = 6,
                 max_len: int = 5):
        
        super().__init__()

        # Create an encoder block
        self.encoder_block = EncoderBlock(
            num_tokens=num_tokens_src,
            d_model=d_model,
            max_len=max_len
        )

        # Create a decoder block
        self.decoder_block = DecoderBlock(
            num_tokens=num_tokens_target,
            d_model=d_model,
            max_len=max_len
        )

        # Create the linear layer
        self.linear_layer = nn.Linear(in_features=d_model,
                                      out_features=num_tokens_target) # one output logit per token in the target vocabulary
        
        # Setup loss function
        self.loss_fn = nn.CrossEntropyLoss()
        
    def forward(self, x_src, x_target): # the target sequence is also needed for teacher forcing during training
        
        # Pass the source sequence to the encoder block
        encoder_embeddings = self.encoder_block(x_src)

        # Pass the target sequence to the decoder block
        decoder_embeddings = self.decoder_block(x=x_target,
                                                encoder_embeddings=encoder_embeddings)
        
        # Pass the result through the final linear layer
        return self.linear_layer(decoder_embeddings)
    
    def configure_optimizers(self):
        return Adam(self.parameters(), lr=0.1)
    
    def training_step(self, batch, batch_idx):
        input_tokens, labels = batch
        # batch_size is 1, so [0] removes the batch dimension the DataLoader adds
        outputs = self.forward(x_src=input_tokens[0],
                               x_target=labels[0])
        loss = self.loss_fn(outputs, labels[0])
        return loss
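In `training_step`, the decoder receives `labels[0]` and the loss compares its output at every position with that same `labels[0]`. Since masked self-attention lets each position attend to itself, the decoder can see the very token it is asked to predict, so the task can be solved by copying the input token rather than by translating. The usual teacher-forcing setup shifts the target by one position: the decoder input drops the last token and the loss target drops `<SOS>`. A minimal sketch of that slicing, with random logits standing in for the model's output and `ignore_index` (not used in the original notebook) so the `<PAD>` position doesn't contribute to the loss:

```python
import torch
import torch.nn as nn

PAD_ID = 15  # "<PAD>" in the target vocabulary above

target = torch.tensor([0, 1, 2, 14, 15])  # <SOS> te amo <EOS> <PAD>

# Position t now sees tokens 0..t of the decoder input and must predict token t+1
decoder_input = target[:-1]   # tensor([ 0,  1,  2, 14])
loss_target = target[1:]      # tensor([ 1,  2, 14, 15])

# Stand-in for the model's output: [len(decoder_input), vocab_size] logits
torch.manual_seed(0)
logits = torch.randn(len(decoder_input), 16)

# ignore_index drops the <PAD> position from the average
loss_fn = nn.CrossEntropyLoss(ignore_index=PAD_ID)
loss = loss_fn(logits, loss_target)
print(loss.item())
```

Applied to the `Transformer` class, this would mean calling `self.forward(x_src=input_tokens[0], x_target=labels[0][:-1])` and computing `self.loss_fn(outputs, labels[0][1:])`, assuming the decoder's causal mask is sliced to the actual sequence length, since the decoder input now has 4 tokens while the mask is built for `max_len=5`.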

In [192]:
# Test the transformer
model = Transformer(num_tokens_src=len(token_to_id_src),
                    num_tokens_target=len(token_to_id_target))
model(x_src=inputs[0],
      x_target=labels[0])

tensor([[-0.2152, -0.4546,  0.0394,  0.0677, -0.0545,  0.8017, -0.0953,  0.4340,
         -0.6535,  0.0716,  0.5408,  0.2377, -0.7780, -0.2315,  0.6845, -0.2330],
        [ 0.1591, -0.4489, -0.5699, -0.5113, -0.5825, -0.1202,  0.1925,  0.0018,
          0.9156,  0.6455,  0.7619, -0.7661, -0.0649,  0.4199,  0.2110, -0.2469],
        [ 0.1039, -0.5878,  1.1718, -0.1426,  0.8702,  0.3124, -0.1313, -0.1212,
          0.5813, -0.1304,  0.2518,  0.6831, -0.7163,  1.0081,  0.3152,  0.6601],
        [-0.4179, -0.4270, -1.2409,  0.0746, -1.2202,  0.4409,  0.2123,  0.5776,
         -0.4422,  0.5272,  0.9960, -0.7501, -0.2821, -0.2973,  0.4988, -0.6226],
        [-0.2152, -0.1593, -1.1012, -0.3428,  0.2441,  0.1139,  0.6229, -0.3206,
          0.6523, -0.0179,  0.8213, -0.7836, -0.0246,  0.6720,  0.9047,  0.5303]],
       grad_fn=<AddmmBackward0>)

## Train the model with `Lightning.Trainer.fit()`

In [193]:
### Create dataloader
from torch.utils.data import TensorDataset, DataLoader

dataset = TensorDataset(inputs, labels)
dataloader = DataLoader(dataset)
len(dataloader)

8
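`DataLoader` uses `batch_size=1` by default and stacks samples along a new leading dimension, so each batch is a pair of `[1, 5]` tensors. The model is written for unbatched `[seq_len]` inputs (`PositionEncoding` slices by `size(0)` and `Attention` transposes dimensions 0 and 1), which is why `training_step` indexes `input_tokens[0]` and `labels[0]`. A small sketch with two toy pairs shows the shapes:

```python
import torch
from torch.utils.data import TensorDataset, DataLoader

# Two toy (source, target) pairs with the same shapes as `inputs` / `labels`
src = torch.tensor([[0, 1, 2, 3, 10], [0, 3, 2, 4, 10]])
tgt = torch.tensor([[0, 1, 2, 14, 15], [0, 3, 4, 14, 15]])

loader = DataLoader(TensorDataset(src, tgt))  # batch_size defaults to 1, no shuffling
src_batch, tgt_batch = next(iter(loader))
print(src_batch.shape)     # torch.Size([1, 5]): a batch dimension is added
print(src_batch[0].shape)  # torch.Size([5]): the unbatched sequence the model expects
```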

In [194]:
# Create an instance of trainer
trainer = L.Trainer(max_epochs=60,
                    accelerator="auto")
trainer.fit(model,
            train_dataloaders=dataloader)

GPU available: True (mps), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs

  | Name          | Type             | Params | Mode 
-----------------------------------------------------------
0 | encoder_block | EncoderBlock     | 246    | train
1 | decoder_block | DecoderBlock     | 390    | train
2 | linear_layer  | Linear           | 112    | train
3 | loss_fn       | CrossEntropyLoss | 0      | train
-----------------------------------------------------------
748       Trainable params
0         Non-trainable params
748       Total params
0.003     Total estimated model params size (MB)
/Users/edison/Git/pytorch-lightning-deep-learning/myenv/lib/python3.11/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:424: The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=7` in the `DataLoader` to improve performance.
/Users/edison/Git/

`Trainer.fit` stopped: `max_epochs=60` reached.


In [204]:
# Test the transformer
for i in range(len(inputs)):
      y_logits = model(x_src=inputs[i],
                  x_target=labels[i])

      pred = list(map(lambda x: id_to_token_target[x.item()], torch.argmax(y_logits, dim=1))) 
      print(f"Input: {list(map(lambda x: id_to_token_src[x.item()], inputs[i]))} | Pred: {pred}")

Input: ['<SOS>', 'i', 'love', 'you', '<EOS>'] | Pred: ['<SOS>', 'te', 'amo', '<EOS>', '<PAD>']
Input: ['<SOS>', 'you', 'love', 'me', '<EOS>'] | Pred: ['<SOS>', 'me', 'amas', '<EOS>', '<PAD>']
Input: ['<SOS>', 'i', 'see', 'you', '<EOS>'] | Pred: ['<SOS>', 'te', 'veo', '<EOS>', '<PAD>']
Input: ['<SOS>', 'you', 'see', 'me', '<EOS>'] | Pred: ['<SOS>', 'me', 'ves', '<EOS>', '<PAD>']
Input: ['<SOS>', 'i', 'like', 'pizza', '<EOS>'] | Pred: ['<SOS>', 'me', 'gusta', 'pizza', '<EOS>']
Input: ['<SOS>', 'you', 'like', 'cake', '<EOS>'] | Pred: ['<SOS>', 'te', 'gusta', 'pastel', '<EOS>']
Input: ['<SOS>', 'i', 'eat', 'pizza', '<EOS>'] | Pred: ['<SOS>', 'yo', 'como', 'pizza', '<EOS>']
Input: ['<SOS>', 'you', 'eat', 'cake', '<EOS>'] | Pred: ['<SOS>', 'tú', 'comes', 'pastel', '<EOS>']
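The test above feeds the reference translation `labels[i]` into the decoder, so it shows how well the model reproduces a target it is already given rather than how it translates a sentence on its own. At inference time the target is unknown, and the usual approach is to generate it one token at a time, feeding each prediction back into the decoder. A minimal sketch of greedy decoding; `greedy_decode` and the `next_token_logits` callable are hypothetical names, and `dummy_logits` is a scripted stand-in for the model that "predicts" `te amo <EOS>`.

```python
import torch

def greedy_decode(next_token_logits, sos_id, eos_id, max_len):
    """Generate token IDs one at a time, feeding each prediction back in.

    next_token_logits: hypothetical callable that takes a 1-D LongTensor
    prefix and returns logits of shape [len(prefix), vocab_size].
    """
    tokens = [sos_id]
    while len(tokens) < max_len:
        logits = next_token_logits(torch.tensor(tokens))
        next_id = int(torch.argmax(logits[-1]))  # prediction at the last position
        tokens.append(next_id)
        if next_id == eos_id:
            break
    return tokens

VOCAB_SIZE = 16
script = [1, 2, 14]  # "te", "amo", "<EOS>" in the target vocabulary above

def dummy_logits(prefix):
    # Scripted stand-in for the model: put the highest logit on the next scripted token
    logits = torch.zeros(len(prefix), VOCAB_SIZE)
    logits[-1, script[len(prefix) - 1]] = 1.0
    return logits

print(greedy_decode(dummy_logits, sos_id=0, eos_id=14, max_len=5))  # [0, 1, 2, 14]
```

With the notebook's model, `next_token_logits` could be something like `lambda prefix: model(x_src=inputs[0], x_target=prefix)`, but that only makes sense for a model trained with targets shifted by one position and a decoder whose causal mask handles prefixes shorter than `max_len`.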
