# The Annotated Transformer
This is a bit redundant, but I am now going through The Annotated Transformer to dive deeply into the transformer architecture step by step. In this case, however, we will put everything in this notebook and adjust things as we go. This is mainly to solidify my understanding of how the Transformer model works under the hood.

In [2]:
import os
from os.path import exists
import torch
import torch.nn as nn
from torch.nn.functional import log_softmax, pad
import math
import copy
import time
from torch.optim.lr_scheduler import LambdaLR
import pandas as pd
import altair as alt
from torchtext.data.functional import to_map_style_dataset
from torch.utils.data import DataLoader
from torchtext.vocab import build_vocab_from_iterator
import torchtext.datasets as datasets
import spacy
import warnings
from torch.utils.data.distributed import DistributedSampler
import torch.distributed as dist
import torch.multiprocessing as mp
from torch.nn.parallel import DistributedDataParallel as DDP

## Model Architecture
Most competitive neural sequence transduction models have an encoder-decoder structure. Here the encoder maps an input sequence of symbol representations to a sequence of continuous representations. Given this continuous representation, the decoder then generates an output sequence of symbols one element at a time. At each step the model is auto-regressive, consuming the previously generated symbols as additional input when generating the next.

In [242]:
class Decoder(nn.Module):
    """Placeholder Decoder that cannot be instantiated.

    Presumably exists only so the ``EncoderDecoder`` signature in this cell can
    reference ``Decoder``; the working Decoder is defined in a later cell.
    """

    def __init__(self):
        nn.Module.__init__(self)
        raise NotImplementedError
    
class Generator(nn.Module):
    """Output head: linear projection to vocabulary size followed by log-softmax.

    Args:
        d_model: width of the incoming feature dimension.
        vocab: number of output classes.
    """

    def __init__(self, d_model, vocab):
        super().__init__()
        self.d_model = d_model
        self.vocab = vocab
        self.linear_proj = nn.Linear(d_model, vocab)

    def forward(self, x):
        """Return log-probabilities over the last dimension of ``linear_proj(x)``."""
        logits = self.linear_proj(x)
        return log_softmax(logits, dim=-1)
        
    
class EncoderDecoder(nn.Module):
    """Encoder-decoder wrapper: embed source, encode, embed target, decode.

    Args:
        encoder: module called as ``encoder(embedded_src, src_mask)``.
        decoder: module called as ``decoder(embedded_tgt, memory, src_mask, tgt_mask)``.
        src_emb: module mapping source token ids to embeddings.
        tgt_emb: module mapping target token ids to embeddings.
        generator: output head; stored here but NOT applied by ``forward``
            (callers apply ``model.generator`` themselves).
    """

    # String annotations: ``Encoder`` (and the working ``Decoder``) are defined in
    # later cells, so evaluating them eagerly would raise NameError on a fresh run.
    def __init__(self, encoder: "Encoder", decoder: "Decoder", src_emb: nn.Module, tgt_emb: nn.Module, generator: "Generator"):
        super().__init__()
        self.encoder = encoder
        self.decoder = decoder
        self.src_emb = src_emb
        self.tgt_emb = tgt_emb
        self.generator = generator

    def forward(self, src, tgt, src_mask, tgt_mask):
        """Encode ``src`` and decode ``tgt`` against it; returns the decoder output."""
        # Go through ``encode`` so the source embedding is applied.
        memory = self.encode(src, src_mask)
        # ``decode`` takes (memory, src_mask, tgt, tgt_mask) in that order.
        return self.decode(memory, src_mask, tgt, tgt_mask)

    def encode(self, src, src_mask):
        """Embed ``src`` and run the encoder over it."""
        embeds = self.src_emb(src)
        return self.encoder(embeds, src_mask)

    def decode(self, z, src_mask, tgt, tgt_mask):
        """Embed ``tgt`` and run the decoder against encoder memory ``z``."""
        embeds = self.tgt_emb(tgt)
        return self.decoder(embeds, z, src_mask, tgt_mask)

### Encoder and Decoder Stacks
The encoder is composed of a stack of $N = 6$ identical layers.

In [243]:
def clones(module, N):
    """Return an ``nn.ModuleList`` holding ``N`` independent deep copies of ``module``."""
    return nn.ModuleList(
        copy.deepcopy(module) for _ in range(N)
    )

class LayerNorm(nn.Module):
    """Layer normalisation over the last dimension with a learnable gain and bias.

    Uses ``Tensor.std`` (unbiased by default) and adds ``eps`` to the standard
    deviation rather than the variance, so results differ slightly from
    ``torch.nn.LayerNorm``.
    """

    def __init__(self, features, eps=1e-6):
        super().__init__()
        self.a_2 = nn.Parameter(torch.ones(features))   # gain
        self.b_2 = nn.Parameter(torch.zeros(features))  # bias
        self.eps = eps

    def forward(self, x):
        centered = x - x.mean(-1, keepdim=True)
        denom = x.std(-1, keepdim=True) + self.eps
        return self.a_2 * centered / denom + self.b_2

class Encoder(nn.Module):
    """Stack of ``N`` cloned encoder layers followed by a final LayerNorm.

    Args:
        layer: prototype layer; must expose ``.size`` (feature width) and be
            callable as ``layer(x, mask)``.
        N: number of copies to stack.
    """

    def __init__(self, layer, N):
        super().__init__()
        self.layers = clones(layer, N)
        self.norm = LayerNorm(layer.size)

    def forward(self, x, mask):
        """Pass ``x`` through every layer in order with the same ``mask``, then normalise."""
        hidden = x
        for encoder_layer in self.layers:
            hidden = encoder_layer(hidden, mask)
        return self.norm(hidden)

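Now we will create a ```sublayer``` connection. In the paper, the output of each sub-layer is ```LayerNorm(x + Sublayer(x))```, where ```Sublayer(x)``` is the function implemented by the sub-layer itself, and dropout is applied to the sub-layer output before it is added to the residual. For code simplicity, the implementation below applies the norm first instead (pre-norm): it computes ```x + Dropout(Sublayer(LayerNorm(x)))```.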

The reason we design it this way is that both the encoder and the decoder are built from sub-modules that each compute some function and are then wrapped in a residual connection and layer norm.

In [244]:
class SubLayerConnection(nn.Module):
    """Residual connection around a sub-layer, with the norm applied first (pre-norm).

    Computes ``x + dropout(sublayer(norm(x)))``. ``sublayer`` is any callable;
    presumably it returns a tensor shaped like ``x`` so the residual add works.
    """

    def __init__(self, size, dropout):
        super().__init__()
        self.norm = LayerNorm(size)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, sublayer):
        normed = self.norm(x)
        branch = self.dropout(sublayer(normed))
        return x + branch

Each layer has two sub-layers. The first is a multi-head self-attention mechanism, and the second is a simple, position-wise fully connected feed-forward network.

In [245]:
class EncoderLayer(nn.Module):
    """One encoder layer: self-attention, then feed-forward, each in a SubLayerConnection.

    Args:
        size: feature width; stored as ``self.size`` (read by ``Encoder`` for its norm).
        self_attn: callable as ``self_attn(query, key, value, mask)``.
        feed_forward: callable applied to the attention sub-layer's output.
        dropout: dropout probability for both residual connections.
    """

    def __init__(self, size, self_attn, feed_forward, dropout):
        super().__init__()
        # Assignment order fixes submodule registration (and thus parameter) order.
        self.self_attn = self_attn
        self.feed_forward = feed_forward
        self.sublayer = clones(SubLayerConnection(size, dropout), 2)
        self.size = size

    def forward(self, x, mask):
        attn_conn, ff_conn = self.sublayer
        # Self-attention: query, key and value all come from the same input.
        x = attn_conn(x, lambda h: self.self_attn(h, h, h, mask))
        return ff_conn(x, self.feed_forward)

The decoder is also composed of a stack of $N = 6$ identical layers.

In [246]:
class Decoder(nn.Module):
    """Stack of ``N`` cloned decoder layers followed by a final LayerNorm.

    Args:
        layer: prototype layer; must expose ``.size`` and be callable as
            ``layer(x, z, src_mask, tgt_mask)``.
        N: number of copies to stack.
    """

    def __init__(self, layer, N):
        super().__init__()
        self.layers = clones(layer, N)
        self.norm = LayerNorm(layer.size)

    def forward(self, x, z, src_mask, tgt_mask):
        """Run ``x`` through every layer against encoder memory ``z``, then normalise."""
        out = x
        for block in self.layers:
            out = block(out, z, src_mask, tgt_mask)
        return self.norm(out)

In addition to the two sub-layers in each encoder layer, the decoder inserts a third sub-layer, which performs multi-head attention over the output of the encoder stack. Similar to the encoder, we employ residual connections around each of the sub-layers, followed by layer normalization

In [247]:
class DecoderLayer(nn.Module):
    """One decoder layer: masked self-attention, source attention, then feed-forward.

    Args:
        size: feature width; stored as ``self.size`` (read by ``Decoder`` for its norm).
        self_attn: attention over the decoder's own inputs, masked by ``tgt_mask``.
        src_attn: attention whose keys/values are the encoder memory ``z``.
        feed_forward: callable applied after the two attention sub-layers.
        dropout: dropout probability for the three residual connections.
    """

    def __init__(self, size, self_attn, src_attn, feed_forward, dropout):
        super().__init__()
        # Assignment order fixes submodule registration (and thus parameter) order.
        self.size = size
        self.self_attn = self_attn
        self.src_attn = src_attn
        self.feed_forward = feed_forward
        self.sublayer = clones(SubLayerConnection(size, dropout), 3)

    def forward(self, x, z, src_mask, tgt_mask):
        self_conn, cross_conn, ff_conn = self.sublayer
        x = self_conn(x, lambda h: self.self_attn(h, h, h, tgt_mask))
        # Queries come from the decoder; keys and values from encoder memory.
        x = cross_conn(x, lambda h: self.src_attn(h, z, z, src_mask))
        return ff_conn(x, self.feed_forward)

We also modify the self-attention sub-layer in the decoder stack to prevent positions from attending to subsequent positions. This masking, combined with the fact that the output embeddings are offset by one position, ensures that the predictions for position $i$ can depend only on the known outputs at positions less than $i$.

In [248]:
# Demo: build the causal ("subsequent") mask by hand for a 512-position sequence.
# Named ``demo_mask`` rather than ``subsequent_mask`` so it does not shadow the
# ``subsequent_mask`` helper function defined in the next cell (re-running this
# cell after that one would otherwise replace the helper with a tensor).
size = 512
attn_shape = (1, size, size)
# 1 strictly above the diagonal (future positions), 0 elsewhere.
demo_mask = torch.triu(torch.ones(attn_shape), diagonal=1).type(torch.uint8)
# Displayed value: True where attention is allowed (on/below the diagonal).
demo_mask == 0

tensor([[[ True, False, False,  ..., False, False, False],
         [ True,  True, False,  ..., False, False, False],
         [ True,  True,  True,  ..., False, False, False],
         ...,
         [ True,  True,  True,  ...,  True, False, False],
         [ True,  True,  True,  ...,  True,  True, False],
         [ True,  True,  True,  ...,  True,  True,  True]]])

In [249]:
def subsequent_mask(size):
    """Return a ``(1, size, size)`` boolean causal mask.

    Entry ``[0, i, j]`` is True when position ``i`` may attend to position ``j``,
    i.e. when ``j <= i`` (lower triangle including the diagonal).
    """
    return torch.ones(1, size, size).tril().bool()

Below, the attention mask shows the positions each tgt word (row) is allowed to look at (column). Words are blocked from attending to future words during training.

In [250]:
# Long-format table of the 20x20 causal mask for plotting: one row per
# (Window=y, Masking=x) cell; "Mask" is True where row x may attend to column y.
# The mask is computed once up front instead of once per cell (400 calls).
_mask_20 = subsequent_mask(20)[0]
LS_data = pd.concat(
    [
        pd.DataFrame(
            {
                "Mask": _mask_20[x, y].flatten(),
                "Window": y,
                "Masking": x,
            }
        )
        for y in range(20)
        for x in range(20)
    ]
)

In [251]:
# Display the long-format mask table (shown as the cell output below).
LS_data

Unnamed: 0,Mask,Window,Masking
0,True,0,0
0,True,0,1
0,True,0,2
0,True,0,3
0,True,0,4
...,...,...,...
0,False,19,15
0,False,19,16
0,False,19,17
0,False,19,18


In [252]:
# alt.Chart(LS_data).mark_rect().properties(height=500*1.2, width=500*1.2).encode(alt.X("Window:O"), alt.Y("Masking:O"), alt.Color("Mask:Q", scale=alt.Scale(scheme="viridis"))).interactive()

An attention function can be described as mapping a query and a set of key-value pairs to an output, where the query, keys, values, and output are all vectors. The output is computed as a weighted sum of the values, where the weight assigned to each value is computed by a compatibility function of the query with the corresponding key.

We call our particular attention 'scaled dot-product attention'. The input consists of queries and keys of dimension $d_k$, and values of dimension $d_v$. We compute the dot products of the query with all keys, divide each by $\sqrt{d_k}$, and apply a softmax function to obtain the weights on the values. In matrix form: $\mathrm{Attention}(Q, K, V) = \mathrm{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right)V$.

In [253]:
def attention(query, key, value, mask=None, dropout=None):
    """Scaled dot-product attention: ``softmax(Q K^T / sqrt(d_k)) V``.

    Args:
        query: tensor whose last two dims are (L_q, d_k).
        key: tensor whose last two dims are (L_k, d_k).
        value: tensor whose last two dims are (L_k, d_v).
        mask: optional tensor broadcastable to the score shape (..., L_q, L_k);
            positions where ``mask == 0`` receive (effectively) zero weight.
        dropout: optional module applied to the attention weights.

    Returns:
        ``(p_attn @ value, p_attn)``.
    """
    d_k = key.size(-1)
    scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(d_k)
    if mask is not None:
        # Most negative finite value for the dtype: a literal -1e9 overflows
        # float16. For float32 the softmax output is unchanged.
        scores = scores.masked_fill(mask == 0, torch.finfo(scores.dtype).min)
    p_attn = scores.softmax(dim=-1)
    if dropout is not None:
        p_attn = dropout(p_attn)
    return torch.matmul(p_attn, value), p_attn

In [254]:
# My silly idea of creating deep attention
# my intuition is that a non-linear function may be able to capture more contextual representations, as opposed to our linear counterpart
class DeepAttention(nn.Module):
    """Attention whose query/key compatibility is scored by a small MLP.

    Every (query_i, key_j) pair is concatenated and passed through a two-layer
    feed-forward network that outputs one scalar score. This is similar in
    spirit to additive attention, with a configurable hidden width and activation.

    Args:
        seq_len: kept for backward compatibility with existing callers; the
            pairwise scorer works for any query/key lengths, so it is unused.
        embed_size: feature size of each query and key vector (per head).
        ff_hidden: hidden width of the scoring MLP.
        act_fn: activation class instantiated between the two linear layers.
    """

    def __init__(self, seq_len, embed_size, ff_hidden, act_fn=nn.ReLU):
        super().__init__()
        self.seq_len = seq_len
        # One scalar per (i, j) pair is what makes the scores an (L_q, L_k)
        # matrix that can weight ``value`` and be combined with attention masks.
        self.qk_deep = nn.Sequential(
            nn.Linear(embed_size * 2, ff_hidden),
            act_fn(),
            nn.Linear(ff_hidden, 1),
        )

    def forward(self, query, key, value, mask=None, dropout=None):
        """Return ``(p_attn @ value, p_attn)`` with MLP-derived attention weights.

        Assumes query is (..., L_q, embed_size), key is (..., L_k, embed_size)
        and value is (..., L_k, d_v), as produced by MultiHeadAttention.
        """
        # (..., L_q, 1, d) and (..., 1, L_k, d) broadcast to (..., L_q, L_k, d).
        q_pairs, k_pairs = torch.broadcast_tensors(query.unsqueeze(-2), key.unsqueeze(-3))
        scores = self.qk_deep(torch.cat([q_pairs, k_pairs], dim=-1)).squeeze(-1)
        if mask is not None:
            scores = scores.masked_fill(mask == 0, torch.finfo(scores.dtype).min)
        p_attn = scores.softmax(dim=-1)
        if dropout is not None:
            p_attn = dropout(p_attn)
        return torch.matmul(p_attn, value), p_attn

The two most commonly used attention functions are additive attention and dot-product attention. Dot-product attention is identical to our algorithm, except for the scaling factor of $\frac{1}{\sqrt{d_k}}$. Additive attention computes the compatibility function using a feed-forward network with a single hidden layer. While the two are similar in theoretical complexity, dot-product attention is much faster and more space-efficient in practice, since it can be implemented using highly optimized matrix multiplication code.

Multi-head attention allows the model to jointly attend to information from different representation subspaces at different positions. With a single attention head, averaging inhibits this.


## Multihead vs Deep Attention

### Multi-head Attention
Multi-head attention, by virtue of its design, captures different aspects or 'views' of the data by splitting the input into multiple heads and then concatenating their outputs. This allows for more complex and expressive representations, facilitating the model's ability to understand various facets of the data.

### Deep Attention
My deep attention mechanism introduces non-linear transformations into the attention mechanism. This could allow each attention head to capture more intricate relationships in the data, which are not strictly linear. This could be particularly useful if you suspect the relationships in your data are inherently non-linear or complex.

### Simple explanation
Imagine you're a chef trying to make the perfect stew. Multi-head attention is like having several different cooks, each specializing in a specific ingredient or cooking technique: say, one for spices, another for vegetables, and another for meats. Each cook focuses on making their part of the stew as tasty as possible. In the end, you mix all these specialized parts together to create a complex and flavorful dish.

Now, deep attention is like giving each of these specialized cooks a set of advanced kitchen tools and techniques. These aren't just ordinary knives and pots but high-tech gadgets that can capture flavors in novel ways, say a sous-vide machine or a nitrogen freezer for flash-freezing herbs.

In [255]:
class MultiHeadAttention(nn.Module):
    """Multi-head attention with an optional MLP-scored ("deep") attention core.

    Args:
        n_heads: number of heads; must divide ``d_model``.
        d_model: width of the inputs and of the output.
        dropout: dropout probability applied to the attention weights.
        deep_attn: if truthy, score with ``DeepAttention(**kwargs)`` instead of
            scaled dot-product ``attention``.
        **kwargs: forwarded to ``DeepAttention`` when ``deep_attn`` is truthy.

    Raises:
        ValueError: if ``d_model`` is not divisible by ``n_heads``.
    """

    def __init__(self, n_heads, d_model, dropout=0.1, deep_attn=None, **kwargs):
        super().__init__()
        if d_model % n_heads != 0:
            raise ValueError(
                f"d_model ({d_model}) must be divisible by n_heads ({n_heads})"
            )
        self.d_k = d_model // n_heads
        self.n_heads = n_heads
        # Q, K and V input projections plus the final output projection.
        self.linears = clones(nn.Linear(d_model, d_model), 4)
        # Store None (never a falsy flag such as False) when disabled, because
        # ``forward`` dispatches on ``self.deep_attn is not None``.
        self.deep_attn = DeepAttention(**kwargs) if deep_attn else None
        self.attn_value = None  # attention weights from the most recent forward
        self.dropout = nn.Dropout(dropout)

    def forward(self, query, key, value, mask=None):
        """Attend from ``query`` over ``key``/``value``; assumes inputs are (B, L, d_model)."""
        if mask is not None:
            # Same mask for every head: insert a head axis at dim 1.
            mask = mask.unsqueeze(1)
        nbatches = query.size(0)

        # 1. Project and split into heads: (B, L, d_model) -> (B, h, L, d_k).
        # e.g. with 8 heads and d_model=512, d_k is 64.
        query, key, value = [
            lin(x).view(nbatches, -1, self.n_heads, self.d_k).transpose(1, 2)
            for lin, x in zip(self.linears, (query, key, value))
        ]

        # 2. Attention over all heads in one batched call. The module is called
        # directly (not via ``.forward``) so any registered hooks run.
        if self.deep_attn is not None:
            x, self.attn_value = self.deep_attn(query, key, value, mask=mask, dropout=self.dropout)
        else:
            x, self.attn_value = attention(query, key, value, mask=mask, dropout=self.dropout)

        # 3. Merge heads: (B, h, L, d_k) -> (B, L, h * d_k), then output projection.
        x = x.transpose(1, 2).contiguous().view(nbatches, -1, self.n_heads * self.d_k)
        return self.linears[-1](x)

## Applications of Attention in our Model
The Transformer uses multi-head attention in three different ways: 1. In 'encoder-decoder attention' layers, the queries come from the previous decoder layer, and the memory keys and values come from the output of the encoder. This allows every position in the decoder to attend over all positions in the input sequence. This mimics the typical encoder-decoder attention mechanism in sequence-to-sequence models.

2. The encoder contains self-attention layers. In a self-attention layer all of the keys, values, and queries come from the same place, in this case, the output of the previous layer in the encoder. Each position in the encoder can attend to all positions in the previous layer of the encoder.
3. Similarly, self-attention layers in the decoder allow each position in the decoder to attend to all positions in the decoder up to and including that position. We need to prevent leftward information flow in the decoder to preserve the auto-regressive property. We implement this inside of scaled dot-product attention by masking out (setting to $-\infty$; in code, a very large negative number) all values in the input of the softmax which correspond to illegal connections.

## Position-wise Feed-Forward Networks
In addition to attention sub-layers, each of the layers in our encoder and decoder contains a fully connected feed-forward network, which is applied to each position separately and identically. This consists of two linear transformations with a ReLU activation in between: $\mathrm{FFN}(x) = \max(0, xW_1 + b_1)W_2 + b_2$.

While the linear transformations are the same across different positions, they use different parameters from layer to layer. Another way of describing this is as two convolutions with kernel size 1. The dimensionality of input and output is $d_{\text{model}} = 512$, and the inner layer has dimensionality $d_{ff} = 2048$.

In [256]:
class PositionWiseFeedForward(nn.Module):
    """Position-wise feed-forward network: ``linear2(dropout(act(linear1(x))))``.

    Args:
        d_model: input and output width.
        d_ff: hidden (inner) width.
        dropout: dropout probability applied to the hidden activations.
        act: activation class, instantiated once.
    """

    def __init__(self, d_model, d_ff, dropout=0.1, act=nn.ReLU):
        super().__init__()
        self.linear1 = nn.Linear(d_model, d_ff)  # expand to the inner width
        self.linear2 = nn.Linear(d_ff, d_model)  # project back to d_model
        self.act = act()
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        hidden = self.dropout(self.act(self.linear1(x)))
        return self.linear2(hidden)

In [257]:
class Embedding(nn.Module):
    """Token-embedding lookup whose output is scaled by ``sqrt(d_model)``.

    Args:
        d_model: embedding width.
        vocab: number of rows in the lookup table.
    """

    def __init__(self, d_model, vocab):
        super().__init__()
        self.lut = nn.Embedding(vocab, d_model)
        self.d_model = d_model

    def forward(self, x):
        scale = math.sqrt(self.d_model)
        return self.lut(x) * scale

### Positional Encoding
Since our model contains no recurrence and no convolution, in order for the model to make use of the order of the sequence, we must inject some information about the relative or absolute position of the tokens in the sequence. To this end, we add 'positional encodings' to the input embeddings at the bottom of the encoder and decoder stacks. The positional encodings have the same dimension $d_{\text{model}}$ as the embeddings, so that the two can be summed. There are many choices for positional encodings; here we use sine and cosine functions of different frequencies: $PE_{(pos, 2i)} = \sin\left(pos / 10000^{2i/d_{\text{model}}}\right)$ and $PE_{(pos, 2i+1)} = \cos\left(pos / 10000^{2i/d_{\text{model}}}\right)$.

In [258]:
class PositionalEncoding(nn.Module):
    """Add fixed sinusoidal position encodings to the input, then apply dropout.

    ``pe[pos, 2i] = sin(pos / 10000^(2i/d_model))`` and
    ``pe[pos, 2i+1] = cos(pos / 10000^(2i/d_model))``.

    Args:
        d_model: feature size. Odd values are supported: the final (even-indexed)
            column gets a sine with no matching cosine.
        dropout: dropout probability applied after the addition.
        max_len: longest sequence the precomputed table covers.
    """

    def __init__(self, d_model, dropout, max_len=5000):
        super().__init__()
        self.dropout = nn.Dropout(dropout)

        pe = torch.zeros(max_len, d_model)
        # Positions as a column vector (max_len, 1) so they broadcast against
        # the per-frequency row vector below.
        position = torch.arange(0, max_len).unsqueeze(1)
        # 1 / 10000^(2i/d_model) for each even index 2i, computed in log space.
        div_term = torch.exp(
            torch.arange(0, d_model, 2) * -(math.log(10000.0) / d_model)
        )
        pe[:, 0::2] = torch.sin(position * div_term)
        # With odd d_model there is one fewer odd column than even column, so
        # only use as many frequencies as there are odd slots.
        pe[:, 1::2] = torch.cos(position * div_term[: d_model // 2])

        # Add a leading batch axis -> (1, max_len, d_model). Registered as a
        # buffer: saved/moved with the module but not a trainable parameter.
        self.register_buffer("pe", pe.unsqueeze(0))

    def forward(self, x):
        """Return ``dropout(x + pe[:, :L])`` for ``x`` with sequence length L at dim 1.

        Raises:
            ValueError: if ``x.size(1)`` exceeds ``max_len``.
        """
        seq_len = x.size(1)
        if seq_len > self.pe.size(1):
            raise ValueError(
                f"sequence length {seq_len} exceeds max_len {self.pe.size(1)}"
            )
        x = x + self.pe[:, :seq_len].requires_grad_(False)
        return self.dropout(x)

In [259]:
# Feed an all-zero input (100 positions, 20 features, dropout 0) so the output
# is just the raw sinusoid table, then collect dimensions 4-7 in long format
# for plotting.
pe = PositionalEncoding(20, 0)
y = pe(torch.zeros(1, 100, 20))

data = pd.concat(
    pd.DataFrame(
        {
            "embedding": y[0, :, dim],
            "dimension": dim,
            "position": list(range(100)),
        }
    )
    for dim in range(4, 8)
)

In [260]:
# Line chart of encoding value vs. position, one coloured line per dimension.
alt.Chart(data).mark_line().properties(width=800).encode(x="position", y="embedding", color="dimension:N")

## Making the model

In [261]:
def make_model(src_vocab, tgt_vocab, N=6, d_model=512, d_ff=2048, h=8, dropout=0.01, act=nn.ReLU, **kwargs):
    """Build an encoder-decoder Transformer and Xavier-initialise its weight matrices.

    Args:
        src_vocab: source vocabulary size.
        tgt_vocab: target vocabulary size (also the generator's output size).
        N: number of encoder layers and of decoder layers.
        d_model: model width.
        d_ff: inner width of the position-wise feed-forward networks.
        h: number of attention heads.
        dropout: dropout for the feed-forward, residual and positional-encoding
            layers. NOTE(review): the default 0.01 differs from the paper's
            P_drop = 0.1; confirm it is intentional. Attention-weight dropout is
            not taken from this argument; it keeps MultiHeadAttention's default.
        act: activation class used inside the feed-forward networks.
        **kwargs: forwarded to MultiHeadAttention (e.g. ``deep_attn`` and the
            DeepAttention settings).
    """
    dup = copy.deepcopy
    # Prototype modules; every layer below receives its own deep copy.
    mha = MultiHeadAttention(h, d_model, **kwargs)
    ffn = PositionWiseFeedForward(d_model, d_ff, dropout, act)
    pos_enc = PositionalEncoding(d_model, dropout)

    # Construction order matches the original so RNG-driven default inits
    # (e.g. biases that Xavier does not touch) are reproducible under a seed.
    encoder = Encoder(EncoderLayer(d_model, dup(mha), dup(ffn), dropout), N)
    decoder = Decoder(DecoderLayer(d_model, dup(mha), dup(mha), dup(ffn), dropout), N)
    src_embed = nn.Sequential(Embedding(d_model, src_vocab), dup(pos_enc))
    tgt_embed = nn.Sequential(Embedding(d_model, tgt_vocab), dup(pos_enc))
    model = EncoderDecoder(encoder, decoder, src_embed, tgt_embed, Generator(d_model, tgt_vocab))

    # Xavier-uniform init for every weight matrix (dim > 1); 1-D parameters
    # such as biases and LayerNorm gains keep their default initialisation.
    for weight in (p for p in model.parameters() if p.dim() > 1):
        nn.init.xavier_uniform_(weight)
    return model

### Inference
Here we run a forward pass to generate a prediction from the model. We try to use our transformer to memorize the input. As you will see, the output is random because the model has not been trained yet. In the next tutorial we will build the training function and try to train our model to memorize the numbers from 1 to 10.

In [265]:
# Smoke test: greedy-decode from 10 freshly initialised (untrained) 2-layer
# models. The predictions are arbitrary because the weights are random; this
# only checks that encode -> decode -> generator run end to end.
# Pure inference, so autograd is disabled to avoid building unused graphs.
with torch.no_grad():
    for _ in range(10):
        test_model = make_model(11, 11, 2)
        test_model.eval()
        src = torch.LongTensor([[1, 2, 3, 4, 5, 6, 7, 8, 9, 10]])
        src_mask = torch.ones(1, 1, 10)
        memory = test_model.encode(src, src_mask)
        ys = torch.zeros(1, 1).type_as(src)  # decoding starts from token id 0
        for _ in range(9):
            out = test_model.decode(
                memory, src_mask, ys, subsequent_mask(ys.size(1)).type_as(src)
            )
            # Log-probabilities for the last position only; take the most likely token.
            prob = test_model.generator(out[:, -1])
            _, next_word = torch.max(prob, dim=1)
            next_word = next_word[0]
            ys = torch.cat([ys, torch.empty(1, 1).type_as(src).fill_(next_word)], dim=1)
        print("Example untrained model prediction: ", ys)

Example untrained model prediction:  tensor([[0, 5, 2, 9, 9, 9, 0, 8, 5, 2]])
Example untrained model prediction:  tensor([[0, 8, 8, 0, 8, 0, 8, 0, 8, 0]])
Example untrained model prediction:  tensor([[0, 5, 3, 2, 3, 2, 3, 2, 3, 2]])
Example untrained model prediction:  tensor([[0, 3, 3, 3, 3, 3, 3, 3, 3, 3]])
Example untrained model prediction:  tensor([[ 0,  8, 10,  6, 10,  8, 10,  8,  0,  0]])
Example untrained model prediction:  tensor([[0, 4, 2, 9, 1, 9, 1, 9, 1, 9]])
Example untrained model prediction:  tensor([[0, 1, 7, 7, 1, 6, 6, 6, 6, 6]])
Example untrained model prediction:  tensor([[0, 9, 4, 8, 9, 8, 9, 4, 8, 9]])
Example untrained model prediction:  tensor([[0, 2, 9, 2, 9, 2, 9, 2, 2, 9]])
Example untrained model prediction:  tensor([[0, 7, 3, 5, 0, 7, 3, 5, 0, 7]])


## Model Training
First we will define a batch object that holds the src and target sentences for training, as well as constructing the masks

In [266]:
class Batch:
    """Hold a batch of source/target token ids plus the masks used in training.

    Args:
        src: source token ids. ``src_mask`` marks non-pad positions, with an
            extra axis inserted at -2.
        tgt: optional target token ids. When given, it is split into the decoder
            input ``tgt`` (all but the last token) and the prediction target
            ``tgt_y`` (all but the first token). When omitted, no target
            attributes are set.
        pad: id of the padding token.
    """

    def __init__(self, src, tgt=None, pad=2):
        self.src = src
        self.src_mask = (src != pad).unsqueeze(-2)
        if tgt is None:
            return
        self.tgt = tgt[:, :-1]
        self.tgt_y = tgt[:, 1:]
        self.tgt_mask = self.make_std_mask(self.tgt, pad)
        # Number of non-pad target tokens (run_epoch divides the loss by it).
        self.ntokens = (self.tgt_y != pad).sum()

    @staticmethod
    def make_std_mask(tgt, pad):
        """Combine the target padding mask with the causal (subsequent) mask."""
        not_pad = (tgt != pad).unsqueeze(-2)
        causal = subsequent_mask(tgt.size(-1)).type_as(not_pad)
        return not_pad & causal

### Training loop

In [267]:
from dataclasses import dataclass


@dataclass
class TrainState:
    """Running counters for training progress, updated in place by ``run_epoch``."""

    step: int = 0        # batches processed in training mode
    accum_step: int = 0  # optimizer steps taken (gradient-accumulation steps)
    samples: int = 0     # source examples seen (sum of batch.src.shape[0])
    tokens: int = 0      # non-pad target tokens seen

In [268]:
def run_epoch(data_iter, model, loss_compute, optimizer, scheduler, mode="train", accum_iter=1, train_state=None):
    """Run one pass over ``data_iter``, training when ``mode`` is "train" or "train+log".

    Args:
        data_iter: iterable of batches exposing ``src``, ``tgt``, ``src_mask``,
            ``tgt_mask``, ``tgt_y`` and ``ntokens`` (see ``Batch``).
        model: called as ``model(src, tgt, src_mask, tgt_mask)``.
        loss_compute: called as ``loss_compute(out, tgt_y, ntokens)``; must return
            ``(loss, loss_node)``, and ``loss_node.backward()`` is called in training.
        optimizer: stepped (then zeroed) on batches where ``i % accum_iter == 0``.
        scheduler: stepped once per batch in training modes.
        mode: "train" or "train+log" enables backprop and updates; any other
            value only accumulates the loss.
        accum_iter: spacing, in batches, between optimizer steps.
        train_state: counters updated in place. A fresh ``TrainState`` is created
            when omitted, rather than a single default instance that would be
            shared (and keep counting) across calls.

    Returns:
        ``(total_loss / total_tokens, train_state)``.
    """
    if train_state is None:
        train_state = TrainState()
    is_train = mode in ("train", "train+log")

    start = time.perf_counter()
    total_tokens = 0
    total_loss = 0
    tokens = 0  # tokens since the last progress print
    n_accum = 0
    for i, batch in enumerate(data_iter):
        out = model(batch.src, batch.tgt, batch.src_mask, batch.tgt_mask)
        loss, loss_node = loss_compute(out, batch.tgt_y, batch.ntokens)
        if is_train:
            loss_node.backward()
            train_state.step += 1
            train_state.samples += batch.src.shape[0]
            train_state.tokens += batch.ntokens
            # Gradient accumulation: only update the weights every accum_iter batches.
            if i % accum_iter == 0:
                optimizer.step()
                optimizer.zero_grad(set_to_none=False)
                n_accum += 1
                train_state.accum_step += 1
            scheduler.step()

        total_loss += loss
        total_tokens += batch.ntokens
        tokens += batch.ntokens
        if i % 40 == 1 and is_train:
            lr = optimizer.param_groups[0]['lr']
            elapsed = time.perf_counter() - start
            # Guard against a zero interval on coarse clocks.
            tok_rate = tokens / elapsed if elapsed > 0 else float("inf")
            print(f"Epoch Step: {i} | Accumulation Step: {n_accum} | Loss: {loss / batch.ntokens} | Tokens / Sec: {tok_rate} | Learning Rate: {lr}")
            start = time.perf_counter()
            tokens = 0

    return total_loss / total_tokens, train_state
