In [1]:
import torch
import torch.nn.functional as F
import torch.nn as nn

In [2]:
torch.manual_seed(42)

<torch._C.Generator at 0x7ff53fffc590>

In [3]:
B, T, C = 4, 8, 2
x = torch.randn(B, T, C)
x.shape

torch.Size([4, 8, 2])

In [4]:
xbow = torch.zeros((B, T, C))
for batch in range(B):
    for time in range(T):
        prev = x[batch, :time+1]
        xbow[batch, time] = torch.mean(prev, 0)

In [5]:
xbow

tensor([[[ 1.9269,  1.4873],
         [ 1.4138, -0.3091],
         [ 1.1687, -0.6176],
         [ 0.8657, -0.8644],
         [ 0.5422, -0.3617],
         [ 0.3864, -0.5354],
         [ 0.2272, -0.5388],
         [ 0.1027, -0.3762]],

        [[ 1.6423, -0.1596],
         [ 0.5725,  0.1400],
         [ 0.1289,  0.4528],
         [ 0.2969,  0.7597],
         [ 0.4933,  0.8671],
         [ 0.5129,  0.9450],
         [ 0.4065,  0.8160],
         [ 0.3242,  0.8215]],

        [[-1.3847, -0.8712],
         [-0.8040,  0.4231],
         [-0.4297,  0.1405],
         [-0.2459, -0.0882],
         [-0.5082,  0.1285],
         [-0.5701,  0.0069],
         [-0.6707,  0.3092],
         [-0.7412,  0.2095]],

        [[-0.9138, -0.6581],
         [-0.4179, -0.0662],
         [-0.4413,  0.3530],
         [-0.5344,  0.0808],
         [-0.7082,  0.0718],
         [-0.6008,  0.1724],
         [-0.5289,  0.4113],
         [-0.6109,  0.5329]]])

## Same with matmul

In [6]:
tril = torch.tril(torch.ones(T, T))
tril /= tril.sum(1, keepdims=True)
xbow_matr = tril @ x
xbow_matr

tensor([[[ 1.9269,  1.4873],
         [ 1.4138, -0.3091],
         [ 1.1687, -0.6176],
         [ 0.8657, -0.8644],
         [ 0.5422, -0.3617],
         [ 0.3864, -0.5354],
         [ 0.2272, -0.5388],
         [ 0.1027, -0.3762]],

        [[ 1.6423, -0.1596],
         [ 0.5725,  0.1400],
         [ 0.1289,  0.4528],
         [ 0.2969,  0.7597],
         [ 0.4933,  0.8671],
         [ 0.5129,  0.9450],
         [ 0.4065,  0.8160],
         [ 0.3242,  0.8215]],

        [[-1.3847, -0.8712],
         [-0.8040,  0.4231],
         [-0.4297,  0.1405],
         [-0.2459, -0.0882],
         [-0.5082,  0.1285],
         [-0.5701,  0.0069],
         [-0.6707,  0.3092],
         [-0.7412,  0.2095]],

        [[-0.9138, -0.6581],
         [-0.4179, -0.0662],
         [-0.4413,  0.3530],
         [-0.5344,  0.0808],
         [-0.7082,  0.0718],
         [-0.6008,  0.1724],
         [-0.5289,  0.4113],
         [-0.6109,  0.5329]]])

## Same but with softmax softmax

In [7]:
wei = torch.zeros(T, T)
wei = wei.masked_fill(torch.tril(torch.ones(T, T)) == 0, float('-inf'))
wei = F.softmax(wei, dim=1)
wei

tensor([[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.5000, 0.5000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.3333, 0.3333, 0.3333, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.2500, 0.2500, 0.2500, 0.2500, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.2000, 0.2000, 0.2000, 0.2000, 0.2000, 0.0000, 0.0000, 0.0000],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.0000, 0.0000],
        [0.1429, 0.1429, 0.1429, 0.1429, 0.1429, 0.1429, 0.1429, 0.0000],
        [0.1250, 0.1250, 0.1250, 0.1250, 0.1250, 0.1250, 0.1250, 0.1250]])

In [8]:
wei @ x

