In [5]:
# autoreload when imports change
%load_ext autoreload
%autoreload 2

In [6]:
import pathlib
from typing import Unpack

In [7]:
from gpt_from_scratch import (
    file_utils,
    vocab_utils,
)

In [8]:
import torch

# imported for typechecking
#
# note: can't easily alias via jaxtyping annotations, as the shape spec is a string literal and
#       likely plays weirdly with typing.Annotated when forwarding a payload
# note: torchtyping is deprecated in favor of jaxtyping, as torchtyping doesn't have mypy integration
#
# note: jaxtyping does support prepending
#
#   Image = Float[Array, "channels height width"]
#   BatchImage = Float[Image, "batch"]
#
#    -->
#
#   BatchImage = Float[Array, "batch channels height width"]
#
# so we can compose aliases
#
from torch import Tensor
import jaxtyping
from jaxtyping import jaxtyped, Float32, Int64
from typeguard import typechecked as typechecker

# The mathematical trick of self attention

In [11]:
# important enough for Karpathy to call out

In [12]:
# Problem setup:
# - every position should be able to receive information from the previous
#   tokens, up to and including the current one
#
# fixed seed so the random example is reproducible
torch.manual_seed(1337)

# batch size, sequence length (time), channels per token
B = 4
T = 8
C = 2

# random stand-in for a batch of token vectors
x = torch.randn((B, T, C))
x.shape

torch.Size([4, 8, 2])

In [13]:
# "version 1": bag of words with explicit loops
#
# - simplest possible scheme: a token only wants the average of itself and
#   everything before it, i.e. for every position t
#     [
#       <avg of channel 0 over positions 0..t>
#       <avg of channel 1 over positions 0..t>
#       ...
#     ]
#   (the mean is taken over the time dimension, one value per channel)
#
# - averaging throws away ordering, hence "bag of words"
#
# Goal: x_bow[b, t] = mean_{i <= t} x[b, i]
x_bow = torch.zeros(B, T, C)

# walk over every sequence in the batch ...
for b in range(B):

    # ... and over every position within that sequence
    for t in range(T):

        # tokens 0..t inclusive: shape (t + 1, C)
        x_prev = x[b, : t + 1]

        # collapse the time axis (dim 0), leaving one mean per channel
        x_bow[b, t] = x_prev.mean(dim=0)

# x_bow[b, t] now holds the average of x[b, 0..t]
x_bow.shape

In [25]:
# position 0 has nothing earlier to average with, so the bag-of-words
# value is expected to match the input itself
print('First element:')
print(f' - {x[0, 0]}')
print(f' -> {x_bow[0, 0]}')

First element:
 - tensor([ 0.1808, -0.0700])
 -> tensor([ 0.1808, -0.0700])


In [26]:
# position 1 should be the mean of the first two input rows
print('Second element: (average of first two)')
print(f' - {x[0, 0]}')
print(f' - {x[0, 1]}')
print(f' -> {x_bow[0, 1]}')

Second element: (average of first two)
 - tensor([ 0.1808, -0.0700])
 - tensor([-0.3596, -0.9152])
 -> tensor([-0.0894, -0.4926])


In [33]:
# The explicit loop is correct but very inefficient; the same running average
# can be written as a single matrix multiplication.
#
# Toy example: matrix multiplication as a "weighted aggregation" of rows.
torch.manual_seed(42)

# With a = torch.ones(3, 3), every row of `a @ b` would be the sum of all rows of b.
# Keeping only the lower triangle makes row i sum just rows 0..i of b, which is
# exactly the running total that bag of words needs.
a = torch.ones(3, 3).tril()
print('a=', a, '--', sep='\n')

# divide each row by its own sum so the row becomes an average instead of a total
a = a / a.sum(dim=1, keepdim=True)

# small random integer-valued matrix to aggregate
b = torch.randint(low=0, high=10, size=(3, 2)).float()

# row i of c is the mean of rows 0..i of b
c = torch.matmul(a, b)

print('a=', a, '--', 'b=', b, '--', 'c=', c, sep='\n')

# ... and that really is the whole trick

a=
tensor([[1., 0., 0.],
        [1., 1., 0.],
        [1., 1., 1.]])
