# Attention (math trick)

In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F

In [2]:
# Toy dimensions: B = batch size, T = time steps (sequence length), C = channels.
B, T, C = 4, 8, 2

# Random input of shape (B, T, C); unseeded, so the values differ on every run.
x = torch.randn((B, T, C))
x.shape

torch.Size([4, 8, 2])

In [3]:
# Reference ("bag of words") implementation with explicit loops:
# x_bow[b, t] is the mean of x[b, 0..t], computed per channel.
x_bow = torch.zeros(B, T, C)
for b in range(B):
    for t in range(T):
        x_prev = x[b, : t + 1]  # (t + 1, C): the current step and everything before it
        x_bow[b, t] = x_prev.mean(dim=0)  # average over time, channels kept separate

In [4]:
# Compare the first batch element with its running mean: row t of x_bow[0]
# is the average of rows 0..t of x[0] (so the first rows are identical).
x[0], x_bow[0]

(tensor([[-0.9963, -1.1696],
         [ 1.2406,  0.0282],
         [-0.1587,  0.3913],
         [ 1.1970, -0.8037],
         [-0.2580, -0.5031],
         [ 0.5135,  0.6065],
         [-1.7594,  0.6156],
         [-0.7402,  0.3914]]),
 tensor([[-0.9963, -1.1696],
         [ 0.1221, -0.5707],
         [ 0.0285, -0.2500],
         [ 0.3206, -0.3885],
         [ 0.2049, -0.4114],
         [ 0.2563, -0.2417],
         [-0.0316, -0.1193],
         [-0.1202, -0.0554]]))

In [5]:
# Small demo of the trick: a lower-triangular matrix whose rows sum to 1.
# Row i holds the weight 1 / (i + 1) in columns 0..i and 0 elsewhere.
a = torch.ones(3, 3, dtype=torch.float32).tril()
a = a / a.sum(dim=1, keepdim=True)

# Random integer-valued data stored as floats so it can be matrix-multiplied with `a`.
b = torch.randint(0, 10, size=(3, 2), dtype=torch.float32)

# Row i of c is the mean of rows 0..i of b.
c = torch.matmul(a, b)

a, b, c

(tensor([[1.0000, 0.0000, 0.0000],
         [0.5000, 0.5000, 0.0000],
         [0.3333, 0.3333, 0.3333]]),
 tensor([[8., 3.],
         [7., 3.],
         [2., 0.]]),
 tensor([[8.0000, 3.0000],
         [7.5000, 3.0000],
         [5.6667, 2.0000]]))

In [6]:
# Reminder of the input shape, (B, T, C), before applying the same trick to x.
x.shape

torch.Size([4, 8, 2])

In [7]:
# (T, T) averaging weights: row t spreads weight 1 / (t + 1) uniformly over
# positions 0..t and gives zero weight to every later position.
wei = torch.ones((T, T), dtype=torch.float32).tril()
wei = wei / wei.sum(dim=1, keepdim=True)
wei

tensor([[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.5000, 0.5000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.3333, 0.3333, 0.3333, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.2500, 0.2500, 0.2500, 0.2500, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.2000, 0.2000, 0.2000, 0.2000, 0.2000, 0.0000, 0.0000, 0.0000],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.0000, 0.0000],
        [0.1429, 0.1429, 0.1429, 0.1429, 0.1429, 0.1429, 0.1429, 0.0000],
        [0.1250, 0.1250, 0.1250, 0.1250, 0.1250, 0.1250, 0.1250, 0.1250]])

In [8]:
# wei is (T, T); matmul broadcasts it across the batch dimension of x.
x_bow2 = wei @ x  # (T, T) -> broadcast (B, T, T) @ (B, T, C) -> (B, T, C)
x_bow2

tensor([[[-0.9963, -1.1696],
         [ 0.1221, -0.5707],
         [ 0.0285, -0.2500],
         [ 0.3206, -0.3885],
         [ 0.2049, -0.4114],
         [ 0.2563, -0.2417],
         [-0.0316, -0.1193],
         [-0.1202, -0.0554]],

        [[-0.2814,  1.0161],
         [-0.0381,  1.6677],
         [-0.3430,  1.5716],
         [-0.4695,  1.2059],
         [-0.5404,  1.0872],
         [-0.6191,  0.8142],
         [-0.6951,  0.8141],
         [-0.4750,  1.0751]],

        [[-1.0570, -0.1697],
         [-0.2733,  0.3190],
         [-0.6672,  0.8230],
         [-0.6174,  1.0370],
         [-0.4288,  1.1590],
         [-0.3252,  0.9189],
         [-0.0913,  0.7229],
         [ 0.1616,  0.5739]],

        [[ 0.3954, -0.0627],
         [-0.2495,  0.5124],
         [-0.1499,  0.8577],
         [-0.2675,  0.6952],
         [-0.2711,  0.6007],
         [-0.1673,  0.3681],
         [ 0.0488,  0.4644],
         [ 0.1968,  0.3429]]])

