## Transformer

In a deep learning context, a Transformer block combines two sub-layers, multi-headed attention and a feedforward network, interspersed with layer normalization and wrapped in residual skip connections that encompass each sub-layer.

Let's take each of these one-by-one:

### 1. Layer Normalization

$y = \frac{x - \mathrm{E}[x]}{ \sqrt{\mathrm{Var}[x] + \epsilon}} * \gamma + \beta$

* A single token's representation can get spiky (a few values much larger than the rest). Layer normalization smooths it out.
* Subtract the mean and divide by the standard deviation, both computed across that token's embedding dimension. The variance is the biased (population) variance, and a small $\epsilon$ is added before the square root for numerical stability.
* Then multiply each normalized embedding element-wise by a learned vector $\gamma$ and, optionally, add a learned bias vector $\beta$.

References: 
* torch docs: https://pytorch.org/docs/stable/generated/torch.nn.LayerNorm.html

In [9]:
import torch

# Show tensors in fixed-point notation (never scientific) and allow wide rows
# so the small example tensors below print on as few lines as possible.
torch.set_printoptions(linewidth=160, sci_mode=False)

In [10]:
import torch
import torch.nn as nn

# 1. A small random input: integers in [0, 4) stored as floats.
batch_size, context_size, n_embd = 2, 3, 4
x = torch.randint(high=4, size=(batch_size, context_size, n_embd), dtype=torch.float)

# 2. Normalize by hand over the last (embedding) dimension. The variance is the
#    biased one (correction=0), and eps keeps the division finite when every
#    value in a row is identical.
eps = 1e-5
gamma, beta = torch.ones(n_embd), torch.zeros(n_embd)
mean = x.mean(dim=-1, keepdim=True)
var = x.var(dim=-1, keepdim=True, correction=0)
output1 = (x - mean) / (var + eps).sqrt() * gamma + beta

# 3. The built-in module should agree (its default gamma is 1, beta is 0).
layer_norm = nn.LayerNorm(n_embd)
output2 = layer_norm(x)
torch.allclose(output1, output2, atol=1e-6)

True

### 2. Feed Forward

* This operates on a per-token level, across the entire embedding space.
* Information from other tokens has already been gathered by the dot products in the attention sub-layer.
* The model then needs to "think" about the information it has gathered.
* In dense transformers, the feedforward network takes up the majority of the parameters.

In [11]:
import torch
import torch.nn as nn

class FeedForward(nn.Module):
    """Position-wise MLP: widen to 4 * n_embd, apply GELU, project back to n_embd.

    Args:
        n_embd: width of each token embedding (input and output size).
        bias: whether the two linear layers carry a bias term.
    """

    def __init__(self, n_embd, bias=False):
        super().__init__()
        hidden = 4 * n_embd  # hidden-layer width
        self.linear1 = nn.Linear(n_embd, hidden, bias=bias)
        self.activation = nn.GELU()
        self.linear2 = nn.Linear(hidden, n_embd, bias=bias)

    def forward(self, x):
        """Apply the MLP to the last dimension of ``x`` (each position independently)."""
        return self.linear2(self.activation(self.linear1(x)))


# Run the feedforward network on a small random batch.
batch_size, context_size, n_embd = 2, 3, 4
shape = (batch_size, context_size, n_embd)
x = torch.rand(size=shape)

ffwd = FeedForward(n_embd)
ffwd(x)

tensor([[[-0.0505,  0.0126, -0.0237, -0.0140],
         [-0.0584,  0.0065, -0.0475,  0.0038],
         [-0.0210,  0.0022, -0.0162, -0.0439]],

        [[ 0.0992, -0.0262, -0.0854, -0.1161],
         [ 0.0382, -0.0444, -0.1140, -0.0680],
         [ 0.0538,  0.0129, -0.1184, -0.0631]]], grad_fn=<UnsafeViewBackward0>)

### 3. Attention

* Information is gathered from other tokens in the context sequence.
* The mechanism is the humble pairwise dot product between every combination of tokens.
* A single sequence provides multiple training examples through triangular masking.
* Input and output dimensions are the same, `(batch_size, context_size, n_embd)`.

In [12]:
import math

import torch
import torch.nn as nn
import torch.nn.functional as F

# Seed the global RNG so the random input and layer weights below are reproducible.
torch.manual_seed(seed=538)

