In [1]:
import torch
import torch.nn as nn
from torch.nn import functional as F

In [2]:
# Toy input shared by the averaging demos that follow.

torch.manual_seed(1337)

# batch size, number of time steps, channels per time step
B = 4
T = 8
C = 2
x = torch.randn((B, T, C))
x[0]

tensor([[ 0.1808, -0.0700],
        [-0.3596, -0.9152],
        [ 0.6258,  0.0255],
        [ 0.9545,  0.0643],
        [ 0.3612,  1.1679],
        [-1.3499, -0.5102],
        [ 0.2360, -0.2398],
        [-0.9211,  1.5433]])

In [3]:
# naive way: explicit loops over batch and time
# Goal: xbow[b, t] = mean of x[b, i] over all i <= t
# ("bow" = bag of words: every earlier position is simply averaged)
xbow = torch.zeros(B, T, C)
for b in range(B):
    for t in range(T):
        # rows 0..t of this batch element, shape (t + 1, C), averaged over time
        xbow[b, t] = x[b, : t + 1].mean(dim=0)

xbow[0]

tensor([[ 0.1808, -0.0700],
        [-0.0894, -0.4926],
        [ 0.1490, -0.3199],
        [ 0.3504, -0.2238],
        [ 0.3525,  0.0545],
        [ 0.0688, -0.0396],
        [ 0.0927, -0.0682],
        [-0.0341,  0.1332]])

In [4]:
# version 2: the same running mean written as one matrix multiply
# row t of the lower-triangular matrix spreads weight uniformly over positions 0..t
weight = torch.ones(T, T).tril()
weight = weight / weight.sum(dim=1, keepdim=True)
print(weight)

# (T x T) @ (B x T x C) -> (B x T x C); the (T x T) weights broadcast over the batch
xbow2 = weight @ x
torch.allclose(xbow, xbow2, atol=1e-4)

tensor([[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.5000, 0.5000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.3333, 0.3333, 0.3333, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.2500, 0.2500, 0.2500, 0.2500, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.2000, 0.2000, 0.2000, 0.2000, 0.2000, 0.0000, 0.0000, 0.0000],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.0000, 0.0000],
        [0.1429, 0.1429, 0.1429, 0.1429, 0.1429, 0.1429, 0.1429, 0.0000],
        [0.1250, 0.1250, 0.1250, 0.1250, 0.1250, 0.1250, 0.1250, 0.1250]])


True

In [5]:
# version 3: obtain the same averaging weights through a softmax
# start from all-zero scores, push the "future" positions to -inf so that
# softmax assigns them zero weight and splits the rest uniformly
tril = torch.ones(T, T).tril()
weight = torch.zeros(T, T).masked_fill(tril == 0, float("-inf"))
weight = F.softmax(weight, dim=-1)
print(weight)

# (T x T) @ (B x T x C) -> (B x T x C)
xbow3 = weight @ x
torch.allclose(xbow, xbow3, atol=1e-4)

tensor([[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.5000, 0.5000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.3333, 0.3333, 0.3333, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.2500, 0.2500, 0.2500, 0.2500, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.2000, 0.2000, 0.2000, 0.2000, 0.2000, 0.0000, 0.0000, 0.0000],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.0000, 0.0000],
        [0.1429, 0.1429, 0.1429, 0.1429, 0.1429, 0.1429, 0.1429, 0.0000],
        [0.1250, 0.1250, 0.1250, 0.1250, 0.1250, 0.1250, 0.1250, 0.1250]])


True

$$
\mathrm{Attention}(Q, K, V) = \mathrm{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right)V
$$

In [6]:
# version 4: self-attention!

torch.manual_seed(1337)
B, T, C = 4, 8, 32  # batch, time, channels
x = torch.randn(B, T, C)

# a single attention head: three bias-free linear maps from C to head_dim
# (created in the order q, k, v so the seeded weight initialisation stays the same)
head_dim = 16
q_proj = nn.Linear(C, head_dim, bias=False)
k_proj = nn.Linear(C, head_dim, bias=False)
v_proj = nn.Linear(C, head_dim, bias=False)

# (B x T x C) @ (C x head_dim) -> (B x T x head_dim)
query_states = q_proj(x)  # query: what this token is looking for
key_states = k_proj(x)  # key: what this token contains
value_states = v_proj(x)  # value: what this token hands over when attended to

