In [1]:
import torch

# Aggregating values

In [2]:
# Fix the global RNG seed so the sampled matrix is reproducible.
torch.manual_seed(1337)

# 3x2 matrix of integers drawn from [0, 10), cast to float32 to match the
# dtype of the float matrices it is multiplied with in the cells below.
b = torch.randint(low=0, high=10, size=(3, 2)).to(torch.float32)
b

tensor([[5., 7.],
        [2., 0.],
        [5., 3.]])

Values can be aggregated via matrix multiplication. For example, left-multiplying `b` by a matrix of ones sums each column of `b` (every row of the result holds the column sums):

In [3]:
# All-ones 3x3 matrix: as a left factor in a matrix product, each output row
# is the sum over all rows of the right factor.
a = torch.ones((3, 3))
a

tensor([[1., 1., 1.],
        [1., 1., 1.],
        [1., 1., 1.]])

In [4]:
# Every row of the product holds the column sums of `b`.
torch.matmul(a, b)

tensor([[12., 10.],
        [12., 10.],
        [12., 10.]])

A cumulative sum down each column works the same way, using a lower-triangular matrix of ones instead:

In [5]:
# Lower-triangular copy of `a`: ones on and below the main diagonal, zeros above.
# NOTE: despite the name `diag`, this is a triangular matrix, not a diagonal one.
diag = a.tril()
diag

tensor([[1., 0., 0.],
        [1., 1., 0.],
        [1., 1., 1.]])

In [6]:
# Row i of the product is the sum of rows 0..i of `b` (a running sum).
torch.matmul(diag, b)

tensor([[ 5.,  7.],
        [ 7.,  7.],
        [12., 10.]])

This can easily be extended to a running average by normalizing each row of the matrix so that its entries sum to 1:

In [7]:
# Normalize every row in place so that it sums to 1; multiplying by the result
# then averages rows 0..i instead of summing them.
diag.div_(diag.sum(dim=1, keepdim=True))
diag

tensor([[1.0000, 0.0000, 0.0000],
        [0.5000, 0.5000, 0.0000],
        [0.3333, 0.3333, 0.3333]])

In [8]:
# Row i of the product is the mean of rows 0..i of `b` (a running average).
torch.matmul(diag, b)

tensor([[5.0000, 7.0000],
        [3.5000, 3.5000],
        [4.0000, 3.3333]])

## Using softmax

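The same weights can be obtained with softmax by exploiting the fact that $\exp(-\infty) = 0$: entries masked with $-\infty$ receive zero weight, and the remaining (equal) entries in each row share the weight uniformly.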

In [9]:
# Start from equal scores everywhere.
a = torch.ones(3, 3)
# Lower-triangular pattern of ones; its zeros mark the entries to mask out.
diag = a.tril()
# Overwrite the masked entries with -inf (in place) so softmax maps them to 0.
a.masked_fill_(diag == 0, float("-inf"))
# Row-wise softmax: each row becomes a uniform average over its unmasked entries.
a = a.softmax(dim=1)
a

tensor([[1.0000, 0.0000, 0.0000],
        [0.5000, 0.5000, 0.0000],
        [0.3333, 0.3333, 0.3333]])

# Self attention

In [23]:
# Reseed so the random input below is reproducible.
torch.manual_seed(1337)

# Output width of each key/query/value projection.
head_size = 16
# Dimensions of the input tensor (presumably batch, time steps, channels).
B = 4
T = 8
C = 32
x = torch.randn(B, T, C)
x.size()

torch.Size([4, 8, 32])

In [24]:
# Three independent linear projections from C input features to head_size
# outputs. They are constructed in key, query, value order, which fixes the
# order in which their random initialization draws from the RNG.
key, query, value = (
    torch.nn.Linear(in_features=C, out_features=head_size) for _ in range(3)
)

In [17]:
# Project the input with the key and query layers.
k, q = key(x), query(x)
k.size(), q.size()

(torch.Size([4, 8, 16]), torch.Size([4, 8, 16]))

In [19]:
# Raw attention scores: wei[b, t, s] is the dot product of the query at
# position t with the key at position s. Rows are indexed by the querying
# position, so a row-wise softmax yields, for each position t, weights over
# the positions s it reads from. (Queries go on the left: q @ k^T, not k @ q^T.)
# NOTE: the scores are unscaled; the standard scaled dot-product formulation
# additionally divides them by sqrt(head_size).
wei = q @ k.transpose(-2, -1)  # (B,T,16) @ (B,16,T) -> (B,T,T)
wei.shape

torch.Size([4, 8, 8])

In [21]:
# Mask every entry above the main diagonal with -inf so softmax gives it zero
# weight, then normalize each row to sum to 1.
# NOTE: this cell reads and overwrites `wei` (and fills it in place), so it is
# not idempotent: recompute the raw scores before running it a second time.
tril = torch.ones(T, T).tril()
wei.masked_fill_(tril == 0, float("-inf"))
wei = wei.softmax(dim=-1)

wei.size()

torch.Size([4, 8, 8])

In [26]:
# Project the input with the value layer.
v = value(x)
v.size()

torch.Size([4, 8, 16])

In [28]:
# Batched matrix product: each output row is a weighted combination of the
# rows of `v`, with the weights taken from the matching row of `wei`.
out = torch.matmul(wei, v)
out.size()

torch.Size([4, 8, 16])