class Attention(nn.Module):
    """Single-head causal self-attention written out by hand.

    Args:
        n_embd: embedding width; key, query, value and output projections are
            all ``n_embd -> n_embd`` linear layers.
        context_size: accepted for interface parity with the other attention
            classes; the causal mask is sized from the input at call time.
        bias: whether the linear layers carry a bias term.
    """

    def __init__(self, n_embd, context_size, bias=False):
        super().__init__()
        # Creation order matters: under a fixed seed it decides which random
        # weights each layer receives.
        self.key = nn.Linear(n_embd, n_embd, bias=bias)
        self.query = nn.Linear(n_embd, n_embd, bias=bias)
        self.value = nn.Linear(n_embd, n_embd, bias=bias)
        self.proj = nn.Linear(n_embd, n_embd, bias=bias)

    def forward(self, x):
        """Return causal self-attention of ``x`` (shape ``(batch_size, context_size, n_embd)``)."""
        _, context_size, _ = x.shape  # also enforces a 3-D input
        k = self.key(x)
        q = self.query(x)
        v = self.value(x)
        # Scale by 1/sqrt(d_k) so the logits do not grow with the embedding width.
        attn_logits = q @ k.transpose(-1, -2) / math.sqrt(k.shape[-1])
        # Lower-triangular mask: position i may only attend to positions <= i.
        # Built on x's device (a CPU mask cannot be applied to CUDA logits) and
        # filled with -inf, since -1e9 overflows float16 and masked_fill rejects it.
        # The diagonal is always kept, so no softmax row is entirely -inf.
        mask = torch.ones(context_size, context_size, dtype=torch.bool, device=x.device).tril()
        attn_logits = attn_logits.masked_fill(~mask, float("-inf"))
        attn_prob = F.softmax(attn_logits, dim=-1)
        return self.proj(attn_prob @ v)

# Run the hand-written attention on a small random batch; output1 is kept for
# comparison with the fused implementation.
batch_size, context_size, n_embd = 2, 3, 4
x = torch.rand(size=(batch_size, context_size, n_embd))
attn = Attention(n_embd=n_embd, context_size=context_size)
output1 = attn(x)
output1

tensor([[[ 0.1911,  0.0591, -0.1139,  0.1038],
         [-0.0683, -0.0674, -0.1091, -0.0690],
         [-0.1098, -0.0891, -0.1046, -0.0989]],

        [[ 0.1770,  0.0769, -0.0124,  0.1006],
         [ 0.1136,  0.0387, -0.0391,  0.0522],
         [ 0.1072,  0.0291, -0.0672,  0.0477]]], grad_fn=<UnsafeViewBackward0>)

scaled_dot_product_attention: 
* Much faster attention computation; requires PyTorch 2.x.
* The mask is applied within the function due to `is_causal=True` argument.

In [13]:
import torch
import torch.nn as nn
import torch.nn.functional as F

# Re-seed with the same value so this cell draws the same input and weights.
torch.manual_seed(seed=538)

class Attention(nn.Module):
    """Causal self-attention backed by ``F.scaled_dot_product_attention``.

    Uses the same layers, created in the same order, as the hand-written
    version, so both give the same output for the same seed and input.

    Args:
        n_embd: embedding width used by the key/query/value and output projections.
        context_size: unused; the fused kernel applies the causal mask itself.
        bias: whether the linear layers carry a bias term.
    """

    def __init__(self, n_embd, context_size, bias=False):
        super().__init__()

        def linear():
            return nn.Linear(n_embd, n_embd, bias=bias)

        # Keep the key -> query -> value -> proj creation order (it fixes which
        # random weights each layer gets under a given seed).
        self.key = linear()
        self.query = linear()
        self.value = linear()
        self.proj = linear()

    def forward(self, x):
        """Apply causal self-attention to ``x`` of shape ``(batch_size, context_size, n_embd)``."""
        _batch, _seq, _width = x.shape  # rejects anything that is not 3-D
        query, key, value = self.query(x), self.key(x), self.value(x)
        attended = F.scaled_dot_product_attention(
            query, key, value, attn_mask=None, dropout_p=0, is_causal=True
        )
        return self.proj(attended)

# Same seed, same input shape and same layer order as the hand-written cell,
# so the fused kernel should reproduce output1.
batch_size, context_size, n_embd = 2, 3, 4
x = torch.rand(size=(batch_size, context_size, n_embd))
attn = Attention(n_embd=n_embd, context_size=context_size)
output2 = attn(x)
print(torch.allclose(output1, output2))
output2

True


