In [3]:
import torch
import torch.nn as nn
import torch.nn.functional as F

In [34]:
# Averaging via matrix multiplication: start from a matrix of ones, keep only
# the lower triangle, then divide each row by its sum so row i holds equal
# weights over columns 0..i (and zeros to the right).
a = torch.ones(3, 3)
print(a)

a = torch.tril(a)
print(a)

a = a / a.sum(dim=1, keepdim=True)
print(a)

tensor([[1., 1., 1.],
        [1., 1., 1.],
        [1., 1., 1.]])
tensor([[1., 0., 0.],
        [1., 1., 0.],
        [1., 1., 1.]])
tensor([[1.0000, 0.0000, 0.0000],
        [0.5000, 0.5000, 0.0000],
        [0.3333, 0.3333, 0.3333]])


In [36]:
# we can see how this works on an example: apply the averaging matrix `a`
# (built in the previous cell) to a small random integer matrix.
# A dedicated, seeded generator makes this cell reproducible on re-run while
# leaving the global RNG (seeded separately by later cells) untouched.
gen = torch.Generator().manual_seed(1337)
b = torch.randint(0, 10, (3, 3), generator=gen).float()
print(b)

c = a @ b
print(c)

# Row i of c is the mean of rows 0..i of b, so each column of c is the running
# (cumulative) mean of the matching column of b. Verify that explicitly.
counts = torch.arange(1, b.shape[0] + 1, dtype=b.dtype).unsqueeze(1)
assert torch.allclose(c, b.cumsum(dim=0) / counts)

tensor([[4., 8., 0.],
        [8., 5., 3.],
        [9., 9., 8.]])
tensor([[4.0000, 8.0000, 0.0000],
        [6.0000, 6.5000, 1.5000],
        [7.0000, 7.3333, 3.6667]])


In [32]:
# however, you can achieve the same result with softmax: setting the "future"
# entries to -inf and applying softmax row-wise yields the same row-normalised
# lower-triangular weights as dividing by the row sums explicitly.
T = 8

# lower-triangular mask of ones (1 = position may be attended to, 0 = masked)
mask = torch.tril(torch.ones(T, T))
print(mask)

# replace the masked (zero) entries with -inf so softmax gives them weight 0
w = mask.masked_fill(mask == 0, float('-inf'))
print(w)

# softmax along dim=1 normalises each ROW to sum to 1; all unmasked entries in
# a row share the same logit, so row i becomes a uniform average over 0..i
w = F.softmax(w, dim=1)
print(w)

# sanity check: identical to explicit row normalisation of the mask
assert torch.allclose(w, mask / mask.sum(dim=1, keepdim=True))


tensor([[1., 0., 0., 0., 0., 0., 0., 0.],
        [1., 1., 0., 0., 0., 0., 0., 0.],
        [1., 1., 1., 0., 0., 0., 0., 0.],
        [1., 1., 1., 1., 0., 0., 0., 0.],
        [1., 1., 1., 1., 1., 0., 0., 0.],
        [1., 1., 1., 1., 1., 1., 0., 0.],
        [1., 1., 1., 1., 1., 1., 1., 0.],
        [1., 1., 1., 1., 1., 1., 1., 1.]])
tensor([[1., -inf, -inf, -inf, -inf, -inf, -inf, -inf],
        [1., 1., -inf, -inf, -inf, -inf, -inf, -inf],
        [1., 1., 1., -inf, -inf, -inf, -inf, -inf],
        [1., 1., 1., 1., -inf, -inf, -inf, -inf],
        [1., 1., 1., 1., 1., -inf, -inf, -inf],
        [1., 1., 1., 1., 1., 1., -inf, -inf],
        [1., 1., 1., 1., 1., 1., 1., -inf],
        [1., 1., 1., 1., 1., 1., 1., 1.]])
tensor([[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.5000, 0.5000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.3333, 0.3333, 0.3333, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.2500, 0.2500, 0.2500, 0.2500, 0.0000,

In [46]:
# now we implement a basic single-head causal self-attention block
torch.manual_seed(42)

B, T, C = 4, 8, 32  # batch, time, channels
x = torch.randn(B, T, C)
print(x.shape)

# we start by initializing a single "head"
h_size = 16
key = nn.Linear(C, h_size, bias=False)
qry = nn.Linear(C, h_size, bias=False)

# key and query vectors are the results of applying linear
# transformations (bias-free nn.Linear layers) to the input
# vectors, which projects them into key and query "spaces"
k = key(x)  # (B, T, h_size)
q = qry(x)  # (B, T, h_size)
print(k.shape, q.shape)
print(k.var(), q.var())

# affinity matrix: entry [b, t, s] is the dot product of query t with key s
# (B, T, h_size) @ (B, h_size, T) --> (B, T, T)
wei = q @ k.transpose(-2, -1)
print(wei.shape)

# scale by 1/sqrt(h_size) (the head dimension). Each entry of wei is a sum of
# h_size products, so (assuming roughly independent entries) its variance grows
# about linearly with h_size; the scaling cancels that growth, leaving roughly
# var(q) * var(k) (about 1 for unit-variance q and k). This does NOT preserve
# the variance of wei - it shrinks it - and it keeps the logits from growing so
# large that softmax collapses towards one-hot vectors.
before = wei.var()
wei = wei * h_size**-0.5
print(before, wei.var())

# causal mask via the previous triangularization/softmax technique:
# position t may only attend to positions <= t
tri = torch.tril(torch.ones(T, T))
wei = wei.masked_fill(tri == 0, float('-inf'))

# softmax over the last (key) dimension, so that for every query position the
# weights over the keys it is allowed to see sum to 1
wei = F.softmax(wei, dim=-1)
print(wei.shape)
print(wei[0])
assert torch.allclose(wei.sum(dim=-1), torch.ones(B, T))
assert (wei.triu(diagonal=1) == 0).all()  # no attention to future positions

# we also do a projection into "value" space
val = nn.Linear(C, h_size, bias=False)
v = val(x)  # (B, T, h_size)

# finally, multiply wei by v: each output row is a weighted average of values
# (B, T, T) @ (B, T, h_size) --> (B, T, h_size)
out = wei @ v
print(out.shape)

torch.Size([4, 8, 32])
torch.Size([4, 8, 16]) torch.Size([4, 8, 16])
tensor(0.3348, grad_fn=<VarBackward0>) tensor(0.3227, grad_fn=<VarBackward0>)
torch.Size([4, 8, 8])
tensor(1.9700, grad_fn=<VarBackward0>) tensor(0.1231, grad_fn=<VarBackward0>)
torch.Size([4, 8, 8])
tensor([[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.4106, 0.5894, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.3657, 0.2283, 0.4061, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.2168, 0.2759, 0.2204, 0.2870, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.2553, 0.1697, 0.1548, 0.2341, 0.1861, 0.0000, 0.0000, 0.0000],
        [0.1318, 0.2060, 0.1405, 0.1917, 0.1949, 0.1351, 0.0000, 0.0000],
        [0.2137, 0.0978, 0.2374, 0.1025, 0.1418, 0.0838, 0.1230, 0.0000],
        [0.0852, 0.1047, 0.0824, 0.1376, 0.1015, 0.1900, 0.1780, 0.1206]],
       grad_fn=<SelectBackward0>)
torch.Size([4, 8, 16])