# scaled affinities: (B x T x head_dim) @ (B x head_dim x T) -> (B x T x T)
scale = head_dim**-0.5
weight = (query_states @ key_states.transpose(-2, -1)) * scale

# causal mask: position t may only attend to positions <= t
# some implementations add a large negative bias instead of filling, e.g. HF:
# https://github.com/huggingface/transformers/blob/2ffef0d0c7a6cfa5a59c4b883849321f66c79d62/src/transformers/modeling_utils.py#L228
tril = torch.ones(T, T).tril()
weight = weight.masked_fill(tril == 0, float("-inf"))
# turn each row of scores into attention probabilities
weight = weight.softmax(dim=-1)
print(weight[0])

# (B x T x T) @ (B x T x head_dim) -> (B x T x head_dim)
out = weight @ value_states

tensor([[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.5221, 0.4779, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.3602, 0.3210, 0.3188, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.2980, 0.4039, 0.1578, 0.1404, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.1643, 0.1243, 0.1678, 0.1865, 0.3570, 0.0000, 0.0000, 0.0000],
        [0.2656, 0.2110, 0.1137, 0.1214, 0.2018, 0.0865, 0.0000, 0.0000],
        [0.1761, 0.1327, 0.1371, 0.0974, 0.1476, 0.1918, 0.1173, 0.0000],
        [0.1046, 0.1260, 0.0922, 0.0906, 0.1476, 0.1588, 0.1432, 0.1371]],
       grad_fn=<SelectBackward0>)


Notes:
- Attention is a **communication mechanism**. Can be seen as nodes in a directed graph looking at each other and aggregating information with a weighted sum from all nodes that point to them, with data-dependent weights.
- There is no notion of space. Attention simply acts over a set of vectors. This is why we need to positionally encode tokens.
- **Examples across the batch dimension are of course processed completely independently and never "talk" to each other.**
- In an "encoder" attention block just delete the single line that does masking with `tril`, allowing all tokens to communicate. This block here is called a "decoder" attention block because it has triangular masking, and is usually used in autoregressive settings, like language modeling.
- "self-attention" just means that the keys and values are produced from the same source as queries. In "cross-attention", the queries still get produced from x, but the keys and values come from some other, external source (e.g. an encoder module)
- "Scaled" attention additionally multiplies `weight` by 1/sqrt(head_dim) (i.e. divides it by sqrt(head_dim)). **This makes it so when input Q,K are unit variance, weight will be unit variance too and Softmax will stay diffuse and not saturate too much.** Illustration below.

In [7]:
# why scale the scores by head_dim ** -0.5
q = torch.randn(B, T, head_dim)
k = torch.randn(B, T, head_dim)

# raw dot products, then the same scores rescaled
weight_wo_norm = q @ k.transpose(-2, -1)
weight = weight_wo_norm * head_dim**-0.5

# q and k are roughly unit variance ...
print(q.var())
print(k.var())
# ... the unscaled scores have a huge variance (on the order of head_dim) ...
print(weight_wo_norm.var())
# ... and scaling brings it back to roughly one
print(weight.var())


tensor(1.0449)
tensor(1.0700)
tensor(17.4690)
tensor(1.0918)


In [8]:
logits = torch.tensor([0.1, -0.2, 0.3, -0.2, 0.5])
# modest logits -> a diffuse distribution
print(torch.softmax(logits, dim=-1))
# when var too big, gets too peaky, converges to one-hot
print(torch.softmax(logits * 8, dim=-1))


tensor([0.1925, 0.1426, 0.2351, 0.1426, 0.2872])
tensor([0.0326, 0.0030, 0.1615, 0.0030, 0.8000])