tensor([[[ 0.1911,  0.0591, -0.1139,  0.1038],
         [-0.0683, -0.0674, -0.1091, -0.0690],
         [-0.1098, -0.0891, -0.1046, -0.0989]],

        [[ 0.1770,  0.0769, -0.0124,  0.1006],
         [ 0.1136,  0.0387, -0.0391,  0.0522],
         [ 0.1072,  0.0291, -0.0672,  0.0477]]], grad_fn=<UnsafeViewBackward0>)

### 4. Multi-Head Attention

* Information is gathered from other tokens in the context sequence.
* The mechanism is the humble pairwise dot product between every combination of tokens.
* A single sequence provides multiple training examples through triangular masking.
* Input and output dimensions are the same, `(batch_size, context_size, n_embd)`.

Multi-head logic:

* Compute self-attention for each head independently.
* Convert the heads into a batch dimension by:
    (1) splitting the embedding dimension into the individual heads using `view`, and
    (2) swapping the `context_size` and `n_head` dimensions using `transpose`.
* The intermediate shape is `(batch_size, n_head, context_size, n_embd // n_head)`.

scaled_dot_product_attention: 
* Compute the pairwise similarity of all of the tokens in the sequence.
* Batch dimensions are the training example and the attention head.
* The mask is applied within the function due to `is_causal=True` argument.

Output transform:

* Convert back to 3D shape
* `contiguous()` is required here: `transpose` returns a non-contiguous view, and `view` needs contiguous memory, so `contiguous()` copies the data into a new tensor.


In [14]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class MultiheadedAttention(nn.Module):
    """Causal multi-head self-attention using ``F.scaled_dot_product_attention``.

    Args:
        head_size: accepted for interface compatibility but not used; the
            per-head width is always derived as ``n_embd // n_head``.
        n_head: number of attention heads; must be positive and divide ``n_embd``.
        n_embd: embedding width (input and output size).
        context_size: unused; the fused kernel applies the causal mask itself.
        bias: whether the linear layers carry a bias term.

    Raises:
        ValueError: if ``n_head`` is not positive or does not divide ``n_embd``.
    """

    def __init__(self, head_size, n_head, n_embd, context_size, bias=False):
        super().__init__()
        # Fail at construction instead of with an opaque `view` error in forward.
        if n_head <= 0 or n_embd % n_head != 0:
            raise ValueError(
                f"n_embd ({n_embd}) must be divisible by a positive n_head ({n_head})"
            )
        self.n_head = n_head
        self.key = nn.Linear(n_embd, n_embd, bias=bias)
        self.query = nn.Linear(n_embd, n_embd, bias=bias)
        self.value = nn.Linear(n_embd, n_embd, bias=bias)
        self.proj = nn.Linear(n_embd, n_embd, bias=bias)

    def forward(self, x):
        """Return multi-head causal self-attention of ``x`` (shape ``(batch_size, context_size, n_embd)``)."""
        batch_size, context_size, n_embd = x.shape
        head_dim = n_embd // self.n_head

        def split_heads(t):
            # (B, T, C) -> (B, n_head, T, head_dim): heads become a batch dimension.
            return t.view(batch_size, context_size, self.n_head, head_dim).transpose(1, 2)

        k = split_heads(self.key(x))
        q = split_heads(self.query(x))
        v = split_heads(self.value(x))
        out = F.scaled_dot_product_attention(q, k, v, attn_mask=None, dropout_p=0, is_causal=True)
        # Merge heads back to (B, T, C); transpose leaves the tensor non-contiguous,
        # so copy it before `view`.
        out = out.transpose(1, 2).contiguous().view(batch_size, context_size, n_embd)
        return self.proj(out)

