# Self-attention example implementations

Query is "what am I looking for?"  
Key is "what do I contain?"

In [1]:
import torch
from torch import nn
from torch.nn import functional as F

# Fix the global RNG seed so the random tensors created below are reproducible.
torch.random.manual_seed(1337)

<torch._C.Generator at 0x119dedef0>

In [15]:
# Toy sizes: B = batch, T = time (sequence length), C = channels.
B, T, C = 4, 8, 2
x = torch.randn(B, T, C)
# Lower-triangular matrix of ones: row t has ones in columns 0..t and zeros after.
wei = torch.ones(T, T).tril()
wei

tensor([[1., 0., 0., 0., 0., 0., 0., 0.],
        [1., 1., 0., 0., 0., 0., 0., 0.],
        [1., 1., 1., 0., 0., 0., 0., 0.],
        [1., 1., 1., 1., 0., 0., 0., 0.],
        [1., 1., 1., 1., 1., 0., 0., 0.],
        [1., 1., 1., 1., 1., 1., 0., 0.],
        [1., 1., 1., 1., 1., 1., 1., 0.],
        [1., 1., 1., 1., 1., 1., 1., 1.]])

In [16]:
# Normalize each row so it sums to 1: row t now holds equal weights 1/(t+1)
# on positions 0..t, i.e. a running average over the current and past tokens.
wei = wei.div(wei.sum(dim=1, keepdim=True))
wei

tensor([[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.5000, 0.5000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.3333, 0.3333, 0.3333, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.2500, 0.2500, 0.2500, 0.2500, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.2000, 0.2000, 0.2000, 0.2000, 0.2000, 0.0000, 0.0000, 0.0000],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.0000, 0.0000],
        [0.1429, 0.1429, 0.1429, 0.1429, 0.1429, 0.1429, 0.1429, 0.0000],
        [0.1250, 0.1250, 0.1250, 0.1250, 0.1250, 0.1250, 0.1250, 0.1250]])

In [17]:
# Batched weighted average: wei is (T, T) and x is (B, T, C). matmul broadcasts
# wei across the batch dimension, so the result is (B, T, C).
xbow = torch.matmul(wei, x)
xbow

tensor([[[-2.0555,  1.8275],
         [-0.3760,  0.6887],
         [ 0.1984,  1.0228],
         [ 0.1177,  0.3465],
         [ 0.0888,  0.2920],
         [ 0.2493,  0.3563],
         [ 0.2575,  0.1987],
         [ 0.3182,  0.2848]],

        [[ 2.2874,  0.9611],
         [ 0.3789,  0.3350],
         [ 0.2146,  0.1187],
         [ 0.0036,  0.3737],
         [-0.1954,  0.3329],
         [ 0.0413,  0.2384],
         [-0.1156,  0.1108],
         [ 0.0977,  0.0096]],

        [[-0.8961,  0.0662],
         [-0.4762,  1.2037],
         [-1.2253,  0.9724],
         [-1.1226,  0.6678],
         [-0.8971,  0.9437],
         [-0.7739,  0.7500],
         [-0.8565,  0.6346],
         [-0.9811,  0.3822]],

        [[-0.3454, -1.1625],
         [-0.1005, -0.4981],
         [ 0.1833, -0.0277],
         [-0.2945,  0.3056],
         [-0.0437,  0.4565],
         [ 0.0685,  0.1660],
         [-0.0395,  0.4477],
         [ 0.0294,  0.5441]]])

### Third way (one we will use)

In [18]:
# Causal averaging via a masked softmax: start from all-zero affinities, set
# every future position (where tril is 0, above the diagonal) to -inf, then
# softmax each row. Row t becomes a uniform distribution over positions 0..t,
# which reproduces xbow.
tril = torch.ones(T, T).tril()
wei = torch.zeros(T, T).masked_fill(tril == 0, float('-inf')).softmax(dim=-1)
xbow3 = wei @ x
xbow3

tensor([[[-2.0555,  1.8275],
         [-0.3760,  0.6887],
         [ 0.1984,  1.0228],
         [ 0.1177,  0.3465],
         [ 0.0888,  0.2920],
         [ 0.2493,  0.3563],
         [ 0.2575,  0.1987],
         [ 0.3182,  0.2848]],

        [[ 2.2874,  0.9611],
         [ 0.3789,  0.3350],
         [ 0.2146,  0.1187],
         [ 0.0036,  0.3737],
         [-0.1954,  0.3329],
         [ 0.0413,  0.2384],
         [-0.1156,  0.1108],
         [ 0.0977,  0.0096]],

        [[-0.8961,  0.0662],
         [-0.4762,  1.2037],
         [-1.2253,  0.9724],
         [-1.1226,  0.6678],
         [-0.8971,  0.9437],
         [-0.7739,  0.7500],
         [-0.8565,  0.6346],
         [-0.9811,  0.3822]],

        [[-0.3454, -1.1625],
         [-0.1005, -0.4981],
         [ 0.1833, -0.0277],
         [-0.2945,  0.3056],
         [-0.0437,  0.4565],
         [ 0.0685,  0.1660],
         [-0.0395,  0.4477],
         [ 0.0294,  0.5441]]])