In [9]:
# batch norm
class BatchNorm1d:
    """Minimal batch normalization for inputs of shape (batch, dim).

    In training mode each feature (column) is normalized with the statistics
    of the current batch, and running estimates of those statistics are
    updated. In eval mode (`training = False`) the running estimates are used
    instead and are left untouched.

    Args:
        dim: number of features (size of the last dimension of the input).
        eps: small constant added to the variance for numerical stability.
        momentum: weight of the current batch in the running-statistics update.
    """

    def __init__(self, dim, eps=1e-5, momentum=0.1):
        self.eps = eps
        self.momentum = momentum
        self.training = True
        # parameters (trained with backprop)
        self.gamma = torch.ones(dim)
        self.beta = torch.zeros(dim)
        # buffers (trained with a running 'momentum update')
        self.running_mean = torch.zeros(dim)
        self.running_var = torch.ones(dim)

    def __call__(self, x):
        """Normalize `x` per feature and return gamma * xhat + beta.

        The result is also stored on `self.out`.
        """
        if self.training:
            # Reduce over the batch axis WITHOUT keepdim: the (dim,) statistics
            # still broadcast against (batch, dim), and the running buffers
            # keep their (dim,) shape instead of drifting to (1, dim).
            xmean = x.mean(0)  # batch mean
            xvar = x.var(0)  # batch variance (torch default: unbiased estimator)
        else:
            xmean = self.running_mean
            xvar = self.running_var
        xhat = (x - xmean) / torch.sqrt(xvar + self.eps)  # normalize to unit variance
        self.out = self.gamma * xhat + self.beta
        # update the buffers (no gradient should flow through the running stats)
        if self.training:
            with torch.no_grad():
                self.running_mean = (1 - self.momentum) * self.running_mean + self.momentum * xmean
                self.running_var = (1 - self.momentum) * self.running_var + self.momentum * xvar
        return self.out

    def parameters(self):
        """Return the trainable tensors: [gamma, beta]."""
        return [self.gamma, self.beta]


torch.manual_seed(1337)
batch_norm = BatchNorm1d(100)
x = torch.randn(32, 100)  # 32 examples, each a 100-dimensional vector
x = batch_norm(x)

print(x.shape)
# one feature viewed across the whole batch: normalized by batch norm
first_feature = x[:, 0]
print(first_feature.mean(), first_feature.std())
# one example viewed across its features: not normalized by batch norm
first_example = x[0, :]
print(first_example.mean(), first_example.std())


torch.Size([32, 100])
tensor(7.4506e-09) tensor(1.0000)
tensor(0.0411) tensor(1.0431)


In [10]:
class LayerNorm1d:
    """Minimal layer normalization over the last (feature) dimension.

    Unlike batch norm, every example is normalized with its own statistics,
    so there is no train/eval distinction and there are no running buffers.
    Inputs may have any number of leading dimensions, e.g. (dim,),
    (batch, dim) or (batch, time, dim).

    Args:
        dim: number of features (size of the last dimension of the input).
        eps: small constant added to the variance for numerical stability.
    """

    def __init__(self, dim, eps=1e-5):
        self.eps = eps
        # parameters (trained with backprop)
        self.gamma = torch.ones(dim)
        self.beta = torch.zeros(dim)

    def __call__(self, x):
        """Normalize `x` over its last axis and return gamma * xhat + beta.

        The result is also stored on `self.out`.
        """
        # The only change compared to batch norm: reduce over the features of
        # each example (last axis) instead of over the batch (axis 0).
        # Using -1 rather than 1 gives the same result for (batch, dim) input
        # and also handles inputs with fewer or more leading dimensions.
        xmean = x.mean(-1, keepdim=True)  # per-example mean
        xvar = x.var(-1, keepdim=True)  # per-example variance (torch default: unbiased)
        xhat = (x - xmean) / torch.sqrt(xvar + self.eps)  # normalize to unit variance
        self.out = self.gamma * xhat + self.beta
        return self.out

    def parameters(self):
        """Return the trainable tensors: [gamma, beta]."""
        return [self.gamma, self.beta]


torch.manual_seed(1337)
layer_norm = LayerNorm1d(100)
x = torch.randn(32, 100)  # 32 examples, each a 100-dimensional vector
x = layer_norm(x)

print(x.shape)
# one feature viewed across the whole batch: not normalized by layer norm
first_feature = x[:, 0]
print(first_feature.mean(), first_feature.std())
# one example viewed across its features: normalized by layer norm
first_example = x[0, :]
print(first_example.mean(), first_example.std())

torch.Size([32, 100])
tensor(0.1469) tensor(0.8803)
tensor(-9.5367e-09) tensor(1.0000)
