## Transformer

In a deep learning context, "Transformer" refers to multi-headed attention combined with a few additional elements: layer normalization, a feedforward network, and residual (skip) connections. Let's take each of these one by one:

### 1. Layer Normalization

> Unlike Batch Normalization and Instance Normalization, which applies scalar scale and bias for each entire channel/plane with the affine option, Layer Normalization applies per-element scale and bias with elementwise_affine ([torch docs](https://pytorch.org/docs/stable/generated/torch.nn.LayerNorm.html)).

The values in a single token's representation can get spiky. To smooth them out, we can use layer normalization. What does this mean?

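For each token, you subtract the mean and divide by the standard deviation. Both are computed across that token's embedding dimensions. You then apply learned parameters per embedding dimension: a scale $\gamma$, plus a shift $\beta$ if the bias is enabled. That makes one parameter per dimension with the bias disabled and two with it enabled.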

$y = \frac{x - \mathrm{E}[x]}{ \sqrt{\mathrm{Var}[x] + \epsilon}} * \gamma + \beta$
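Below is a minimal sketch of the full formula written with plain tensor operations. It includes the shift $\beta$, which the cells below leave out because they set `bias=False`. The names `layer_norm`, `gamma` and `beta` are illustrative. The random `gamma` and `beta` stand in for learned parameters. The result is compared against PyTorch's functional `F.layer_norm`:

```python
import torch
import torch.nn.functional as F


def layer_norm(x, gamma, beta, eps=1e-5):
    # Mean and biased variance (divide by n_embd) over each token's embedding dimension.
    mean = x.mean(dim=-1, keepdim=True)
    var = x.var(dim=-1, keepdim=True, unbiased=False)
    x_hat = (x - mean) / torch.sqrt(var + eps)  # eps sits inside the square root
    return x_hat * gamma + beta  # per-dimension scale and shift


torch.manual_seed(0)
n_embd = 4
x = torch.randn(2, 3, n_embd)  # (batch_size, context_size, n_embd)
gamma = torch.randn(n_embd)    # stand-in for the learned scale
beta = torch.randn(n_embd)     # stand-in for the learned shift

ours = layer_norm(x, gamma, beta)
ref = F.layer_norm(x, (n_embd,), weight=gamma, bias=beta, eps=1e-5)
print(torch.allclose(ours, ref, atol=1e-5))  # True
```

Setting `gamma` to ones and `beta` to zeros reduces this to the plain normalization. In that case each token's output has roughly zero mean and unit variance.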

Here's what it looks like with PyTorch's `nn.LayerNorm` and with plain tensor operations.

In [11]:
import torch
torch.set_printoptions(sci_mode=False, linewidth=160)

In [1]:
import torch
import torch.nn as nn

batch_size = 2
context_size = 3
n_embd = 4
bias=False
x = torch.randint(high=4, size=(batch_size, context_size, n_embd), dtype=torch.float)
print(x)

layer_norm = nn.LayerNorm(n_embd, bias=bias)
print(layer_norm.weight)

output = layer_norm(x)
output

tensor([[[3., 0., 1., 3.],
         [2., 2., 0., 2.],
         [1., 0., 3., 3.]],

        [[2., 2., 1., 1.],
         [1., 2., 3., 2.],
         [0., 0., 1., 0.]]])
Parameter containing:
tensor([1., 1., 1., 1.], requires_grad=True)


tensor([[[ 0.9622, -1.3471, -0.5773,  0.9622],
         [ 0.5773,  0.5773, -1.7320,  0.5773],
         [-0.5773, -1.3471,  0.9622,  0.9622]],

        [[ 1.0000,  1.0000, -1.0000, -1.0000],
         [-1.4142,  0.0000,  1.4142,  0.0000],
         [-0.5773, -0.5773,  1.7320, -0.5773]]],
       grad_fn=<NativeLayerNormBackward0>)

In [2]:
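eps = 1e-5  # nn.LayerNorm's default
var = torch.var(x, dim=-1, keepdim=True, unbiased=False)  # biased variance, as nn.LayerNorm uses
output2 = (x - torch.mean(x, dim=-1, keepdim=True)) / torch.sqrt(var + eps)  # eps inside the sqrt, as in the formula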
output2

tensor([[[ 0.9622, -1.3471, -0.5773,  0.9622],
         [ 0.5773,  0.5773, -1.7320,  0.5773],
         [-0.5773, -1.3471,  0.9622,  0.9622]],

        [[ 1.0000,  1.0000, -1.0000, -1.0000],
         [-1.4142,  0.0000,  1.4142,  0.0000],
         [-0.5773, -0.5773,  1.7320, -0.5773]]])

In [3]:
gamma = torch.ones(n_embd) * 2  # stand-in for a learned scale of 2 in every dimension
gamma

tensor([2., 2., 2., 2.])

In [4]:
output2 * gamma

tensor([[[ 1.9245, -2.6943, -1.1547,  1.9245],
         [ 1.1547,  1.1547, -3.4641,  1.1547],
         [-1.1547, -2.6943,  1.9245,  1.9245]],

        [[ 2.0000,  2.0000, -2.0000, -2.0000],
         [-2.8284,  0.0000,  2.8284,  0.0000],
         [-1.1547, -1.1547,  3.4640, -1.1547]]])

In [5]:
print(f"Outputs are equivalent: {torch.allclose(output, output2)}")

Outputs are equivalent: True


### 2. Feed Forward

* It operates on each token independently, mixing information across that token's full embedding dimension.
* Information from other tokens has already been gathered by the attention's dot products.
* The model then needs to "think" about the information it has gathered.
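The first point can be checked directly. `nn.Linear` and GELU act only on the last dimension. So applying the network to the whole `(batch_size, context_size, n_embd)` tensor should give the same result as applying it to each token vector on its own. Here is a minimal sketch. It assumes the same two-layer GELU layout as the `FeedForward` cell below, rebuilt with `nn.Sequential` so the block stands alone:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
n_embd = 4

# Same layout as the FeedForward cell: expand 4x, GELU, project back to n_embd.
ffwd = nn.Sequential(
    nn.Linear(n_embd, 4 * n_embd, bias=False),
    nn.GELU(),
    nn.Linear(4 * n_embd, n_embd, bias=False),
)

x = torch.rand(2, 3, n_embd)  # (batch_size, context_size, n_embd)
all_at_once = ffwd(x)

# Apply the network to one token vector at a time, then stack back into (2, 3, n_embd).
one_by_one = torch.stack(
    [torch.stack([ffwd(x[b, t]) for t in range(x.shape[1])]) for b in range(x.shape[0])]
)

print(torch.allclose(all_at_once, one_by_one, atol=1e-5))  # True
```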

In [1]:
import torch
import torch.nn as nn

class FeedForward(nn.Module):
    def __init__(self, n_embd, bias=False):
        super().__init__()
        self.linear1 = nn.Linear(n_embd, 4 * n_embd, bias=bias)
        self.activation = nn.GELU()
        self.linear2 = nn.Linear(4 * n_embd, n_embd, bias=bias)

    def forward(self, x):
        x = self.linear1(x)
        x = self.activation(x)
        return self.linear2(x)


batch_size = 2; context_size = 3; n_embd = 4
x = torch.rand(size=(batch_size, context_size, n_embd))

ffwd = FeedForward(n_embd)
ffwd(x)

tensor([[[-0.0568,  0.0302, -0.0024,  0.1023],
         [ 0.0030,  0.0199, -0.0254,  0.0903],
         [-0.0646, -0.0231, -0.0872,  0.0039]],

        [[-0.0100, -0.0171, -0.1008, -0.0094],
         [-0.0560,  0.0301,  0.0239,  0.0919],
         [ 0.0095, -0.0186, -0.0808, -0.0112]]], grad_fn=<UnsafeViewBackward0>)

In [2]:
print(f"Number params = {sum(p.numel() for p in ffwd.parameters() if p.requires_grad)}")

Number params = 128
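Here is a worked check of that number. With `bias=False`, each linear layer is just an `n_embd × 4·n_embd` weight matrix, and there are two of them:

$2 \cdot (n_{\text{embd}} \cdot 4\,n_{\text{embd}}) = 8\,n_{\text{embd}}^2 = 8 \cdot 4^2 = 128$

With `bias=True`, the two layers would add $4\,n_{\text{embd}} + n_{\text{embd}} = 20$ bias terms, for 148 parameters in total.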


### 3. Attention

* Information is gathered from other tokens in the context sequence.
* The mechanism is the humble dot product between every pair of tokens.
* A single sequence provides multiple training examples through triangular (causal) masking: each position only sees the tokens up to and including itself.
* The input and output shapes are the same: `(batch_size, context_size, n_embd)`.
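Here is a small pure-Python sketch of the third point. It assumes the usual next-token setup, where the targets are the inputs shifted left by one. The token ids are made up. With a causal mask, position `t` sees only `inputs[: t + 1]`. So a sequence of `context_size` inputs yields `context_size` training examples:

```python
# Hypothetical token ids: context_size + 1 tokens give context_size (input, target) positions.
tokens = [5, 12, 7, 3]
inputs, targets = tokens[:-1], tokens[1:]  # targets are the inputs shifted by one

# Under causal masking, position t sees inputs[: t + 1] and is trained to predict targets[t].
examples = [(inputs[: t + 1], targets[t]) for t in range(len(inputs))]
for context, target in examples:
    print(context, "->", target)
# [5] -> 12
# [5, 12] -> 7
# [5, 12, 7] -> 3
```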

In [1]:
import math

import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(538)

class Attention(nn.Module):
    def __init__(self, n_embd, context_size, bias=False):
        super().__init__()
        self.key = nn.Linear(n_embd, n_embd, bias=bias)
        self.query = nn.Linear(n_embd, n_embd, bias=bias)
        self.value = nn.Linear(n_embd, n_embd, bias=bias)
        self.proj = nn.Linear(n_embd, n_embd, bias=bias)

    def forward(self, x):
        batch_size, context_size, n_embd = x.shape
        k = self.key(x)
        q = self.query(x)
        v = self.value(x)

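        # Scaled dot-product scores between every query and key: (batch_size, context_size, context_size)
        attn_logits = q @ k.transpose(-1, -2) / math.sqrt(k.shape[-1])
        # Lower-triangular (causal) mask: token i may only attend to tokens 0..i
        mask = torch.tril(torch.ones(context_size, context_size, device=x.device))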
        attn_logits = attn_logits.masked_fill(mask == 0, -1e9)
        attn_prob = F.softmax(attn_logits, dim=-1)
        attn_out = attn_prob @ v
        return self.proj(attn_out)

batch_size = 2; context_size = 3; n_embd = 4
x = torch.rand(size=(batch_size, context_size, n_embd))
attn = Attention(n_embd, context_size)
attn(x)

tensor([[[ 0.1911,  0.0591, -0.1139,  0.1038],
         [-0.0683, -0.0674, -0.1091, -0.0690],
         [-0.1098, -0.0891, -0.1046, -0.0989]],

        [[ 0.1770,  0.0769, -0.0124,  0.1006],
         [ 0.1136,  0.0387, -0.0391,  0.0522],
         [ 0.1072,  0.0291, -0.0672,  0.0477]]], grad_fn=<UnsafeViewBackward0>)

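`F.scaled_dot_product_attention`:
* Computes the pairwise similarity of all the tokens in the sequence.
* Usually faster, since PyTorch can dispatch it to fused attention kernels; requires PyTorch 2.x.
* The causal mask is applied inside the function because of the `is_causal=True` argument.
* With the same seed and inputs, it produces the same output as the manual version above.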
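This sketch shows what `is_causal=True` replaces. It takes the manual steps from the first `Attention` cell (scale, lower-triangular mask, softmax, weighted sum) and compares them directly against the fused call, using random `q`, `k` and `v` tensors. This version masks with `float("-inf")` rather than `-1e9`. In float32, both give future tokens a probability of exactly zero after the softmax:

```python
import math

import torch
import torch.nn.functional as F

torch.manual_seed(0)
batch_size, context_size, d = 2, 3, 4
q = torch.rand(batch_size, context_size, d)
k = torch.rand(batch_size, context_size, d)
v = torch.rand(batch_size, context_size, d)

# Manual path: scaled scores, causal mask, softmax, weighted sum of values.
logits = q @ k.transpose(-1, -2) / math.sqrt(d)
mask = torch.tril(torch.ones(context_size, context_size, dtype=torch.bool))
probs = F.softmax(logits.masked_fill(~mask, float("-inf")), dim=-1)
manual = probs @ v

# Fused path: the same computation, with the mask applied via is_causal=True.
fused = F.scaled_dot_product_attention(q, k, v, is_causal=True)
print(torch.allclose(manual, fused, atol=1e-5))  # True
```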

In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(538)

class Attention(nn.Module):
    def __init__(self, n_embd, context_size, bias=False):
        super().__init__()
        self.key = nn.Linear(n_embd, n_embd, bias=bias)
        self.query = nn.Linear(n_embd, n_embd, bias=bias)
        self.value = nn.Linear(n_embd, n_embd, bias=bias)
        self.proj = nn.Linear(n_embd, n_embd, bias=bias)

    def forward(self, x):
        batch_size, context_size, n_embd = x.shape
        k = self.key(x)
        q = self.query(x)
        v = self.value(x)
        out = F.scaled_dot_product_attention(q, k, v, attn_mask=None, dropout_p=0, is_causal=True)
        return self.proj(out)

batch_size = 2; context_size = 3; n_embd = 4

x = torch.rand(size=(batch_size, context_size, n_embd))

attn = Attention(n_embd, context_size)
attn(x)

tensor([[[ 0.1911,  0.0591, -0.1139,  0.1038],
         [-0.0683, -0.0674, -0.1091, -0.0690],
         [-0.1098, -0.0891, -0.1046, -0.0989]],

        [[ 0.1770,  0.0769, -0.0124,  0.1006],
         [ 0.1136,  0.0387, -0.0391,  0.0522],
         [ 0.1072,  0.0291, -0.0672,  0.0477]]], grad_fn=<UnsafeViewBackward0>)

### 4. Multi-Head Attention

Everything from single-head attention still applies. Information is gathered from other tokens through pairwise dot products. Triangular masking lets a single sequence provide multiple training examples. The input and output shapes are both `(batch_size, context_size, n_embd)`.

Multi-head logic:

* Compute self-attention for each head independently.
* Treat the head as an extra batch dimension by:
    1. splitting the embedding dimension into `n_head` chunks of size `n_embd // n_head`, then
    2. swapping the `context_size` and `n_head` dimensions.
* The intermediate shape is `(batch_size, n_head, context_size, n_embd // n_head)`.
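Here is a quick shape walkthrough of both steps, and of undoing them. It is a sketch that uses the same sizes as the `MultiheadedAttention` cell below, with a random input:

```python
import torch

batch_size, context_size, n_embd, n_head = 2, 3, 12, 4
head_size = n_embd // n_head
x = torch.rand(batch_size, context_size, n_embd)

# (1) split the embedding dimension into n_head chunks, (2) swap context_size and n_head.
heads = x.view(batch_size, context_size, n_head, head_size).transpose(1, 2)
print(heads.shape)  # torch.Size([2, 4, 3, 3])

# Head h holds embedding columns h*head_size .. (h+1)*head_size - 1 of every token.
h = 1
print(torch.equal(heads[:, h], x[..., h * head_size:(h + 1) * head_size]))  # True

# Undo both steps to recover the original (batch_size, context_size, n_embd) tensor.
merged = heads.transpose(1, 2).contiguous().view(batch_size, context_size, n_embd)
print(torch.equal(merged, x))  # True
```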

`F.scaled_dot_product_attention`:
* Computes the pairwise similarity of all the tokens in the sequence.
* The first two dimensions, training example and attention head, are both treated as batch dimensions.
* The causal mask is applied inside the function because of the `is_causal=True` argument.
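This sketch checks that folding the heads into the batch really computes each head independently. It compares one 4D call against a loop that calls `F.scaled_dot_product_attention` once per head. The random `q`, `k` and `v` are already in the `(batch_size, n_head, context_size, head_size)` layout:

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
batch_size, n_head, context_size, head_size = 2, 4, 3, 3
q = torch.rand(batch_size, n_head, context_size, head_size)
k = torch.rand(batch_size, n_head, context_size, head_size)
v = torch.rand(batch_size, n_head, context_size, head_size)

# One call: example and head are both batch dimensions.
batched = F.scaled_dot_product_attention(q, k, v, is_causal=True)

# One head at a time, then stack the heads back along dim 1.
per_head = torch.stack(
    [F.scaled_dot_product_attention(q[:, h], k[:, h], v[:, h], is_causal=True) for h in range(n_head)],
    dim=1,
)
print(batched.shape)                                 # torch.Size([2, 4, 3, 3])
print(torch.allclose(batched, per_head, atol=1e-5))  # True
```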


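Output transform:

* Swap the head and context dimensions back, then merge the heads into the 3D shape `(batch_size, context_size, n_embd)`.
* `.contiguous()` is required before `.view()`. `transpose` only changes the strides, so the data is no longer laid out in the order `view` needs. `.contiguous()` copies it into a new tensor that is.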
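Here is a small sketch of why `.contiguous()` is needed. It uses a random tensor in the layout that `F.scaled_dot_product_attention` returns:

```python
import torch

batch_size, n_head, context_size, head_size = 2, 4, 3, 3
out = torch.rand(batch_size, n_head, context_size, head_size)  # attention output layout

swapped = out.transpose(1, 2)  # (batch_size, context_size, n_head, head_size): same memory, new strides
print(swapped.is_contiguous())  # False

# view() cannot merge n_head and head_size here because their strides don't form a flat layout.
view_failed = False
try:
    swapped.view(batch_size, context_size, n_head * head_size)
except RuntimeError:
    view_failed = True
print(view_failed)  # True

# contiguous() copies the data into the transposed order, after which view() works.
merged = swapped.contiguous().view(batch_size, context_size, n_head * head_size)
print(merged.shape)  # torch.Size([2, 3, 12])

# reshape() is an alternative that copies only when a view is impossible.
print(torch.equal(merged, swapped.reshape(batch_size, context_size, n_head * head_size)))  # True
```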


In [19]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class MultiheadedAttention(nn.Module):
    def __init__(self, head_size, n_head, n_embd, context_size, bias=False):
        super().__init__()
        self.n_head = n_head
        self.key = nn.Linear(n_embd, n_embd, bias=bias)
        self.query = nn.Linear(n_embd, n_embd, bias=bias)
        self.value = nn.Linear(n_embd, n_embd, bias=bias)
        self.proj = nn.Linear(n_embd, n_embd, bias=bias)

    def forward(self, x):
        batch_size, context_size, n_embd = x.shape
        k = self.key(x)
        q = self.query(x)
        v = self.value(x)
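        # Split the embedding into heads, then move heads forward: (batch_size, n_head, context_size, n_embd // n_head)
        k = k.view(batch_size, context_size, self.n_head, n_embd // self.n_head).transpose(1, 2)
        q = q.view(batch_size, context_size, self.n_head, n_embd // self.n_head).transpose(1, 2)
        v = v.view(batch_size, context_size, self.n_head, n_embd // self.n_head).transpose(1, 2)
        # Each head attends independently; is_causal=True applies the triangular mask
        out = F.scaled_dot_product_attention(q, k, v, attn_mask=None, dropout_p=0, is_causal=True)
        # Merge heads back to (batch_size, context_size, n_embd); contiguous() is needed before view()
        out = out.transpose(1, 2).contiguous().view(batch_size, context_size, n_embd)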
        return self.proj(out)

batch_size = 2; context_size = 3; n_embd = 12; n_head = 4

x = torch.rand(size=(batch_size, context_size, n_embd))

mha = MultiheadedAttention(n_embd // n_head, n_head, n_embd, context_size)
mha(x)

tensor([[[-0.0695, -0.1569, -0.0522,  0.0398,  0.1019,  0.1882,  0.2080,  0.0817, -0.0100,  0.0229,  0.0451, -0.2700],
         [-0.1275, -0.1344, -0.1049, -0.0612,  0.1591,  0.2479,  0.1442,  0.1388, -0.0788, -0.0085, -0.0199, -0.2941],
         [-0.0926, -0.0455, -0.0927, -0.1061,  0.1230,  0.2879,  0.1255,  0.1476, -0.1061,  0.0566, -0.0785, -0.2065]],

        [[-0.0419, -0.0335, -0.0098, -0.0156,  0.0757,  0.3924,  0.3331,  0.0867,  0.0588,  0.1902, -0.0457, -0.2093],
         [-0.0660, -0.0605, -0.0492, -0.0656,  0.0940,  0.2376,  0.2284,  0.1606,  0.0476,  0.0316, -0.0006, -0.1983],
         [-0.1252, -0.0439, -0.0597, -0.0248,  0.0417,  0.3308,  0.1892,  0.1849, -0.0597,  0.0695, -0.0121, -0.1741]]], grad_fn=<UnsafeViewBackward0>)
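With each piece covered, here is a minimal sketch of how they combine into a single Transformer block. It includes the residual (skip) connections listed at the start of this section.

The sketch assumes the "pre-norm" arrangement used by GPT-2-style models. In that arrangement, layer norm is applied to the input of each sub-layer, and the sub-layer's output is added back to its input. The original Transformer instead normalized after the residual addition.

`CausalSelfAttention` condenses the `MultiheadedAttention` cell above. Both class names are illustrative:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class CausalSelfAttention(nn.Module):
    """Condensed version of the MultiheadedAttention cell above."""

    def __init__(self, n_embd, n_head, bias=False):
        super().__init__()
        assert n_embd % n_head == 0, "n_embd must be divisible by n_head"
        self.n_head = n_head
        self.key = nn.Linear(n_embd, n_embd, bias=bias)
        self.query = nn.Linear(n_embd, n_embd, bias=bias)
        self.value = nn.Linear(n_embd, n_embd, bias=bias)
        self.proj = nn.Linear(n_embd, n_embd, bias=bias)

    def forward(self, x):
        batch_size, context_size, n_embd = x.shape

        def split_heads(t):
            # (batch_size, context_size, n_embd) -> (batch_size, n_head, context_size, head_size)
            return t.view(batch_size, context_size, self.n_head, n_embd // self.n_head).transpose(1, 2)

        q, k, v = split_heads(self.query(x)), split_heads(self.key(x)), split_heads(self.value(x))
        out = F.scaled_dot_product_attention(q, k, v, is_causal=True)
        out = out.transpose(1, 2).contiguous().view(batch_size, context_size, n_embd)
        return self.proj(out)


class Block(nn.Module):
    """Pre-norm Transformer block: x + attn(ln1(x)), then x + ffwd(ln2(x))."""

    def __init__(self, n_embd, n_head, bias=False):
        super().__init__()
        self.ln1 = nn.LayerNorm(n_embd)
        self.attn = CausalSelfAttention(n_embd, n_head, bias=bias)
        self.ln2 = nn.LayerNorm(n_embd)
        self.ffwd = nn.Sequential(  # same layout as the FeedForward cell
            nn.Linear(n_embd, 4 * n_embd, bias=bias),
            nn.GELU(),
            nn.Linear(4 * n_embd, n_embd, bias=bias),
        )

    def forward(self, x):
        x = x + self.attn(self.ln1(x))  # residual connection around attention
        x = x + self.ffwd(self.ln2(x))  # residual connection around the feedforward network
        return x


torch.manual_seed(0)
block = Block(n_embd=12, n_head=4)
x = torch.rand(2, 3, 12)  # (batch_size, context_size, n_embd)
print(block(x).shape)  # torch.Size([2, 3, 12])
```

The residual path adds each sub-layer's output to its input, so the shapes must match. This is why attention and the feedforward network both map `n_embd` back to `n_embd`.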