--
a=
tensor([[1.0000, 0.0000, 0.0000],
        [0.5000, 0.5000, 0.0000],
        [0.3333, 0.3333, 0.3333]])
--
b=
tensor([[2., 7.],
        [6., 4.],
        [6., 5.]])
--
c=
tensor([[2.0000, 7.0000],
        [4.0000, 5.5000],
        [4.6667, 5.3333]])


In [35]:
# "version 2": the same running average as one matrix multiply
#
# - mapping onto the toy example:
#   - `a` -> `weights`
#   - `b` -> `x`
#
# - `weights` is only an intermediate tool for computing the aggregation
#
# - open question: is self attention about letting them go in different directions?
#
weights = torch.ones(T, T).tril()
weights = weights / weights.sum(dim=1, keepdim=True)

# the (T, T) weights broadcast over the batch: (B, T, T) @ (B, T, C) ----> (B, T, C)
x_bow2 = torch.matmul(weights, x)

# every row shows the same normalized running-average pattern
weights

tensor([[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.5000, 0.5000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.3333, 0.3333, 0.3333, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.2500, 0.2500, 0.2500, 0.2500, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.2000, 0.2000, 0.2000, 0.2000, 0.2000, 0.0000, 0.0000, 0.0000],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.0000, 0.0000],
        [0.1429, 0.1429, 0.1429, 0.1429, 0.1429, 0.1429, 0.1429, 0.0000],
        [0.1250, 0.1250, 0.1250, 0.1250, 0.1250, 0.1250, 0.1250, 0.1250]])

In [36]:
# now we can show that this gives us the same thing our for loop was doing
# (compared within a tolerance rather than exactly, since the loop and the
#  matmul may differ by floating-point rounding)
torch.allclose(x_bow, x_bow2)

True

In [40]:
# "version 3": the same result again, this time via softmax
#
# - why bother if the output is identical?
# - the weights now begin as raw scores of 0 ...
# - ... which we can essentially read as "interaction strength" between positions
import torch.nn.functional as F

# lower-triangular mask: 1 where a position may look (itself and earlier), 0 for later positions
tril = torch.ones(T, T).tril()

# affinities are a constant 0 for now; eventually we want them to vary with the data
weights = torch.zeros(T, T)

# wherever the mask is 0 (later positions), set the score to -inf so softmax gives it weight 0
weights = weights.masked_fill(tril == 0, float('-inf'))

print('Before Softmax:')
print(weights)

# normalize along the last dimension so each row sums to 1
weights = F.softmax(weights, dim=-1)

print('After Softmax:')
print(weights)

# the (T, T) weights broadcast over the batch dimension of x
x_bow3 = torch.matmul(weights, x)

Before Softmax:
tensor([[0., -inf, -inf, -inf, -inf, -inf, -inf, -inf],
        [0., 0., -inf, -inf, -inf, -inf, -inf, -inf],
        [0., 0., 0., -inf, -inf, -inf, -inf, -inf],
        [0., 0., 0., 0., -inf, -inf, -inf, -inf],
        [0., 0., 0., 0., 0., -inf, -inf, -inf],
        [0., 0., 0., 0., 0., 0., -inf, -inf],
        [0., 0., 0., 0., 0., 0., 0., -inf],
        [0., 0., 0., 0., 0., 0., 0., 0.]])
After Softmax:
tensor([[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.5000, 0.5000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.3333, 0.3333, 0.3333, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.2500, 0.2500, 0.2500, 0.2500, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.2000, 0.2000, 0.2000, 0.2000, 0.2000, 0.0000, 0.0000, 0.0000],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.0000, 0.0000],
        [0.1429, 0.1429, 0.1429, 0.1429, 0.1429, 0.1429, 0.1429, 0.0000],
        [0.1250, 0.1250, 0.1250, 0.1250, 0.1250, 0.1250, 0

In [None]:
# check that the softmax version reproduces the explicit-loop bag of words
torch.allclose(x_bow, x_bow3)

In [None]:
# tldr: this is:
#  - weighted aggregation of information from previous tokens
#  - written with softmax so the weights can later change (depend on the data)
#    instead of staying a fixed uniform average