tensor([[[ 1.9269,  1.4873],
         [ 1.4138, -0.3091],
         [ 1.1687, -0.6176],
         [ 0.8657, -0.8644],
         [ 0.5422, -0.3617],
         [ 0.3864, -0.5354],
         [ 0.2272, -0.5388],
         [ 0.1027, -0.3762]],

        [[ 1.6423, -0.1596],
         [ 0.5725,  0.1400],
         [ 0.1289,  0.4528],
         [ 0.2969,  0.7597],
         [ 0.4933,  0.8671],
         [ 0.5129,  0.9450],
         [ 0.4065,  0.8160],
         [ 0.3242,  0.8215]],

        [[-1.3847, -0.8712],
         [-0.8040,  0.4231],
         [-0.4297,  0.1405],
         [-0.2459, -0.0882],
         [-0.5082,  0.1285],
         [-0.5701,  0.0069],
         [-0.6707,  0.3092],
         [-0.7412,  0.2095]],

        [[-0.9138, -0.6581],
         [-0.4179, -0.0662],
         [-0.4413,  0.3530],
         [-0.5344,  0.0808],
         [-0.7082,  0.0718],
         [-0.6008,  0.1724],
         [-0.5289,  0.4113],
         [-0.6109,  0.5329]]])

## Self-attention

In [21]:
def mask(inp: torch.Tensor) -> torch.Tensor:
    return inp.masked_fill(torch.tril(torch.ones(T, T)) == 0, float('-inf'))

In [22]:
head_size = 16
# Linear projections to key an query space
key = nn.Linear(C, head_size, bias=False) # (B, T, head_size)
query = nn.Linear(C, head_size, bias=False)# (B, T, head_size)
value = nn.Linear(C, head_size, bias=False)# (B, T, head_size)
# get k, q
k = key(x)
q = query(x)
v = value(x)
qk =  q @ k.transpose(-1, -2) # (B, T, head_size) @ (B, head_size, T) -> (B, T, T)
wei_masked = F.softmax(mask(qk), -1)
wei_masked @ x # (B, T, T) @ (B, T, head_size) -> (B, T, head_size)

tensor([[[ 1.9269,  1.4873],
         [ 1.2838, -0.7644],
         [ 1.1342, -0.6618],
         [ 0.2858, -1.5342],
         [ 1.0300, -0.3864],
         [-0.3604, -0.2428],
         [-0.3803, -0.5656],
         [ 0.0351, -0.6497]],

        [[ 1.6423, -0.1596],
         [ 0.4623,  0.1708],
         [ 0.4445,  0.3149],
         [ 1.4839,  0.1545],
         [ 1.4105,  0.5621],
         [ 1.3007,  0.6685],
         [ 0.2506,  0.7846],
         [ 0.6968,  0.7907]],

        [[-1.3847, -0.8712],
         [-0.4140,  1.2924],
         [-0.4378,  0.2068],
         [-0.4467, -0.0978],
         [-0.9661, -0.1311],
         [-1.2174, -0.0925],
         [-0.2688, -0.1308],
         [-1.2619,  0.0074]],

        [[-0.9138, -0.6581],
         [-0.2679,  0.1129],
         [-0.3532,  0.4092],
         [-0.7504, -0.3459],
         [-1.0265, -0.2122],
         [-0.4563,  0.2896],
         [-0.2058,  0.7832],
         [-0.6481,  0.3489]]], grad_fn=<UnsafeViewBackward0>)

In [23]:
wei_masked[0]

tensor([[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.3733, 0.6267, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.3098, 0.3104, 0.3798, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.0232, 0.1486, 0.1982, 0.6300, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.3616, 0.2648, 0.2014, 0.1236, 0.0487, 0.0000, 0.0000, 0.0000],
        [0.0036, 0.0378, 0.0490, 0.1842, 0.3950, 0.3304, 0.0000, 0.0000],
        [0.0060, 0.0481, 0.0543, 0.1520, 0.2025, 0.2319, 0.3053, 0.0000],
        [0.0727, 0.1466, 0.1307, 0.1509, 0.0969, 0.1522, 0.1381, 0.1120]],
       grad_fn=<SelectBackward0>)

In [44]:
q = torch.randn((B, T, head_size))
k = torch.randn((B, T, head_size))
(q@k.transpose(-2, -1)).var(), (q @ k.transpose(-2, -1) / (head_size ** 0.5)).var()

(tensor(18.1281), tensor(1.1330))