## Transformer

In a deep learning context, a Transformer refers to multi-headed attention combined with a few additional elements: layer normalization, a feedforward network, and residual skip connections. Let's take each of these one-by-one:
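
Before going through them, it may help to see how the pieces could fit together. The sketch below is one possible arrangement, not a definitive implementation: the class name `Block`, the pre-norm ordering (normalize before each sub-layer), `n_head=2`, and the use of PyTorch's built-in `nn.MultiheadAttention` as a stand-in for the attention layer are all assumptions made for illustration. It also omits the causal mask that a decoder would need.

```python
import torch
import torch.nn as nn

class Block(nn.Module):
    """Hypothetical transformer block; the pre-norm ordering is an assumption."""

    def __init__(self, n_embd, n_head):
        super().__init__()
        self.ln1 = nn.LayerNorm(n_embd)
        # Stand-in attention layer (assumption); no causal mask is applied.
        self.attn = nn.MultiheadAttention(n_embd, n_head, batch_first=True)
        self.ln2 = nn.LayerNorm(n_embd)
        self.ffwd = nn.Sequential(
            nn.Linear(n_embd, 4 * n_embd),
            nn.GELU(),
            nn.Linear(4 * n_embd, n_embd),
        )

    def forward(self, x):
        # Residual skip connection around attention.
        h = self.ln1(x)
        attn_out, _ = self.attn(h, h, h, need_weights=False)
        x = x + attn_out
        # Residual skip connection around the feedforward network.
        x = x + self.ffwd(self.ln2(x))
        return x

torch.manual_seed(0)
block = Block(n_embd=4, n_head=2)
x = torch.rand(2, 3, 4)  # (batch_size, context_size, n_embd)
y = block(x)
print(y.shape)
```

The output has the same shape as the input, which is what lets blocks like this be stacked.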

### 1. Layer Normalization

> Unlike Batch Normalization and Instance Normalization, which applies scalar scale and bias for each entire channel/plane with the affine option, Layer Normalization applies per-element scale and bias with elementwise_affine ([torch docs](https://pytorch.org/docs/stable/generated/torch.nn.LayerNorm.html)).

A single token's representation can get spiky: some embedding dimensions can have much larger magnitudes than others. To smooth these values out, we can use layer normalization. What does this mean?

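You subtract the mean and divide by the standard deviation, both computed across that token's embedding dimensions. You also learn one or two parameters for each embedding dimension: only a scale $\gamma$ if you disable the bias, or a scale $\gamma$ and a bias $\beta$ if you enable it.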

$y = \frac{x - \mathrm{E}[x]}{ \sqrt{\mathrm{Var}[x] + \epsilon}} * \gamma + \beta$
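
To make the formula concrete, here is a worked example for a single token. The values $\gamma = 1$, $\beta = 0$ and $\epsilon = 10^{-5}$ are assumptions chosen for illustration, and the variance is the biased one (dividing by 4, not 3). For $x = [2, 2, 1, 1]$:

$\mathrm{E}[x] = \frac{2 + 2 + 1 + 1}{4} = 1.5$

$\mathrm{Var}[x] = \frac{0.5^2 + 0.5^2 + (-0.5)^2 + (-0.5)^2}{4} = 0.25$

$y = \frac{[2, 2, 1, 1] - 1.5}{\sqrt{0.25 + 10^{-5}}} \approx [1, 1, -1, -1]$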

Here's what it looks like in PyTorch, first with `nn.LayerNorm` and then with plain tensor operations.

In [1]:
import torch
import torch.nn as nn

batch_size = 2
context_size = 3
n_embd = 4
bias=False
x = torch.randint(high=4, size=(batch_size, context_size, n_embd), dtype=torch.float)
print(x)

layer_norm = nn.LayerNorm(n_embd, bias=bias)
print(layer_norm.weight)

output = layer_norm(x)
output

tensor([[[3., 0., 1., 3.],
         [2., 2., 0., 2.],
         [1., 0., 3., 3.]],

        [[2., 2., 1., 1.],
         [1., 2., 3., 2.],
         [0., 0., 1., 0.]]])
Parameter containing:
tensor([1., 1., 1., 1.], requires_grad=True)


tensor([[[ 0.9622, -1.3471, -0.5773,  0.9622],
         [ 0.5773,  0.5773, -1.7320,  0.5773],
         [-0.5773, -1.3471,  0.9622,  0.9622]],

        [[ 1.0000,  1.0000, -1.0000, -1.0000],
         [-1.4142,  0.0000,  1.4142,  0.0000],
         [-0.5773, -0.5773,  1.7320, -0.5773]]],
       grad_fn=<NativeLayerNormBackward0>)

In [2]:
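eps = 1e-5
# eps goes inside the square root, next to the (biased) variance, as in the formula above
output2 = (x - torch.mean(x, dim=-1, keepdim=True)) / torch.sqrt(torch.var(x, dim=-1, keepdim=True, unbiased=False) + eps)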
output2

tensor([[[ 0.9622, -1.3471, -0.5773,  0.9622],
         [ 0.5773,  0.5773, -1.7320,  0.5773],
         [-0.5773, -1.3471,  0.9622,  0.9622]],

        [[ 1.0000,  1.0000, -1.0000, -1.0000],
         [-1.4142,  0.0000,  1.4142,  0.0000],
         [-0.5773, -0.5773,  1.7320, -0.5773]]])

In [3]:
gamma = torch.ones(n_embd) * 2
gamma

tensor([2., 2., 2., 2.])

In [4]:
output2 * gamma

tensor([[[ 1.9245, -2.6943, -1.1547,  1.9245],
         [ 1.1547,  1.1547, -3.4641,  1.1547],
         [-1.1547, -2.6943,  1.9245,  1.9245]],

        [[ 2.0000,  2.0000, -2.0000, -2.0000],
         [-2.8284,  0.0000,  2.8284,  0.0000],
         [-1.1547, -1.1547,  3.4640, -1.1547]]])

In [5]:
print(f"Outputs are equivalent: {torch.allclose(output, output2)}")

Outputs are equivalent: True


### 2. Feed Forward

* The feed-forward network operates on each token independently, across that token's entire embedding vector.
* Information from other tokens has already been gathered by the dot-product attention.
* The feed-forward network then lets the model "think" about the information it has gathered.

In [3]:
import torch
import torch.nn as nn

batch_size = 2
context_size = 3
n_embd = 4

x = torch.rand(size=(batch_size, context_size, n_embd))

class FeedForward(nn.Module):
    def __init__(self, n_embd, bias=False):
        super().__init__()
        self.linear1 = nn.Linear(n_embd, 4 * n_embd, bias=bias)
        self.activation = nn.GELU()
        self.linear2 = nn.Linear(4 * n_embd, n_embd, bias=bias)

    def forward(self, x):
        x = self.linear1(x)
        x = self.activation(x)
        return self.linear2(x)

ffwd = FeedForward(n_embd)

In [5]:
print(f"Input {x=}")
print(f"Output {ffwd(x)=}")
print(f"Number params ={sum(p.numel() for p in ffwd.parameters() if p.requires_grad)}")

Input x=tensor([[[0.0611, 0.6234, 0.4217, 0.0938],
         [0.4592, 0.5698, 0.8631, 0.4320],
         [0.0084, 0.3418, 0.8917, 0.6578]],

        [[0.1050, 0.4519, 0.0675, 0.1947],
         [0.6866, 0.1235, 0.0039, 0.3434],
         [0.7643, 0.6961, 0.1230, 0.5648]]])
Output ffwd(x)=tensor([[[-0.0580, -0.0916, -0.0183,  0.0028],
         [-0.1717, -0.1305,  0.0164, -0.1736],
         [-0.1749, -0.1288, -0.0602, -0.1383]],

        [[-0.0338, -0.0571, -0.0499,  0.0172],
         [-0.0675, -0.0149,  0.0274, -0.1064],
         [-0.1100, -0.1116, -0.0275, -0.0852]]], grad_fn=<UnsafeViewBackward0>)
Number params =128
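
The per-token claim can be checked directly. The sketch below is a minimal check, assuming the same layer sizes as above but rebuilt with `nn.Sequential` so it is self-contained. It runs one token through the network on its own, then changes a different token, and compares the results. It also reproduces the parameter count: two bias-free weight matrices of size `n_embd * 4 * n_embd` each.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
n_embd = 4

# Same structure as the feed-forward network above, without biases.
ffwd = nn.Sequential(
    nn.Linear(n_embd, 4 * n_embd, bias=False),
    nn.GELU(),
    nn.Linear(4 * n_embd, n_embd, bias=False),
)

x = torch.rand(2, 3, n_embd)  # (batch_size, context_size, n_embd)
full = ffwd(x)

# Running one token alone gives the same result as running it in the batch.
single = ffwd(x[0, 1])
print(torch.allclose(full[0, 1], single, atol=1e-6))

# Overwriting a neighbouring token does not affect this token's output.
x_changed = x.clone()
x_changed[0, 0] = 0.0
print(torch.allclose(ffwd(x_changed)[0, 1], full[0, 1], atol=1e-6))

# Parameter count: 4 * 16 for the first layer plus 16 * 4 for the second.
expected = 2 * n_embd * (4 * n_embd)
actual = sum(p.numel() for p in ffwd.parameters())
print(expected, actual)  # 128 128
```

Any mixing of information between tokens therefore has to happen in the attention layer, not here.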
