# The Math Behind Self-Attention

In [29]:
import torch
import torch.nn as nn
from torch.nn import functional as F

## Toy Example

In [3]:
torch.manual_seed(42)
# Toy dimensions: batch size, sequence length (time), channels per token
B = 4
T = 8
C = 2
x = torch.randn(B, T, C)
x.size()

torch.Size([4, 8, 2])

### Averaging Each Token with Its Past Tokens

#### Using a For-Loop to Get the Average

In [6]:
xbow = torch.zeros((B, T, C))  # Bag-of-Words: running mean over each prefix
for batch_idx in range(B):
    for time_idx in range(T):
        prefix = x[batch_idx, :time_idx + 1]  # (time_idx + 1, C): this token and all earlier ones
        xbow[batch_idx, time_idx] = prefix.mean(dim=0)
xbow.shape

torch.Size([4, 8, 2])

#### Using Matrix Multiplication to Get the Average

##### Sum of All Tokens

In [9]:
torch.manual_seed(42)
# An all-ones matrix: every output row is the column-wise sum of all rows of b.
a = torch.ones((3, 3))
b = torch.randint(low=0, high=10, size=(3, 2)).to(torch.float32)
c = torch.matmul(a, b)
for tensor_to_show in (a, b, c):
    print(tensor_to_show)

tensor([[1., 1., 1.],
        [1., 1., 1.],
        [1., 1., 1.]])
tensor([[2., 7.],
        [6., 4.],
        [6., 5.]])
tensor([[14., 16.],
        [14., 16.],
        [14., 16.]])


##### Sum of Tokens Up to Token T

In [10]:
torch.manual_seed(42)
# Lower-triangular ones: row t of c is the sum of rows 0..t of b (a running sum).
a = torch.ones(3, 3).tril()
b = torch.randint(0, 10, size=(3, 2)).float()
c = torch.matmul(a, b)
print(a, b, c, sep="\n")

tensor([[1., 0., 0.],
        [1., 1., 0.],
        [1., 1., 1.]])
tensor([[2., 7.],
        [6., 4.],
        [6., 5.]])
tensor([[ 2.,  7.],
        [ 8., 11.],
        [14., 16.]])


##### Average of Tokens Up to Token T (Normalize the Rows of Matrix a)

In [11]:
torch.manual_seed(42)
# Normalize each row of the lower-triangular matrix so it sums to 1; row t of c
# is then the mean of rows 0..t of b.
a = torch.tril(torch.ones(3, 3))
row_totals = a.sum(dim=1, keepdim=True)  # shape (3, 1) so it broadcasts across columns
a = a / row_totals
b = torch.randint(0, 10, (3, 2)).float()
c = torch.matmul(a, b)
print(a, b, c, sep="\n")

tensor([[1.0000, 0.0000, 0.0000],
        [0.5000, 0.5000, 0.0000],
        [0.3333, 0.3333, 0.3333]])
tensor([[2., 7.],
        [6., 4.],
        [6., 5.]])
tensor([[2.0000, 7.0000],
        [4.0000, 5.5000],
        [4.6667, 5.3333]])


##### Matrix Multiplication for Weighted Aggregation

In [12]:
# Get information from preceding tokens using tril
# Aggregate information from preceding tokens: after normalization, row t of
# `weight` holds 1/(t+1) in columns 0..t and 0 elsewhere.
weight = torch.tril(torch.ones(T, T))
row_sums = weight.sum(dim=1, keepdim=True)
weight = weight / row_sums
# Broadcasting: (T, T) @ (B, T, C) is treated as (B, T, T) @ (B, T, C),
# i.e. one (T, T) @ (T, C) product per batch element -> (B, T, C).
xbow2 = torch.matmul(weight, x)
torch.allclose(xbow, xbow2)

True

##### Softmax

In [20]:
# Build the causal mask in this cell. `tril` is otherwise first defined in a
# later cell, so on a fresh kernel (Restart & Run All) this cell would raise NameError.
tril = torch.tril(torch.ones(T, T))
# Start from all-zero scores and hide future positions (the upper triangle) with
# -inf, so that a softmax along each row gives them zero weight.
weight_test = torch.zeros((T, T))
weight_test = weight_test.masked_fill(tril == 0, float('-inf'))
weight_test