In [9]:
# The matrix-multiply version must match the loop version (up to float tolerance).
torch.allclose(x_bow, x_bow2)

True

In [10]:
# Lower-triangular mask: tril[t, s] is 1 when s <= t and 0 otherwise.
tril = torch.ones((T, T), dtype=torch.float32).tril()

# Start from all-zero scores and set every masked (future) position to -inf,
# so a later softmax assigns those positions zero weight.
wei = torch.zeros((T, T)).masked_fill(tril == 0, float("-inf"))

wei

tensor([[0., -inf, -inf, -inf, -inf, -inf, -inf, -inf],
        [0., 0., -inf, -inf, -inf, -inf, -inf, -inf],
        [0., 0., 0., -inf, -inf, -inf, -inf, -inf],
        [0., 0., 0., 0., -inf, -inf, -inf, -inf],
        [0., 0., 0., 0., 0., -inf, -inf, -inf],
        [0., 0., 0., 0., 0., 0., -inf, -inf],
        [0., 0., 0., 0., 0., 0., 0., -inf],
        [0., 0., 0., 0., 0., 0., 0., 0.]])

In [11]:
# Row-wise softmax: exp(0) = 1 for allowed positions and exp(-inf) = 0 for masked
# ones, so each row becomes a uniform average over positions 0..t.
# NOTE: this overwrites `wei`; re-running the cell without re-running the previous
# one applies softmax a second time and changes the result.
wei = F.softmax(wei, dim=-1)
wei

tensor([[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.5000, 0.5000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.3333, 0.3333, 0.3333, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.2500, 0.2500, 0.2500, 0.2500, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.2000, 0.2000, 0.2000, 0.2000, 0.2000, 0.0000, 0.0000, 0.0000],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.0000, 0.0000],
        [0.1429, 0.1429, 0.1429, 0.1429, 0.1429, 0.1429, 0.1429, 0.0000],
        [0.1250, 0.1250, 0.1250, 0.1250, 0.1250, 0.1250, 0.1250, 0.1250]])

In [12]:
# Same aggregation as before, now with the softmax-derived (T, T) weights,
# broadcast over the batch dimension of x.
x_bow3 = wei @ x
# The softmax version must also match the loop version.
torch.allclose(x_bow, x_bow3)

True

# Attention

In [13]:
# Dimensions for the attention example: batch, time steps, embedding channels.
# This rebinds the B, T, C used in the first section.
B, T, C = 4, 8, 32

# Random stand-in for token embeddings, shape (B, T, C).
token_embeddings = torch.randn((B, T, C))
token_embeddings.shape

torch.Size([4, 8, 32])

In [16]:
head_size = 16

# Three independent, bias-free linear projections from C channels to head_size.
key = nn.Linear(C, head_size, bias=False)
query = nn.Linear(C, head_size, bias=False)
value = nn.Linear(C, head_size, bias=False)

# Project every token embedding; each result has shape (B, T, head_size).
k = key(token_embeddings)  # (B, T, head_size)
q = query(token_embeddings)  # (B, T, head_size)
v = value(token_embeddings)  # (B, T, head_size)

k.shape, q.shape, v.shape

(torch.Size([4, 8, 16]), torch.Size([4, 8, 16]), torch.Size([4, 8, 16]))

In [17]:
wei = q @ k.transpose(-2, -1)  # (B, T, head_size) @ (B, head_size, T) -> (B, T, T)
# wei[b, i, j] is the dot product of query i with key j: the raw affinity
# between every query (rows) and every key (columns).
wei.shape

torch.Size([4, 8, 8])

In [18]:
# Scale wei by 1 / sqrt(head_size) (scaled dot-product attention), which keeps
# the magnitude of the scores from growing with head_size before the softmax.
# NOTE: this overwrites `wei`; re-running the cell scales it again.
wei = wei * (head_size ** -0.5)

In [19]:
# Causal mask: tril[t, s] == 0 marks a key position s that lies after query t.
tril = torch.ones(T, T, dtype=torch.int64).tril()

# Block the future positions with -inf, then normalise each row into weights
# that sum to 1 over the allowed (past and present) positions.
wei = wei.masked_fill(tril == 0, float("-inf"))
wei = wei.softmax(dim=-1)
wei.shape

torch.Size([4, 8, 8])

