# Mathematical Trick

In [11]:
import torch

In [3]:
# Toy input: B batches, T time steps, C channels.
# Every channel of x[b, t] holds the value 10*b + t, so batch and position
# can be read straight off the numbers.
B, T, C = 3, 5, 2
batch_base = 10 * torch.arange(B, dtype=torch.float32)   # (B,)  0, 10, 20
time_step = torch.arange(T, dtype=torch.float32)         # (T,)  0 .. T-1
grid = batch_base[:, None] + time_step[None, :]          # (B, T)
x = grid[:, :, None].repeat(1, 1, C)                     # (B, T, C), same value in every channel

print(x.shape)
print(x[..., 0])    # B, T
print(x[0])         # T, C

torch.Size([3, 5, 2])
tensor([[ 0.,  1.,  2.,  3.,  4.],
        [10., 11., 12., 13., 14.],
        [20., 21., 22., 23., 24.]])
tensor([[0., 0.],
        [1., 1.],
        [2., 2.],
        [3., 3.],
        [4., 4.]])


In [4]:
# Method 1: explicit causal running mean.
# Output position `step` is the average of x over positions 0..step (inclusive),
# taken along the time axis independently for every batch and channel.
y = torch.stack(
    [x[:, :step + 1, :].mean(dim=1) for step in range(T)],   # each entry: (B, C)
    dim=1,
)                   # (B, T, C)
print(y[..., 0])    # B, T
print(y[0])         # T, C

tensor([[ 0.0000,  0.5000,  1.0000,  1.5000,  2.0000],
        [10.0000, 10.5000, 11.0000, 11.5000, 12.0000],
        [20.0000, 20.5000, 21.0000, 21.5000, 22.0000]])
tensor([[0.0000, 0.0000],
        [0.5000, 0.5000],
        [1.0000, 1.0000],
        [1.5000, 1.5000],
        [2.0000, 2.0000]])


In [5]:
# Method 2: express the running mean as a (T, T) weight matrix.
# Row i has ones in columns 0..i; dividing by the row sum turns each row
# into uniform averaging weights over the positions up to and including i.
t = torch.ones(T, T).tril()                 # lower-triangular 0/1 matrix
row_totals = t.sum(dim=-1, keepdim=True)    # (T, 1): number of ones per row
w = t.div(row_totals)                       # each row sums to 1
print(w)

tensor([[1.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.5000, 0.5000, 0.0000, 0.0000, 0.0000],
        [0.3333, 0.3333, 0.3333, 0.0000, 0.0000],
        [0.2500, 0.2500, 0.2500, 0.2500, 0.0000],
        [0.2000, 0.2000, 0.2000, 0.2000, 0.2000]])


In [6]:
# Apply the averaging weights with a batched matmul:
# (T,T) @ (B,T,C) -> (B,T,C); w is broadcast across the batch dimension.
y2 = w @ x
print(y2[:,:,0])  # B, T
print(y2[0,:,:])  # T, C
# y comes from mean() over slices, y2 from a matmul with fractional weights
# (1/3, 1/5, ...); the two need not agree bit-for-bit, so compare with a
# floating-point tolerance instead of exact ==.
print(torch.isclose(y, y2).all())

tensor([[ 0.0000,  0.5000,  1.0000,  1.5000,  2.0000],
        [10.0000, 10.5000, 11.0000, 11.5000, 12.0000],
        [20.0000, 20.5000, 21.0000, 21.5000, 22.0000]])
tensor([[0.0000, 0.0000],
        [0.5000, 0.5000],
        [1.0000, 1.0000],
        [1.5000, 1.5000],
        [2.0000, 2.0000]])
tensor(True)


In [9]:
# Method 3: same averaging weights, built with a mask followed by softmax.
# The scores start at all-zeros (placeholder: these will come from key and query).
# Entries above the diagonal (future tokens) are set to -inf so softmax gives
# them weight 0; the remaining equal scores in each row normalize to a uniform
# distribution that sums to 1.
t = torch.ones(T, T).tril()                                   # lower-triangular 0/1 matrix
w = torch.zeros(T, T).masked_fill(t == 0, float('-inf'))      # mask future tokens
w = w.softmax(dim=-1)                                         # each row: values in 0..1, sum = 1
print(w)

tensor([[1.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.5000, 0.5000, 0.0000, 0.0000, 0.0000],
        [0.3333, 0.3333, 0.3333, 0.0000, 0.0000],
        [0.2500, 0.2500, 0.2500, 0.2500, 0.0000],
        [0.2000, 0.2000, 0.2000, 0.2000, 0.2000]])


