In [73]:
import numpy as np
import torch

import torch.nn.functional as F
import torch.nn as nn

In [2]:
torch.manual_seed(1337)

<torch._C.Generator at 0x2448efc0cf0>

#### self attention mechanism

In [44]:
B, T, C = 4, 8, 32  # batch, tokens, channels
x = torch.randn(B,T,C)
x.shape

torch.Size([4, 8, 32])

#### slow for-loops

In [54]:
xbow = torch.zeros((B, T, C))

for b in range(B):
    for t in range(T):
        xprev = x[b,:t+1]  # (t, C)
        xbow[b, t] = torch.mean(xprev, 0)

#### fast vector multiplication

In [55]:
W = torch.tril(torch.ones(T, T)) # TxT
W /= torch.sum(W, 1, keepdims=True)

In [56]:
xbow2 = W @ x

#### now with softmax

In [70]:
tril = torch.tril(torch.ones(T, T))
W = torch.zeros((T, T))  # defines affinity between nodes
W = W.masked_fill(tril == 0, float('-inf'))  # prevent looking into future
W = F.softmax(W, dim=-1)

In [74]:
xbow3 = W @ x

#### single head implementation

In [75]:
head_size = 16
key = nn.Linear(C, head_size, bias=False)    # what do i contain?
query = nn.Linear(C, head_size, bias=False)  # what am i looking for?
value = nn.Linear(C, head_size, bias=False)  # what do i pass to you?

In [76]:
k = key(x)  # (B, T, head_size)
q = query(x)  # (B, T, head_size)

In [79]:
W = q @ k.transpose(-2, -1)  # affinity (B,T,H) @ (B,H,T) --> (B,T,T)

In [80]:
W.shape

torch.Size([4, 8, 8])

In [82]:
tril = torch.tril(torch.ones(T, T))
W = W.masked_fill(tril == 0, float('-inf'))  # prevent looking into future
W = F.softmax(W, dim=-1)

In [83]:
v = value(x)
xbow4 = W @ v

In [85]:
W[0]

tensor([[1.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
         0.0000e+00, 0.0000e+00],
        [5.4949e-01, 4.5051e-01, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
         0.0000e+00, 0.0000e+00],
        [3.0814e-02, 1.8298e-02, 9.5089e-01, 0.0000e+00, 0.0000e+00, 0.0000e+00,
         0.0000e+00, 0.0000e+00],
        [2.7306e-02, 5.9927e-03, 9.6422e-01, 2.4795e-03, 0.0000e+00, 0.0000e+00,
         0.0000e+00, 0.0000e+00],
        [1.0356e-01, 8.4526e-02, 2.9864e-01, 4.2126e-01, 9.2020e-02, 0.0000e+00,
         0.0000e+00, 0.0000e+00],
        [5.5813e-02, 2.6597e-01, 6.9153e-02, 8.4581e-02, 2.7192e-01, 2.5256e-01,
         0.0000e+00, 0.0000e+00],
        [1.1987e-02, 3.8042e-04, 9.6417e-01, 3.2126e-03, 5.0234e-03, 1.4995e-02,
         2.3348e-04, 0.0000e+00],
        [3.9488e-02, 2.8300e-01, 1.5332e-02, 7.4047e-02, 5.5218e-02, 4.9383e-01,
         1.6254e-02, 2.2827e-02]], grad_fn=<SelectBackward0>)