### Implementing a Transformer from Scratch
Based on `The Annotated Transformer` by Harvard NLP. Let's do this ...

In [2]:
import os
from os.path import exists
import math
import spacy
import copy
import time
import GPUtil
import warnings

import torch
import torch.nn as nn
import torch.distributed as dist
import torch.multiprocessing as mp
import torchtext.datasets as datasets

from torch.nn.functional import log_softmax, pad
from torch.optim.lr_scheduler import LambdaLR
from torchtext.data.functional import to_map_style_dataset
from torchtext.vocab import build_vocab_from_iterator
from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler
from torch.nn.parallel import DistributedDataParallel as DDP

import pandas as pd
import altair as alt

In [3]:
# Module-level switch consulted by the example helpers below.
RUN_EXAMPLES = True
# Suppress all Python warnings for the rest of the session.
warnings.filterwarnings("ignore")

In [4]:
# some helper functions
def is_interactive_notebook():
    '''Return True when this module is executing as ``__main__``.

    Notebook cells run in the ``__main__`` namespace, so this is True inside
    the notebook and False when the code is imported as a module. (The
    previous version compared two string literals and was always False.)
    '''
    return __name__ == '__main__'

def show_example(fn, args=None):
    '''Call ``fn(*args)`` and return its result when examples are enabled.

    Examples run only when this module executes as ``__main__`` and the
    module-level ``RUN_EXAMPLES`` flag is truthy. Otherwise ``fn`` is not
    called and ``None`` is returned.

    Args:
        fn: callable producing the example output (e.g. a chart).
        args: positional arguments for ``fn``; ``None`` (the default) means
            no arguments. ``None`` replaces a shared mutable ``[]`` default.
    '''
    if __name__ == '__main__' and RUN_EXAMPLES:
        return fn(*(() if args is None else args))
    return None
    
def execute_example(fn, args=None):
    '''Call ``fn(*args)`` for its side effects when examples are enabled.

    Examples run only when this module executes as ``__main__`` and the
    module-level ``RUN_EXAMPLES`` flag is truthy. The result of ``fn`` is
    discarded; this always returns ``None``.

    Args:
        fn: callable to run.
        args: positional arguments for ``fn``; ``None`` (the default) means
            no arguments. ``None`` replaces a shared mutable ``[]`` default.
    '''
    if __name__ == '__main__' and RUN_EXAMPLES:
        fn(*(() if args is None else args))

class DummyOptimizer(torch.optim.Optimizer):
    '''No-op optimizer stand-in.

    Exposes a single parameter group with ``lr == 0``. ``step`` and
    ``zero_grad`` do nothing. ``Optimizer.__init__`` is not called, so no
    parameters are registered.
    '''

    def __init__(self):
        self.param_groups = [{'lr': 0}]

    def step(self):
        '''Do nothing.'''

    def zero_grad(self, set_to_none=False):
        '''Do nothing; ``set_to_none`` is accepted only for signature compatibility.'''

class DummyScheduler:
    '''No-op learning-rate scheduler stand-in.'''

    def step(self):
        '''Do nothing.'''

#### Model Architecture
The transformer follows an encoder-decoder architecture

<center>
    <img src='transformer.png' width='400'>
</center>

In [5]:
class EncoderDecoder(nn.Module):
    '''
    Generic encoder-decoder model assembled from injected components.

    ``encode`` embeds the source and runs the encoder. ``decode`` embeds the
    target and runs the decoder against the encoder output (``memory``).
    ``generator`` is stored for later use but is not called by ``forward``.
    '''

    def __init__(self, encoder, decoder, src_embed, tgt_embed, generator):
        super().__init__()
        self.encoder = encoder
        self.decoder = decoder
        self.src_embed = src_embed
        self.tgt_embed = tgt_embed
        self.generator = generator

    def forward(self, src, tgt, src_mask, tgt_mask):
        "Encode the masked source, then decode the masked target against it"
        memory = self.encode(src, src_mask)
        return self.decode(memory, src_mask, tgt, tgt_mask)

    def encode(self, src, src_mask):
        embedded_src = self.src_embed(src)
        return self.encoder(embedded_src, src_mask)

    def decode(self, memory, src_mask, tgt, tgt_mask):
        embedded_tgt = self.tgt_embed(tgt)
        return self.decoder(embedded_tgt, memory, src_mask, tgt_mask)

In [6]:
class Generator(nn.Module):
    '''
    Output head: a linear projection to vocabulary size followed by log-softmax
    over the last dimension.
    '''
    def __init__(self, d_model, vocab):
        super().__init__()
        self.proj = nn.Linear(d_model, vocab)

    def forward(self, x):
        logits = self.proj(x)
        return logits.log_softmax(dim=-1)

#### Encoder and Decoder Stacks
##### Encoder
Composed of a stack of $N=6$ identical layers

In [7]:
def clones(module, N):
    '''Return an ``nn.ModuleList`` of ``N`` independent deep copies of ``module``.'''
    copies = (copy.deepcopy(module) for _ in range(N))
    return nn.ModuleList(copies)

In [8]:
class Encoder(nn.Module):
    '''Core encoder: a stack of N cloned layers followed by a final layer norm.

    ``layer`` must expose a ``size`` attribute (the feature dimension used for
    the final norm) and be callable as ``layer(x, mask)``.
    '''

    def __init__(self, layer, N):
        super(Encoder, self).__init__()
        self.layers = clones(layer, N)
        # Use the file's own LayerNorm, as Decoder and SublayerConnection do, so
        # every normalisation in the model uses the same formula, eps and
        # parameter names (a_2/b_2) instead of mixing in torch's nn.LayerNorm.
        self.norm = LayerNorm(layer.size)

    def forward(self, x, mask):
        '''Pass the input (and mask) through each layer in turn, then normalise.'''
        for layer in self.layers:
            x = layer(x, mask)
        return self.norm(x)

There is a residual connection around each of the two sub-layers in the encoder, followed by layer normalisation. The normalised input is given by:

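$$\hat{x}_i = \frac{x_i - \mu}{\sqrt{\sigma^2 + \epsilon}}$$

where $\mu$ and $\sigma^2$ are the mean and variance of $x$ over its feature (last) dimension, and $\epsilon$ is a small constant for numerical stability.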

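First, we implement a layer normalisation module. It also applies a learnable elementwise gain (`a_2`) and bias (`b_2`). Note that the implementation divides by $\sigma + \epsilon$, where $\sigma$ is the unbiased (Bessel-corrected) standard deviation. It does not divide by $\sqrt{\sigma^2 + \epsilon}$ as in the formula above.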

In [9]:
class LayerNorm(nn.Module):
    '''Layer normalisation over the last dimension with learnable gain and bias.

    NOTE: this divides by ``std + eps``, where ``std`` is ``Tensor.std``'s
    default (unbiased, Bessel-corrected) estimate. It does not divide by
    ``sqrt(var + eps)`` as in the textbook formula.
    '''
    def __init__(self, features, eps=1e-6):
        super().__init__()
        self.a_2 = nn.Parameter(torch.ones(features))   # gain, initialised to 1
        self.b_2 = nn.Parameter(torch.zeros(features))  # bias, initialised to 0
        self.eps = eps

    def forward(self, x: torch.Tensor):
        mu = x.mean(dim=-1, keepdim=True)
        sigma = x.std(dim=-1, keepdim=True)
        centred = x - mu
        return self.a_2 * centred / (sigma + self.eps) + self.b_2

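The output of each sub-layer is $\mathbf{LayerNorm}(x + \mathbf{Sublayer}(x))$, where $\mathbf{Sublayer}(x)$ is the function implemented by the sub-layer itself. Dropout is applied to the output of each sub-layer before it is added to the sub-layer input and normalised. (For code simplicity, the implementation below instead normalises the sub-layer's input, computing $x + \mathrm{Dropout}(\mathbf{Sublayer}(\mathbf{LayerNorm}(x)))$.)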

To facilitate residual connections, all sub-layers in the model, including embedding layers, produce outputs of dimension $d_{model}=512$.

In [10]:
class SublayerConnection(nn.Module):
    """
    Residual wrapper around a sub-layer: ``x + dropout(sublayer(norm(x)))``.
    NOTE: the norm is applied to the sub-layer's input (pre-norm) rather than
    after the residual addition, for code simplicity.
    """

    def __init__(self, size, dropout):
        super().__init__()
        self.norm = LayerNorm(size)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, sublayer):
        """Apply ``sublayer`` (a callable preserving feature size) inside the residual."""
        normed = self.norm(x)
        branch = self.dropout(sublayer(normed))
        return x + branch

Each encoder layer has two sub-layers. The first is a multi-head self-attention mechanism, and the second is a position-wise fully connected feed-forward network.

In [11]:
class EncoderLayer(nn.Module):
    '''One encoder layer: self-attention, then feed-forward, each in a residual sub-layer.'''

    def __init__(self, size, self_attn, feed_forward, dropout):
        super().__init__()
        self.size = size  # read by Encoder to size its final norm
        self.self_attn = self_attn
        self.feed_forward = feed_forward
        self.sublayer = clones(SublayerConnection(size, dropout), 2)

    def forward(self, x, mask):
        '''Follow the architecture figure (left) for connections.'''
        def _self_attend(y):
            return self.self_attn(y, y, y, mask)

        attn_block, ff_block = self.sublayer
        x = attn_block(x, _self_attend)
        return ff_block(x, self.feed_forward)

##### Decoder
Composed of a stack of $N=6$ identical layers

In [12]:
class Decoder(nn.Module):
    '''Stack of N cloned decoder layers followed by a final LayerNorm.'''

    def __init__(self, layer, N):
        super().__init__()
        self.layers = clones(layer, N)
        self.norm = LayerNorm(layer.size)

    def forward(self, x, memory, src_mask, tgt_mask):
        '''Run every layer against the encoder ``memory``, then normalise.'''
        out = x
        for block in self.layers:
            out = block(out, memory, src_mask, tgt_mask)
        return self.norm(out)

In addition to the two sub-layers in each encoder layer, the decoder inserts a third sub-layer, which performs multi-head attention over the output of the encoder stack. As in the encoder, there are residual connections around each sub-layer, followed by layer normalisation.

In [13]:
class DecoderLayer(nn.Module):
    '''
    One decoder layer: masked self-attention, attention over the encoder
    output (source attention), then feed-forward, each in a residual sub-layer.
    '''

    def __init__(self, size, self_attn, src_attn, feed_forward, dropout):
        super().__init__()
        self.size = size  # read by Decoder to size its final norm
        self.self_attn = self_attn
        self.src_attn = src_attn
        self.feed_forward = feed_forward
        self.sublayer = clones(SublayerConnection(size, dropout), 3)

    def forward(self, x, memory, src_mask, tgt_mask):
        '''Follow the architecture figure (right) for connections.'''
        def _masked_self_attn(y):
            return self.self_attn(y, y, y, tgt_mask)

        def _cross_attn(y):
            # queries come from the decoder; keys/values from the encoder memory
            return self.src_attn(y, memory, memory, src_mask)

        self_block, cross_block, ff_block = self.sublayer
        x = self_block(x, _masked_self_attn)
        x = cross_block(x, _cross_attn)
        return ff_block(x, self.feed_forward)

We need to ensure that the predictions for position $i$ can depend only on the known outputs at positions less than $i$. To achieve this, the self-attention sub-layer in the decoder stack is modified to prevent positions from attending to subsequent positions. This is also aided by the fact that the output embeddings are offset by one position.

In [14]:
def subsequent_mask(size):
    '''Return a ``(1, size, size)`` boolean causal mask.

    Entry ``[0, i, j]`` is True when ``j <= i`` (position ``i`` may attend to
    ``j``) and False for future positions ``j > i``.
    '''
    future = torch.triu(torch.ones((1, size, size)), diagonal=1)
    return future.type(torch.uint8) == 0

In [15]:
# Preview the causal mask for a length-5 sequence.
subsequent_mask(size=5)

tensor([[[ True, False, False, False, False],
         [ True,  True, False, False, False],
         [ True,  True,  True, False, False],
         [ True,  True,  True,  True, False],
         [ True,  True,  True,  True,  True]]])

During training, words are blocked from attending to future words.
Let's see an illustration of the masking ...

In [16]:
def example_mask(size=20):
    '''Build an Altair heat-map of ``subsequent_mask(size)``.

    Each cell (Window=y, Masking=x) shows ``subsequent_mask(size)[0][x, y]``,
    i.e. whether position x may attend to position y.

    Args:
        size: sequence length to visualise (default 20, as before).

    Returns:
        An interactive ``alt.Chart``.
    '''
    # Compute the mask once; previously it was rebuilt for every one of the
    # size**2 cells and each cell became its own one-row DataFrame.
    mask = subsequent_mask(size)[0]
    LS_data = pd.DataFrame(
        [
            {'Subsequent Mask': bool(mask[x, y]), 'Window': y, 'Masking': x}
            for y in range(size)
            for x in range(size)
        ]
    )

    return (
        alt.Chart(LS_data)
        .mark_rect()
        .properties(height=250, width=250)
        .encode(
            alt.X('Window:O'),
            alt.Y('Masking:O'),
            alt.Color('Subsequent Mask:Q', scale=alt.Scale(scheme='viridis')),
        )
        .interactive()
    )

# Render the mask heat-map (only when examples are enabled).
show_example(fn=example_mask)

##### Attention
An attention function maps a query and a set of key-value pairs to an output, where the query, keys, values, and output are all vectors. The output is computed as a weighted sum of the values, where the weight assigned to each value is computed by a compatibility function of the query with the corresponding key.

It is implemented as **Scaled Dot-Product Attention**, where the input consists of queries and keys of dimension $d_k$, and values of dimension $d_v$. We compute the dot products of the query with all keys, divide each by $\sqrt{d_k}$, and apply a softmax function to obtain the weights on the values.

In practice, attention is computed on a set of queries, packed together into a matrix $Q$, simultaneously. The keys and values are also packed together into matrices $K$ and $V$. The output is calculated as:

\begin{equation}
    \mathbf{Attention}(Q,K,V) = \text{softmax} \left(\frac{QK^T}{\sqrt{d_k}} \right)V
\end{equation}

In [None]:
def attention(query:torch.Tensor, key:torch.Tensor, value:torch.Tensor, mask=None, dropout=None):
    '''
    Compute "scaled dot-product attention"

    -------------------------------------------------------------------------------------------- #
    # NOTES:
    - Each query, key, and value has shape (batch_size, h, 1, d_k) where h is the number of heads
    - The operation key.transpose(-2, -1) changes the shape of key to (batch_size, h, d_k, 1)
    - So, scores will have shape (batch_size, h, 1, 1)
    
    -------------------------------------------------------------------------------------------- #'
    '''

    d_k = query.size(-1)
    scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(d_k)
    if mask is not None:
        scores = scores.masked_fill(mask == 0, -1e9)
    p_attn = scores.softmax(dim=-1)
    if dropout is not None:
        p_attn = dropout(p_attn)
    
    return torch.matmul(p_attn, value), p_attn

Multi-head attention allows the model to jointly attend to information from different representation subspaces at different positions. When using a single attention head, averaging inhibits this.

\begin{align}
    \mathbf{MultiHead}(Q, K, V) &= \mathbf{Concat}(\mathrm{head}_1, \ldots, \mathrm{head}_h)W^O \\
    \text{where } \mathrm{head}_i &= \mathbf{Attention}(QW^Q_i, KW^K_i, VW^V_i)
\end{align}

where the projections are parameter matrices $W^Q_i \in \mathbb{R}^{d_{model} \times d_k}$, $W^K_i \in \mathbb{R}^{d_{model} \times d_k}$, $W^V_i \in \mathbb{R}^{d_{model} \times d_v}$, and $W^O \in \mathbb{R}^{hd_v \times d_{model}}$.

In the transformer, we employ $h=8$ parallel attention layers/heads. For each of these, we use $d_k = d_v = d_{model}/h=64$. Due to this reduced dimensionality of each head, the total computational cost is similar to that of single-head attention with full dimensionality.

<center>
    <img src='multihead_attention.jpg' width=500>
</center>

In [217]:
class MultiheadedAttention(nn.Module):
    '''Implements multi-headed attention.

    Uses four ``d_model x d_model`` linear layers. The first three project the
    query, key and value inputs; each output is split into ``h`` heads of size
    ``d_k = d_model // h``. The last one projects the concatenated head outputs
    back to ``d_model``.
    '''

    def __init__(self, h, d_model, dropout=0.1):
        '''Take in the number of heads ``h``, the model size ``d_model`` and the
        dropout probability applied to the attention weights.'''
        super(MultiheadedAttention, self).__init__()
        assert d_model % h == 0, "d_model must be divisible by the number of heads h"
        # We assume d_v always equals d_k
        self.d_k = d_model // h
        self.h = h
        self.linears = clones(nn.Linear(d_model, d_model), 4)
        self.attn = None  # attention weights from the most recent forward pass
        self.dropout = nn.Dropout(p=dropout)

    def forward(self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, mask: torch.Tensor = None):
        '''
        Apply multi-headed attention.

        Shapes, as implied by the ``view``/``transpose`` calls below:

        - ``query`` is ``(nbatches, seq_q, d_model)``; ``key`` and ``value``
          are ``(nbatches, seq_k, d_model)`` and must share ``seq_k``.
        - Each input goes through its own ``d_model x d_model`` linear layer.
          This is equivalent to ``h`` stacked ``d_model x d_k`` projections.
          The result is viewed as ``(nbatches, seq, h, d_k)`` and transposed
          to ``(nbatches, h, seq, d_k)``, so all heads run in one batched call.
        - ``mask`` (optional) gets a head axis inserted at dim 1, so the same
          mask applies to every head. After that it must broadcast against the
          scores ``(nbatches, h, seq_q, seq_k)``.

        Returns a tensor of shape ``(nbatches, seq_q, d_model)``. The attention
        weights ``(nbatches, h, seq_q, seq_k)`` are stored in ``self.attn``.
        '''
        if mask is not None:
            mask = mask.unsqueeze(1)  # same mask applied to all h heads
        nbatches = query.size(0)

        # 1) Do all the linear projections in batch from d_model => h x d_k
        query, key, value = [
            linear(x).view(nbatches, -1, self.h, self.d_k).transpose(1, 2)
            for linear, x in zip(self.linears, (query, key, value))
        ]

        # 2) Apply attention on all the projected vectors in batch
        x, self.attn = attention(query, key, value, mask=mask, dropout=self.dropout)

        # 3) "Concat" heads back to (nbatches, seq_q, h * d_k) and apply a final linear
        x = x.transpose(1, 2).contiguous().view(nbatches, -1, self.h * self.d_k)

        # Drop the per-head projections before the output projection; presumably
        # to lower peak memory, since only x (and self.attn) are needed now.
        del query
        del key
        del value

        # this is the top-most linear layer before the output of the multi-headed attention
        return self.linears[-1](x)
