### Single head attention block
(For code below, see from ~1:01:30 in video)

We consider this a <ins>decoder</ins> attention block specifically due to the mask preventing tokens from attending to future tokens within their sequence.

More specifically, this is <ins>self-attention</ins>, since the keys, queries, and values are all produced from the same set of input tokens. The keys and values could instead come from an encoder while the queries come from an earlier layer in the decoder; this is called <ins>cross-attention</ins> (see 1:16:05 in video).
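
To make the contrast with cross-attention concrete, here is a minimal sketch of a single cross-attention head. The names `x_dec`, `x_enc`, `T_dec`, and `T_enc` are assumptions for illustration and do not come from the video; random tensors stand in for the decoder and encoder activations. No causal mask is applied here, on the assumption that the decoder may look at every encoder position.

```python
import torch
import torch.nn as nn
from torch.nn import functional as F

torch.manual_seed(1337)
B, T_dec, T_enc, C = 4, 8, 10, 32  # hypothetical sizes; encoder and decoder lengths may differ
head_size = 16

x_dec = torch.randn(B, T_dec, C)  # stand-in for activations from an earlier decoder layer
x_enc = torch.randn(B, T_enc, C)  # stand-in for encoder outputs

query = nn.Linear(C, head_size, bias=False)
key = nn.Linear(C, head_size, bias=False)
value = nn.Linear(C, head_size, bias=False)

q = query(x_dec)  # (B, T_dec, head_size): queries come from the decoder
k = key(x_enc)    # (B, T_enc, head_size): keys come from the encoder
v = value(x_enc)  # (B, T_enc, head_size): values come from the encoder

weights = q @ k.transpose(-2, -1) * head_size**-0.5  # (B, T_dec, T_enc)
weights = F.softmax(weights, dim=-1)  # no causal mask in this sketch
out = weights @ v  # (B, T_dec, head_size)
print(out.shape)
```

The only change from self-attention is where `k` and `v` come from; the affinity matrix is no longer square when the two sequence lengths differ.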

In [15]:
import torch
import torch.nn as nn
from torch.nn import functional as F

torch.manual_seed(1337)
B,T,C = 4,8,32 # batch, time, channels

# think of this as 4 separate sequences, each of 8 chars (with chars represented as 32-channel vectors)
x = torch.randn(B,T,C)
#print("x:", x)

head_size = 16  # size of resulting key and query vectors
value_size = 16 # size of value representations created from each token

key = nn.Linear(C, head_size, bias=False)
query = nn.Linear(C, head_size, bias=False)
value = nn.Linear(C, value_size, bias=False)

k = key(x) # (B,T,head_size)
q = query(x) # (B,T,head_size)
assert k.shape == q.shape

# now we multiply to get the dot product of respective key and query vectors for each sequence
tmp = k.transpose(-2, -1) # swaps dims -2 and -1 -> (B, head_size, T)
weights = q @ tmp # (B,T,head_size) @ (B,head_size,T) = (B, T, T)
# weights now hold, for each sequence, the raw self-attention affinities between every query token and every key token


# now we use a mask to prevent tokens from attending to future tokens in their sequence
tril = torch.tril(torch.ones(T, T))

# note: replacing the zeros of the triangular mask with -inf and taking softmax
#  is equivalent to exponentiating the weights, zeroing the masked positions, and normalizing each row
#  we also scale by 1/sqrt(head_size) to keep the variance of the scores near 1, so softmax does not collapse toward one-hot vectors
weights = weights.masked_fill(tril == 0, float('-inf')) * head_size**-0.5
weights = F.softmax(weights, dim=-1)

# simplest version: let the tokens in each sequence combine according to their attention affinities,
# using the original token vectors themselves as the "values"
#out = weights @ x # (B,T,T) @ (B,T,C) = (B,T,C)
#assert out.shape == (B,T,C)

# what we actually do: combine the projected values according to the query:key affinities
v = weights @ value(x) # (B,T,T) @ (B, T, value_size) = (B,T,value_size)
assert v.shape == (B,T,value_size)
#weights[0]
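
The two claims made in the comments of the cell above (that masking with -inf before softmax amounts to zeroing future positions and renormalizing each row, and that scaling by `head_size**-0.5` brings the score variance back toward 1) can be checked numerically. This is a rough sketch: the larger batch and sequence sizes in the second part are assumptions chosen only to get a stable variance estimate, and unit-variance random `q` and `k` stand in for real projections.

```python
import torch
from torch.nn import functional as F

torch.manual_seed(1337)

# Claim 1: masked softmax == exponentiate, zero the future positions, normalize rows
T = 8
scores = torch.randn(T, T)
tril = torch.tril(torch.ones(T, T))
masked_softmax = F.softmax(scores.masked_fill(tril == 0, float('-inf')), dim=-1)
manual = torch.exp(scores) * tril                    # zero out the future positions
manual = manual / manual.sum(dim=-1, keepdim=True)   # normalize each row to sum to 1
same = torch.allclose(masked_softmax, manual, atol=1e-6)

# Claim 2: dot products of unit-variance vectors have variance ~head_size;
# scaling by head_size**-0.5 brings that back to ~1
B, T2, head_size = 64, 32, 16  # hypothetical sizes, larger than above for a stable estimate
q = torch.randn(B, T2, head_size)
k = torch.randn(B, T2, head_size)
raw = q @ k.transpose(-2, -1)      # (B, T2, T2)
scaled = raw * head_size**-0.5

print(same, raw.var().item(), scaled.var().item())
```

The unscaled variance should come out near `head_size` and the scaled variance near 1; exact values depend on the seed.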


In [5]:
torch.manual_seed(1337)
x = torch.randint(0, 10, (2,3,5))
print(x)
x.transpose(-2, -1)

tensor([[[5, 7, 2, 0, 5],
         [3, 5, 0, 4, 0],
         [2, 0, 7, 6, 0]],

        [[8, 1, 4, 9, 5],
         [3, 6, 2, 0, 2],
         [1, 6, 5, 9, 4]]])


tensor([[[5, 3, 2],
         [7, 5, 0],
         [2, 0, 7],
         [0, 4, 6],
         [5, 0, 0]],

        [[8, 3, 1],
         [1, 6, 6],
         [4, 2, 5],
         [9, 0, 9],
         [5, 2, 4]]])