`pip install torch torchtext jaxtyping circuitsvis`

# Setup

In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchtext.data.utils import get_tokenizer
from torchtext.vocab import build_vocab_from_iterator
import random
import math
import string

from typing import List, Tuple
from jaxtyping import Float  # Float[Tensor, "..."] shape annotations come from jaxtyping, not typing
from torch import Tensor



In 2017, Google Brain released a paper titled *Attention Is All You Need*, which departed from the traditional Seq2Seq architectures used in NLP and introduced a new kind of architecture: the transformer. As the title suggests, the transformer is based on attention mechanisms and does not use recurrent state as RNNs do. This means a transformer can be trained significantly faster than an RNN-based Seq2Seq model, because sequences are no longer processed one token at a time but all at once. It also makes the models simpler to work with: instead of elaborate Seq2Seq variants that can quickly become intractable, the transformer is composed of relatively straightforward matrix multiplications all the way through.

Since their introduction, transformers have been taking over the world of NLP. GPT-3, a state-of-the-art language model capable of producing text that is nearly impossible to distinguish from human writing, is a massive transformer model. Transformers have also been applied in other fields of machine learning, such as computer vision, thanks to their expressive power.

In this notebook, we'll use our intuition from the previous section to build a transformer from scratch. The notebook will break down each component of the transformer for you to implement and understand yourself!

# The Transformer Block

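A transformer consists of a stack of `n_layers` blocks (the `num_layers` argument of the `TransformerModel` class at the end of this notebook), each applied to the output of the previous one.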

## Attention

First, let's look at the attention layer in detail.

The job of the attention layer is to allow a token to "see" information from tokens that appeared much earlier in the sequence. It does this by forming a "query" vector, which is matched against the "key" vectors of other tokens. Each token also carries a "value" vector; the values of the matched tokens are combined in a weighted sum and added to the state of the querying token.
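
To make the query/key/value description concrete, here is a minimal sketch for a single querying token. The vectors are hand-picked for illustration and are not taken from any trained model; `state` is a hypothetical current state of the querying token.

```python
import torch
import torch.nn.functional as F

# Three earlier tokens, each with an illustrative key and value vector.
keys = torch.tensor([[1.0, 0.0],
                     [0.0, 1.0],
                     [1.0, 1.0]])
values = torch.tensor([[10.0, 0.0],
                       [0.0, 10.0],
                       [5.0, 5.0]])

# The querying token "asks" for keys that point along the first axis.
query = torch.tensor([4.0, 0.0])

scores = keys @ query                 # one match score per key, shape (3,)
weights = F.softmax(scores, dim=-1)   # normalise the scores into attention weights
read = weights @ values               # weighted sum of the value vectors, shape (2,)

state = torch.tensor([1.0, 1.0])      # hypothetical state of the querying token
new_state = state + read              # the read-out is added back onto the state
print(weights)
print(new_state)
```

The first and third keys match the query equally well, so they share almost all of the attention weight, while the second key, which is orthogonal to the query, receives very little.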

We'll load in our set of saved attention heads from GPT-2 small and take a look at some common patterns.

In [30]:
import circuitsvis as cv

str_tokens = torch.load('tokens.pt')
attn_patterns = torch.load('attn_patterns.pt')

display(
    cv.attention.attention_patterns(
        tokens=str_tokens, 
        attention=attn_patterns.squeeze(0)
    )
)

Take a look at some of the patterns.

We notice that there are three basic patterns which repeat quite frequently:

* `prev_token_heads`, which attend mainly to the previous token (e.g. head `0.7`)
* `current_token_heads`, which attend mainly to the current token (e.g. head `0.1`)
* `first_token_heads`, which attend mainly to the first token (e.g. head `0.0`, although these heads are a bit less clear-cut than the other two)

The `prev_token_heads` and `current_token_heads` are perhaps unsurprising, because words that are close together in a sequence probably have a lot more mutual information (i.e. we could get quite far using bigram or trigram prediction).

The `first_token_heads` are a bit more surprising. The basic intuition here is that the first token in a sequence is often used as a resting or null position for heads that only sometimes activate (since our attention probabilities always have to add up to 1).
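
The "resting position" idea can be illustrated with a few lines of plain Python. This is only a sketch with made-up scores: because softmax always produces weights that sum to 1, a head whose query matches no later token well still has to put its weight somewhere, and a modest score on the first token is enough to absorb nearly all of it.

```python
import math

def softmax(scores):
    """Numerically stable softmax over a list of floats."""
    m = max(scores)
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]

# Hypothetical scores for one query; index 0 is the first token.
# No later token matches well, but the first token has a modest score.
scores = [2.0, -3.0, -3.0, -3.0]
weights = softmax(scores)

print([round(w, 3) for w in weights])
print(sum(weights))
```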

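Putting the pieces together, the attention layer computes scaled dot-product attention, where $Q$, $K$ and $V$ stack the query, key and value vectors of all tokens and $d_k$ is the dimension of each key vector:

\begin{equation}
\text{Attention}(Q, K, V) = \text{softmax}\left(\frac{Q K^T}{\sqrt{d_k}}\right) V
\end{equation}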
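
One way to see why the formula divides by $\sqrt{d_k}$ is to measure how the spread of raw dot products grows with the vector size. The sketch below assumes random queries and keys with unit-variance entries, which is an idealisation rather than what a trained model produces; `score_std` is a helper introduced here for illustration.

```python
import math
import torch

torch.manual_seed(0)

def score_std(d_k, n=10_000):
    """Empirical standard deviation of q . k for random unit-variance q, k of size d_k."""
    q = torch.randn(n, d_k)
    k = torch.randn(n, d_k)
    return (q * k).sum(dim=-1).std().item()

for d_k in (4, 64, 512):
    raw = score_std(d_k)
    # Columns: d_k, std of raw scores, std after dividing by sqrt(d_k)
    print(d_k, raw, raw / math.sqrt(d_k))
```

Under this assumption the raw scores spread out roughly in proportion to $\sqrt{d_k}$, while the scaled scores stay near unit scale, which keeps the softmax from saturating as the embedding size grows.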

In [None]:
class SingleHeadAttention(nn.Module):
    def __init__(self, d):
        """
        Here we will assume that the input dimensions are same as the
        output dims.
        """
        super().__init__()
        # SOLUTION
        self.q_layer = torch.nn.Linear(d, d)
        self.k_layer = torch.nn.Linear(d, d)
        self.v_layer = torch.nn.Linear(d, d)

    def forward(self, x, mask=None, return_weights=False):
        """
        Assume x is <t x d> -- t being the sequence length, d
        the embed dims.

        W_q, W_k, and W_v are weights for projecting into queries,
        keys, and values, respectively. Here these will have shape
        <d x d>, yielding d dimensional vectors for each input.

        This function should return a t dimensional attention vector
        for each input -- i.e., an attention matrix with shape <t x t>,
        and the values derived from this <t x d>.

        Derive Q, K, V matrices, then self attention weights. These should
        be used to compute the final representations (t x d); optionally
        return the weights matrix if `return_weights=True`.
        """
        # SOLUTION
        Q = self.q_layer(x)
        K = self.k_layer(x)
        V = self.v_layer(x)

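        # Scaled dot-product scores: divide by sqrt(d_k), as in the attention formula
        A = Q @ K.transpose(-2, -1) / math.sqrt(K.size(-1))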
        if mask is not None:
            A = A.masked_fill(mask == 0, -1e9)
        weights = F.softmax(A, dim=-1)

        if return_weights:
            return weights, weights @ V

        return weights @ V

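Because the scores for every query are computed with a single matrix multiplication (`Q @ K.transpose(-2, -1)`), all positions in the sequence are processed in parallel rather than one token at a time.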
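
As a sanity check on the parallelism claim, the sketch below compares the all-at-once matrix form with an explicit loop over query positions. It uses random tensors in place of learned projections and includes the $\sqrt{d}$ scaling from the attention formula.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
t, d = 5, 8
Q, K, V = torch.randn(t, d), torch.randn(t, d), torch.randn(t, d)

# All positions at once: one (t x t) score matrix, one softmax, one matmul.
parallel = F.softmax(Q @ K.T / d ** 0.5, dim=-1) @ V

# The same computation, one query position at a time.
rows = []
for i in range(t):
    scores_i = K @ Q[i] / d ** 0.5               # scores of query i against every key
    rows.append(F.softmax(scores_i, dim=-1) @ V)  # weighted sum of values for query i
sequential = torch.stack(rows)

print(torch.allclose(parallel, sequential, atol=1e-5))
```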

## Layer Norm

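Layer normalization rescales each token's vector to zero mean and unit variance across its `d_model` features, then applies a learned scale $\gamma$ and shift $\beta$ (the parameters `w` and `b` in the code below). The small constant $\epsilon$ avoids division by zero:

\begin{equation}
y = \frac{x - \mathbb{E}[x]}{\sqrt{\mathrm{Var}[x] + \epsilon}} \cdot \gamma + \beta
\end{equation}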

In [None]:
class LayerNorm(nn.Module):
    def __init__(self, d_model, eps=1e-6):
        super().__init__()
        self.eps = eps
        self.w = nn.Parameter(torch.ones(d_model))
        self.b = nn.Parameter(torch.zeros(d_model))

    def forward(self, residual: Float[Tensor, "batch posn d_model"]) -> Float[Tensor, "batch posn d_model"]:
        # SOLUTION
        residual_mean = residual.mean(dim=-1, keepdim=True)
        residual_std = (residual.var(dim=-1, keepdim=True, unbiased=False) + self.eps).sqrt()

        residual = (residual - residual_mean) / residual_std
        return residual * self.w + self.b
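
A quick way to check a hand-written layer norm is to compare the formula against PyTorch's built-in `F.layer_norm`. The sketch below applies the formula directly with $\gamma = 1$ and $\beta = 0$; it uses `eps=1e-5` in both places so the two results are comparable, which differs from the `1e-6` default of the class above.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
x = torch.randn(2, 3, 4) * 5 + 10       # (batch, posn, d_model), deliberately off-centre
eps = 1e-5                              # same eps passed to both implementations

mean = x.mean(dim=-1, keepdim=True)
var = x.var(dim=-1, keepdim=True, unbiased=False)   # biased variance, as in the formula
manual = (x - mean) / (var + eps).sqrt()            # gamma = 1, beta = 0

builtin = F.layer_norm(x, normalized_shape=(4,), eps=eps)

print(torch.allclose(manual, builtin, atol=1e-5))
print(manual.mean(dim=-1))              # close to zero for every token
```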

## Self Attention

In [None]:
class SelfAttentionLayer(nn.Module):
    def __init__(self, embed_dim):
        super(SelfAttentionLayer, self).__init__()
        self.self_attn = SingleHeadAttention(embed_dim)
        self.norm1 = LayerNorm(embed_dim)

    def forward(self, src, mask=None):
        # Residual connection: add the normalized attention output back onto the input
        src = src + self.norm1(self.self_attn(src, mask=mask))
        return src

## Positional Encoding

In [None]:
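class PositionalEncoding(nn.Module):
    def __init__(self, d_model, max_len=5000):
        super(PositionalEncoding, self).__init__()
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        # Geometric range of frequencies; 10000 is the base used in the original paper
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)  # even feature indices
        pe[:, 1::2] = torch.cos(position * div_term)  # odd feature indices
        # Shape (1, max_len, d_model) so the table broadcasts over the batch dimension
        pe = pe.unsqueeze(0)
        self.register_buffer('pe', pe)

    def forward(self, x):
        # x is (batch, seq_len, d_model), matching how TransformerModel builds its mask
        return x + self.pe[:, :x.size(1)]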

Now we stack several attention layers into a transformer model.

Around each attention layer we create a residual structure, so that each layer adds its output to its input, and we use LayerNorm to keep the activations from blowing up.

The last detail is positional encoding: a vector we add to each token based on where it is, rather than what it is. This allows a token to be "queried" according to its position, not only its content.
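
To show what such a position vector looks like, here is a plain-Python sketch of the sinusoidal scheme for a single position. The function name `sinusoidal_encoding` and the `base` parameter are introduced here for illustration; `base=10000.0` is the value used in the original paper, and `d_model` is assumed to be even.

```python
import math

def sinusoidal_encoding(pos, d_model, base=10000.0):
    """Return the d_model-dimensional encoding of one position (d_model assumed even)."""
    vec = []
    for i in range(0, d_model, 2):
        angle = pos / (base ** (i / d_model))    # lower frequency for later feature pairs
        vec.extend([math.sin(angle), math.cos(angle)])
    return vec

d_model = 8
table = [sinusoidal_encoding(pos, d_model) for pos in range(50)]

print([round(v, 3) for v in table[0]])   # encoding of the first position
print([round(v, 3) for v in table[1]])   # encoding of the second position
```

Every position gets a different vector with entries between -1 and 1, so adding it to a token embedding tags the token with where it sits in the sequence without changing the scale of the embedding much.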

In [None]:
class TransformerModel(nn.Module):
    def __init__(self, vocab_size, embed_dim, num_layers):
        super(TransformerModel, self).__init__()
        self.embed = nn.Embedding(vocab_size, embed_dim)
        self.pos_encoder = PositionalEncoding(embed_dim)
        self.layers = nn.ModuleList([SelfAttentionLayer(embed_dim) for _ in range(num_layers)])
        self.norm = nn.LayerNorm(embed_dim)
        self.decoder = nn.Linear(embed_dim, vocab_size)

    def forward(self, src):
        # Causal mask: lower-triangular, so position i can only attend to positions j <= i
        mask = torch.triu(torch.ones(
            src.size(1), src.size(1), dtype=torch.bool, device=src.device)).T
        src = self.embed(src)
        src = self.pos_encoder(src)
        for layer in self.layers:
            src = layer(src, mask)
        src = self.norm(src)
        output = self.decoder(src)
        return output
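
The mask built in `forward` is a causal mask. As a small standalone sketch, the code below builds the same lower-triangular pattern for a sequence of length 4 and applies it to uniform scores, which are chosen only to make the effect easy to read.

```python
import torch
import torch.nn.functional as F

t = 4
# Entry [i, j] is True when query position i may attend to key position j (j <= i).
mask = torch.tril(torch.ones(t, t)).bool()

# Transposing an upper-triangular matrix, as the model does, gives the same pattern.
same_as_model = torch.equal(mask, torch.triu(torch.ones(t, t)).T.bool())
print(same_as_model)

scores = torch.zeros(t, t)                          # uniform scores, for illustration only
masked = scores.masked_fill(~mask, float("-inf"))   # hide future positions
weights = F.softmax(masked, dim=-1)
print(weights)
```

Each row of `weights` spreads its probability evenly over the current and earlier positions and assigns exactly zero to later ones, which is what prevents a token from looking ahead during next-token prediction.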