In [20]:
# Attention weights for the first batch element: row t holds the softmax-normalised
# affinities of query t to keys 0..t; masked (future) positions are exactly 0.
wei[0]

tensor([[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.4600, 0.5400, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.2981, 0.3809, 0.3210, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.2081, 0.1776, 0.3219, 0.2924, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.1320, 0.1445, 0.2468, 0.2143, 0.2624, 0.0000, 0.0000, 0.0000],
        [0.1487, 0.3262, 0.1134, 0.1219, 0.1106, 0.1792, 0.0000, 0.0000],
        [0.1744, 0.1945, 0.1255, 0.1502, 0.0901, 0.1157, 0.1496, 0.0000],
        [0.0898, 0.2187, 0.1074, 0.1153, 0.0945, 0.0834, 0.1602, 0.1307]],
       grad_fn=<SelectBackward0>)

In [21]:
# First two of the head_size value channels for every time step of batch element 0.
v[0, :, :2]

tensor([[ 1.2951, -0.2425],
        [ 0.1870,  0.3171],
        [ 0.3554, -0.5937],
        [ 0.7292, -0.1772],
        [ 0.6960, -0.3666],
        [-0.5648, -0.0270],
        [ 0.3143,  0.3055],
        [ 0.6276,  0.1756]], grad_fn=<SliceBackward0>)

In [22]:
# (B, T, T) @ (B, T, head_size): each output row is the average of the value
# vectors, weighted by that query's key affinities (the weights sum to 1).
out = wei @ v
out.shape  # (B, T, head_size)

torch.Size([4, 8, 16])

# Layer Norm

In [23]:
# Fresh random input for the layer-norm demo: 32 samples with 100 features each.
# This rebinds `x` from the first section.
x = torch.randn(32, 100)
x.shape

torch.Size([32, 100])

In [24]:
# LayerNorm over the last dimension (size 100): each row of x is normalised
# on its own, independently of the other rows in the batch.
layer_norm = torch.nn.LayerNorm(100)
y = layer_norm(x)
y.shape

torch.Size([32, 100])

In [25]:
# Statistics of feature 0 across the batch: layer norm does NOT normalise along
# this axis, so the mean and std are not expected to be 0 and 1.
y[:,0].mean(), y[:,0].std()

(tensor(-0.1692, grad_fn=<MeanBackward0>),
 tensor(0.8966, grad_fn=<StdBackward0>))

In [26]:
# Statistics of sample 0 across its features: this is the axis layer norm
# normalises, so the mean is ~0. The std is slightly above 1 because .std() uses
# the unbiased (N - 1) estimator while LayerNorm divides by the biased variance:
# sqrt(100 / 99) ~= 1.0050.
y[0, :].mean(), y[0, :].std()

(tensor(0., grad_fn=<MeanBackward0>), tensor(1.0050, grad_fn=<StdBackward0>))

# Multi-Head Attention

In [27]:
# Integers 0..23 laid out as a (2, 3, 4) tensor, so every element is easy to
# track through the reshapes below.
a = torch.arange(2 * 3 * 4).reshape(2, 3, 4)
a

tensor([[[ 0,  1,  2,  3],
         [ 4,  5,  6,  7],
         [ 8,  9, 10, 11]],

        [[12, 13, 14, 15],
         [16, 17, 18, 19],
         [20, 21, 22, 23]]])

In [28]:
# Split the last dimension (4) into 2 groups of 2, then swap axes 1 and 2:
# (2, 3, 4) -> (2, 3, 2, 2) -> (2, 2, 3, 2).
# Presumably this mirrors the (B, T, C) -> (B, n_heads, T, head_size) split used
# for multi-head attention -- the notebook does not name the axes explicitly.
b = a.view(2, 3, 2, 2).transpose(1, 2)
b

tensor([[[[ 0,  1],
          [ 4,  5],
          [ 8,  9]],

         [[ 2,  3],
          [ 6,  7],
          [10, 11]]],


        [[[12, 13],
          [16, 17],
          [20, 21]],

         [[14, 15],
          [18, 19],
          [22, 23]]]])

In [29]:
# Undo the split: swap axes 1 and 2 back and merge the last two dimensions,
# recovering the original (2, 3, 4) tensor `a`.
# .view() works here because transposing back restores the original contiguous
# layout; a tensor that is not contiguous after the transpose would need
# .contiguous().view(...) or .reshape(...) instead.
b.transpose(1, 2).view(2, 3, 4)

tensor([[[ 0,  1,  2,  3],
         [ 4,  5,  6,  7],
         [ 8,  9, 10, 11]],

        [[12, 13, 14, 15],
         [16, 17, 18, 19],
         [20, 21, 22, 23]]])