# Self attention example implementations

Query is "what am I looking for?"
Key is "what do I contain?"

In [1]:
import torch
from torch import nn
from torch.nn import functional as F

torch.manual_seed(1337)

<torch._C.Generator at 0x119dedef0>

In [4]:
B, T, C = 4, 8, 2
x = torch.randn(B, T, C)
wei = torch.tril(torch.ones(T, T))
wei

tensor([[1., 0., 0., 0., 0., 0., 0., 0.],
        [1., 1., 0., 0., 0., 0., 0., 0.],
        [1., 1., 1., 0., 0., 0., 0., 0.],
        [1., 1., 1., 1., 0., 0., 0., 0.],
        [1., 1., 1., 1., 1., 0., 0., 0.],
        [1., 1., 1., 1., 1., 1., 0., 0.],
        [1., 1., 1., 1., 1., 1., 1., 0.],
        [1., 1., 1., 1., 1., 1., 1., 1.]])

In [5]:
wei = wei / wei.sum(1, keepdim=True)
wei

tensor([[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.5000, 0.5000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.3333, 0.3333, 0.3333, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.2500, 0.2500, 0.2500, 0.2500, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.2000, 0.2000, 0.2000, 0.2000, 0.2000, 0.0000, 0.0000, 0.0000],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.0000, 0.0000],
        [0.1429, 0.1429, 0.1429, 0.1429, 0.1429, 0.1429, 0.1429, 0.0000],
        [0.1250, 0.1250, 0.1250, 0.1250, 0.1250, 0.1250, 0.1250, 0.1250]])

In [6]:
xbow = wei @ x # (T, T) x (B, T, C) = (B, T, C), for dimension B is preserved
xbow

tensor([[[-2.0555,  1.8275],
         [-0.3760,  0.6887],
         [ 0.1984,  1.0228],
         [ 0.1177,  0.3465],
         [ 0.0888,  0.2920],
         [ 0.2493,  0.3563],
         [ 0.2575,  0.1987],
         [ 0.3182,  0.2848]],

        [[ 2.2874,  0.9611],
         [ 0.3789,  0.3350],
         [ 0.2146,  0.1187],
         [ 0.0036,  0.3737],
         [-0.1954,  0.3329],
         [ 0.0413,  0.2384],
         [-0.1156,  0.1108],
         [ 0.0977,  0.0096]],

        [[-0.8961,  0.0662],
         [-0.4762,  1.2037],
         [-1.2253,  0.9724],
         [-1.1226,  0.6678],
         [-0.8971,  0.9437],
         [-0.7739,  0.7500],
         [-0.8565,  0.6346],
         [-0.9811,  0.3822]],

        [[-0.3454, -1.1625],
         [-0.1005, -0.4981],
         [ 0.1833, -0.0277],
         [-0.2945,  0.3056],
         [-0.0437,  0.4565],
         [ 0.0685,  0.1660],
         [-0.0395,  0.4477],
         [ 0.0294,  0.5441]]])

### Third way (one we will use)

In [None]:
tril = torch.tril(torch.ones(T, T))
wei = torch.zeros((T, T))
wei = wei.masked_fill(tril == 0, float('-inf'))
wei = F.softmax(wei, dim=-1)
xbow3 = wei @ x 
xbow3

### Single self-attention head

In [13]:
torch.manual_seed(1337)
B, T, C = 4, 8, 2
x = torch.randn(B, T, C)

# Single head
head_size = 16

# initialized weights between key and query will be different
# thus yeilding differeing values for each item in the batch
key = nn.Linear(C, head_size, bias=False) # inits with weights here
query = nn.Linear(C, head_size, bias=False)
k = key(x) # (B, T, 16)
q = query(x) # (B, T, 16)

# For each item in the batch, we'll have a TxT affinity matrix
wei = q @ k.transpose(-2, -1) # (B, T, 16) x (B, 16, T) --> (B, T, T)

tril = torch.tril(torch.ones(T, T))
wei = wei / wei.sum(1, keepdim=True)
xbow_qk = wei @ x # (B, T, T) x (B, T, C) = (B, T, C), for dimension B is preserved
xbow_qk

tensor([[[ 3.9814e-01, -9.0421e-01],
         [-1.1518e+00,  2.1249e+00],
         [ 1.4696e+00, -3.2129e+00],
         [ 2.2503e+00, -4.9086e+00],
         [ 1.2415e+00, -2.2109e+00],
         [-3.3254e+00,  7.0715e+00],
         [ 4.6905e-01, -1.1343e+00],
         [-1.6239e+00,  4.2404e+00]],

        [[ 2.4597e+01,  7.7186e+01],
         [-1.2764e+01, -4.2438e+01],
         [-4.2250e+01, -1.3326e+02],
         [ 1.3654e+01,  4.1061e+01],
         [ 3.0302e+01,  9.8804e+01],
         [-1.2934e+01, -3.9497e+01],
         [ 2.1814e+01,  7.0265e+01],
         [-2.0398e+01, -7.0405e+01]],

        [[-1.0258e+00, -1.1936e+00],
         [ 1.9371e+00,  2.5300e+00],
         [-1.2944e+00, -2.6407e+00],
         [-1.9841e+00, -2.4554e+00],
         [ 2.6350e-01,  1.0794e+00],
         [ 3.2852e+00,  5.6206e+00],
         [-3.3397e+00, -4.9728e+00],
         [-2.5425e-01,  1.9779e-01]],

        [[ 1.8216e+00, -2.2928e-01],
         [ 1.3755e+00, -1.9726e-01],
         [-2.1019e+00,  1.7857e-