In [1]:
import torch
import torch.nn as nn
from torch.nn import functional as F
# Seed torch's RNG once so every random tensor in this notebook is reproducible.
torch.manual_seed(1337)

# Toy sizes: B sequences per batch, T tokens per sequence, C channels per token
# (C would be the embedding width in a real model).
B, T, C = 4, 8, 2
x = torch.randn((B, T, C))  # random stand-in for token embeddings, shape (B, T, C)
x.size()

torch.Size([4, 8, 2])

In [4]:
# Goal: each token should be able to "see" every earlier token in its sequence,
# plus itself.  The simplest way to combine them is a plain average: position t
# of the result is the mean of x over positions 0..t.  Not smart, but a start.
#
# These nested Python loops make the idea explicit, but they are very slow.  The
# next cells compute exactly the same thing with one matrix multiply.
xavg = torch.zeros((B, T, C))
for b_idx in range(B):
    for t_idx in range(T):
        # x[b_idx, :t_idx + 1] has shape (t_idx + 1, C); averaging over dim 0 gives (C,).
        xavg[b_idx, t_idx] = x[b_idx, :t_idx + 1].mean(dim=0)

print(f"x[0]:\n{x[0]}")
print(f"x avg:\n{xavg[0]}")


x[0]:
tensor([[ 0.1808, -0.0700],
        [-0.3596, -0.9152],
        [ 0.6258,  0.0255],
        [ 0.9545,  0.0643],
        [ 0.3612,  1.1679],
        [-1.3499, -0.5102],
        [ 0.2360, -0.2398],
        [-0.9211,  1.5433]])
x avg:
tensor([[ 0.1808, -0.0700],
        [-0.0894, -0.4926],
        [ 0.1490, -0.3199],
        [ 0.3504, -0.2238],
        [ 0.3525,  0.0545],
        [ 0.0688, -0.0396],
        [ 0.0927, -0.0682],
        [-0.0341,  0.1332]])


In [8]:
# Matrix multiplies are fast (especially on a GPU), so let's express the
# prefix average as one.
#
# torch.tril keeps the lower triangle of a matrix and zeroes everything above
# the diagonal.  If `a` were a lower triangle of ones, row i of `a @ b` would be
# the sum of rows 0..i of `b` (a running sum down each column).  Dividing each
# row of `a` by how many ones it holds turns that running sum into a running
# average.
all_ones = torch.ones(3, 3)
a = all_ones.tril()                       # lower triangle of ones
a = a / a.sum(dim=1, keepdim=True)        # row i -> 1/(i+1) in columns 0..i
b = torch.tensor([[8., 6.], [6., 4.], [4., 2.]])
c = torch.matmul(a, b)                    # row i = mean of b's rows 0..i

print('a')
print(a)
print('b')
print(b)
print('c=')
print(c)

a
tensor([[1.0000, 0.0000, 0.0000],
        [0.5000, 0.5000, 0.0000],
        [0.3333, 0.3333, 0.3333]])
b
tensor([[8., 6.],
        [6., 4.],
        [4., 2.]])
c=
tensor([[8.0000, 6.0000],
        [7.0000, 5.0000],
        [6.0000, 4.0000]])


In [9]:
# Apply the triangle trick to the real data.  Row t of `wei` holds 1/(t+1) in
# columns 0..t and 0 after that, so multiplying by it averages positions 0..t.
wei = torch.ones(T, T).tril()
wei = wei / wei.sum(dim=-1, keepdim=True)

# wei is (T, T) and x is (B, T, C).  matmul broadcasts wei across the leading
# batch dimension, so each of the B sequences gets the same (T, T) @ (T, C)
# product as the (3, 3) @ (3, 2) example above.  The result is (B, T, C).
xavg2 = torch.matmul(wei, x)

print(xavg2[0])
print("equal? ", torch.allclose(xavg, xavg2))


tensor([[ 0.1808, -0.0700],
        [-0.0894, -0.4926],
        [ 0.1490, -0.3199],
        [ 0.3504, -0.2238],
        [ 0.3525,  0.0545],
        [ 0.0688, -0.0396],
        [ 0.0927, -0.0682],
        [-0.0341,  0.1332]])
equal?  True


In [10]:
# A third way to build the same averaging weights: softmax over a masked score
# matrix.
#
# Start with all-zero scores and set every "future" position (above the
# diagonal) to -inf.  Then softmax each row.  Because exp(0) = 1 and
# exp(-inf) = 0, row t ends up with t+1 equal entries of 1/(t+1) and zeros
# after them.  That is exactly the averaging matrix from before.
#   t=0: [0, -inf, -inf, ...] -> [1.0, 0,   0, ...]
#   t=1: [0,    0, -inf, ...] -> [0.5, 0.5, 0, ...]
# The payoff is that the zeros can later be replaced by learned scores.
tril = torch.tril(torch.ones(T, T))
wei2 = torch.zeros((T, T)).masked_fill(tril == 0, float('-inf'))
wei2 = wei2.softmax(dim=-1)
xavg3 = wei2 @ x
print("equal3?", torch.allclose(xavg, xavg3))