### Single self-attention head

Note: self-attention means q, k, and v all come from the same source x.  
Cross-attention means the queries come from x, while the keys and values come from a separate source.

Because of the softmax, if wei contains very large positive or very negative numbers, softmax will  
converge toward one-hot vectors. In other words: if the numbers in a row are diffuse (close together), softmax
produces a fairly even distribution. If the numbers are very large in magnitude at initialization, softmax
sharpens toward the single largest entry, effectively saying "only pay attention to this one token".  
The dot product q @ k^T sums head_size terms, so its variance grows with head_size (d). Multiplying by
1/sqrt(d) brings the variance of wei back to roughly 1.

This is why we scale by sqrt(d): the scaling controls the variance of wei at initialization, which keeps
the softmax diffuse instead of one-hot.

In [23]:
torch.manual_seed(1337)
B, T, C = 4, 8, 2  # batch, time (sequence length), channels
x = torch.randn(B, T, C)

# Single head
head_size = 16

# key, query and value are separate nn.Linear layers, each with its own random
# initial weights, so they project the same x into different vectors for
# every token.
key = nn.Linear(C, head_size, bias=False)
query = nn.Linear(C, head_size, bias=False)
value = nn.Linear(C, head_size, bias=False)

k = key(x)    # (B, T, head_size) - what do I contain?
q = query(x)  # (B, T, head_size) - what am I looking for?

# Affinities: for each batch element, a (T, T) matrix of query-key dot
# products. The dot product sums head_size terms, so its variance grows with
# head_size; scaling by head_size**-0.5 (1/sqrt(head_size)) keeps it roughly
# constant so the softmax below does not saturate into near one-hot rows.
wei = q @ k.transpose(-2, -1) * head_size**-0.5  # (B, T, hs) @ (B, hs, T) -> (B, T, T)

tril = torch.tril(torch.ones(T, T))

# Causal mask: token t may only attend to positions <= t, so no information
# flows from the future. This is the decoder-style head; an encoder-style
# head would skip this line so every token can attend to every other token.
wei = wei.masked_fill(tril == 0, float('-inf'))

# Turn each row into attention weights that sum to 1; the masked (-inf)
# entries become exactly 0.
wei = F.softmax(wei, dim=-1)

# v is what each token communicates to the tokens that attend to it (within
# this head); out is the attention-weighted sum of those value vectors.
v = value(x)   # (B, T, head_size)
out = wei @ v  # (B, T, T) @ (B, T, head_size) -> (B, T, head_size)
out

tensor([[[ 6.6016e-02,  8.6541e-02, -2.1800e-03, -9.7871e-02,  4.9378e-02,
          -8.4692e-02, -1.6165e-01, -4.9517e-02,  1.2838e-01,  1.3316e-01,
           9.1477e-03,  5.9705e-02,  1.5792e-01, -3.8152e-02,  4.1841e-02,
          -8.9396e-02],
         [-2.5548e-01,  1.1884e-01, -2.2966e-01, -1.9912e-01,  3.3471e-01,
           1.5141e-01, -2.4099e-01,  7.8147e-02,  2.9808e-02,  2.5287e-01,
           1.9010e-01, -9.2274e-02,  2.7042e-01, -6.0876e-02, -1.4815e-01,
          -2.5797e-01],
         [-2.7583e-02,  1.5441e-01, -9.9084e-02, -2.0180e-01,  2.0019e-01,
          -3.8674e-02, -2.9640e-01, -2.6971e-02,  1.6753e-01,  2.6698e-01,
           9.0885e-02,  3.3340e-02,  3.0425e-01, -7.1635e-02, -1.1698e-02,
          -2.1629e-01],
         [ 1.5503e-01,  2.0607e-01, -6.6096e-03, -2.3345e-01,  1.1925e-01,
          -1.9999e-01, -3.8504e-01, -1.1699e-01,  3.0476e-01,  3.1750e-01,
           2.2893e-02,  1.4108e-01,  3.7636e-01, -9.0899e-02,  9.8343e-02,
          -2.1371e-01],
    

In [21]:
# Attention weights for the first sequence in the batch.
wei.select(0, 0)

tensor([[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.5599, 0.4401, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.3220, 0.2016, 0.4764, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.1640, 0.0815, 0.2961, 0.4585, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.2051, 0.3007, 0.1894, 0.1808, 0.1241, 0.0000, 0.0000, 0.0000],
        [0.0600, 0.1273, 0.0291, 0.0169, 0.0552, 0.7114, 0.0000, 0.0000],
        [0.1408, 0.1025, 0.1744, 0.2038, 0.1690, 0.0669, 0.1426, 0.0000],
        [0.0223, 0.1086, 0.0082, 0.0040, 0.0080, 0.7257, 0.0216, 0.1016]],
       grad_fn=<SelectBackward0>)