In [10]:
# Weighted sum of token info using the softmax-derived weights:
# (T,T) @ (B,T,C) -> (B,T,C); w is broadcast across the batch dimension.
y3 = w @ x
print(y3[:,:,0])  # B, T
print(y3[0,:,:])  # T, C
# y comes from mean() over slices, y3 from softmax weights and a matmul; the
# two need not agree bit-for-bit, so compare with a floating-point tolerance
# instead of exact ==.
print(torch.isclose(y, y3).all())

tensor([[ 0.0000,  0.5000,  1.0000,  1.5000,  2.0000],
        [10.0000, 10.5000, 11.0000, 11.5000, 12.0000],
        [20.0000, 20.5000, 21.0000, 21.5000, 22.0000]])
tensor([[0.0000, 0.0000],
        [0.5000, 0.5000],
        [1.0000, 1.0000],
        [1.5000, 1.5000],
        [2.0000, 2.0000]])
tensor(True)


# Self-Attention

In [1]:
import torch

In [6]:
# Toy input for the attention demo: B batches, T time steps, C channels,
# plus H, the head size used by the projections in the next cell.
# Every channel of x[b, t] holds the value 10*b + t.
B, T, C, H = 3, 5, 2, 4
batch_base = 10 * torch.arange(B, dtype=torch.float32)   # (B,)  0, 10, 20
time_step = torch.arange(T, dtype=torch.float32)         # (T,)  0 .. T-1
grid = batch_base[:, None] + time_step[None, :]          # (B, T)
x = grid[:, :, None].repeat(1, 1, C)                     # (B, T, C), same value in every channel

print(x.shape)
print(x[..., 0])    # B, T
print(x[0])         # T, C

torch.Size([3, 5, 2])
tensor([[ 0.,  1.,  2.,  3.,  4.],
        [10., 11., 12., 13., 14.],
        [20., 21., 22., 23., 24.]])
tensor([[0., 0.],
        [1., 1.],
        [2., 2.],
        [3., 3.],
        [4., 4.]])


In [8]:
# Single-head causal self-attention over x (B,T,C) with head size H.

# Draw the random projections from a local, fixed-seed generator so the values
# printed below repeat from run to run; the global RNG state is left untouched.
g = torch.Generator().manual_seed(1337)

# Key / query / value projection matrices, each (C,H), scaled by 1/sqrt(C).
W_k = torch.randn(C, H, generator=g) / C**0.5
W_q = torch.randn(C, H, generator=g) / C**0.5
W_v = torch.randn(C, H, generator=g) / C**0.5

x_k = x @ W_k            # B,T,H keys
x_q = x @ W_q            # B,T,H queries
x_v = x @ W_v            # B,T,H values
print(f"{x_k.shape=}")
print(f"{x_q.shape=}")
print(f"{x_v.shape=}")

# Pairwise query-key affinities, scaled by 1/sqrt(H).
W_a = x_q @ x_k.mT / H**0.5     # B,T,T affinity  <-  B,T,H @ B,H,T
print(f"{W_a.shape=}")

# Method 3: mask + softmax turns the affinities into causal attention weights.
t = torch.tril(torch.ones((T,T)))              # T,T lower-triangular 0/1 matrix
W_a = W_a.masked_fill(t==0, float('-inf'))     # B,T,T mask future tokens (mask broadcasts over B)
W_a = torch.softmax(W_a, dim=-1)               # B,T,T each row: values in 0..1, sum = 1

y = W_a @ x_v           # B,T,H <- B,T,T @ B,T,H weighted sum of token info
print(f"{y.shape=}")
print(y[:,:,0])         # B, T
print(y[0,:,:])         # T, H

x_k.shape=torch.Size([3, 5, 4])
x_q.shape=torch.Size([3, 5, 4])
x_v.shape=torch.Size([3, 5, 4])
W_a.shape=torch.Size([3, 5, 5])
y.shape=torch.Size([3, 5, 4])
tensor([[ 0.0000,  0.2316,  0.5068,  0.8469,  1.2597],
        [ 4.3678,  4.7136,  5.1254,  5.5677,  6.0177],
        [ 8.7356,  9.1407,  9.5771, 10.0174, 10.4576]])
tensor([[ 0.0000,  0.0000,  0.0000,  0.0000],
        [ 0.2316,  0.3265, -0.5058, -0.3795],
        [ 0.5068,  0.7143, -1.1066, -0.8304],
        [ 0.8469,  1.1937, -1.8493, -1.3876],
        [ 1.2597,  1.7756, -2.7508, -2.0641]])