equal3? True


In [11]:
# Version 4: self-attention.
# Instead of starting the score matrix `wei` from zeros (which gives a plain
# average), let the data decide how much each token attends to each earlier
# token.  Two linear layers, `key` and `query`, will produce those scores.
# Their weights start random and are learned during training.
# First, switch to a more realistic embedding width.
B, T, C = 4, 8, 32  # 4 sequences per batch, 8 tokens each, 32-dim embeddings
x = torch.randn(B, T, C)
# Random token ids in [0, 1000) for one sequence.  The length is tied to T
# rather than a hard-coded 8, so it stays consistent if T changes.
tokens = torch.randint(0, 1000, (T,))
# Label each printout with what it actually is; `tokens` is not `x`.
print("New x shape:", x.shape)
print("Sample tokens:", tokens)


New x tensor([641, 475, 438, 170, 611, 347, 184, 203])


In [18]:
# Build a single self-attention head.
head_size = 16
# Each projection is a bias-free linear map from C (32) to head_size (16).  It is
# a learned weight matrix applied to every token independently.  Creation order
# (key, then query) matters: it fixes which random weights each one receives.
key   = nn.Linear(C, head_size, bias=False)
query = nn.Linear(C, head_size, bias=False)
k = key(x)    # (B, T, 16): every token projected 32 -> 16
q = query(x)  # (B, T, 16): same input, different weights

# So far q and k are computed independently from different random weights, and
# no token has been compared with any other.  Training will shape those weights.
# For the 'a' token, for example, the key might come to say "I'm a vowel" and
# the query might say "I'm looking for certain consonants".  The next step lines
# them up against each other.

# Dot every query with every key.  Within one sequence, entry [t, s] is
# q[t] . k[s]: large where the two vectors are well aligned, small or negative
# where they are not.  That is how the dot product measures affinity.
wei = torch.matmul(q, k.transpose(-2, -1))  # (B, T, 16) @ (B, 16, T) -> (B, T, T)

# These data-dependent scores replace the all-zero start used for plain
# averaging.  Each entry says how much "attention" token t pays to token s.
# Hide future tokens with -inf, then softmax each row.
tril = torch.ones(T, T).tril()
wei = wei.masked_fill(tril == 0, float('-inf'))
wei = wei.softmax(dim=-1)

# wei now holds the token-to-token weighting.  Unlike the averaging matrix, its
# rows are no longer uniform.
print("wei[0]:")
print(wei[0])
print(f"wei shape: {wei.shape}")

out = torch.matmul(wei, x)  # weighted mix of the raw inputs: (B, T, C)
print(f"out shape: {out.shape}")


wei[0]:
tensor([[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.7552, 0.2448, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.6585, 0.2948, 0.0467, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.5087, 0.3908, 0.0152, 0.0854, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.4088, 0.4924, 0.0443, 0.0049, 0.0495, 0.0000, 0.0000, 0.0000],
        [0.0061, 0.0300, 0.5775, 0.0287, 0.0047, 0.3530, 0.0000, 0.0000],
        [0.0276, 0.0179, 0.8380, 0.0565, 0.0049, 0.0467, 0.0083, 0.0000],
        [0.0143, 0.1380, 0.0144, 0.3662, 0.0086, 0.0912, 0.0092, 0.3581]],
       grad_fn=<SelectBackward0>)
wei shape: torch.Size([4, 8, 8])
out shape: torch.Size([4, 8, 32])


In [17]:
# Catch: the attention weights are not applied to the raw x.  A third bias-free
# projection, `value`, maps each token from C to head_size.  The weights then
# mix those values instead of x itself.
value = nn.Linear(C, head_size, bias=False)
v = value(x)  # (B, T, head_size)

out = torch.matmul(wei, v)
print(out.shape)  # (B, T, head_size): the output now has the head's width

torch.Size([4, 8, 16])


In [20]:
# Catch 2: softmax is very sensitive to the scale of its inputs.  With small
# scores it spreads the weight across entries.  Multiply the same scores by a
# large factor and it collapses toward a one-hot vector that picks whichever
# score happens to be largest.
test1 = torch.tensor([0.1, -0.2, 0.3, -0.2, 0.5]).softmax(dim=-1)        # diffuse
test2 = (torch.tensor([0.1, -0.2, 0.3, -0.2, 0.5]) * 8).softmax(dim=-1)  # peaky
print(test1)
print(test2)
# test2 is already close to one-hot: about 80% of the weight is on the last
# entry.  We want attention weights to stay diffuse, so each token can draw on
# several others.  The standard fix (scaled dot-product attention) is to divide
# q @ k^T by sqrt(head_size) before the softmax.  This keeps the scores at
# roughly unit variance.

tensor([0.1925, 0.1426, 0.2351, 0.1426, 0.2872])
tensor([0.0326, 0.0030, 0.1615, 0.0030, 0.8000])