# A 12-wide embedding split across 4 heads, each of width 3.
batch_size, context_size, n_embd, n_head = 2, 3, 12, 4
x = torch.rand(size=(batch_size, context_size, n_embd))
mha = MultiheadedAttention(n_embd // n_head, n_head, n_embd, context_size)
mha(x)

tensor([[[ 0.0281, -0.2927, -0.2525, -0.0889,  0.1167,  0.1468, -0.0820, -0.2655,  0.1458,  0.0630,  0.0043,  0.1676],
         [-0.0009, -0.1137, -0.3365, -0.0507,  0.2371,  0.0397, -0.1004, -0.1876,  0.0511,  0.1017,  0.0875, -0.0008],
         [-0.0237, -0.0629, -0.3037, -0.0972,  0.2089,  0.0214, -0.0721, -0.1648,  0.0580,  0.0751,  0.0676,  0.0171]],

        [[-0.0217, -0.2904, -0.3444, -0.0450,  0.1474,  0.2053, -0.2469, -0.3188,  0.0560,  0.1473, -0.0262,  0.0488],
         [ 0.0701, -0.1827, -0.3194, -0.0396,  0.2232,  0.1187, -0.0864, -0.2168,  0.0954,  0.0748,  0.0209,  0.0241],
         [ 0.0656, -0.1826, -0.3278, -0.1010,  0.2486,  0.0924, -0.0423, -0.2427,  0.0861,  0.0443, -0.0071,  0.0052]]], grad_fn=<UnsafeViewBackward0>)

## 5. Full Transformer

* We need to assemble the multi-headed attention and the feedforward sub-layers.
* But there's more to the Transformer block than this.
* Also need layer normalization and residual skip connections.

In [15]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class Transformer(nn.Module):
    """One Transformer block with layer normalization applied before each sub-layer.

    Attention and then the feedforward network each read a normalized copy of
    ``x`` and add their result back onto ``x`` (residual connections).

    ``config`` must provide ``n_embd``, ``n_head``, ``context_size`` and ``bias``.
    """

    def __init__(self, config):
        super().__init__()
        width, heads, use_bias = config.n_embd, config.n_head, config.bias
        # Submodules are created in the same order as their use in forward.
        self.layer_norm1 = nn.LayerNorm(width, bias=use_bias)
        self.multi_headed_attention = MultiheadedAttention(
            width // heads, heads, width, config.context_size, use_bias
        )
        self.layer_norm2 = nn.LayerNorm(width, bias=use_bias)
        self.feed_forward = FeedForward(width, use_bias)

    def forward(self, x):
        """Return ``x`` after the attention and feedforward residual updates."""
        attended = self.multi_headed_attention(self.layer_norm1(x))
        x = x + attended
        mixed = self.feed_forward(self.layer_norm2(x))
        return x + mixed

class Config:
    # Model hyperparameters, read as attributes (e.g. ``config.n_embd``).
    n_layer = 12          # number of Transformer blocks
    n_head = 12           # attention heads per block
    n_embd = 768          # embedding width
    context_size = 1024   # number of positions (context length)
    vocab_size = 50304    # presumably GPT-2's 50,257 tokens padded to a multiple of 64; confirm
    bias = False          # no bias terms in the Linear / LayerNorm layers

# Push a full-length random batch of embeddings through a single block.
batch_size, context_size, n_embd = 2, 1024, 768
x = torch.rand(size=(batch_size, context_size, n_embd))

model = Transformer(Config())
model(x)

tensor([[[ 0.1371, -0.0388,  0.6705,  ...,  0.9291,  1.2065,  0.0958],
         [-0.0254,  1.2155,  0.4800,  ...,  0.9933,  1.2587,  0.2396],
         [ 0.1490,  0.4489,  0.0122,  ...,  1.1132,  0.5381,  1.2893],
         ...,
         [ 0.1049,  0.1704,  0.5994,  ...,  0.7017,  0.9758,  0.0885],
         [ 0.2474,  0.5638,  0.4740,  ...,  0.5415,  1.2365,  0.7050],
         [ 0.5544,  1.2055,  0.7379,  ...,  0.1003,  0.9279,  0.3755]],

        [[ 0.3582,  0.8588, -0.1149,  ...,  0.8905,  0.6593, -0.1255],
         [ 0.8621,  0.4575,  0.5921,  ...,  0.7179,  0.1728,  0.2621],
         [ 1.5342,  0.8590,  0.3469,  ...,  0.1434,  0.4204,  0.4043],
         ...,
         [ 0.6552,  0.6969,  0.8733,  ...,  0.0807,  0.6024,  1.0817],
         [ 0.4097,  0.4047,  0.4242,  ...,  0.8062,  0.8283,  0.6442],
         [ 1.0054,  0.7435,  0.6510,  ..., -0.0248,  0.4735,  0.1544]]], grad_fn=<AddBackward0>)

## 6. Putting the model all together

Finally, we need to add three things:
* The input: position embedding and token embedding
* The output: final layer normalization and linear projection back into the vocabulary space
* The initialization: drawing all linear and embedding weights from a normal distribution with mean 0 and standard deviation 0.02, with an additional scaling of $1/\sqrt{N}$ for the weights of the residual projections (the final linear layer before each residual skip connection), where $N$ is the number of residual layers (two per Transformer block).

In [16]:
class GPT2(nn.Module):
    """GPT-2 style decoder-only language model.

    Token embeddings plus learned position embeddings feed a stack of
    ``config.n_layer`` Transformer blocks, then a final LayerNorm and a linear
    head whose weight is tied to the token embedding.

    ``config`` must provide ``n_layer``, ``n_embd``, ``n_head``,
    ``context_size``, ``vocab_size`` and ``bias``.
    """

    def __init__(self, config):
        super().__init__()
        self.context_size = config.context_size
        self.token_embedding_table = nn.Embedding(config.vocab_size, config.n_embd)
        self.position_embedding_table = nn.Embedding(config.context_size, config.n_embd)
        self.transformers = nn.Sequential(*[Transformer(config) for _ in range(config.n_layer)])
        self.layer_norm_final = nn.LayerNorm(config.n_embd, bias=config.bias)
        self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=config.bias)
        # Weight tying: the output projection reuses the token embedding matrix.
        self.lm_head.weight = self.token_embedding_table.weight

        # Every Linear/Embedding weight ~ Normal(mean=0, std=0.02); Linear biases zeroed.
        self.apply(self._init_weights)
        # Projections that write into the residual stream (attention `proj`,
        # feedforward `linear2`) get their std scaled by 1/sqrt(N), where
        # N counts two residual additions per block.
        n_residual_layers = 2 * config.n_layer
        for pn, p in self.named_parameters():
            if pn.endswith(("proj.weight", "linear2.weight")):
                torch.nn.init.normal_(p, mean=0.0, std=0.02 / math.sqrt(n_residual_layers))

    def _init_weights(self, module):
        """Initialize one submodule; used via ``self.apply``."""
        if isinstance(module, nn.Linear):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if module.bias is not None:
                torch.nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)

    def forward(self, idx):
        """Return vocabulary logits for a batch of token ids.

        Args:
            idx: integer token ids of shape ``(batch, seq_len)`` with
                ``seq_len <= self.context_size``.

        Returns:
            Logits of shape ``(batch, seq_len, vocab_size)``.

        Raises:
            ValueError: if ``seq_len`` exceeds the model's context size.
        """
        seq_len = idx.shape[-1]
        if seq_len > self.context_size:
            raise ValueError(
                f"sequence length {seq_len} exceeds context_size {self.context_size}"
            )
        # Position ids are sized from the input itself, so sequences shorter
        # than context_size work too.
        pos_idx = torch.arange(seq_len, dtype=torch.long, device=idx.device)
        x = self.token_embedding_table(idx) + self.position_embedding_table(pos_idx)
        x = self.transformers(x)
        x = self.layer_norm_final(x)
        return self.lm_head(x)