tensor([[0., -inf, -inf, -inf, -inf, -inf, -inf, -inf],
        [0., 0., -inf, -inf, -inf, -inf, -inf, -inf],
        [0., 0., 0., -inf, -inf, -inf, -inf, -inf],
        [0., 0., 0., 0., -inf, -inf, -inf, -inf],
        [0., 0., 0., 0., 0., -inf, -inf, -inf],
        [0., 0., 0., 0., 0., 0., -inf, -inf],
        [0., 0., 0., 0., 0., 0., 0., -inf],
        [0., 0., 0., 0., 0., 0., 0., 0.]])

In [22]:
# Softmax along the last dim turns each row into a probability distribution.
# Since exp(-inf) = 0, masked (future) positions get zero weight, and the
# remaining entries of each row share the probability mass equally.
# The result goes into a new name so that re-running this cell does not apply
# softmax a second time to its own output.
weight_test_softmax = F.softmax(weight_test, dim=-1)
weight_test_softmax

tensor([[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.5000, 0.5000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.3333, 0.3333, 0.3333, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.2500, 0.2500, 0.2500, 0.2500, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.2000, 0.2000, 0.2000, 0.2000, 0.2000, 0.0000, 0.0000, 0.0000],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.0000, 0.0000],
        [0.1429, 0.1429, 0.1429, 0.1429, 0.1429, 0.1429, 0.1429, 0.0000],
        [0.1250, 0.1250, 0.1250, 0.1250, 0.1250, 0.1250, 0.1250, 0.1250]])

In [17]:
# Same causal average as the matmul version, expressed as masked scores + softmax:
# zero scores on allowed positions, -inf on future ones, normalized per row.
tril = torch.tril(torch.ones(T, T))
masked_scores = torch.zeros((T, T)).masked_fill(tril == 0, float('-inf'))
weight = F.softmax(masked_scores, dim=-1)
xbow3 = torch.matmul(weight, x)
torch.allclose(xbow, xbow3)

True

### Self-Attention

In [37]:
torch.manual_seed(42)
B, T, C = 4, 8, 32  # toy sizes: batch, time (tokens), channels
x = torch.randn(B, T, C)  # toy input: B sequences of T tokens with C channels each

# Single self-attention head. The three projections are created in this order
# right after seeding, so their random initialisation depends on the seed above.
head_size = 16
key = nn.Linear(C, head_size, bias=False)
query = nn.Linear(C, head_size, bias=False)
value = nn.Linear(C, head_size, bias=False)
k = key(x)    # (B, T, head_size): what each token offers to be matched against
q = query(x)  # (B, T, head_size): what each token is looking for

# Affinities between every query and every key:
# (B, T, head_size) @ (B, head_size, T) -> (B, T, T).
# NOTE: no 1/sqrt(head_size) scaling is applied here (unscaled dot-product attention).
weight = q @ k.transpose(-2, -1)
print(f"Shape of weight=q@k.T is (B, T, T): {weight.shape}")

# Causal mask: token t may only attend to tokens 0..t. Softmax then turns each
# row into non-negative weights that sum to 1.
tril = torch.tril(torch.ones(T, T))
weight = weight.masked_fill(tril == 0, float('-inf'))
weight = F.softmax(weight, dim=-1)

v = value(x)  # (B, T, head_size): what each token communicates if it is attended to
output = weight @ v  # (B, T, T) @ (B, T, head_size) -> (B, T, head_size)
print(f"Shape of weight@v is (B, T, head_size): {output.shape}")

Shape of weight=q@k.T is (B, T, T): torch.Size([4, 8, 8])
Shape of weight@v is (B, T, head_size): torch.Size([4, 8, 16])


In [32]:
weight.select(dim=0, index=0)  # attention weights of the first sequence in the batch, (T, T)

tensor([[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.1905, 0.8095, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.3742, 0.0568, 0.5690, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.1288, 0.3380, 0.1376, 0.3956, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.4311, 0.0841, 0.0582, 0.3049, 0.1217, 0.0000, 0.0000, 0.0000],
        [0.0537, 0.3205, 0.0694, 0.2404, 0.2568, 0.0592, 0.0000, 0.0000],
        [0.3396, 0.0149, 0.5165, 0.0180, 0.0658, 0.0080, 0.0373, 0.0000],
        [0.0165, 0.0375, 0.0144, 0.1120, 0.0332, 0.4069, 0.3136, 0.0660]],
       grad_fn=<SelectBackward0>)