# The Annotated Transformer ++

The idea of this notebook is to make it easier even for non-researchers to understand the omnipresent transformer model!

In this notebook you'll learn:

✅ What are transformers exactly? <br/>
✅ How to train them? <br/>
✅ How to use them? (machine translation example) <br/>

After you complete this one you'll have a much better understanding of transformers!

So, let's start!

---

## What the heck are transformers and how did they come to be?

Transformers were originally proposed by Vaswani et al. in a seminal paper called [Attention Is All You Need](https://arxiv.org/pdf/1706.03762.pdf).

You've probably heard of transformers one way or another; GPT-3 and BERT are two well-known examples 🦄. The paper's main idea is that you don't need recurrent or convolutional layers: a simple architecture built around attention is super powerful. It models long-range dependencies much better, and because the architecture itself is highly parallelizable (💻💻💻) it is also more compute-efficient!

Here is what their beautifully simple architecture looks like:

<img src="data/readme_pics/transformer_architecture.PNG" alt="transformer architecture" align="center" style="width: 350px;"/> <br/>

---

That was everything you need to know for now! If you need further help understanding all of the details, I created this [in-depth overview of the paper](https://www.youtube.com/watch?v=cbYxHkgkSVs):

<a href="https://www.youtube.com/watch?v=cbYxHkgkSVs" target="_blank"><img src="https://img.youtube.com/vi/cbYxHkgkSVs/0.jpg" 
alt="An in-depth overview of the Attention Is All You Need paper" width="480" align="left" height="360" border="10" /></a>

In [1]:
# I always like to structure my imports into Python's native libs,
# stuff I installed via conda/pip and local file imports (we don't have those here)

import math
import copy
import os
import time
import enum
import argparse

# Visualization related imports
import matplotlib.pyplot as plt
import seaborn

# Deep learning related imports
import torch
import torch.nn as nn
from torch.optim import Adam
from torch.utils.tensorboard import SummaryWriter
from torch.hub import download_url_to_file

# Data manipulation related imports
from torchtext.data import Dataset, BucketIterator, Field, Example
from torchtext.data.utils import interleave_keys
from torchtext import datasets
import spacy

# BLEU (the metric we'll be using for the machine translation task)
from nltk.translate.bleu_score import corpus_bleu

In [2]:
# Let's create some constants to make the code a bit cleaner

# Architecture related constants taken from the paper
BASELINE_MODEL_NUMBER_OF_LAYERS = 6
BASELINE_MODEL_DIMENSION = 512
BASELINE_MODEL_NUMBER_OF_HEADS = 8
BASELINE_MODEL_DROPOUT_PROB = 0.1
BASELINE_MODEL_LABEL_SMOOTHING_VALUE = 0.1

CHECKPOINTS_PATH = os.path.join(os.getcwd(), 'models', 'checkpoints') # semi-trained models during training will be dumped here
BINARIES_PATH = os.path.join(os.getcwd(), 'models', 'binaries') # location where trained models are located
DATA_DIR_PATH = os.path.join(os.getcwd(), 'data') # training data will be stored here

os.makedirs(CHECKPOINTS_PATH, exist_ok=True)
os.makedirs(BINARIES_PATH, exist_ok=True)
os.makedirs(DATA_DIR_PATH, exist_ok=True)

# Special token symbols used later in the data section
BOS_TOKEN = '<s>'
EOS_TOKEN = '</s>'
PAD_TOKEN = "<pad>"

# Part 1: Understanding the model (transformer)

There are 2 really important parts of any deep learning pipeline:
1. Model 🦄
2. Data 📜

Let's start with the model (🦄) first and understand how the transformer is structured.

I'll take a top-down approach here.

We should first understand how all of the modules fit together (**level 0**) and then we'll start digging deeper into the nitty-gritty details!

In [3]:
class Transformer(nn.Module):

    def __init__(self, model_dimension, src_vocab_size, trg_vocab_size, number_of_heads, number_of_layers, dropout_probability, log_attention_weights=False):
        super().__init__()

        # Embeds source/target token ids into embedding vectors
        self.src_embedding = Embedding(src_vocab_size, model_dimension)
        self.trg_embedding = Embedding(trg_vocab_size, model_dimension)

        # Adds positional information to source/target token's embedding vector
        # (otherwise we'd lose the positional information which is important in human languages)
        self.src_pos_embedding = PositionalEncoding(model_dimension, dropout_probability)
        self.trg_pos_embedding = PositionalEncoding(model_dimension, dropout_probability)

        # All of these will get deep-copied multiple times internally
        mha = MultiHeadedAttention(model_dimension, number_of_heads, dropout_probability, log_attention_weights)
        pwn = PositionwiseFeedForwardNet(model_dimension, dropout_probability)
        encoder_layer = EncoderLayer(model_dimension, dropout_probability, mha, pwn)
        decoder_layer = DecoderLayer(model_dimension, dropout_probability, mha, pwn)

        # Encoder and Decoder stacks
        self.encoder = Encoder(encoder_layer, number_of_layers)
        self.decoder = Decoder(decoder_layer, number_of_layers)

        # Converts final target token representations into log probability vectors of trg_vocab_size dimensionality
        # Why log? -> PyTorch's nn.KLDivLoss expects log probabilities
        self.decoder_generator = DecoderGenerator(model_dimension, trg_vocab_size)
        
        self.init_params()

    # This part wasn't mentioned in the paper, but it's super important!
    def init_params(self):
        # I tested both PyTorch's default initialization and this, and xavier has tremendous impact! I didn't expect
        # that the model's perf, with normalization layers, is so dependent on the choice of weight initialization.
        for name, p in self.named_parameters():
            if p.dim() > 1:
                nn.init.xavier_uniform_(p)

    def forward(self, src_token_ids_batch, trg_token_ids_batch, src_mask, trg_mask):
        src_representations_batch = self.encode(src_token_ids_batch, src_mask)
        trg_log_probs = self.decode(trg_token_ids_batch, src_representations_batch, trg_mask, src_mask)
        
        return trg_log_probs

    # Modularize into encode/decode functions for optimizing the decoding/translation process (we'll get to it later)
    def encode(self, src_token_ids_batch, src_mask):
        # Shape = (B, S, D) , where B - batch size, S - longest source token-sequence length and D - model dimension
        # The whole encoder stack preserves this shape
        src_embeddings_batch = self.src_embedding(src_token_ids_batch)  # get embedding vectors for src token ids
        src_embeddings_batch = self.src_pos_embedding(src_embeddings_batch)  # add positional embedding
        src_representations_batch = self.encoder(src_embeddings_batch, src_mask)  # forward pass through the encoder

        return src_representations_batch

    def decode(self, trg_token_ids_batch, src_representations_batch, trg_mask, src_mask):
        trg_embeddings_batch = self.trg_embedding(trg_token_ids_batch)  # get embedding vectors for trg token ids
        trg_embeddings_batch = self.trg_pos_embedding(trg_embeddings_batch)  # add positional embedding
        
        # Shape (B, T, D), where B - batch size, T - longest target token-sequence length and D - model dimension
        trg_representations_batch = self.decoder(trg_embeddings_batch, src_representations_batch, trg_mask, src_mask)

        # After this line we'll have a shape (B, T, V), where V - target vocab size, 
        # decoder generator does a simple linear projection followed by log softmax
        trg_log_probs = self.decoder_generator(trg_representations_batch)

        # Reshape into (B*T, V) as that's a suitable format for passing it into KL div loss
        trg_log_probs = trg_log_probs.reshape(-1, trg_log_probs.shape[-1])

        return trg_log_probs  # the reason I use log here is that PyTorch's nn.KLDivLoss expects log probabilities

# Encoder and Decoder

Stay with me here! 

These initial parts will probably be the most confusing, as you're not yet familiar with the terminology and I'll reference many things that are yet to be defined.

Get used to it. Whether you're reading research papers or code, things can't always be linearized. <br/>
Think of it as a graph which can't be converted into a [DAG](https://en.wikipedia.org/wiki/Directed_acyclic_graph) (directed acyclic graph) because we have circular dependencies!

Let's go **1 level** deeper - and see exactly how encoder and decoder are structured.

<img src="data/readme_pics/deeper_inception.jpg" alt="we need to go deeper" align="center" style="width: 350px;"/> <br/>

High level overview of the functionality:

1. The encoder takes as the input a batch of source sequences whose tokens have already been embedded.
2. It then does 6 iterations (6 layers for the base transformer) of mixing those via attention.
3. The final output will later get consumed by the decoder. That's it.
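
The three steps above can be sketched in a few lines. This is only an illustration under simplifying assumptions: `ToyMixingLayer` is a hypothetical stand-in (single head, no masking, no feed forward net) and not the `EncoderLayer` used in this notebook. The point is just that every layer mixes token representations via attention and that the `(B, S, D)` shape survives the whole stack.

```python
import copy

import torch
import torch.nn as nn


class ToyMixingLayer(nn.Module):
    """Hypothetical stand-in for an encoder layer (single head, no mask, no feed forward net)."""

    def __init__(self, model_dimension):
        super().__init__()
        self.norm = nn.LayerNorm(model_dimension)
        self.projection = nn.Linear(model_dimension, model_dimension)

    def forward(self, representations):
        normed = self.norm(representations)
        # (B, S, D) x (B, D, S) -> (B, S, S): how much every token attends to every other token
        scores = normed @ normed.transpose(-2, -1) / normed.shape[-1] ** 0.5
        weights = torch.softmax(scores, dim=-1)
        # Mix the token representations and add the result back onto the input (residual connection)
        return representations + self.projection(weights @ normed)


torch.manual_seed(0)
B, S, D, number_of_layers = 2, 5, 16, 6  # toy sizes, the base transformer uses D = 512

# Independent deep copies, one per layer (same idea as a stack of 6 encoder layers)
template = ToyMixingLayer(D)
layers = nn.ModuleList([copy.deepcopy(template) for _ in range(number_of_layers)])

src_embeddings_batch = torch.randn(B, S, D)  # pretend these are already embedded source tokens
src_representations_batch = src_embeddings_batch
for layer in layers:
    src_representations_batch = layer(src_representations_batch)

print(src_representations_batch.shape)  # torch.Size([2, 5, 16])
```

The output of the last layer is what the decoder would later consume.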

In [4]:
class Encoder(nn.Module):

    def __init__(self, encoder_layer, number_of_layers):
        super().__init__()
        assert isinstance(encoder_layer, EncoderLayer), f'Expected EncoderLayer got {type(encoder_layer)}.'

        # Get a list of 'number_of_layers' independent encoder layers
        self.encoder_layers = get_clones(encoder_layer, number_of_layers)  
        self.norm = nn.LayerNorm(encoder_layer.model_dimension)

    def forward(self, src_embeddings_batch, src_mask):
        # Just update the naming so as to reflect the semantics of what this var will become (the initial encoder layer
        # has embedding vectors as input but later layers have richer token representations)
        src_representations_batch = src_embeddings_batch

        # Forward pass through the encoder stack
        for encoder_layer in self.encoder_layers:
            # src_mask's role is to mask/ignore padded token representations in the multi-headed self-attention module
            src_representations_batch = encoder_layer(src_representations_batch, src_mask)

        # Not mentioned explicitly in the paper 
        # (a consequence of using LayerNorm before instead of after the SublayerLogic module)
        return self.norm(src_representations_batch)


class EncoderLayer(nn.Module):

    def __init__(self, model_dimension, dropout_probability, multi_headed_attention, pointwise_net):
        super().__init__()
        num_of_sublayers_encoder = 2
        self.sublayers = get_clones(SublayerLogic(model_dimension, dropout_probability), num_of_sublayers_encoder)

        self.multi_headed_attention = multi_headed_attention
        self.pointwise_net = pointwise_net

        self.model_dimension = model_dimension

    def forward(self, src_representations_batch, src_mask):
        # Define an anonymous (lambda) function which only takes src_representations_batch (srb) as input,
        # this way we have a uniform interface for the sublayer logic.
        encoder_self_attention = lambda srb: self.multi_headed_attention(query=srb, key=srb, value=srb, mask=src_mask)

        # Self-attention MHA sublayer followed by point-wise feed forward net sublayer
        # SublayerLogic takes as the input the data and the logic it should execute (attention/feedforward)
        src_representations_batch = self.sublayers[0](src_representations_batch, encoder_self_attention)
        src_representations_batch = self.sublayers[1](src_representations_batch, self.pointwise_net)

        return src_representations_batch

The decoder has a similar structure to the encoder.

1. It again starts with a batch of target sequences whose tokens have already been embedded.
2. It then does 6 iterations (6 layers for the base transformer) of mixing those via attention (this time also attending to source token representations)
3. The final output, i.e. the target token representations, goes into the DecoderGenerator module which will convert them into log probabilities

A really important difference is that the decoder uses **causal masking** so as to prevent tokens from looking into the future.

Here is an example of a causal mask used in the decoder's self-attention module:
<img src="data/readme_pics/causal_mask.PNG" alt="causal mask" align="center"/> <br/>

What this mask tells us is that each token on the left side of the chart (a row) can only attend to the tokens (columns) whose values are set to 1.0 (True).

As you can see, a token can attend to itself and to the tokens that come before it, but never to later ones - hence **causal**.
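
Here is a minimal sketch of how such a mask can be built and what it does to the attention weights. The helper name `build_causal_mask` is hypothetical, and the sketch covers only the causal part (the target mask used later in this notebook also hides pad tokens).

```python
import torch


def build_causal_mask(sequence_length):
    # Lower-triangular matrix: row i (the attending token) is True in columns 0..i (itself + the past)
    return torch.tril(torch.ones(sequence_length, sequence_length)).bool()


mask = build_causal_mask(4)
print(mask.int())

# Pretend every token is equally similar to every other token (all scores are 0)
scores = torch.zeros(4, 4)

# Future positions get -inf so that softmax assigns them exactly 0 probability
weights = torch.softmax(scores.masked_fill(~mask, float('-inf')), dim=-1)
print(weights)  # row 0 -> [1, 0, 0, 0], row 1 -> [0.5, 0.5, 0, 0], ...
```

Each row still sums to 1, the probability mass is simply redistributed over the allowed (current and past) positions.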

In [5]:
class Decoder(nn.Module):

    def __init__(self, decoder_layer, number_of_layers):
        super().__init__()
        assert isinstance(decoder_layer, DecoderLayer), f'Expected DecoderLayer got {type(decoder_layer)}.'

        self.decoder_layers = get_clones(decoder_layer, number_of_layers)
        self.norm = nn.LayerNorm(decoder_layer.model_dimension)

    def forward(self, trg_embeddings_batch, src_representations_batch, trg_mask, src_mask):
        # Just update the naming so as to reflect the semantics of what this var will become
        trg_representations_batch = trg_embeddings_batch

        # Forward pass through the decoder stack
        for decoder_layer in self.decoder_layers:
            # Target mask masks pad tokens as well as future tokens (current target token can't look forward)
            trg_representations_batch = decoder_layer(trg_representations_batch, src_representations_batch, trg_mask, src_mask)

        # Not mentioned explicitly in the paper 
        # (a consequence of using LayerNorm before instead of after the SublayerLogic module)
        return self.norm(trg_representations_batch)


class DecoderLayer(nn.Module):

    def __init__(self, model_dimension, dropout_probability, multi_headed_attention, pointwise_net):
        super().__init__()
        num_of_sublayers_decoder = 3
        self.sublayers = get_clones(SublayerLogic(model_dimension, dropout_probability), num_of_sublayers_decoder)

        self.trg_multi_headed_attention = copy.deepcopy(multi_headed_attention)
        self.src_multi_headed_attention = copy.deepcopy(multi_headed_attention)
        self.pointwise_net = pointwise_net

        self.model_dimension = model_dimension

    def forward(self, trg_representations_batch, src_representations_batch, trg_mask, src_mask):
        # Define an anonymous (lambda) function which only takes trg_representations_batch (trb - funny name I know)
        # as input - this way we have a uniform interface for the sublayer logic.
        # The inputs which are not passed into the lambdas (masks/srb) are captured by the closures, that's why this works.
        srb = src_representations_batch  # simple/short alias
        decoder_trg_self_attention = lambda trb: self.trg_multi_headed_attention(query=trb, key=trb, value=trb, mask=trg_mask)
        decoder_src_attention = lambda trb: self.src_multi_headed_attention(query=trb, key=srb, value=srb, mask=src_mask)

        # Self-attention MHA sublayer followed by a source-attending MHA and point-wise feed forward net sublayer
        trg_representations_batch = self.sublayers[0](trg_representations_batch, decoder_trg_self_attention)
        trg_representations_batch = self.sublayers[1](trg_representations_batch, decoder_src_attention)
        trg_representations_batch = self.sublayers[2](trg_representations_batch, self.pointwise_net)

        return trg_representations_batch

And the final **level 1** module is the DecoderGenerator (level 0 was the topmost Transformer class).

This module does 2 things:
1. Projects the final decoder token representations of dimension D into V dimensions (target vocabulary size)
2. Applies a log softmax ([PyTorch's KL div loss](https://pytorch.org/docs/stable/generated/torch.nn.KLDivLoss.html) expects log probabilities again)
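
A quick sanity check of those 2 steps, as a sketch with toy sizes (the numbers and the uniform target distribution are made up purely for illustration): after the log softmax, exponentiating gives a proper probability distribution over the vocabulary, and the log probabilities can be fed straight into `nn.KLDivLoss`.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
B, T, D, V = 2, 3, 8, 10  # toy batch size, target length, model dimension and vocab size

linear = nn.Linear(D, V)
log_softmax = nn.LogSoftmax(dim=-1)  # normalize over the vocab dimension

trg_representations_batch = torch.randn(B, T, D)
trg_log_probs = log_softmax(linear(trg_representations_batch))  # shape (B, T, V)

# exp(log probabilities) sums to 1 over the vocabulary for every target token
print(trg_log_probs.exp().sum(dim=-1))

# KLDivLoss takes log probabilities as input and (by default) plain probabilities as target
kl_div_loss = nn.KLDivLoss(reduction='batchmean')
uniform_target = torch.full((B * T, V), 1.0 / V)  # made-up target distribution
loss = kl_div_loss(trg_log_probs.reshape(-1, V), uniform_target)
print(loss.item())
```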

In [6]:
class DecoderGenerator(nn.Module):
    def __init__(self, model_dimension, vocab_size):
        super().__init__()

        self.linear = nn.Linear(model_dimension, vocab_size)

        # -1 stands for apply the log-softmax along the last dimension i.e. over the vocab dimension as the output from
        # the linear layer has shape (B, T, V), B - batch size, T - max target token-sequence, V - target vocab size
        
        # again using log softmax as PyTorch's nn.KLDivLoss expects log probabilities (just a technical detail)
        self.log_softmax = nn.LogSoftmax(dim=-1)

    def forward(self, trg_representations_batch):
        # Project from D (model dimension) into V (target vocab size) and apply the log softmax along V dimension
        return self.log_softmax(self.linear(trg_representations_batch))

# Going even deeper

Phew, give yourself a pat on the back and take a deep breath, as we're diving into the final details of the transformer's architecture.

Here is what the positional encodings look like:

<img src="data/readme_pics/positional_encoding_visualized.jpg"/>
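
The pattern in the picture comes from the formula in the paper: PE(pos, 2i) = sin(pos / 10000^(2i/D)) and PE(pos, 2i+1) = cos(pos / 10000^(2i/D)). Below is a small standalone sketch of that table; the function name `sinusoidal_table` and the toy sizes are my own choices for illustration.

```python
import math

import torch


def sinusoidal_table(max_sequence_length, model_dimension):
    # One row per position, shape (max_sequence_length, 1)
    position_id = torch.arange(max_sequence_length, dtype=torch.float).unsqueeze(1)
    # One frequency per (sine, cosine) pair of dimensions, shape (model_dimension / 2,)
    frequencies = torch.pow(10000., -torch.arange(0, model_dimension, 2, dtype=torch.float) / model_dimension)

    table = torch.zeros(max_sequence_length, model_dimension)
    table[:, 0::2] = torch.sin(position_id * frequencies)  # even dimensions
    table[:, 1::2] = torch.cos(position_id * frequencies)  # odd dimensions
    return table


table = sinusoidal_table(max_sequence_length=50, model_dimension=8)

print(table.shape)  # torch.Size([50, 8])
print(table[0])     # position 0: sin(0) = 0 on even dimensions, cos(0) = 1 on odd dimensions

# Spot check against the formula: pos = 3, dimension 2 (i = 1) -> sin(3 / 10000^(2/8)) = sin(0.3)
print(table[3, 2].item(), math.sin(0.3))
```

Low dimensions oscillate quickly with position and high dimensions oscillate slowly, which is what produces the stripes in the image.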

In [7]:
class SublayerLogic(nn.Module):
    def __init__(self, model_dimension, dropout_probability):
        super().__init__()
        self.norm = nn.LayerNorm(model_dimension)
        self.dropout = nn.Dropout(p=dropout_probability)

    # Note: the original paper applied LayerNorm AFTER the residual connection and addition operation.
    # Multiple experiments I found showed that it's more effective to apply it BEFORE the sublayer. How did they
    # figure out which one is better? Experiments! There is a similar thing in DCGAN and elsewhere.
    def forward(self, representations_batch, sublayer_module):
        # Residual connection between input and sublayer output, details: Page 7, Chapter 5.4 "Regularization"
        return representations_batch + self.dropout(sublayer_module(self.norm(representations_batch)))


class PositionwiseFeedForwardNet(nn.Module):
    """
        It's position-wise because this feed forward net will be independently applied to every token's representation.

        You can think of it as if there was a nested for-loop going over the batch size and max token sequence length
        dimensions, applying this network to each token representation. PyTorch does this auto-magically behind
        the scenes.

    """
    def __init__(self, model_dimension, dropout_probability, width_mult=4):
        super().__init__()

        self.linear1 = nn.Linear(model_dimension, width_mult * model_dimension)
        self.linear2 = nn.Linear(width_mult * model_dimension, model_dimension)

        # This dropout layer is not explicitly mentioned in the paper but it's common to use to avoid over-fitting
        self.dropout = nn.Dropout(p=dropout_probability)
        self.relu = nn.ReLU()

    # Representations_batch's shape = (B - batch size, S/T - max token sequence length, D- model dimension).
    def forward(self, representations_batch):
        return self.linear2(self.dropout(self.relu(self.linear1(representations_batch))))


#
# Input modules
#


class Embedding(nn.Module):

    def __init__(self, vocab_size, model_dimension):
        super().__init__()
        self.embeddings_table = nn.Embedding(vocab_size, model_dimension)
        self.model_dimension = model_dimension

    def forward(self, token_ids_batch):
        assert token_ids_batch.ndim == 2, f'Expected: (batch size, max token sequence length), got {token_ids_batch.shape}'

        # token_ids_batch has shape (B, S/T), where B - batch size, S/T max src/trg token-sequence length
        # Final shape will be (B, S/T, D) where D is the model dimension, every token id has associated vector
        embeddings = self.embeddings_table(token_ids_batch)

        # (stated in the paper) multiply the embedding weights by the square root of model dimension
        # Page 5, Chapter 3.4 "Embeddings and Softmax"
        return embeddings * math.sqrt(self.model_dimension)


class PositionalEncoding(nn.Module):

    def __init__(self, model_dimension, dropout_probability, expected_max_sequence_length=5000):
        super().__init__()
        self.dropout = nn.Dropout(p=dropout_probability)

        # (stated in the paper) Use sine functions whose frequencies form a geometric progression as position encodings,
        # (learning encodings will also work so feel free to change it!). Page 6, Chapter 3.5 "Positional Encoding"
        position_id = torch.arange(0, expected_max_sequence_length).unsqueeze(1)
        frequencies = torch.pow(10000., -torch.arange(0, model_dimension, 2, dtype=torch.float) / model_dimension)

        # Check out playground.py for a visualization of what these look like (it's super simple, don't get scared)
        positional_encodings_table = torch.zeros(expected_max_sequence_length, model_dimension)
        positional_encodings_table[:, 0::2] = torch.sin(position_id * frequencies)  # sine on even dimensions
        positional_encodings_table[:, 1::2] = torch.cos(position_id * frequencies)  # cosine on odd dimensions

        # Register buffer because we want to save the positional encodings table inside state_dict even though
        # these are not trainable (not model's parameters) so they otherwise would be excluded from the state_dict
        self.register_buffer('positional_encodings_table', positional_encodings_table)

    def forward(self, embeddings_batch):
        assert embeddings_batch.ndim == 3 and embeddings_batch.shape[-1] == self.positional_encodings_table.shape[1], \
            f'Expected (batch size, max token sequence length, model dimension) got {embeddings_batch.shape}'

        # embedding_batch's shape = (B, S/T, D), where S/T max src/trg token-sequence length, D - model dimension
        # So here we get (S/T, D) shape which will get broadcast to (B, S/T, D) when we add it to the embeddings
        positional_encodings = self.positional_encodings_table[:embeddings_batch.shape[1]]

        # (stated in the paper) Applying dropout to the sum of positional encodings and token embeddings
        # Page 7, Chapter 5.4 "Regularization"
        return self.dropout(embeddings_batch + positional_encodings)


#
# Helper model functions
#


def get_clones(module, num_of_deep_copies):
    # Create deep copies so that we can tweak each module's weights independently
    return nn.ModuleList([copy.deepcopy(module) for _ in range(num_of_deep_copies)])

# The attention part of "Attention Is All You Need"

And last but definitely not least - the multi headed attention module.
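
Before reading the full module, it may help to see the core operation in isolation. The sketch below is a stripped-down version of scaled dot-product attention, softmax(QK^T / sqrt(HD))V, using the same shape notation as the module (no linear projections, no dropout). The toy sizes and the made-up padding mask are assumptions for illustration only.

```python
import math

import torch


def scaled_dot_product_attention(query, key, value, mask=None):
    # query: (B, NH, T, HD), key/value: (B, NH, S, HD) -> scores: (B, NH, T, S)
    scores = query @ key.transpose(-2, -1) / math.sqrt(query.shape[-1])
    if mask is not None:
        # Positions where mask is False get -inf, so softmax gives them 0 probability
        scores = scores.masked_fill(~mask, float('-inf'))
    attention_weights = torch.softmax(scores, dim=-1)
    # Weighted sum of the values: (B, NH, T, S) x (B, NH, S, HD) -> (B, NH, T, HD)
    return attention_weights @ value, attention_weights


torch.manual_seed(0)
B, NH, T, S, HD = 2, 8, 3, 5, 64  # 8 heads * 64 = 512 = model dimension of the base transformer

query = torch.randn(B, NH, T, HD)  # e.g. target tokens attending to ...
key = torch.randn(B, NH, S, HD)    # ... source tokens
value = torch.randn(B, NH, S, HD)

# Made-up padding mask of shape (B, 1, 1, S): pretend the last source token is padding
src_mask = torch.ones(B, 1, 1, S, dtype=torch.bool)
src_mask[..., -1] = False

representations, attention_weights = scaled_dot_product_attention(query, key, value, src_mask)
print(representations.shape)    # torch.Size([2, 8, 3, 64])
print(attention_weights.shape)  # torch.Size([2, 8, 3, 5])
```

Note how the `(B, 1, 1, S)` mask is broadcast across all heads and all query positions.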

In [8]:
class MultiHeadedAttention(nn.Module):
    """
        This module already exists in PyTorch. The reason I implemented it here from scratch is that
        PyTorch implementation is super complicated as they made it as generic/robust as possible whereas
        on the other hand I only want to support a limited use-case.

        Also this is arguably the most important architectural component in the Transformer model.

        Additional note:
        This is conceptually super easy stuff. It's just that matrix implementation makes things a bit less intuitive.
        If you take your time and go through the code and figure out all of the dimensions + write stuff down on paper
        you'll understand everything. Also do check out this amazing blog for conceptual understanding:

        https://jalammar.github.io/illustrated-transformer/

        Optimization notes:

        qkv_nets could be replaced by Parameter(torch.empty(3 * model_dimension, model_dimension)) and one more matrix
        for bias, which would make the implementation a bit more optimized. For the sake of easier understanding though,
        I'm doing it like this - using 3 "feed forward nets" (without activation/identity hence the quotation marks).
        Conceptually both implementations are the same.

        PyTorch's query/key/value are of different shape namely (max token sequence length, batch size, model dimension)
        whereas I'm using (batch size, max token sequence length, model dimension) because it's easier to understand
        and consistent with computer vision apps (batch dimension is always first followed by the number of channels (C)
        and image's spatial dimensions height (H) and width (W) -> (B, C, H, W)).

        This has an important optimization implication, they can reshape their matrix into (B*NH, S/T, HD)
        (where B - batch size, S/T - max src/trg sequence length, NH - number of heads, HD - head dimension)
        in a single step and I can only get to (B, NH, S/T, HD) in single step
        (I could call contiguous() followed by view but that's expensive as it would incur additional matrix copy)

    """

    def __init__(self, model_dimension, number_of_heads, dropout_probability, log_attention_weights):
        super().__init__()
        assert model_dimension % number_of_heads == 0, 'Model dimension must be divisible by the number of heads.'

        self.head_dimension = model_dimension // number_of_heads
        self.number_of_heads = number_of_heads

        self.qkv_nets = get_clones(nn.Linear(model_dimension, model_dimension), 3)  # identity activation hence "nets"
        self.out_projection_net = nn.Linear(model_dimension, model_dimension)

        self.attention_dropout = nn.Dropout(p=dropout_probability)  # no pun intended, not explicitly mentioned in paper
        self.softmax = nn.Softmax(dim=-1)  # -1 stands for apply the softmax along the last dimension

        self.log_attention_weights = log_attention_weights  # should we log attention weights
        self.attention_weights = None  # for visualization purposes, I cache the weights here (translation_script.py)

    def attention(self, query, key, value, mask):
        # Step 1: Scaled dot-product attention, Page 4, Chapter 3.2.1 "Scaled Dot-Product Attention"
        # Notation: B - batch size, S/T max src/trg token-sequence length, NH - number of heads, HD - head dimension
        # query/key/value shape = (B, NH, S/T, HD), scores shape = (B, NH, S, S), (B, NH, T, T) or (B, NH, T, S)
        # scores have different shapes as MHA is used in 3 contexts, self attention for src/trg and source attending MHA
        scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(self.head_dimension)

        # Step 2: Optionally mask tokens whose representations we want to ignore by setting the scores at the
        # locations corresponding to those tokens to -inf (forces softmax to output 0 probability on those locations).
        # mask shape = (B, 1, 1, S) or (B, 1, T, T) will get broadcast (copied) as needed to match scores shape
        if mask is not None:
            scores.masked_fill_(mask == 0, float("-inf"))

        # Step 3: Calculate the attention weights - how much should we attend to surrounding token representations
        attention_weights = self.softmax(scores)

        # Step 4: Apply dropout to the attention weights as well (not explicitly defined in the original paper)
        attention_weights = self.attention_dropout(attention_weights)

        # Step 5: based on attention weights calculate new token representations
        # attention_weights shape = (B, NH, S, S)/(B, NH, T, T) or (B, NH, T, S), value shape = (B, NH, S/T, HD)
        # Final shape (B, NH, S, HD) for source MHAs or (B, NH, T, HD) target MHAs (again MHAs are used in 3 contexts)
        intermediate_token_representations = torch.matmul(attention_weights, value)

        return intermediate_token_representations, attention_weights  # attention weights for visualization purposes

    def forward(self, query, key, value, mask):
        batch_size = query.shape[0]

        # Step 1: Input linear projection
        # Notation: B - batch size, NH - number of heads, S/T - max src/trg token-sequence length, HD - head dimension
        # Shape goes from (B, S/T, NH*HD) over (B, S/T, NH, HD) to (B, NH, S/T, HD) (NH*HD=D where D is model dimension)
        query, key, value = [net(x).view(batch_size, -1, self.number_of_heads, self.head_dimension).transpose(1, 2)
                             for net, x in zip(self.qkv_nets, (query, key, value))]

        # Step 2: Apply attention - compare query with key and use that to combine values (see the function for details)
        intermediate_token_representations, attention_weights = self.attention(query, key, value, mask)

        # Potentially, for visualization purposes, log the attention weights, turn off during training though!
        # I had memory problems when I left this on by default
        if self.log_attention_weights:
            self.attention_weights = attention_weights

        # Step 3: Reshape from (B, NH, S/T, HD) over (B, S/T, NH, HD) (via transpose) into (B, S/T, NHxHD) which is
        # the same shape as in the beginning of this forward function i.e. input to MHA (multi-head attention) module
        reshaped = intermediate_token_representations.transpose(1, 2).reshape(batch_size, -1, self.number_of_heads * self.head_dimension)

        # Step 4: Output linear projection
        token_representations = self.out_projection_net(reshaped)

        return token_representations

That's it! You've done it, you now have a better understanding of how the transformer is structured. 

Level 2 unlocked. 😍 Let's proceed.

**Let me stress this - it takes time to understand this project, so don't get frustrated! ❤️ (it took me ~3 weeks to code this up!)**

# Part 2: Understanding the data pipeline

The second important part of any deep learning pipeline is the data engine or pipeline. 📜

It's the infrastructure responsible for feeding the data to your transformer (or to any ML model in general).

We'll be using PyTorch's torchtext, which has its problems, but it gets the job done so let's continue! <br/>

Let's first define functions which will directly feed the data to the transformer.
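
To get a feeling for what the field processors below do to a batch of target sentences, here is a plain-Python sketch of the same idea: tokenize, wrap with begin/end tokens, build a vocabulary with a minimum frequency, then numericalize and pad. It does not use torchtext. The helper names, the whitespace tokenizer (standing in for spaCy) and the `<unk>` token with its index are assumptions made for illustration.

```python
from collections import Counter

BOS_TOKEN, EOS_TOKEN, PAD_TOKEN = '<s>', '</s>', '<pad>'
UNK_TOKEN = '<unk>'  # assumed symbol for out-of-vocabulary tokens


def build_vocab(tokenized_sentences, min_freq=2):
    # Keep only tokens that appear at least min_freq times, everything else maps to UNK_TOKEN
    counts = Counter(token for sentence in tokenized_sentences for token in sentence)
    tokens = [UNK_TOKEN, PAD_TOKEN, BOS_TOKEN, EOS_TOKEN]
    tokens += sorted(token for token, count in counts.items() if count >= min_freq)
    return {token: index for index, token in enumerate(tokens)}


def numericalize_and_pad(tokenized_sentences, token_to_id):
    # Wrap every sentence with BOS/EOS, map tokens to ids and pad to the longest sentence in the batch
    wrapped = [[BOS_TOKEN] + sentence + [EOS_TOKEN] for sentence in tokenized_sentences]
    max_len = max(len(sentence) for sentence in wrapped)
    unk_id, pad_id = token_to_id[UNK_TOKEN], token_to_id[PAD_TOKEN]
    return [
        [token_to_id.get(token, unk_id) for token in sentence] + [pad_id] * (max_len - len(sentence))
        for sentence in wrapped
    ]


# Whitespace tokenization stands in for the spaCy tokenizers used in the notebook
sentences = [text.split() for text in ['the cat sat', 'the dog sat down', 'a cat']]

token_to_id = build_vocab(sentences, min_freq=2)
batch = numericalize_and_pad(sentences, token_to_id)

print(token_to_id)  # {'<unk>': 0, '<pad>': 1, '<s>': 2, '</s>': 3, 'cat': 4, 'sat': 5, 'the': 6}
for row in batch:
    print(row)
# [2, 6, 4, 5, 3, 1]
# [2, 6, 0, 5, 0, 3]
# [2, 0, 4, 3, 1, 1]
```

The resulting rectangular batch of ids (batch first) is the kind of input the `Embedding` module expects.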

In [9]:
# Let's define a couple of useful enum classes
class DatasetType(enum.Enum):
    IWSLT = 0
    WMT14 = 1


class LanguageDirection(enum.Enum):
    E2G = 0
    G2E = 1
    

def get_datasets_and_vocabs(dataset_path, language_direction, use_iwslt=True, use_caching_mechanism=True):
    german_to_english = language_direction == LanguageDirection.G2E.name
    spacy_de = spacy.load('de_core_news_sm')
    spacy_en = spacy.load('en_core_web_sm')

    def tokenize_de(text):
        return [tok.text for tok in spacy_de.tokenizer(text)]

    def tokenize_en(text):
        return [tok.text for tok in spacy_en.tokenizer(text)]

    src_tokenizer = tokenize_de if german_to_english else tokenize_en
    trg_tokenizer = tokenize_en if german_to_english else tokenize_de
    
    # These will tokenize the source/target sentences
    # batch first set to true as my transformer is expecting that format (that's consistent with the format
    # used in computer vision), namely (B, C, H, W) -> batch size, number of channels, height and width
    src_field_processor = Field(tokenize=src_tokenizer, pad_token=PAD_TOKEN, batch_first=True)
    trg_field_processor = Field(tokenize=trg_tokenizer, init_token=BOS_TOKEN, eos_token=EOS_TOKEN, pad_token=PAD_TOKEN, batch_first=True)

    fields = [('src', src_field_processor), ('trg', trg_field_processor)]
    MAX_LEN = 100  # filter out examples that have more than MAX_LEN tokens
    filter_pred = lambda x: len(x.src) <= MAX_LEN and len(x.trg) <= MAX_LEN

    # Caching mechanism stuff - avoid redoing the tokenization over and over again
    prefix = 'de_en' if german_to_english else 'en_de'
    prefix += '_iwslt' if use_iwslt else '_wmt14'
    train_cache_path = os.path.join(dataset_path, f'{prefix}_train_cache.csv')
    val_cache_path = os.path.join(dataset_path, f'{prefix}_val_cache.csv')
    test_cache_path = os.path.join(dataset_path, f'{prefix}_test_cache.csv')

    # This simple caching mechanism gave me ~30x speedup on my machine! From ~70s -> ~2.5s!
    ts = time.time()
    if not use_caching_mechanism or not (os.path.exists(train_cache_path) and os.path.exists(val_cache_path)):
        # dataset objects have a list of examples where example is simply an empty Python Object that has
        # .src and .trg attributes which contain a tokenized list of strings (created by tokenize_en and tokenize_de).
        # It's that simple, we can consider our datasets as a table with 2 columns 'src' and 'trg'
        # each containing fields with tokenized strings from source and target languages
        src_ext = '.de' if german_to_english else '.en'
        trg_ext = '.en' if german_to_english else '.de'
        dataset_split_fn = datasets.IWSLT.splits if use_iwslt else datasets.WMT14.splits
        
        train_dataset, val_dataset, test_dataset = dataset_split_fn(
            exts=(src_ext, trg_ext),
            fields=fields,
            root=dataset_path,
            filter_pred=filter_pred
        )

        save_cache(train_cache_path, train_dataset)
        save_cache(val_cache_path, val_dataset)
        save_cache(test_cache_path, test_dataset)
    else:
        # it's actually better to load from cache as we'll get rid of '\xa0', '\xa0 ' and '\x85' unicode characters
        # which we don't need and which SpaCy unfortunately includes as tokens.
        train_dataset, val_dataset = DatasetWrapper.get_train_and_val_datasets(
            train_cache_path,
            val_cache_path,
            fields,
            filter_pred=filter_pred
        )

    print(f'Time it took to prepare the data: {time.time() - ts:.3f} seconds.')

    MIN_FREQ = 2
    # __getattr__ implementation in the base Dataset class enables us to call .src on Dataset objects even though
    # we only have a list of examples in the Dataset object and the example itself had .src attribute.
    # Implementation will yield examples and call .src/.trg attributes on them (and those contain tokenized lists)
    src_field_processor.build_vocab(train_dataset.src, min_freq=MIN_FREQ)
    trg_field_processor.build_vocab(train_dataset.trg, min_freq=MIN_FREQ)

    return train_dataset, val_dataset, src_field_processor, trg_field_processor


# https://github.com/pytorch/text/issues/536#issuecomment-719945594 <- there is a "bug" in BucketIterator, i.e. its
# description is misleading, as it won't group examples of similar length unless you set sort_within_batch to True!
def get_data_loaders(dataset_path, language_direction, dataset_name, batch_size, device):
    train_dataset, val_dataset, src_field_processor, trg_field_processor = get_datasets_and_vocabs(dataset_path, language_direction, dataset_name == DatasetType.IWSLT.name)

    train_token_ids_loader, val_token_ids_loader = BucketIterator.splits(
     datasets=(train_dataset, val_dataset),
     batch_size=batch_size,
     device=device,
     sort_within_batch=True,  # this part is really important otherwise we won't group similar length sentences
     batch_size_fn=batch_size_fn  # this helps us max out GPU's VRAM
    )

    return train_token_ids_loader, val_token_ids_loader, src_field_processor, trg_field_processor


global longest_src_sentence, longest_trg_sentence


def batch_size_fn(new_example, count, sofar):
    """
        If we use this function in the BucketIterator the batch_size is no longer the number of examples/sentences
        in a batch but a number of tokens in a batch - which allows us to max out VRAM on a given GPU.

        Example: if we don't use this function and we set batch size to say 10 we will sometimes end up with
        a tensor of size (10, 100) because the longest sentence had a size of 100 tokens but other times we'll end
        up with a size of (10, 5) because the longest sentence had only 5 tokens!

        With this function what we do is we specify that source and target tensors can't go over a certain number
        of tokens like 1000. So usually either source or target tensors will contain around 1000 tokens and
        in the worst case both will be really close to 1000 tokens each. If that is still below the max VRAM
        available on the system, we're using the max potential of our GPU w.r.t. VRAM.

        Note: to understand this function you unfortunately would probably have to dig deeper into torchtext's
        source code.

    """
    global longest_src_sentence, longest_trg_sentence

    if count == 1:
        longest_src_sentence = 0
        longest_trg_sentence = 0

    longest_src_sentence = max(longest_src_sentence, len(new_example.src))
    # 2 because of start/end of sentence tokens (<s> and </s>)
    longest_trg_sentence = max(longest_trg_sentence, len(new_example.trg) + 2)

    num_of_tokens_in_src_tensor = count * longest_src_sentence
    num_of_tokens_in_trg_tensor = count * longest_trg_sentence

    return max(num_of_tokens_in_src_tensor, num_of_tokens_in_trg_tensor)
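
To make the token-budget idea a bit more concrete, here is a minimal sketch in plain Python. It is not torchtext's actual batching logic: the `pack_batches` and `padded_tokens` helpers and the toy sentence lengths are made up for illustration, and only the source side is tracked to keep it short.

```python
def padded_tokens(lengths):
    """Number of slots in a padded tensor of shape (len(lengths), max(lengths))."""
    return len(lengths) * max(lengths)


def pack_batches(sentence_lengths, token_budget):
    """Greedily group sentences so that each padded batch stays within token_budget."""
    batches, current = [], []
    for length in sentence_lengths:
        # Adding this sentence would push the padded batch over the budget, so start a new batch
        if current and padded_tokens(current + [length]) > token_budget:
            batches.append(current)
            current = []
        current.append(length)
    if current:
        batches.append(current)
    return batches


# Hypothetical sentence lengths, already sorted so that similar lengths end up together
lengths = [5, 6, 7, 40, 42, 45, 100]
batches = pack_batches(lengths, token_budget=150)

for batch in batches:
    print(batch, '->', padded_tokens(batch), 'padded tokens')
```

Notice that the number of sentences per batch varies a lot while the padded size stays under the budget, which is exactly why the "batch size" becomes a number of tokens rather than a number of sentences.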

Those were the main functions we need to later feed the data to the transformer inside the training loop!

Now let's just define some functions which will get called in the training loop and will provide us with:
1. correct source/target masks
2. correct batches of token ids

In [10]:
def get_masks_and_count_tokens_src(src_token_ids_batch, pad_token_id):
    batch_size = src_token_ids_batch.shape[0]

    # src_mask shape = (B, 1, 1, S) check out attention function in transformer_model.py where masks are applied
    # src_mask only masks pad tokens as we want to ignore their representations (no information in there...)
    src_mask = (src_token_ids_batch != pad_token_id).view(batch_size, 1, 1, -1)
    num_src_tokens = torch.sum(src_mask.long())

    return src_mask, num_src_tokens


def get_masks_and_count_tokens_trg(trg_token_ids_batch, pad_token_id):
    batch_size = trg_token_ids_batch.shape[0]
    device = trg_token_ids_batch.device

    # Same as src_mask but we additionally want to mask tokens from looking forward into the future tokens
    # Note: wherever the mask value is true we want to attend to that token, otherwise we mask (ignore) it.
    sequence_length = trg_token_ids_batch.shape[1]  # trg_token_ids shape = (B, T) where T is the max trg token-sequence length
    trg_padding_mask = (trg_token_ids_batch != pad_token_id).view(batch_size, 1, 1, -1)  # shape = (B, 1, 1, T)
    trg_no_look_forward_mask = torch.triu(torch.ones((1, 1, sequence_length, sequence_length), device=device) == 1).transpose(2, 3)

    # logic AND operation (both padding mask and no-look-forward must be true to attend to a certain target token)
    trg_mask = trg_padding_mask & trg_no_look_forward_mask  # final shape = (B, 1, T, T)
    num_trg_tokens = torch.sum(trg_padding_mask.long())

    return trg_mask, num_trg_tokens


def get_masks_and_count_tokens(src_token_ids_batch, trg_token_ids_batch, pad_token_id, device):
    src_mask, num_src_tokens = get_masks_and_count_tokens_src(src_token_ids_batch, pad_token_id)
    trg_mask, num_trg_tokens = get_masks_and_count_tokens_trg(trg_token_ids_batch, pad_token_id)

    return src_mask, trg_mask, num_src_tokens, num_trg_tokens


def get_src_and_trg_batches(token_ids_batch):
    src_token_ids_batch, trg_token_ids_batch = token_ids_batch.src, token_ids_batch.trg

    # Target input should be shifted by 1 compared to the target output tokens
    # Example: if we had a sentence like: [<s>,what,is,up,</s>] then to train the NMT model what we do is we pass
    # [<s>,what,is,up] as the input and set [what,is,up,</s>] as the expected output.
    trg_token_ids_batch_input = trg_token_ids_batch[:, :-1]

    # We reshape from (B, T) into (BxT, 1) as that's the shape expected by LabelSmoothing which will produce
    # the shape (BxT, V) where V is the target vocab size which is the same shape as the one that comes out
    # from the transformer so we can directly pass them into the KL divergence loss
    trg_token_ids_batch_gt = trg_token_ids_batch[:, 1:].reshape(-1, 1)

    return src_token_ids_batch, trg_token_ids_batch_input, trg_token_ids_batch_gt
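
If the masks and the shift-by-one trick feel abstract, here is a small plain-Python sketch for a single target sentence. The token ids, `PAD_ID` and `build_target_mask` are hypothetical and exist only for this example; the real functions above do the same thing on batched tensors.

```python
PAD_ID = 0  # hypothetical pad token id, chosen only for this example


def build_target_mask(token_ids, pad_id=PAD_ID):
    """Rows are query positions, columns are key positions; True means 'allowed to attend'."""
    length = len(token_ids)
    is_real_token = [tok != pad_id for tok in token_ids]
    # A query may look at a key only if the key is not in the future AND the key is not padding
    return [[key <= query and is_real_token[key] for key in range(length)]
            for query in range(length)]


full_target = [1, 7, 8, 2, PAD_ID, PAD_ID]  # <s>, what, is, </s>, <pad>, <pad>

decoder_input = full_target[:-1]    # what we feed into the decoder
expected_output = full_target[1:]   # what we want the decoder to predict at each position

mask = build_target_mask(decoder_input)
for row in mask:
    print(''.join('x' if allowed else '.' for allowed in row))
```

The printed pattern is lower triangular (no looking into the future) with the last column switched off, because that column corresponds to a pad token.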

## Caching mechanism

This section is optional: it's not crucial for understanding how transformers work, but make sure to run the cell.

We'll need a caching mechanism, otherwise every run will spend a lot of time on data preparation before the actual training loop starts.

In [11]:
#
# Caching mechanism datasets and functions (you don't need this but it makes things a lot faster!)
#


class FastTranslationDataset(Dataset):
    """
        After understanding the source code of torch text's IWSLT, TranslationDataset and Dataset I realized how I
        can make data preparation much faster (tokenization was taking a lot of time and there is no need to redo it
        every time) by using a simple caching mechanism.

        This dataset leverages that caching mechanism which reduced loading time from ~70s -> 2.5s (massive!)

    """

    @staticmethod
    def sort_key(ex):
        # What this does is basically it takes a 16-bit binary representation of lengths and interleaves them.
        # Example: lengths len(ex.src)=5 and len(ex.trg)=3 result in f(101, 011)=100111, 7 and 1 in f(111, 001)=101011
        # It's basically a heuristic that helps the BucketIterator sort bigger batches first
        return interleave_keys(len(ex.src), len(ex.trg))

    def __init__(self, cache_path, fields, **kwargs):
        # save_cache interleaves src and trg examples so here we read the cache file having that format in mind
        with open(cache_path, encoding='utf-8') as cache_file:
            cached_data = [line.split() for line in cache_file]

        cached_data_src = cached_data[0::2]  # Even lines contain source examples
        cached_data_trg = cached_data[1::2]  # Odd lines contain target examples

        assert len(cached_data_src) == len(cached_data_trg), 'Source and target data should be of the same length.'

        examples = []
        src_dataset_total_number_of_tokens = 0
        trg_dataset_total_number_of_tokens = 0
        for src_tokenized_data, trg_tokenized_data in zip(cached_data_src, cached_data_trg):
            ex = Example()

            setattr(ex, 'src', src_tokenized_data)
            setattr(ex, 'trg', trg_tokenized_data)

            examples.append(ex)

            # Update the number of tokens
            src_dataset_total_number_of_tokens += len(src_tokenized_data)
            trg_dataset_total_number_of_tokens += len(trg_tokenized_data)

        # Print relevant information about the dataset (parsing the cache file name)
        filename_parts = os.path.split(cache_path)[1].split('_')
        src_language, trg_language = ('English', 'German') if filename_parts[0] == 'en' else ('German', 'English')
        dataset_name = 'IWSLT' if filename_parts[2] == 'iwslt' else 'WMT-14'
        dataset_type = 'train' if filename_parts[3] == 'train' else 'val'
        print(f'{dataset_type} dataset ({dataset_name}) has {src_dataset_total_number_of_tokens} tokens in the source language ({src_language}) corpus.')
        print(f'{dataset_type} dataset ({dataset_name}) has {trg_dataset_total_number_of_tokens} tokens in the target language ({trg_language}) corpus.')

        # Call the parent class Dataset's constructor
        super().__init__(examples, fields, **kwargs)


class DatasetWrapper(FastTranslationDataset):
    """
        Just a wrapper around the FastTranslationDataset.

    """

    @classmethod
    def get_train_and_val_datasets(cls, train_cache_path, val_cache_path, fields, **kwargs):

        train_dataset = cls(train_cache_path, fields, **kwargs)
        val_dataset = cls(val_cache_path, fields, **kwargs)

        return train_dataset, val_dataset


def save_cache(cache_path, dataset):
    with open(cache_path, 'w', encoding='utf-8') as cache_file:
        # Interleave source and target tokenized examples, source is on even lines, target is on odd lines
        for ex in dataset.examples:
            cache_file.write(' '.join(ex.src) + '\n')
            cache_file.write(' '.join(ex.trg) + '\n')
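
Here is a tiny round trip through the cache format, as a sketch with made-up sentence pairs and a temporary directory (none of this touches the real dataset). It also shows why loading from the cache drops whitespace-like tokens such as `'\xa0'`: `str.split()` treats them as separators.

```python
import os
import tempfile

# Hypothetical tokenized (source, target) pairs
pairs = [
    (['Guten', 'Morgen'], ['Good', 'morning']),
    (['Danke', '!'], ['Thanks', '!']),
]

with tempfile.TemporaryDirectory() as tmp_dir:
    cache_path = os.path.join(tmp_dir, 'de_en_iwslt_train_cache.csv')

    # Write: source on even lines, target on odd lines (same layout as save_cache)
    with open(cache_path, 'w', encoding='utf-8') as cache_file:
        for src_tokens, trg_tokens in pairs:
            cache_file.write(' '.join(src_tokens) + '\n')
            cache_file.write(' '.join(trg_tokens) + '\n')

    # Read: splitting on whitespace recovers the token lists
    with open(cache_path, encoding='utf-8') as cache_file:
        lines = [line.split() for line in cache_file]

restored = list(zip(lines[0::2], lines[1::2]))
print(restored == pairs)

# Whitespace-like "tokens" do not survive the join/split round trip
noisy_tokens = ['Hello', '\xa0', 'world']
cleaned_tokens = ' '.join(noisy_tokens).split()
print(cleaned_tokens)
```

One consequence of this format is that a token containing a regular space would be split in two when read back, so it only works because the tokens here never contain whitespace.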

Phew, we've got our data pipeline 📜 and model 🦄 all ready to go!

Level 3 unlocked. 😍

**Let me again stress this: it takes time to understand this project, so don't get frustrated! ❤️ (it took me ~3 weeks to code this up!)**

Now there are a couple more things we need to take care of.

# Part 3: Training loop utilities

So we're done setting up the model and data! Now let's add a couple of other things we'll need in the training loop.

First, let's add a wrapper around Adam which will give it a custom learning rate schedule.

<img src="data/readme_pics/custom_learning_rate_schedule.PNG" alt="Custom LR Adam"/>

It's also good design to make it behave like PyTorch's native optimizers (by defining `step()` and `zero_grad()` functions).

In [12]:
class CustomLRAdamOptimizer:
    """
        Linear ramp learning rate for the warm-up number of steps and then start decaying
        according to the inverse square root law of the current training step number.

        Check out playground.py for visualization of the learning rate (visualize_custom_lr_adam).
    """

    def __init__(self, optimizer, model_dimension, num_of_warmup_steps):
        self.optimizer = optimizer
        self.model_size = model_dimension
        self.num_of_warmup_steps = num_of_warmup_steps

        self.current_step_number = 0

    def step(self):
        self.current_step_number += 1
        current_learning_rate = self.get_current_learning_rate()

        for p in self.optimizer.param_groups:
            p['lr'] = current_learning_rate

        self.optimizer.step()  # apply gradients

    # Check out the formula on page 7, Section 5.3 "Optimizer" of the paper and playground.py for visualization
    def get_current_learning_rate(self):
        # For readability purpose
        step = self.current_step_number
        warmup = self.num_of_warmup_steps

        return self.model_size ** (-0.5) * min(step ** (-0.5), step * warmup ** (-1.5))

    def zero_grad(self):
        self.optimizer.zero_grad()
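
To get a feeling for the numbers, here is the same formula as a standalone function, evaluated at a few steps. This is a sketch that assumes a model dimension of 512 and 4000 warm-up steps; `transformer_lr` is a name I made up for this example.

```python
def transformer_lr(step, model_dimension=512, warmup_steps=4000):
    """Linear warm-up followed by inverse square root decay."""
    return model_dimension ** -0.5 * min(step ** -0.5, step * warmup_steps ** -1.5)


for step in (1, 1000, 4000, 16000, 100000):
    print(f'step={step:>6} lr={transformer_lr(step):.6f}')

peak_lr = transformer_lr(4000)
```

The learning rate grows linearly until the warm-up step, where both arguments of `min` are equal, and then decays. Because of the inverse square root, quadrupling the step count after warm-up halves the learning rate.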

Label smoothing, what's that? Well, you usually set your target vocabulary distribution to a `one-hot` vector, meaning that 1 position out of 30k (or whatever your vocab size is) is set to a probability of 1.0 and everything else to 0.

In `label smoothing`, instead of placing 1.0 on that particular position you place, say, 0.9 and you evenly distribute the rest of the "probability mass" (0.1 in this case) over the other positions (that's visualized as a different shade of purple in the image below, for a fictional vocab of size 4, hence 4 columns):

<img src="data/readme_pics/label_smoothing.PNG" width="700" alt="Label smoothing example"/>

*Note: Pad token's distribution is set to all zeros as we don't want our model to predict those!*

In [13]:
class LabelSmoothingDistribution(nn.Module):
    """
        Instead of one-hot target distribution set the target word's probability to "confidence_value" (usually 0.9)
        and distribute the rest of the "smoothing_value" mass (usually 0.1) over the rest of the vocab.

        Check out playground.py for a visualization of what the smooth target distribution looks like compared to one-hot.
    """

    def __init__(self, smoothing_value, pad_token_id, trg_vocab_size, device):
        assert 0.0 <= smoothing_value <= 1.0

        super(LabelSmoothingDistribution, self).__init__()

        self.confidence_value = 1.0 - smoothing_value
        self.smoothing_value = smoothing_value

        self.pad_token_id = pad_token_id
        self.trg_vocab_size = trg_vocab_size
        self.device = device

    def forward(self, trg_token_ids_batch):

        batch_size = trg_token_ids_batch.shape[0]
        smooth_target_distributions = torch.zeros((batch_size, self.trg_vocab_size), device=self.device)

        # -2 because we are not distributing the smoothing mass over the pad token index and over the ground truth index
        # those 2 values will be overwritten by the following 2 lines with confidence_value and 0 (for pad token index)
        smooth_target_distributions.fill_(self.smoothing_value / (self.trg_vocab_size - 2))

        smooth_target_distributions.scatter_(1, trg_token_ids_batch, self.confidence_value)
        smooth_target_distributions[:, self.pad_token_id] = 0.

        # If we had a pad token as a target we set the distribution to all 0s instead of smooth labeled distribution
        smooth_target_distributions.masked_fill_(trg_token_ids_batch == self.pad_token_id, 0.)

        return smooth_target_distributions
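
Here is the same logic for a single target token in plain Python, so you can check the numbers by hand. It's a sketch with a hypothetical vocab of size 6 and a hypothetical pad token id of 1; `smooth_distribution` is not part of the project.

```python
def smooth_distribution(target_id, vocab_size, pad_token_id, smoothing_value=0.1):
    """Smoothed target distribution for one token, mirroring the module above."""
    if target_id == pad_token_id:
        return [0.0] * vocab_size  # pad targets contribute nothing to the loss

    # Spread the smoothing mass over everything except the target and the pad token
    row = [smoothing_value / (vocab_size - 2)] * vocab_size
    row[target_id] = 1.0 - smoothing_value
    row[pad_token_id] = 0.0
    return row


row = smooth_distribution(target_id=3, vocab_size=6, pad_token_id=1)
pad_row = smooth_distribution(target_id=1, vocab_size=6, pad_token_id=1)

print(row)
print(sum(row))
print(pad_row)
```

The `- 2` in the denominator is what makes a regular row still sum to 1: the 0.1 of smoothing mass is shared by the 4 tokens that are neither the target nor the pad token.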

Finally we'll need a function that will help us evaluate our model in the training loop.

We'll be using the [BLEU metric](https://en.wikipedia.org/wiki/BLEU), which is a common metric for the machine translation task.

Here is an example of how using Xavier initialization dramatically improved the BLEU score:

<img src="data/readme_pics/bleu_score_xavier_vs_default_pt_init.PNG" width="450"/>

You can see here 3 runs, the 2 lower ones used PyTorch default initialization (one used `mean` for KL divergence loss and the better one used `batchmean`), whereas the upper one used **Xavier uniform** initialization!

In [14]:
# Calculate the BLEU-4 score
def calculate_bleu_score(transformer, token_ids_loader, trg_field_processor):
    with torch.no_grad():
        pad_token_id = trg_field_processor.vocab.stoi[PAD_TOKEN]

        gt_sentences_corpus = []
        predicted_sentences_corpus = []

        ts = time.time()
        for batch_idx, token_ids_batch in enumerate(token_ids_loader):
            src_token_ids_batch, trg_token_ids_batch = token_ids_batch.src, token_ids_batch.trg
            if batch_idx % 10 == 0:
                print(f'batch={batch_idx}, time elapsed = {time.time()-ts} seconds.')

            # Optimization - compute the source token representations only once
            src_mask, _ = get_masks_and_count_tokens_src(src_token_ids_batch, pad_token_id)
            src_representations_batch = transformer.encode(src_token_ids_batch, src_mask)

            predicted_sentences = greedy_decoding(transformer, src_representations_batch, src_mask, trg_field_processor)
            predicted_sentences_corpus.extend(predicted_sentences)  # add them to the corpus of translations

            # Get the token and not id version of GT (ground-truth) sentences
            trg_token_ids_batch = trg_token_ids_batch.cpu().numpy()
            for target_sentence_ids in trg_token_ids_batch:
                target_sentence_tokens = [trg_field_processor.vocab.itos[token_id] for token_id in target_sentence_ids if token_id != pad_token_id]
                gt_sentences_corpus.append([target_sentence_tokens])  # add them to the corpus of GT translations

        bleu_score = corpus_bleu(gt_sentences_corpus, predicted_sentences_corpus)
        print(f'BLEU-4 corpus score = {bleu_score}, corpus length = {len(gt_sentences_corpus)}, time elapsed = {time.time()-ts} seconds.')
        return bleu_score
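
To build some intuition for what BLEU measures, here is a sketch of its main ingredient, the clipped n-gram precision, for a single reference sentence. This is not the full BLEU-4 score: the real metric combines the precisions for n = 1..4 and adds a brevity penalty. The helper names and the toy sentences are made up for this example.

```python
from collections import Counter


def ngrams(tokens, n):
    """All contiguous n-grams of a token list, as tuples."""
    return [tuple(tokens[i:i + n]) for i in range(len(tokens) - n + 1)]


def clipped_precision(reference, candidate, n):
    """Fraction of candidate n-grams that also appear in the reference (counts are clipped)."""
    candidate_counts = Counter(ngrams(candidate, n))
    reference_counts = Counter(ngrams(reference, n))
    # Clipping: an n-gram can't be credited more often than it occurs in the reference
    overlap = sum(min(count, reference_counts[gram]) for gram, count in candidate_counts.items())
    total = sum(candidate_counts.values())
    return overlap / total if total else 0.0


reference = ['the', 'cat', 'is', 'on', 'the', 'mat']
candidate = ['the', 'the', 'the', 'cat']

unigram_precision = clipped_precision(reference, candidate, 1)
bigram_precision = clipped_precision(reference, candidate, 2)
print(unigram_precision, bigram_precision)
```

Without clipping, the candidate would get a perfect unigram precision just by repeating "the"; with clipping, only 2 of the 3 copies are credited because the reference contains "the" twice.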

As you can see, to calculate the BLEU score we need a decoding function.

The original paper used **beam search** but for now I've only got **greedy** implemented. <br/> It's not crucial to have beam search to understand how transformers work so this is fine.

*Note: because we use greedy decoding, the BLEU score will be a pessimistic estimate of our model's performance, as beam search usually gives better results.*
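
Before looking at the batched implementation, here is the core idea of greedy decoding as a toy sketch in plain Python. The "model" is just a hypothetical lookup table (`TOY_TABLE`) that returns scores for the next token given the last one; a real transformer would look at the whole prefix and the source sentence.

```python
def greedy_decode(next_token_scores, bos='<s>', eos='</s>', max_tokens=10):
    """Repeatedly append the single highest-scoring next token until EOS or max_tokens."""
    tokens = [bos]
    while len(tokens) < max_tokens:
        scores = next_token_scores(tokens)
        best_token = max(scores, key=scores.get)  # the "greedy" choice: keep only the top token
        tokens.append(best_token)
        if best_token == eos:
            break
    return tokens


# Hypothetical next-token scores, keyed by the last token of the prefix
TOY_TABLE = {
    '<s>': {'I': 0.6, 'you': 0.4},
    'I': {'am': 0.7, 'was': 0.3},
    'am': {'here': 0.5, 'happy': 0.4, '</s>': 0.1},
    'here': {'</s>': 0.9, 'now': 0.1},
}


def toy_model(tokens):
    return TOY_TABLE[tokens[-1]]


decoded = greedy_decode(toy_model)
print(decoded)
```

Beam search would instead keep several candidate prefixes alive at every step, which is why it can find a better overall sentence than always taking the locally best token.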

In [15]:
def greedy_decoding(baseline_transformer, src_representations_batch, src_mask, trg_field_processor, max_target_tokens=100):
    """
    Supports batch (decode multiple source sentences) greedy decoding.

    Decoding could be further optimized to cache old token activations because they can't look ahead and so
    adding a newly predicted token won't change old token's activations.

    Example: we input <s> and do a forward pass. We get intermediate activations for <s> and at the output at position
    0, after applying the linear layer, we get e.g. token <I>. Now we input <s>,<I> but <s>'s activations will remain
    the same. Similarly say we now got <am> at output position 1, in the next step we input <s>,<I>,<am> and so <I>'s
    activations will remain the same as it only looks at/attends to itself and to <s> and so forth.

    """

    device = next(baseline_transformer.parameters()).device
    pad_token_id = trg_field_processor.vocab.stoi[PAD_TOKEN]

    # Initial prompt is the beginning/start of the sentence token. Make it compatible shape with source batch => (B,1)
    target_sentences_tokens = [[BOS_TOKEN] for _ in range(src_representations_batch.shape[0])]
    trg_token_ids_batch = torch.tensor([[trg_field_processor.vocab.stoi[tokens[0]]] for tokens in target_sentences_tokens], device=device)

    # Set to true for a particular target sentence once it reaches the EOS (end-of-sentence) token
    is_decoded = [False] * src_representations_batch.shape[0]

    while True:
        trg_mask, _ = get_masks_and_count_tokens_trg(trg_token_ids_batch, pad_token_id)
        # Shape = (B*T, V) where T is the current token-sequence length and V target vocab size
        predicted_log_distributions = baseline_transformer.decode(trg_token_ids_batch, src_representations_batch, trg_mask, src_mask)

        # Extract only the indices of last token for every target sentence (we take every T-th token)
        num_of_trg_tokens = len(target_sentences_tokens[0])
        predicted_log_distributions = predicted_log_distributions[num_of_trg_tokens-1::num_of_trg_tokens]

        # This is the "greedy" part of the greedy decoding:
        # We find indices of the highest probability target tokens and discard every other possibility
        most_probable_last_token_indices = torch.argmax(predicted_log_distributions, dim=-1).cpu().numpy()

        # Find target tokens associated with these indices
        predicted_words = [trg_field_processor.vocab.itos[index] for index in most_probable_last_token_indices]

        for idx, predicted_word in enumerate(predicted_words):
            target_sentences_tokens[idx].append(predicted_word)

            if predicted_word == EOS_TOKEN:  # once we find EOS token for a particular sentence we flag it
                is_decoded[idx] = True

        if all(is_decoded) or num_of_trg_tokens == max_target_tokens:
            break

        # Prepare the input for the next iteration (merge old token ids with the new column of most probable token ids)
        trg_token_ids_batch = torch.cat((trg_token_ids_batch, torch.unsqueeze(torch.tensor(most_probable_last_token_indices, device=device), 1)), 1)

    # Post process the sentences - remove everything after the EOS token
    target_sentences_tokens_post = []
    for target_sentence_tokens in target_sentences_tokens:
        try:
            target_index = target_sentence_tokens.index(EOS_TOKEN) + 1
        except ValueError:  # no EOS token was produced for this sentence, keep all of the tokens
            target_index = None

        target_sentence_tokens = target_sentence_tokens[:target_index]
        target_sentences_tokens_post.append(target_sentence_tokens)

    return target_sentences_tokens_post
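
The strided slice `[num_of_trg_tokens-1::num_of_trg_tokens]` is probably the least obvious line in the function above, so here is a sketch of what it does using a plain Python list instead of a tensor. It assumes, as the shape comment in the code says, that the decoder output is flattened to (B*T, V) with all positions of sentence 0 first, then all positions of sentence 1, and so on.

```python
B, T = 3, 4  # hypothetical batch size and current target length

# Stand-in for the (B*T, V) output: one label per row instead of a distribution over the vocab
flat_rows = [f'sentence{b}_pos{t}' for b in range(B) for t in range(T)]

# Start at the last position of the first sentence and then jump T rows at a time
last_position_rows = flat_rows[T - 1::T]
print(last_position_rows)
```

Only these rows matter during decoding, because each one is the prediction for the next, not yet generated token of its sentence.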

A couple more functions which will keep things nice and clean.

Feel free to ignore these as they are not crucial for understanding how the transformer works.

In [16]:
def get_available_binary_name():
    prefix = 'transformer'

    def valid_binary_name(binary_name):
        # First time you see a raw f-string? Don't worry, the only trick is to double the braces.
        pattern = re.compile(rf'{prefix}_[0-9]{{6}}\.pth')
        return re.fullmatch(pattern, binary_name) is not None

    # Just list the existing binaries so that we don't overwrite them but write to a new one
    valid_binary_names = list(filter(valid_binary_name, os.listdir(BINARIES_PATH)))
    if len(valid_binary_names) > 0:
        last_binary_name = sorted(valid_binary_names)[-1]
        new_suffix = int(last_binary_name.split('.')[0][-6:]) + 1  # increment by 1
        return f'{prefix}_{str(new_suffix).zfill(6)}.pth'
    else:
        return f'{prefix}_000000.pth'


def get_training_state(training_config, model):
    training_state = {
        # "commit_hash": git.Repo(search_parent_directories=True).head.object.hexsha,
        "dataset_name": training_config['dataset_name'],
        "language_direction": training_config['language_direction'],

        "num_of_epochs": training_config['num_of_epochs'],
        "batch_size": training_config['batch_size'],

        "state_dict": model.state_dict()
    }

    return training_state


def print_model_metadata(training_state):
    header = f'\n{"*"*5} Model training metadata: {"*"*5}'
    print(header)

    for key, value in training_state.items():
        if key != 'state_dict':  # don't print state_dict it's a bunch of numbers...
            if key == 'language_direction':  # convert into human readable format
                value = 'English to German' if value == 'E2G' else 'German to English'
            print(f'{key}: {value}')
    print(f'{"*" * len(header)}\n')
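
If you're curious how the binary naming scheme behaves, here is a sketch that runs the same logic on a list of file names instead of a real directory. `next_binary_name` and the file names are made up for this example.

```python
import re


def next_binary_name(existing_names, prefix='transformer'):
    """Return the next free '<prefix>_NNNNNN.pth' name given the names that already exist."""
    # Doubled braces produce a literal {6} in the regex: exactly 6 digits
    pattern = re.compile(rf'{prefix}_[0-9]{{6}}\.pth')
    valid_names = sorted(name for name in existing_names if pattern.fullmatch(name))
    if not valid_names:
        return f'{prefix}_000000.pth'
    next_index = int(valid_names[-1].split('.')[0][-6:]) + 1
    return f'{prefix}_{next_index:06d}.pth'


existing = ['transformer_000000.pth', 'transformer_000004.pth', 'notes.txt', 'transformer_12.pth']
print(next_binary_name([]))
print(next_binary_name(existing))
```

Files that don't match the pattern exactly (like `transformer_12.pth`) are ignored, and the zero padding is what makes plain string sorting agree with numeric order.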

Level 4 unlocked. 😍

**Let me again^2 stress this: it takes time to understand this project, so don't get frustrated! ❤️ (it took me ~3 weeks to code this up!)**

# Part 4: Training loop

Now let us combine all of the above sections together and finally train our transformer!

Here are the functions we'll be using:

In [None]:
# Global vars for logging purposes
num_of_trg_tokens_processed = 0
bleu_scores = []
global_train_step, global_val_step = [0, 0]
writer = SummaryWriter()  # (tensorboard) writer will output to ./runs/ directory by default


# Simple closure so that I don't have to pass these arguments every time I call train_val_loop
def get_train_val_loop(baseline_transformer, custom_lr_optimizer, kl_div_loss, label_smoothing, pad_token_id, time_start):

    def train_val_loop(is_train, token_ids_loader, epoch):
        global num_of_trg_tokens_processed, global_train_step, global_val_step, writer

        if is_train:
            baseline_transformer.train()
        else:
            baseline_transformer.eval()

        device = next(baseline_transformer.parameters()).device

        #
        # Main loop - start of the CORE PART
        #
        for batch_idx, token_ids_batch in enumerate(token_ids_loader):
            src_token_ids_batch, trg_token_ids_batch_input, trg_token_ids_batch_gt = get_src_and_trg_batches(token_ids_batch)
            src_mask, trg_mask, num_src_tokens, num_trg_tokens = get_masks_and_count_tokens(src_token_ids_batch, trg_token_ids_batch_input, pad_token_id, device)

            # log because the KL loss expects log probabilities (just an implementation detail)
            predicted_log_distributions = baseline_transformer(src_token_ids_batch, trg_token_ids_batch_input, src_mask, trg_mask)
            smooth_target_distributions = label_smoothing(trg_token_ids_batch_gt)  # these are regular probabilities

            if is_train:
                custom_lr_optimizer.zero_grad()  # clean the trainable weights gradients in the computational graph

            loss = kl_div_loss(predicted_log_distributions, smooth_target_distributions)

            if is_train:
                loss.backward()  # compute the gradients for every trainable weight in the computational graph
                custom_lr_optimizer.step()  # apply the gradients to weights

            # End of CORE PART

            #
            # Logging and metrics
            #

            if is_train:
                global_train_step += 1
                num_of_trg_tokens_processed += num_trg_tokens

                if training_config['enable_tensorboard']:
                    writer.add_scalar('training_loss', loss.item(), global_train_step)

                if training_config['console_log_freq'] is not None and batch_idx % training_config['console_log_freq'] == 0:
                    print(f'Transformer training: time elapsed= {(time.time() - time_start):.2f} [s] '
                          f'| epoch={epoch + 1} | batch= {batch_idx + 1} '
                          f'| target tokens/batch= {num_of_trg_tokens_processed / training_config["console_log_freq"]}')

                    num_of_trg_tokens_processed = 0

                # Save model checkpoint
                if training_config['checkpoint_freq'] is not None and (epoch + 1) % training_config['checkpoint_freq'] == 0 and batch_idx == 0:
                    ckpt_model_name = f"transformer_ckpt_epoch_{epoch + 1}.pth"
                    torch.save(get_training_state(training_config, baseline_transformer), os.path.join(CHECKPOINTS_PATH, ckpt_model_name))
            else:
                global_val_step += 1

                if training_config['enable_tensorboard']:
                    writer.add_scalar('val_loss', loss.item(), global_val_step)

    return train_val_loop


def train_transformer(training_config):
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")  # checking whether you have a GPU, I hope so!

    # Step 1: Prepare data loaders
    train_token_ids_loader, val_token_ids_loader, src_field_processor, trg_field_processor = get_data_loaders(
        training_config['dataset_path'],
        training_config['language_direction'],
        training_config['dataset_name'],
        training_config['batch_size'],
        device)

    pad_token_id = src_field_processor.vocab.stoi[PAD_TOKEN]  # pad token id is the same for target as well
    src_vocab_size = len(src_field_processor.vocab)
    trg_vocab_size = len(trg_field_processor.vocab)

    # Step 2: Prepare the model (original transformer) and push to GPU
    baseline_transformer = Transformer(
        model_dimension=BASELINE_MODEL_DIMENSION,
        src_vocab_size=src_vocab_size,
        trg_vocab_size=trg_vocab_size,
        number_of_heads=BASELINE_MODEL_NUMBER_OF_HEADS,
        number_of_layers=BASELINE_MODEL_NUMBER_OF_LAYERS,
        dropout_probability=BASELINE_MODEL_DROPOUT_PROB
    ).to(device)

    # Step 3: Prepare other training related utilities
    kl_div_loss = nn.KLDivLoss(reduction='batchmean')  # gives better BLEU score than "mean"

    # Makes smooth target distributions as opposed to conventional one-hot distributions
    # My feeling is that this is a really dummy and arbitrary heuristic but time will tell.
    label_smoothing = LabelSmoothingDistribution(BASELINE_MODEL_LABEL_SMOOTHING_VALUE, pad_token_id, trg_vocab_size, device)

    # Check out playground.py for an intuitive visualization of how the LR changes with time/training steps, easy stuff.
    custom_lr_optimizer = CustomLRAdamOptimizer(
                Adam(baseline_transformer.parameters(), betas=(0.9, 0.98), eps=1e-9),
                BASELINE_MODEL_DIMENSION,
                training_config['num_warmup_steps']
            )

    # The closure makes things cleaner since there is a lot of redundancy between the train and val loops
    train_val_loop = get_train_val_loop(baseline_transformer, custom_lr_optimizer, kl_div_loss, label_smoothing, pad_token_id, time.time())

    # Step 4: Start the training
    for epoch in range(training_config['num_of_epochs']):
        # Training loop
        train_val_loop(is_train=True, token_ids_loader=train_token_ids_loader, epoch=epoch)

        # Validation loop
        with torch.no_grad():
            train_val_loop(is_train=False, token_ids_loader=val_token_ids_loader, epoch=epoch)

            bleu_score = calculate_bleu_score(baseline_transformer, val_token_ids_loader, trg_field_processor)
            if training_config['enable_tensorboard']:
                writer.add_scalar('bleu_score', bleu_score, epoch)

    # Save the latest transformer in the binaries directory
    torch.save(get_training_state(training_config, baseline_transformer), os.path.join(BINARIES_PATH, get_available_binary_name()))
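
To see what the loss in the core part actually computes, here is a plain-Python sketch of the KL divergence with `batchmean` reduction: sum `target * (log(target) - log_prediction)` over all entries and divide by the number of rows. The numbers are made up (a vocab of size 4, one regular row and one all-zero row like the one a pad target gets), and `kl_div_batchmean` is just an illustration of the formula, not PyTorch's implementation.

```python
import math


def kl_div_batchmean(log_predictions, targets):
    """Sum of target * (log(target) - log_prediction) over all entries, divided by the number of rows."""
    total = 0.0
    for log_prediction_row, target_row in zip(log_predictions, targets):
        for log_q, p in zip(log_prediction_row, target_row):
            if p > 0.0:  # entries with zero target probability contribute nothing
                total += p * (math.log(p) - log_q)
    return total / len(targets)


uniform_log_prediction = [math.log(0.25)] * 4

targets = [
    [0.0, 0.9, 0.05, 0.05],  # smoothed target (position 0 plays the role of the pad token)
    [0.0, 0.0, 0.0, 0.0],    # target was a pad token: all zeros, no loss
]
log_predictions = [uniform_log_prediction, uniform_log_prediction]

loss = kl_div_batchmean(log_predictions, targets)
print(loss)
```

Note that the all-zero row adds nothing to the sum but still counts in the denominator, since the division is by the number of rows.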

## Let's train it!

And finally run this cell to start training your transformer! 

Feel free to skip it and proceed to the next section, where I've added the option to automatically download pretrained models (in case you don't want to train it yourself or you don't have enough resources).

In [None]:
#
# Fixed args - don't change these unless you have a good reason
#
num_warmup_steps = 4000

#
# Modifiable args - feel free to play with these (only small subset is exposed by design to avoid cluttering)
#
parser = argparse.ArgumentParser()
# From the paper I inferred that the baseline was trained for ~19 epochs on the WMT-14 dataset and I got
# nice returns up to epoch ~20 on IWSLT as well (nice round number)
parser.add_argument("--num_of_epochs", type=int, help="number of training epochs", default=20)
# You should adjust this for your particular machine (I have RTX 2080 with 8 GBs of VRAM so 1500 fits nicely!)
parser.add_argument("--batch_size", type=int, help="target number of tokens in a src/trg batch", default=1500)

# Data related args
parser.add_argument("--dataset_name", choices=[el.name for el in DatasetType], help='which dataset to use for training', default=DatasetType.IWSLT.name)
parser.add_argument("--language_direction", choices=[el.name for el in LanguageDirection], help='which direction to translate', default=LanguageDirection.E2G.name)
parser.add_argument("--dataset_path", type=str, help='download dataset to this path', default=DATA_DIR_PATH)

# Logging/debugging/checkpoint related (helps a lot with experimentation)
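# Note: argparse's type=bool turns any non-empty string into True (even "False"), so flip the default here instead
parser.add_argument("--enable_tensorboard", type=bool, help="enable tensorboard logging", default=True)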
parser.add_argument("--console_log_freq", type=int, help="log to output console (batch) freq", default=10)
parser.add_argument("--checkpoint_freq", type=int, help="checkpoint model saving (epoch) freq", default=1)
args = parser.parse_args("")

# Wrapping training configuration into a dictionary
training_config = dict()
for arg in vars(args):
    training_config[arg] = getattr(args, arg)
training_config['num_warmup_steps'] = num_warmup_steps

# Train the original transformer model
train_transformer(training_config)

Level 5 unlocked!!! 😍

**Let me again^3 stress this - it takes time to understand this project, so don't get frustrated! ❤️ (it took me ~3 weeks to code this up!)**

# Part 5: Let's translate! (machine translation example)

Now for the interesting part (well, that depends on whom you ask)! Let's translate between **German** and **English**.

But before we do that we'll want to add functions which will help us visualize the attention mechanism!

As an example here are the attentions I get for the input sentence `Ich bin ein guter Mensch, denke ich.`

These belong to layer 6 of the encoder. You can see all of the 8 multi-head attention heads.

<img src="data/readme_pics/attention_enc_self.PNG" width="850"/>

These help us understand which part of the source/target sentence the transformer is attending to.
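
As a reminder of what each heatmap cell contains, here is a minimal pure-Python sketch of scaled dot-product attention weights for a single head. The tiny 2-dimensional vectors and the helper names `softmax` and `attention_weights_single_head` are made up for illustration and are not the notebook's MHA module. Each row is a softmax distribution, which is why the heatmaps use a 0 to 1 value range.

```python
import math


def softmax(scores):
    """Numerically stable softmax over a list of floats."""
    max_score = max(scores)
    exps = [math.exp(s - max_score) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]


def attention_weights_single_head(queries, keys):
    """Row-wise softmax(Q K^T / sqrt(d_k)); queries: T x d_k, keys: S x d_k -> T x S."""
    d_k = len(keys[0])
    return [
        softmax([sum(q_i * k_i for q_i, k_i in zip(q, k)) / math.sqrt(d_k) for k in keys])
        for q in queries
    ]


# Toy "token representations" (made-up numbers, not real embeddings)
tokens = ['Ich', 'bin', 'ein']
vectors = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]

# Self-attention: queries and keys come from the same tokens
weights = attention_weights_single_head(vectors, vectors)
for token, row in zip(tokens, weights):
    print(token, [round(w, 2) for w in row])
```

Every printed row sums to 1, so a bright cell in a heatmap row means that token puts most of its attention on that column's token.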

In [17]:
def plot_attention_heatmap(data, x, y, head_id, ax):
    seaborn.heatmap(data, xticklabels=x, yticklabels=y, square=True, vmin=0.0, vmax=1.0, cbar=False, annot=True, fmt=".2f", ax=ax)
    ax.set_title(f'MHA head id = {head_id}')


def visualize_attention_helper(attention_weights, source_sentence_tokens=None, target_sentence_tokens=None, title=''):
    num_columns = 4
    num_rows = 2
    fig, axs = plt.subplots(num_rows, num_columns, figsize=(20, 10))  # prepare the figure and axes

    assert source_sentence_tokens is not None or target_sentence_tokens is not None, \
        'Either source or target sentence must be passed in.'

    target_sentence_tokens = source_sentence_tokens if target_sentence_tokens is None else target_sentence_tokens
    source_sentence_tokens = target_sentence_tokens if source_sentence_tokens is None else source_sentence_tokens

    for head_id, head_attention_weights in enumerate(attention_weights):
        row_index = head_id // num_columns
        column_index = head_id % num_columns
        plot_attention_heatmap(head_attention_weights, source_sentence_tokens, target_sentence_tokens if head_id % num_columns == 0 else [], head_id, axs[row_index, column_index])

    fig.suptitle(title)
    plt.show()


def visualize_attention(baseline_transformer, source_sentence_tokens, target_sentence_tokens):
    encoder = baseline_transformer.encoder
    decoder = baseline_transformer.decoder

    # Remove the end of sentence token </s> as we never attend to it, it's produced at the output and we stop
    target_sentence_tokens = target_sentence_tokens[0][:-1]

    # Visualize encoder attention weights
    for layer_id, encoder_layer in enumerate(encoder.encoder_layers):
        mha = encoder_layer.multi_headed_attention  # Every encoder layer has 1 MHA module

        # attention_weights shape = (B, NH, S, S), extract 0th batch and loop over NH (number of heads) MHA heads
        # S stands for maximum source token-sequence length
        attention_weights = mha.attention_weights.cpu().numpy()[0]

        title = f'Encoder layer {layer_id + 1}'
        visualize_attention_helper(attention_weights, source_sentence_tokens, title=title)

    # Visualize decoder attention weights
    for layer_id, decoder_layer in enumerate(decoder.decoder_layers):
        mha_trg = decoder_layer.trg_multi_headed_attention  # Extract the self-attention MHA
        mha_src = decoder_layer.src_multi_headed_attention  # Extract the source attending MHA

        # attention_weights shape = (B, NH, T, T), T stands for maximum target token-sequence length
        attention_weights_trg = mha_trg.attention_weights.cpu().numpy()[0]
        # shape = (B, NH, T, S), target token representations create queries and keys/values come from the encoder
        attention_weights_src = mha_src.attention_weights.cpu().numpy()[0]

        title = f'Decoder layer {layer_id + 1}, self-attention MHA'
        visualize_attention_helper(attention_weights_trg, target_sentence_tokens=target_sentence_tokens, title=title)

        title = f'Decoder layer {layer_id + 1}, source-attending MHA'
        visualize_attention_helper(attention_weights_src, source_sentence_tokens, target_sentence_tokens, title)

## Download pretrained transformers automatically

In case you skipped training your own model, this code will automatically download a pretrained model for you once you call the `translate_a_single_sentence()` function further below.
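
The lookup logic in the next cell boils down to: build a key from the dataset name and the language direction, reuse the local file if it already exists, otherwise download it. Here is a minimal, library-free sketch of that pattern. `resolve_model_path`, the injected `download_fn` and the `example.com` URL are hypothetical and used only here (the notebook itself calls `download_url_to_file`).

```python
import os
import tempfile


def resolve_model_path(binaries_dir, dataset_name, language_direction, url_by_key, download_fn):
    """Return a local checkpoint path, downloading the file only if it is not cached yet."""
    key = f'{dataset_name.lower()}_{language_direction.lower()}'
    model_path = os.path.join(binaries_dir, f'{key}.pth')

    if os.path.exists(model_path):  # already on disk, nothing to do
        return model_path

    url = url_by_key.get(key)
    if url is None:  # no pretrained model published for this combination
        raise FileNotFoundError(f'No pretrained model available for {key}.')

    download_fn(url, model_path)
    return model_path


download_calls = []


def fake_download(url, destination):
    """Stand-in downloader that writes a dummy file instead of hitting the network."""
    download_calls.append(url)
    with open(destination, 'wb') as f:
        f.write(b'not a real checkpoint')


toy_urls = {'iwslt_e2g': 'https://example.com/iwslt_e2g.pth', 'wmt14_e2g': None}

with tempfile.TemporaryDirectory() as tmp_dir:
    first_path = resolve_model_path(tmp_dir, 'IWSLT', 'E2G', toy_urls, fake_download)
    second_path = resolve_model_path(tmp_dir, 'IWSLT', 'E2G', toy_urls, fake_download)
    print(os.path.basename(first_path), len(download_calls))
```

The second call finds the cached file, so the downloader runs only once. Note that this sketch raises an exception for a missing URL, whereas the notebook's `download_models` prints a message and exits.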

In [18]:
IWSLT_ENGLISH_TO_GERMAN_MODEL_URL = r'https://www.dropbox.com/s/a6pfo6t9m2dh1jq/iwslt_e2g.pth?dl=1'
IWSLT_GERMAN_TO_ENGLISH_MODEL_URL = r'https://www.dropbox.com/s/dgcd4xhwig7ygqd/iwslt_g2e.pth?dl=1'


# Not yet trained
WMT14_ENGLISH_TO_GERMAN_MODEL_URL = None
WMT14_GERMAN_TO_ENGLISH_MODEL_URL = None


DOWNLOAD_DICT = {
    'iwslt_e2g': IWSLT_ENGLISH_TO_GERMAN_MODEL_URL,
    'iwslt_g2e': IWSLT_GERMAN_TO_ENGLISH_MODEL_URL,
    'wmt14_e2g': WMT14_ENGLISH_TO_GERMAN_MODEL_URL,
    'wmt14_g2e': WMT14_GERMAN_TO_ENGLISH_MODEL_URL
}


download_choices = list(DOWNLOAD_DICT.keys())


def download_models(translation_config):
    # Step 1: Form the key
    language_direction = translation_config['language_direction'].lower()
    dataset_name = translation_config['dataset_name'].lower()
    key = f'{dataset_name}_{language_direction}'

    # Step 2: Check whether this model already exists
    model_name = f'{key}.pth'
    model_path = os.path.join(BINARIES_PATH, model_name)
    if os.path.exists(model_path):
        print(f'No need to download, found model {model_path} that was trained on {dataset_name} for language direction {language_direction}.')
        return model_path

    # Step 3: Download the resource to local filesystem
    remote_resource_path = DOWNLOAD_DICT[key]
    if remote_resource_path is None:  # handle models which I've not provided URLs for yet
        print(f'No model found that was trained on {dataset_name} for language direction {language_direction}.')
        exit(0)

    print(f'Downloading from {remote_resource_path}. This may take a while.')
    download_url_to_file(remote_resource_path, model_path)

    return model_path

## Neural Machine Translation ❤️ 

Finally this is where the magic happens!
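
The decoding step in the cell below is greedy: at every step the most probable next token is appended, until the end-of-sentence token shows up. The toy sketch below mirrors that loop with a hand-written lookup table instead of a transformer. `TOY_SCORES`, `toy_next_token_scores`, `greedy_decode_sketch` and `max_len` are hypothetical names, and all scores are made up; the notebook's own `greedy_decoding` works with the actual model.

```python
BOS_TOKEN, EOS_TOKEN = '<s>', '</s>'

# Made-up "model": maps the previous token to scores over candidate next tokens
TOY_SCORES = {
    '<s>': {'Wie': 0.9, 'Was': 0.1},
    'Wie': {'geht': 0.8, 'ist': 0.2},
    'geht': {'es': 0.7, '?': 0.3},
    'es': {'Ihnen': 0.6, 'dir': 0.4},
    'Ihnen': {'heute': 0.6, '?': 0.4},
    'heute': {'?': 0.9, '</s>': 0.1},
    '?': {'</s>': 1.0},
}


def toy_next_token_scores(tokens_so_far):
    """A real decoder would look at the whole prefix and the source; this toy only uses the last token."""
    return TOY_SCORES[tokens_so_far[-1]]


def greedy_decode_sketch(next_token_scores_fn, max_len=20):
    """Append the highest-scoring token until </s> is produced or max_len is reached."""
    tokens = [BOS_TOKEN]
    while len(tokens) < max_len:
        scores = next_token_scores_fn(tokens)
        best_token = max(scores, key=scores.get)
        tokens.append(best_token)
        if best_token == EOS_TOKEN:
            break
    return tokens


print(greedy_decode_sketch(toy_next_token_scores))
# ['<s>', 'Wie', 'geht', 'es', 'Ihnen', 'heute', '?', '</s>']
```

Greedy decoding never revisits a choice, which keeps it simple and fast. Beam search keeps several candidate prefixes alive instead of just one.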

In [19]:
# Super easy to add translation for a batch of sentences passed as a .txt file for example
def translate_a_single_sentence(translation_config):
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")  # checking whether you have a GPU

    # Step 1: Prepare the field processor (tokenizer, numericalizer)
    _, _, src_field_processor, trg_field_processor = get_datasets_and_vocabs(
        translation_config['dataset_path'],
        translation_config['language_direction'],
        translation_config['dataset_name'] == DatasetType.IWSLT.name
    )
    assert src_field_processor.vocab.stoi[PAD_TOKEN] == trg_field_processor.vocab.stoi[PAD_TOKEN]
    pad_token_id = src_field_processor.vocab.stoi[PAD_TOKEN]  # needed for constructing masks

    # Step 2: Prepare the model
    baseline_transformer = Transformer(
        model_dimension=BASELINE_MODEL_DIMENSION,
        src_vocab_size=len(src_field_processor.vocab),
        trg_vocab_size=len(trg_field_processor.vocab),
        number_of_heads=BASELINE_MODEL_NUMBER_OF_HEADS,
        number_of_layers=BASELINE_MODEL_NUMBER_OF_LAYERS,
        dropout_probability=BASELINE_MODEL_DROPOUT_PROB,
        log_attention_weights=True
    ).to(device)

    model_path = os.path.join(BINARIES_PATH, translation_config['model_name'])
    if not os.path.exists(model_path):
        print(f'Model {model_path} does not exist, attempting to download.')
        model_path = download_models(translation_config)

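    # map_location lets a checkpoint saved on a GPU load on a CPU-only machine as well
    model_state = torch.load(model_path, map_location=device)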
    print_model_metadata(model_state)
    baseline_transformer.load_state_dict(model_state["state_dict"], strict=True)
    baseline_transformer.eval()

    # Step 3: Prepare the input sentence
    source_sentence = translation_config['source_sentence']
    ex = Example.fromlist([source_sentence], fields=[('src', src_field_processor)])  # tokenize the sentence

    source_sentence_tokens = ex.src
    print(f'Source sentence tokens = {source_sentence_tokens}')

    # Numericalize and convert to a tensor on the selected device (GPU if available, otherwise CPU)
    src_token_ids_batch = src_field_processor.process([source_sentence_tokens], device)

    with torch.no_grad():
        # Step 4: Optimization - compute the source token representations only once
        src_mask, _ = get_masks_and_count_tokens_src(src_token_ids_batch, pad_token_id)
        src_representations_batch = baseline_transformer.encode(src_token_ids_batch, src_mask)

        # Step 5: Decoding process
        target_sentence_tokens = greedy_decoding(baseline_transformer, src_representations_batch, src_mask, trg_field_processor)
        print(f'Translation | Target sentence tokens = {target_sentence_tokens}')

        # Step 6: Potentially visualize the encoder/decoder attention weights
        if translation_config['visualize_attention']:
            visualize_attention(baseline_transformer, source_sentence_tokens, target_sentence_tokens)


#
# modifiable args - feel free to play with these (only small subset is exposed by design to avoid cluttering)
#
parser = argparse.ArgumentParser()
parser.add_argument("--source_sentence", type=str, help="source sentence to translate into target", default="How are you doing today?")
parser.add_argument("--model_name", type=str, help="transformer model name", default=r'iwslt_e2g.pth')

# Keep these 2 in sync with the model you pick via model_name
parser.add_argument("--dataset_name", type=str, choices=['IWSLT', 'WMT14'], help='which dataset the model was trained on', default=DatasetType.IWSLT.name)
parser.add_argument("--language_direction", type=str, choices=[el.name for el in LanguageDirection], help='which direction to translate', default=LanguageDirection.E2G.name)

# Cache files and datasets are downloaded here during training, keep them in sync for speed
parser.add_argument("--dataset_path", type=str, help='download dataset to this path', default=DATA_DIR_PATH)

# Decoding related args
parser.add_argument("--beam_size", type=int, help="used only in case beam search decoding is chosen", default=4)
parser.add_argument("--length_penalty_coefficient", type=float, help="length penalty for the beam search", default=0.6)

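# Note: argparse's type=bool turns any non-empty string into True (even "False"), so flip the default here instead
parser.add_argument("--visualize_attention", type=bool, help="should visualize encoder/decoder attention", default=False)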
args = parser.parse_args("")

# Wrapping translation configuration into a dictionary
translation_config = dict()
for arg in vars(args):
    translation_config[arg] = getattr(args, arg)

# Translate the given source sentence
translate_a_single_sentence(translation_config)

train dataset (IWSLT) has 3937527 tokens in the source language (English) corpus.
train dataset (IWSLT) has 3634135 tokens in the target language (German) corpus.
val dataset (IWSLT) has 20911 tokens in the source language (English) corpus.
val dataset (IWSLT) has 19540 tokens in the target language (German) corpus.
Time it took to prepare the data: 2.328090 seconds.

***** Model training metadata: *****
dataset_name: IWSLT
language_direction: English to German
num_of_epochs: 20
batch_size: 1500
*************************************

Source sentence tokens = ['How', 'are', 'you', 'doing', 'today', '?']
Translation | Target sentence tokens = [['<s>', 'Wie', 'geht', 'es', 'Ihnen', 'heute', '?', '</s>']]


Phew!!! That was a mouthful! If you stayed with me until here, **congrats!** (achievement unlocked - transformer master 😍)

Take your time to analyze this notebook. This is not a toy project, it took me ~3 weeks to finish it <br/> 
so don't expect to understand everything in 30 minutes unless you're really familiar with most of the concepts mentioned here.

And last but not least!

# Connect with me

I share lots of useful (I hope so at least!) content on LinkedIn, Twitter, YouTube and Medium. <br/>
So feel free to connect with me there:
1. My [LinkedIn](https://www.linkedin.com/in/aleksagordic) and [Twitter](https://twitter.com/gordic_aleksa) profiles
2. My YouTube channel - [The AI Epiphany](https://www.youtube.com/c/TheAiEpiphany)
3. My [Medium](https://gordicaleksa.medium.com/) profile

Also do drop me a message if you found this useful or if you think I could've done something better! <br/>
I always like getting some feedback on the work I do.

If you notice some bugs/errors feel free to **open up an issue** or even **submit a pull request**.

# Additional resources

I highly recommend [this blog by Jay Alammar](https://jalammar.github.io/illustrated-transformer/) to everyone learning about transformers.


Also do check out my [YouTube video](https://www.youtube.com/watch?v=bvBK-coXf9I) on how I structured my transformers learning journey!

<a href="https://www.youtube.com/watch?v=bvBK-coXf9I" target="_blank"><img src="https://img.youtube.com/vi/bvBK-coXf9I/0.jpg" 
alt="learning about transformers" width="480" align="left" height="360" border="10" /></a>