class Config:
    # Model hyperparameters, read as attributes (e.g. ``config.vocab_size``).
    n_layer = 12          # number of Transformer blocks
    n_head = 12           # attention heads per block
    n_embd = 768          # embedding width
    context_size = 1024   # number of positions (size of the position embedding table)
    vocab_size = 50304    # presumably GPT-2's 50,257 tokens padded to a multiple of 64; confirm
    bias = False          # no bias terms in the Linear / LayerNorm layers

# Random token ids (all below vocab_size) for a full-length batch, then logits.
batch_size, context_size, n_embd = 2, 1024, 768
idx = torch.randint(high=50_254, size=(batch_size, context_size))

model = GPT2(Config())
model(idx)

tensor([[[     0.6351,     -0.0949,     -0.5541,  ...,     -0.1742,      0.1028,      0.2494],
         [     0.7023,     -0.6657,     -0.4743,  ...,      0.3498,      0.3731,     -0.0730],
         [     0.1394,     -0.9582,     -0.8774,  ...,      0.6104,      0.0692,     -0.5602],
         ...,
         [     0.9415,      0.8835,     -0.5578,  ...,      0.5964,      0.0697,      0.0509],
         [     0.8785,      0.8142,     -0.0001,  ...,      0.1103,      1.2653,     -0.5562],
         [     0.8092,      0.9330,     -0.5170,  ...,      0.2967,      0.3159,     -0.9122]],

        [[     0.7075,     -0.1322,      0.7380,  ...,      0.3899,     -0.1218,     -0.5173],
         [     0.4188,     -0.0548,      0.3190,  ...,      0.6858,      0.2561,     -0.8678],
         [     0.0219,     -0.3555,     -0.1466,  ...,      0.5057,      0.0858,      0.0092],
         ...,
         [     0.7707,      0.1874,     -0.8775,  ...,      0.4231,      0.2930,     -0.3817],
         [     1.311