In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
# seed the global RNG so the random tensors and layer weights created below are reproducible
torch.manual_seed(1337)

<torch._C.Generator at 0x1f5a38fb6b0>

In [2]:
# where we stopped previously:
# all weights are uniform, so all previous tokens are treated equally

B, T, C = 4, 8, 2
x = torch.randn(B, T, C)  # [B, T, C]

# lower-triangular matrix of ones: row t keeps columns 0..t and zeroes the rest
tril = torch.tril(torch.ones(T, T))
wei = torch.zeros((T, T))
# positions where tril is 0 are set to -inf so softmax gives them weight 0
wei = wei.masked_fill(tril == 0, float('-inf'))
# the remaining entries are all 0, so each row becomes a uniform average
wei = F.softmax(wei, dim=-1)
wei


tensor([[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.5000, 0.5000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.3333, 0.3333, 0.3333, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.2500, 0.2500, 0.2500, 0.2500, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.2000, 0.2000, 0.2000, 0.2000, 0.2000, 0.0000, 0.0000, 0.0000],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.0000, 0.0000],
        [0.1429, 0.1429, 0.1429, 0.1429, 0.1429, 0.1429, 0.1429, 0.0000],
        [0.1250, 0.1250, 0.1250, 0.1250, 0.1250, 0.1250, 0.1250, 0.1250]])

In [3]:
# uniform weights are not what we want: each token is going to find
# other tokens more or less interesting, so the weights should be
# data dependent instead of uniform

# how self-attention solves this problem:
# every single token emits three vectors: a query, a key and a value
#   query - "what am I looking for?"
#   key   - "what do I have to offer?"
#   value - "what I hand over if you find me interesting"
# the query of a given token is dot-producted with the keys of all the tokens,
# and these dot products become wei

# single head of self-attention
head_size = 16
query = nn.Linear(C, head_size, bias=False)
key = nn.Linear(C, head_size, bias=False)
value = nn.Linear(C, head_size, bias=False)

# each token emits its query and key independently;
# no communication between tokens has happened yet
q = query(x)  # [B, T, head_size]
k = key(x)  # [B, T, head_size]

# token communication happens here: every query is dot-producted with every key
# [B, T, head_size] @ [B, head_size, T] ====> [B, T, T]
wei = torch.matmul(q, k.transpose(-2, -1)) * head_size**-0.5

# the head_size**-0.5 factor reduces the variance of the dot products and keeps it around 1
# this normalization matters because softmax is applied later:
# if the dot products are too large, softmax focuses only on the largest value,
# and instead of a distribution of weights we get a single big weight

wei[0]

tensor([[ 0.0518, -0.0279,  0.1603,  0.2428,  0.0102, -0.3134,  0.0783, -0.3488],
        [-0.1272, -0.2687, -0.3077, -0.4568,  0.3930,  0.4386, -0.2394,  1.2351],
        [ 0.1856, -0.0142,  0.5522,  0.8337, -0.0695, -1.0379,  0.2922, -1.3450],
        [ 0.2836, -0.0138,  0.8420,  1.2711, -0.1160, -1.5788,  0.4477, -2.0646],
        [ 0.1334,  0.3463,  0.3064,  0.4526, -0.4921, -0.3969,  0.2602, -1.3681],
        [-0.4107, -0.1093, -1.1863, -1.7872,  0.3284,  2.1592, -0.6664,  3.1349],
        [ 0.0643, -0.0820,  0.2109,  0.3205,  0.0716, -0.4351,  0.0904, -0.3791],
        [-0.2371,  0.5069, -0.8297, -1.2664, -0.5174,  1.8049, -0.3045,  1.1684]],
       grad_fn=<SelectBackward0>)

In [4]:
# and the rest is the same except wei is not a zero matrix anymore
# the raw scores are rebuilt from q and k here rather than overwriting `wei` in place,
# so re-running this cell does not mask and softmax already-normalised weights
tril = torch.tril(torch.ones(T, T))
scores = q @ k.transpose(-2, -1) * head_size**-0.5  # [B, T, T]
# positions where tril is 0 are set to -inf so softmax gives them weight 0
wei = F.softmax(scores.masked_fill(tril == 0, float('-inf')), dim=-1)
wei

tensor([[[1.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
          0.0000e+00, 0.0000e+00, 0.0000e+00],
         [5.3533e-01, 4.6467e-01, 0.0000e+00, 0.0000e+00, 0.0000e+00,
          0.0000e+00, 0.0000e+00, 0.0000e+00],
         [3.0658e-01, 2.5107e-01, 4.4235e-01, 0.0000e+00, 0.0000e+00,
          0.0000e+00, 0.0000e+00, 0.0000e+00],
         [1.6194e-01, 1.2028e-01, 2.8305e-01, 4.3472e-01, 0.0000e+00,
          0.0000e+00, 0.0000e+00, 0.0000e+00],
         [1.8736e-01, 2.3183e-01, 2.2275e-01, 2.5782e-01, 1.0024e-01,
          0.0000e+00, 0.0000e+00, 0.0000e+00],
         [5.4877e-02, 7.4173e-02, 2.5266e-02, 1.3854e-02, 1.1490e-01,
          7.1693e-01, 0.0000e+00, 0.0000e+00],
         [1.4379e-01, 1.2422e-01, 1.6649e-01, 1.8579e-01, 1.4484e-01,
          8.7269e-02, 1.4759e-01, 0.0000e+00],
         [5.7180e-02, 1.2032e-01, 3.1615e-02, 2.0428e-02, 4.3204e-02,
          4.4064e-01, 5.3452e-02, 2.3316e-01]],

        [[1.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.00

In [5]:
# aggregate the value vectors with the attention weights
# (the earlier version aggregated the raw tokens instead: out = wei @ x)
v = value(x)  # [B, T, head_size]
# [B, T, T] @ [B, T, head_size] ====> [B, T, head_size]
out = torch.matmul(wei, v)
out[0]

tensor([[ 0.0660,  0.0865, -0.0022, -0.0979,  0.0494, -0.0847, -0.1617, -0.0495,
          0.1284,  0.1332,  0.0091,  0.0597,  0.1579, -0.0382,  0.0418, -0.0894],
        [-0.2420,  0.1175, -0.2201, -0.1949,  0.3228,  0.1415, -0.2377,  0.0728,
          0.0339,  0.2479,  0.1825, -0.0859,  0.2657, -0.0599, -0.1402, -0.2509],
        [ 0.0238,  0.1610, -0.0689, -0.2005,  0.1682, -0.0809, -0.3061, -0.0503,
          0.1969,  0.2677,  0.0678,  0.0612,  0.3090, -0.0734,  0.0190, -0.2049],
        [ 0.2719,  0.2343,  0.0551, -0.2476,  0.0619, -0.3013, -0.4325, -0.1734,
          0.3870,  0.3417, -0.0230,  0.2085,  0.4131, -0.1010,  0.1686, -0.2056],
        [ 0.1616,  0.1602,  0.0217, -0.1739,  0.0611, -0.1873, -0.2971, -0.1083,
          0.2543,  0.2386, -0.0032,  0.1304,  0.2863, -0.0697,  0.1009, -0.1501],
        [-0.6233, -0.2353, -0.2847,  0.1831,  0.2080,  0.5735,  0.4151,  0.3220,
         -0.5368, -0.2723,  0.2028, -0.3858, -0.3607,  0.0929, -0.3774,  0.0697],
        [ 0.1164,  0.0