In [5]:
# autoreload when imports change
%load_ext autoreload
%autoreload 2

In [54]:
import pathlib
from typing import Unpack
import math

In [7]:
from gpt_from_scratch import (
    file_utils,
    vocab_utils,
)

In [8]:
import torch

# imported for typechecking
#
# note: can't easily alias via jaxtyping annotations, as it's a string literal and
#       likely plays weirdly with typing.Annotated to forward a payload
# note: torchtyping is deprecated in favor of jaxtyping, as torchtyping doesn't have mypy integration
#
# note: jaxtyping does support prepending
#
#   Image = Float[Array, "channels height width"]
#   BatchImage = Float[Image, "batch"]
#
#    -->
#
#   BatchImage = Float[Array, "batch channels height width"]
#
# so we can compose aliases
#
import jaxtyping
import torch.nn as nn
import torch.nn.functional as F
from jaxtyping import jaxtyped, Float32, Int64
from torch import Tensor
from typeguard import typechecked as typechecker

# The mathematical trick of self attention

In [11]:
# important enough for Karpathy to call out

In [12]:
# Goal: let every position in a sequence receive information from the
# positions before it (and from itself), i.e. "previous tokens up to the current token".
torch.manual_seed(1337)  # fixed seed so the random toy input is reproducible

# toy input: B sequences, T positions per sequence, C values per position
B = 4
T = 8
C = 2
x = torch.randn((B, T, C))
x.shape

torch.Size([4, 8, 2])

In [13]:
# "version 1": causal bag of words, computed the slow and obvious way
#
# - simplest way for a token to hear from its past: replace it with the
#   average of itself and everything before it
# - so, reading down the time axis, the result is
#     [
#       <avg of positions 0..0>
#       <avg of positions 0..1>
#       <avg of positions 0..2>
#       ...
#     ]
#   where each average is taken over time, separately for every channel
#
# - averaging discards ordering, hence "bag of words"
#
# target: x_bow[b, t] = mean_{i<=t} x[b, i]
x_bow = torch.zeros(B, T, C)

for b in range(B):        # every sequence in the batch
    for t in range(T):    # every position in that sequence

        # positions 0..t inclusive -> shape (t+1, C)
        x_prev = x[b, : t + 1]

        # collapse the time axis (dim 0), leaving one value per channel -> shape (C,)
        x_bow[b, t] = x_prev.mean(dim=0)

# same shape as `x`; every position now holds the average "up to" itself
x_bow.shape

In [25]:
# position 0 has nothing before it, so its running average is just itself
print('First element:')
print(f' - {x[0, 0]}')
print(f' -> {x_bow[0, 0]}')

First element:
 - tensor([ 0.1808, -0.0700])
 -> tensor([ 0.1808, -0.0700])


In [26]:
# position 1 should be the mean of positions 0 and 1
print('Second element: (average of first two)')
print(f' - {x[0, 0]}')
print(f' - {x[0, 1]}')
print(f' -> {x_bow[0, 1]}')

Second element: (average of first two)
 - tensor([ 0.1808, -0.0700])
 - tensor([-0.3596, -0.9152])
 -> tensor([-0.0894, -0.4926])


In [33]:
# the nested loop is correct, but very inefficient
# the trick: the same running average can be written as one matrix multiplication

# toy example: matrix multiplication as a "weighted aggregation" of rows
torch.manual_seed(42)

# (with a plain `torch.ones(3, 3)` every output row would add up *all* rows of `b`)
#
# lower-triangular ones: row i has ones in columns 0..i, so `a @ b` becomes a
# running total -- the same "everything up to and including me" pattern as bag of words
a = torch.ones(3, 3).tril()
print('a=', a, '--', sep='\n')

# divide every row by its own sum: rows now sum to 1, running total -> running mean
a = a / a.sum(dim=1, keepdim=True)

# random integers in [0, 10), cast to float to match `a`
b = torch.randint(0, 10, (3, 2)).float()

c = torch.matmul(a, b)

print('a=', a, '--', sep='\n')
print('b=', b, '--', sep='\n')
print('c=', c, sep='\n')

# ...and that is the whole trick: row i of `c` is the mean of rows 0..i of `b`

a=
tensor([[1., 0., 0.],
        [1., 1., 0.],
        [1., 1., 1.]])
--
a=
tensor([[1.0000, 0.0000, 0.0000],
        [0.5000, 0.5000, 0.0000],
        [0.3333, 0.3333, 0.3333]])
--
b=
tensor([[2., 7.],
        [6., 4.],
        [6., 5.]])
--
c=
tensor([[2.0000, 7.0000],
        [4.0000, 5.5000],
        [4.6667, 5.3333]])


In [35]:
# "version 2": the same running average, done as a single matrix multiply
#
# - compared with the toy example:
#   - `a = weights`
#   - `b = x`
#
# - the weights are only an intermediate tool for computing the aggregation
#
# - is self attention about letting them go in different directions?
#
weights = torch.ones(T, T).tril()
weights = weights / weights.sum(dim=1, keepdim=True)

# `weights` is (T, T) and gets broadcast over the batch dimension:
# (T, T) @ (B, T, C) ----> (B, T, C)
x_bow2 = torch.matmul(weights, x)

# each row shows the same normalized running-average pattern as the toy `a`
weights

tensor([[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.5000, 0.5000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.3333, 0.3333, 0.3333, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.2500, 0.2500, 0.2500, 0.2500, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.2000, 0.2000, 0.2000, 0.2000, 0.2000, 0.0000, 0.0000, 0.0000],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.0000, 0.0000],
        [0.1429, 0.1429, 0.1429, 0.1429, 0.1429, 0.1429, 0.1429, 0.0000],
        [0.1250, 0.1250, 0.1250, 0.1250, 0.1250, 0.1250, 0.1250, 0.1250]])

In [36]:
# sanity check: the matmul version gives the same result as the explicit
# for loop (compared with `allclose`, i.e. up to floating point tolerance)
torch.allclose(x_bow, x_bow2)

True

In [40]:
# version 3: use softmax
#
# - so why do we care? with constant affinities this is the same thing as version 2
# - the point is that the raw scores ("affinities") merely *start* at 0 here;
#   they do not have to stay constant
# - softmax turns whatever scores we have into rows that sum to 1, so we can
#   essentially treat them as "interaction strength"
#
# note: `F` (torch.nn.functional) is imported in the imports cell at the top
tril = torch.tril(torch.ones(T, T))

# placeholder affinities: constant 0 for now, later we want them to depend on the data
weights = torch.zeros((T, T))

# everything above the diagonal is "the future": set it to -inf so that
# softmax gives those positions exactly zero weight
weights = weights.masked_fill(tril == 0, float('-inf'))

print('Before Softmax:')
print(weights)

# normalize every row over the positions it is allowed to see
weights = F.softmax(weights, dim=-1)

print('After Softmax:')
print(weights)

# `weights` is (T, T), broadcast over the batch: (T, T) @ (B, T, C) ---> (B, T, C)
x_bow3 = weights @ x

Before Softmax:
tensor([[0., -inf, -inf, -inf, -inf, -inf, -inf, -inf],
        [0., 0., -inf, -inf, -inf, -inf, -inf, -inf],
        [0., 0., 0., -inf, -inf, -inf, -inf, -inf],
        [0., 0., 0., 0., -inf, -inf, -inf, -inf],
        [0., 0., 0., 0., 0., -inf, -inf, -inf],
        [0., 0., 0., 0., 0., 0., -inf, -inf],
        [0., 0., 0., 0., 0., 0., 0., -inf],
        [0., 0., 0., 0., 0., 0., 0., 0.]])
After Softmax:
tensor([[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.5000, 0.5000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.3333, 0.3333, 0.3333, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.2500, 0.2500, 0.2500, 0.2500, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.2000, 0.2000, 0.2000, 0.2000, 0.2000, 0.0000, 0.0000, 0.0000],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.0000, 0.0000],
        [0.1429, 0.1429, 0.1429, 0.1429, 0.1429, 0.1429, 0.1429, 0.0000],
        [0.1250, 0.1250, 0.1250, 0.1250, 0.1250, 0.1250, 0

In [None]:
# sanity check: compare the softmax-built aggregation against the for-loop result
torch.allclose(x_bow, x_bow3)

In [None]:
# tldr: this is:
#  - a weighted aggregation of information from previous tokens (and the current one)
#  - with weights that are allowed to change (e.g. depend on the data) instead of being fixed

In [56]:
# version 4: self-attention!
#
# - attention for a single individual head
#
# note: Karpathy uses tokens <-> nodes interchangeably?
#
# - oh so Karpathy *literally* thinks of this as a graph
#
# - nodes with no notion of space
#
# note: `nn` and `F` are imported in the imports cell at the top of the notebook
torch.manual_seed(1337)

B,T,C = 4, 8, 32 # batch, time, channels

x = torch.randn(B,T,C)

# let's see a single Head perform self-attention
head_size = 16

# we'll have every token "emit" a:
# - key:   "what do I contain"
# - query: "what am I looking for"
# - value: "if you find anything interesting to me, here's what I'll give you"
#
# - are these in terms of the updated weighted running average of previous tokens?
#
# - query . key = affinity (essentially acts like the weighted average matrix)
#
key   = nn.Linear(C, head_size, bias=False)
query = nn.Linear(C, head_size, bias=False)
value = nn.Linear(C, head_size, bias=False)

# all tokens individually produce a key and a query vector
# - no communication has happened yet
k = key(x)   # (B, T, head_size)
q = query(x) # (B, T, head_size)

# - don't want the affinities to be completely 0 and data independent any more, i.e. not
#   weights = torch.zeros((T,T))
#
# note: batch dimension is only reason we can't just do `q @ k.T`
#
# (B, T, head_size) @ (B, head_size, T) ---> (B, T, T)
weights: Float32[Tensor, 'batch_size block_size block_size'] =  q @ k.transpose(-2, -1)

print('\nExample: Batch 0 (before normalization by head size)\n')
print(weights[0])

# each entry of `weights` is a sum of `head_size` products, so its variance
# is on the order of `head_size` (the width of the head, not the number of heads)
# - (assuming the entries of `q` and `k` are roughly unit variance)
# - want `weights` to be roughly unit variance too, since otherwise
#   `softmax` will saturate really quickly
# - want these to be very diffuse at initialization, otherwise softmax will
#   overly sharpen towards the max (i.e. attend to a single node)
# - that's fine if it happens later, but at initialization we don't want that
weights = weights / math.sqrt(head_size)

# so now we have a `block_size x block_size` matrix for each batch element,
# essentially representing what was previously a fixed weighted
# sum of previous tokens, but is now produced by learnable linear layers
#
# - so we've essentially created:
#  - a weighted average of previous tokens' contribution for each logit
#  - a pair of linear layers that lets us actually learn this via backprop
#  - they learn it in a data dependent way

# want to mask before aggregating, since we actually don't want to allow future tokens to propagate
print('\nExample: Batch 0 (before masking)\n')
print(weights[0])

# lower-triangular mask: position t may only look at positions 0..t
tril = torch.tril(torch.ones(T, T))

# the (T, T) mask is broadcast over the batch dimension of the (B, T, T) scores
weights = weights.masked_fill(tril == 0, float('-inf'))

print('\nExample: Batch 0 (before softmax)\n')
print(weights[0])

# -inf scores get zero weight; each row is normalized to sum to 1
weights = F.softmax(weights, dim=-1)

print('\nExample: Batch 0 (after softmax)\n')
print(weights[0])


Example: Batch 0 (before normalization by head size)

tensor([[-1.7629, -1.3011,  0.5652,  2.1616, -1.0674,  1.9632,  1.0765, -0.4530],
        [-3.3334, -1.6556,  0.1040,  3.3782, -2.1825,  1.0415, -0.0557,  0.2927],
        [-1.0226, -1.2606,  0.0762, -0.3813, -0.9843, -1.4303,  0.0749, -0.9547],
        [ 0.7836, -0.8014, -0.3368, -0.8496, -0.5602, -1.1701, -1.2927, -1.0260],
        [-1.2566,  0.0187, -0.7880, -1.3204,  2.0363,  0.8638,  0.3719,  0.9258],
        [-0.3126,  2.4152, -0.1106, -0.9931,  3.3449, -2.5229,  1.4187,  1.2196],
        [ 1.0876,  1.9652, -0.2621, -0.3158,  0.6091,  1.2616, -0.5484,  0.8048],
        [-1.8044, -0.4126, -0.8306,  0.5898, -0.7987, -0.5856,  0.6433,  0.6303]],
       grad_fn=<SelectBackward0>)

Example: Batch 0 (before masking)

tensor([[-0.4407, -0.3253,  0.1413,  0.5404, -0.2668,  0.4908,  0.2691, -0.1132],
        [-0.8334, -0.4139,  0.0260,  0.8446, -0.5456,  0.2604, -0.0139,  0.0732],
        [-0.2557, -0.3152,  0.0191, -0.0953, -0.2461, 

In [51]:

# the thing being aggregated is learned as well, because why not: instead of
# averaging the raw `x`, every token emits a `value` vector and those get averaged
#
# - read it as: "when my key *does* match your query, here's how much I'll give you"
#
# - it is also how the affinities get interpreted, since on their own they
#   don't have any intrinsic meaning as far as units
#
v = value(x)  # (B, T, head_size)

# (B, T, T) @ (B, T, head_size) ---> (B, T, head_size)
out = torch.matmul(weights, v)

out.shape

torch.Size([4, 8, 16])

In [52]:
# attended output for the first batch element: one row of head_size values per position
out[0]

tensor([[-0.1571,  0.8801,  0.1615, -0.7824, -0.1429,  0.7468,  0.1007, -0.5239,
         -0.8873,  0.1907,  0.1762, -0.5943, -0.4812, -0.4860,  0.2862,  0.5710],
        [ 0.6764, -0.5477, -0.2478,  0.3143, -0.1280, -0.2952, -0.4296, -0.1089,
         -0.0493,  0.7268,  0.7130, -0.1164,  0.3266,  0.3431, -0.0710,  1.2716],
        [ 0.4823, -0.1069, -0.4055,  0.1770,  0.1581, -0.1697,  0.0162,  0.0215,
         -0.2490, -0.3773,  0.2787,  0.1629, -0.2895, -0.0676, -0.1416,  1.2194],
        [ 0.1971,  0.2856, -0.1303, -0.2655,  0.0668,  0.1954,  0.0281, -0.2451,
         -0.4647,  0.0693,  0.1528, -0.2032, -0.2479, -0.1621,  0.1947,  0.7678],
        [ 0.2510,  0.7346,  0.5939,  0.2516,  0.2606,  0.7582,  0.5595,  0.3539,
         -0.5934, -1.0807, -0.3111, -0.2781, -0.9054,  0.1318, -0.1382,  0.6371],
        [ 0.3428,  0.4960,  0.4725,  0.3028,  0.1844,  0.5814,  0.3824,  0.2952,
         -0.4897, -0.7705, -0.1172, -0.2541, -0.6892,  0.1979, -0.1513,  0.7666],
        [ 0.1866, -0.0

Notes:
- Attention is a **communication mechanism**. Can be seen as nodes in a directed graph looking at each other and aggregating information with a weighted sum from all nodes that point to them, with data-dependent weights.

- There is no notion of space. Attention simply acts over a set of vectors. This is why we need to positionally encode tokens.

- Examples across the batch dimension are of course processed completely independently and never "talk" to each other.

- In an $encoder$ attention block just delete the single line that does masking with `tril`, allowing all tokens to communicate. 
- This block here is called a $decoder$ attention block because it has triangular masking, and is usually used in autoregressive settings, like language modeling.

- $self\_attention$ just means that the keys and values are produced from the same source as queries. 
- In $cross\_attention$, the queries still get produced from x, but the keys and values come from some other, external source (e.g. an encoder module)

- "Scaled" attention additionally multiplies `weights` by `1/sqrt(head_size)` (i.e. divides it by `sqrt(head_size)`). This makes it so when input `Q,K` are `unit variance`, `weights` will be unit variance too and `softmax` will stay diffuse and not saturate too much. Illustration below

In [45]:
# shape of the single-head attention output: (B, T, head_size)
out.shape

torch.Size([4, 8, 16])

In [None]:
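# tldr: this is:
#  - a weighted aggregation of information from previous tokens (and the current one)
#  - with weights that are allowed to change: they are computed from the data instead of being fixed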