# Transformers from Scratch

To gain a thorough understanding of Transformers, I want to attempt writing one from scratch using PyTorch. I'll be referring to Peter Bloem's [blog post]().

## Basic self-attention

Basic self-attention takes a sequence of input vectors $x_1, x_2, ..., x_t$ and produces a sequence of output vectors $y_1, y_2, ..., y_t$. The vectors all have a dimension of $k$.

To produce output vector $y_i$, the self-attention operation simply takes a *weighted average over all the input vectors*:

$$ y_i = \sum_{j} w_{ij} x_j $$

Here $j$ indexes over the whole sequence, and the weights sum to $1$ over all $j$. The weight $w_{ij}$ is not a parameter; it is *derived* from a function over $x_i$ and $x_j$. The simplest option for this function is the dot product:

$$ w'_{ij} = {x_i}^{T}{x_j}$$

>**Note:** $x_i$ is the input vector at the same position as the current output vector $y_i$.
>For the next output vector, we get an entirely new series of dot products, and a different weighted sum.

The dot product can take any value in $(-\infty, \infty)$, so we apply a $softmax$ over $j$ to map the raw weights to $[0, 1]$ and make them sum to $1$:

$$
w_{ij} = softmax(w'_{ij}) =
\frac
{\exp(w'_{ij})}
{\sum_{j'} \exp(w'_{ij'})}
$$
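As a small worked sketch (plain Python, not taken from the blog post), here is the softmax over one row of raw weights $w'_{i1}, \dots, w'_{it}$. Subtracting the row maximum before exponentiating is a common numerical-stability trick. It does not change the result, because the common factor $\exp(-\max)$ cancels between numerator and denominator. The function name `softmax_row` is hypothetical.

```python
import math

def softmax_row(raw_weights):
    """Softmax over the raw weights w'_i1, ..., w'_it of one output position i."""
    # Subtract the maximum so exp() cannot overflow; the result is unchanged
    # because exp(-m) cancels between numerator and denominator.
    m = max(raw_weights)
    exps = [math.exp(w - m) for w in raw_weights]
    total = sum(exps)
    return [e / total for e in exps]

# Raw dot products for one position i (illustrative values).
weights = softmax_row([2.0, 1.0, -1.0])
print([round(w, 3) for w in weights])
print(sum(weights))  # sums to 1 (up to floating-point rounding)
```

Larger raw weights get larger normalized weights, and the ordering of the inputs is preserved.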

And that's the basic operation of self-attention.

![A visual illustration of basic self-attention. Note that the $softmax$ operation over the weights is not illustrated](../assets/self-attention.svg)
*A visual illustration of basic self-attention. Note that the $softmax$ operation over the weights is not illustrated*
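To make the two equations concrete before moving to tensors, here is a deliberately naive sketch in plain Python. It loops over output positions $i$ and input positions $j$. The function name `basic_self_attention` is hypothetical, and the input is assumed to be a list of $t$ vectors, each a list of $k$ floats.

```python
import math

def basic_self_attention(xs):
    """Basic (parameter-free) self-attention over a list of t vectors of dimension k."""
    ys = []
    for x_i in xs:
        # Raw weights w'_ij = x_i . x_j for every position j in the sequence.
        raw = [sum(a * b for a, b in zip(x_i, x_j)) for x_j in xs]
        # Softmax over j (max subtracted for numerical stability).
        m = max(raw)
        exps = [math.exp(r - m) for r in raw]
        total = sum(exps)
        w = [e / total for e in exps]
        # y_i = sum_j w_ij x_j, computed one output dimension d at a time.
        k = len(x_i)
        y_i = [sum(w[j] * xs[j][d] for j in range(len(xs))) for d in range(k)]
        ys.append(y_i)
    return ys

xs = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
ys = basic_self_attention(xs)
print(ys)
```

Each output vector is a convex combination of the inputs, so it has the same dimension $k$ as they do. The loop over $i$ recomputes a fresh set of dot products for every output position, which is exactly what the note above describes.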

## In PyTorch: basic self-attention

We'll represent the input, a sequence of $t$ vectors of dimension $k$, as a $t$ by $k$ matrix $X$.
Including a minibatch dimension $b$ gives us an input tensor of shape $(b, t, k)$.

The set of all raw dot products $w'_{ij}$ forms a matrix, which we can compute simply by multiplying $X$ by its transpose, $W' = X X^T$:

```python
import torch
import torch.nn.functional as F

batch_size = 2            # b
sequence_len = 3          # t
input_dimension = 4       # k

# Random inputs, so that the attention weights are not all identical.
X = torch.randn((batch_size, sequence_len, input_dimension))

# - torch.bmm(input: Tensor, mat2: Tensor) -> Tensor
#   Performs a batched matrix multiplication of input and mat2.
#   input and mat2 must be 3-D tensors each containing the same number of matrices.
#
#   If input is (b × n × m) and mat2 is (b × m × p), out will be (b × n × p).
#   It can be thought of as b independent products (n × m)(m × p) = (n × p).
#
# - torch.transpose(input: Tensor, dim0: int, dim1: int) -> Tensor
#   Returns a transposed version of input (sharing its storage),
#   with the dimensions dim0 and dim1 swapped.
#
# We swap dimensions 1 and 2 (not 0) so that each (t × k) matrix in the
# batch becomes (k × t), while the batch dimension stays first.
# (b, t, k) x (b, k, t) -> (b, t, t)
raw_weights = torch.bmm(X, X.transpose(dim0=1, dim1=2))

# w_ij = softmax(w'_ij), normalized over j (the last dimension),
# so every row of weights sums to 1.
weights = F.softmax(raw_weights, dim=2)

# y_i = sum_j w_ij x_j
# (b, t, t) x (b, t, k) -> (b, t, k)
y = torch.bmm(weights, X)

# And that's all: two matrix multiplications and one softmax give
# us basic self-attention.
```
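As a sanity check (a sketch, not part of the blog post), the batched `torch.bmm` version can be compared against the same equations written with `torch.einsum`. In the einsum strings, each letter names a dimension: `b` for batch, `i` and `j` for sequence positions, and `k` for the embedding.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
X = torch.randn(2, 3, 4)  # (b, t, k)

# Batched version, as above.
raw_bmm = torch.bmm(X, X.transpose(1, 2))
y_bmm = torch.bmm(F.softmax(raw_bmm, dim=2), X)

# The same equations in index notation:
# w'_ij = sum_k x_ik x_jk,   y_i = sum_j w_ij x_j
raw_einsum = torch.einsum("bik,bjk->bij", X, X)
y_einsum = torch.einsum("bij,bjk->bik", F.softmax(raw_einsum, dim=2), X)

print(torch.allclose(raw_bmm, raw_einsum, atol=1e-6))
print(torch.allclose(y_bmm, y_einsum, atol=1e-6))
print(F.softmax(raw_bmm, dim=2).sum(dim=2))  # every row of weights sums to 1
```

The einsum form is handy for reading off which dimension each sum runs over. The `bmm` form makes the two matrix multiplications explicit.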

## In PyTorch: Complete Self-Attention

```python
import torch
from torch import nn
import torch.nn.functional as F

class SelfAttention(nn.Module):
    def __init__(self, k, heads=4, mask=False) -> None:
        super().__init__()

        # The embedding dimension must be divisible by the number of heads.
        assert k % heads == 0, "k must be divisible by heads"

        # Note: `mask` is stored but not used yet; this version attends over
        # the whole sequence (no causal masking).
        self.k, self.heads, self.mask = k, heads, mask

        # These compute the keys, queries and values for all heads at once.
        self.to_keys    = nn.Linear(k, k, bias=False)
        self.to_queries = nn.Linear(k, k, bias=False)
        self.to_values  = nn.Linear(k, k, bias=False)

        # This will be applied after the multi-head attention operation.
        self.unify_heads = nn.Linear(k, k)

    def forward(self, x):
        batch_size, sequence_len, input_dimension = x.size()
        assert input_dimension == self.k
        number_of_heads = self.heads

        # This gives us three vector sequences of the full embedding dimension.
        keys    = self.to_keys(x)
        queries = self.to_queries(x)
        values  = self.to_values(x)

        # The dimension of the keys, queries and values within one head.
        slice_len = input_dimension // number_of_heads

        # Reshape from (b, t, k) -> (b, t, h, s)
        # b - Batch size
        # t - Sequence length
        # k - Original dimension of the input vectors
        # h - Number of attention heads
        # s - Dimension of the key, query and value for one head (k/h)
        keys    = keys.view(batch_size, sequence_len, number_of_heads, slice_len)
        queries = queries.view(batch_size, sequence_len, number_of_heads, slice_len)
        values  = values.view(batch_size, sequence_len, number_of_heads, slice_len)

        # To compute the dot products, we fold the heads into the batch dimension,
        # so that we can use torch.bmm() as before.
        # Since the head and batch dimensions are not next to each other,
        # we need to transpose before we reshape: (b, t, h, s) -> (b*h, t, s).
        keys    = keys.transpose(1, 2).contiguous().view(batch_size * number_of_heads, sequence_len, slice_len)
        queries = queries.transpose(1, 2).contiguous().view(batch_size * number_of_heads, sequence_len, slice_len)
        values  = values.transpose(1, 2).contiguous().view(batch_size * number_of_heads, sequence_len, slice_len)

        # As before, the dot products are computed in a single matrix multiplication,
        # but now between the queries and the keys.
        dot = torch.bmm(queries, keys.transpose(1, 2))
        # -- dot has size (b*h, t, t) and contains the raw weights.

        # Scale by the square root of the per-head dimension s (the scaling used in
        # "Attention Is All You Need"), so the softmax does not saturate as s grows.
        dot = dot / (slice_len ** 0.5)

        # Normalize over the keys (the last dimension).
        dot = F.softmax(dot, dim=2)
        # -- dot now contains row-wise normalized weights.

        # Apply the attention weights to the values: (b*h, t, s) -> (b, h, t, s)
        out = torch.bmm(dot, values).view(batch_size, number_of_heads, sequence_len, slice_len)

        # To unify the attention heads, we transpose again so that the head dimension
        # and the per-head embedding dimension are next to each other, then reshape
        # to get concatenated vectors of dimension k = h * s.
        out = out.transpose(1, 2).contiguous().view(batch_size, sequence_len, slice_len * number_of_heads)

        return self.unify_heads(out)
```
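The transpose-then-reshape step that folds the heads into the batch dimension is easy to get wrong. The sketch below (not from the blog post, with arbitrary sizes) checks it on random tensors. Attention computed on the folded `(b*h, t, s)` tensors should match attention computed separately for each head on its own slice `[..., i*s:(i+1)*s]` of the embedding dimension. The helper names `attend` and `fold` are hypothetical, and `attend` assumes the per-head scaling $1/\sqrt{s}$.

```python
import torch
import torch.nn.functional as F

def attend(queries, keys, values):
    """Scaled dot-product attention on (batch, t, s) tensors."""
    s = queries.size(-1)
    w = F.softmax(torch.bmm(queries, keys.transpose(1, 2)) / s ** 0.5, dim=2)
    return torch.bmm(w, values)

torch.manual_seed(0)
b, t, k, h = 2, 5, 8, 4
s = k // h
q, kk, v = (torch.randn(b, t, k) for _ in range(3))

def fold(x):
    # (b, t, k) -> (b, t, h, s) -> (b, h, t, s) -> (b*h, t, s)
    return x.view(b, t, h, s).transpose(1, 2).contiguous().view(b * h, t, s)

# Folded version, as in SelfAttention.forward, then unfolded back to (b, t, k).
out = attend(fold(q), fold(kk), fold(v))
out = out.view(b, h, t, s).transpose(1, 2).contiguous().view(b, t, k)

# Reference: loop over the heads, each using its own slice of the embedding.
per_head = [
    attend(q[..., i*s:(i+1)*s], kk[..., i*s:(i+1)*s], v[..., i*s:(i+1)*s])
    for i in range(h)
]
reference = torch.cat(per_head, dim=2)

print(torch.allclose(out, reference, atol=1e-6))
```

If the `transpose(1, 2)` were skipped and the `(b, t, h, s)` tensor reshaped directly, the rows of each folded matrix would mix positions and heads, and this check would fail.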

## Building Transformers

A transformer is not just a self-attention layer; it is an *architecture*. We'll define it as:

*Any architecture designed to process a connected set of units, such as the tokens in a sequence or the pixels in an image, where the only interaction between the units is through self-attention.*

As with other mechanisms, like convolutions, a more or less standard approach has emerged for how to
build self-attention layers into a larger network.
The first step is to wrap the self-attention into a **block** that we can repeat.

## The Transformer Block

There are some variations on how to build a basic transformer block, but most of them are structured roughly like this:

![Transformer Block](../assets/transformer-block.svg)

That is, the block applies in sequence:

1. A self-attention layer
2. A layer normalization
3. A feed-forward layer (a single MLP applied independently to each vector)
4. Another layer normalization

The order isn't set in stone. The important thing is to combine self-attention with a local feed-forward layer, and to add normalization and residual connections.
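*Local* here means the feed-forward layer looks at one vector at a time. `nn.Linear` only acts on the last dimension, so applying the MLP to a whole `(b, t, k)` tensor should give the same result as applying it to each position separately. Here is a quick check, a sketch with arbitrary sizes and a `k -> 4k -> k` MLP:

```python
import torch
from torch import nn

torch.manual_seed(0)
b, t, k = 2, 5, 8

# A position-wise MLP: k -> 4k -> k.
ff = nn.Sequential(nn.Linear(k, 4 * k), nn.ReLU(), nn.Linear(4 * k, k))

x = torch.randn(b, t, k)

# Applied to the whole sequence at once (nn.Linear acts on the last dimension).
out_all = ff(x)

# Applied to each position i separately, then stacked back along the sequence dimension.
out_each = torch.stack([ff(x[:, i, :]) for i in range(t)], dim=1)

print(torch.allclose(out_all, out_each, atol=1e-6))
```

Since the feed-forward layer never mixes positions, self-attention remains the only place where the vectors in a sequence interact, which matches the definition of a transformer given above.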

>**Note:** Normalization and residual connections are standard tricks used to help deep neural networks train faster and more accurately.
>The layer normalization is applied over the embedding dimension only.
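To see what "over the embedding dimension only" means, the sketch below (arbitrary sizes, not from the blog post) normalizes a `(b, t, k)` tensor with `nn.LayerNorm(k)` and repeats the computation by hand. Each vector `x[b, t, :]` is shifted and scaled by its own mean and variance, independently of the other positions and batch entries.

```python
import torch
from torch import nn

torch.manual_seed(0)
b, t, k = 2, 5, 8
x = 3.0 * torch.randn(b, t, k) + 1.0   # inputs with mean ~1 and std ~3

# normalized_shape=k: mean and variance are taken over the last (embedding) dimension only.
norm = nn.LayerNorm(k)
y = norm(x)

# The same computation by hand. At initialization the learned scale is 1 and the
# shift is 0, and eps matches nn.LayerNorm's default of 1e-5.
mean = x.mean(dim=-1, keepdim=True)
var = ((x - mean) ** 2).mean(dim=-1, keepdim=True)   # biased variance
manual = (x - mean) / torch.sqrt(var + 1e-5)

print(torch.allclose(y, manual, atol=1e-5))
print(y.mean(dim=-1))  # per-vector means, all close to 0
```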

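```python
import torch
from torch import nn

# Uses the SelfAttention class defined above.
class TransformerBlock(nn.Module):
    def __init__(self, k, heads):
        super().__init__()

        self.attention = SelfAttention(k=k, heads=heads)

        # Layer normalization over the embedding dimension k only.
        self.norm_1 = nn.LayerNorm(k)
        self.norm_2 = nn.LayerNorm(k)

        # Position-wise MLP; the hidden layer is 4 times as wide as the input.
        self.ff = nn.Sequential(
            nn.Linear(k, 4 * k),
            nn.ReLU(),
            nn.Linear(4 * k, k),
        )

    def forward(self, x):
        # Self-attention, then a residual connection and layer normalization.
        attended = self.attention(x)
        x = self.norm_1(attended + x)

        # Feed-forward, then a residual connection and layer normalization.
        feedforward = self.ff(x)
        return self.norm_2(feedforward + x)
```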
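Because the block maps a `(b, t, k)` tensor to a tensor of the same shape, blocks can simply be chained. The sketch below shows the idea with `nn.Sequential`. To stay self-contained, it swaps in PyTorch's built-in `nn.MultiheadAttention` (with `batch_first=True`, available in recent PyTorch versions) for the `SelfAttention` class above. It illustrates the structure rather than reproducing this notebook's exact block. The class name `PostNormBlock` and the depth of 3 are hypothetical choices.

```python
import torch
from torch import nn

class PostNormBlock(nn.Module):
    """Hypothetical stand-in for TransformerBlock, using nn.MultiheadAttention."""
    def __init__(self, k, heads):
        super().__init__()
        self.attention = nn.MultiheadAttention(k, heads, batch_first=True)
        self.norm_1 = nn.LayerNorm(k)
        self.norm_2 = nn.LayerNorm(k)
        self.ff = nn.Sequential(nn.Linear(k, 4 * k), nn.ReLU(), nn.Linear(4 * k, k))

    def forward(self, x):
        # Self-attention: queries, keys and values all come from x.
        attended, _ = self.attention(x, x, x, need_weights=False)
        x = self.norm_1(attended + x)
        return self.norm_2(self.ff(x) + x)

torch.manual_seed(0)
k, heads, depth = 16, 4, 3
model = nn.Sequential(*[PostNormBlock(k, heads) for _ in range(depth)])

x = torch.randn(2, 10, k)  # (b, t, k)
y = model(x)
print(y.shape)  # same shape as the input, so blocks can be stacked
```

Swapping `PostNormBlock` for the `TransformerBlock` defined above should give the same stacking behavior, since both preserve the input shape.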