In [17]:
import torch
import torch.nn as nn
from torch.nn import functional as F

In [18]:
# Toy input: B batch elements, T time steps each, C channels per step.
torch.manual_seed(1337)  # fixed seed so the values printed below are reproducible
B, T, C = 4, 8, 2
x = torch.randn((B, T, C))
x.size()

torch.Size([4, 8, 2])

In [19]:
# version 1: explicit loops.
# xbow[b, t] is the per-channel mean of x[b, 0], ..., x[b, t]
# (a "bag of words": the current token and everything before it, averaged).
xbow = torch.zeros(B, T, C)
for b in range(B):
    for t in range(T):
        xprev = x[b, : t + 1]  # [t+1, C]: positions 0..t inclusive
        xbow[b, t] = xprev.mean(dim=0)  # average over the time axis

In [20]:
# first batch element of the raw input, to compare against its running average
x[0]

tensor([[ 0.1808, -0.0700],
        [-0.3596, -0.9152],
        [ 0.6258,  0.0255],
        [ 0.9545,  0.0643],
        [ 0.3612,  1.1679],
        [-1.3499, -0.5102],
        [ 0.2360, -0.2398],
        [-0.9211,  1.5433]])

In [21]:
# row t is the mean of x[0, :t+1], so row 0 is identical to x[0, 0]
xbow[0]

tensor([[ 0.1808, -0.0700],
        [-0.0894, -0.4926],
        [ 0.1490, -0.3199],
        [ 0.3504, -0.2238],
        [ 0.3525,  0.0545],
        [ 0.0688, -0.0396],
        [ 0.0927, -0.0682],
        [-0.0341,  0.1332]])

In [22]:
# version 2: the same running average expressed as a single matrix multiply
torch.manual_seed(42)  # nothing in this cell draws random numbers; this only resets the global RNG state
wei = torch.ones(T, T).tril()  # lower-triangular ones: row t selects positions 0..t
wei = wei / wei.sum(dim=1, keepdim=True)  # each row now sums to 1 -> uniform average over 0..t
# wei is [T, T] and is broadcast over the batch: [B, T, T] @ [B, T, C] -> [B, T, C]
xbow2 = torch.matmul(wei, x)
torch.allclose(xbow, xbow2)

True

In [23]:
# version 3: the same running average, written as a masked softmax
tril = torch.ones(T, T).tril()
# Affinities between positions (interaction strength with past tokens).
# All zeros here, so every visible position ends up with the same weight.
wei = torch.zeros(T, T)
# Wherever tril is 0 (column index > row index) the affinity becomes -inf ...
wei = wei.masked_fill(tril == 0, float("-inf"))
# ... so the row-wise softmax gives those positions weight exp(-inf) = 0.
wei = torch.softmax(wei, dim=-1)
xbow3 = torch.matmul(wei, x)
torch.allclose(xbow, xbow3)

True

In [71]:
# version 4: self-attention -- keys, queries and values are all computed from the
# same source x, so the tokens are attending to one another
torch.manual_seed(1337)
B, T, C = 4, 8, 32  # batch, time, channels
x = torch.randn(B, T, C)

# A single self-attention head.
# Three independent, randomly initialised, bias-free linear maps project every
# token to a key, a query and a value of size head_size.
# Every operation below is applied per batch element: examples in a batch are
# processed in parallel and never exchange information.
head_size = 16
key = nn.Linear(C, head_size, bias=False)
query = nn.Linear(C, head_size, bias=False)
value = nn.Linear(C, head_size, bias=False)

k = key(x)  # [B, T, head_size]
q = query(x)  # [B, T, head_size]
v = value(x)  # [B, T, head_size]

# Attention scores: dot product of every query with every key.
# Unlike version 3, where the affinities were all zeros, these depend on the data.
# [B, T, head_size] @ [B, head_size, T] -> [B, T, T]
wei = torch.matmul(q, k.transpose(-2, -1))

# Causal mask: position t may only attend to positions 0..t. Entries above the
# diagonal become -inf, so the softmax assigns them zero weight.
tril = torch.ones(T, T).tril()
wei = wei.masked_fill(tril == 0, float("-inf"))
wei = wei.softmax(dim=-1)

# Aggregate the projected values v (rather than the raw x) with the attention weights.
out = torch.matmul(wei, v)  # [B, T, T] @ [B, T, head_size] -> [B, T, head_size]

In [None]:
# Attention is a communication mechanism.
# It can be viewed as nodes in a directed graph: each node aggregates information from the nodes that point to it.

In [74]:
# Scaled attention: control the variance of the scores at initialization.
# q and k are drawn with unit variance, so each raw dot product is a sum of
# head_size products and its variance grows with head_size. Multiplying by
# head_size**-0.5 brings the variance back to roughly 1, which keeps a
# subsequent softmax from saturating.
# B, T and head_size are the values defined in the self-attention cell.
k = torch.randn(B, T, head_size)
q = torch.randn(B, T, head_size)
wei = q @ k.transpose(-2, -1) * head_size**-0.5
# Display the variance of the scaled scores (expected to be close to 1);
# without this final expression the cell produces no output at all.
wei.var()

tensor(1.0065)
