# HackingGPT
## Part 6
Part 6 covers full self-attention with Query, Key, and Value projections, the complete attention mechanism with data-dependent weights, scaled attention to prevent peaky softmax, and layer normalization for training stability.

#### Author: [Kevin Thomas](mailto:ket189@pitt.edu)

In [144]:
import torch
import torch.nn as nn
from torch.nn import functional as F

## Step 1: Load and Inspect the Data
Now let's read the file and see what we're working with. Understanding your data is crucial before building any model!

In [145]:
with open('input.txt', 'r', encoding='utf-8') as f:
    text = f.read()

In [146]:
text

'A dim glow rises behind the glass of a screen and the machine exhales in binary tides. The hum is a language and one who listens leans close to catch the quiet grammar. Patterns fold like small maps and seams hint at how the thing holds itself together. Treat each blinking diode and each idle tick as a sentence in a story that asks to be read.\n\nThere is patience here, not of haste but of careful unthreading. Where others see a sealed box the curious hand traces the join and wonders which thought made it fit. Do not rush to break, coax the meaning out with questions, and watch how the logic replies in traces and errors and in the echoes of forgotten interfaces.\n\nTechnology is artifact and argument at once. It makes a claim about what should be simple, what should be hidden, and what should be trusted. Reverse the gaze and learn its rhetoric, see where it promises ease, where it buries complexity, and where it leaves a backdoor as a sigh between bricks. To read that rhetoric is to b

## Step 2: Version 4 - FULL SELF-ATTENTION
Now we get to the real thing! In self-attention, each token plays three roles.

1. Each token produces a **Query** (Q): "What am I looking for?"
2. Each token produces a **Key** (K): "What do I contain?"
3. Each token produces a **Value** (V): "What information do I provide?"

The raw attention scores are computed as: **wei = Q @ K^T**
- A high dot product means two tokens are "relevant" to each other.
- The resulting weights are **data-dependent**: different inputs give different weights!
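
To make the data-dependence concrete, here is a minimal sketch. The small sizes, the helper `attention_weights`, and the inputs `x_a` and `x_b` are assumptions made for illustration and are not part of the notebook. The same query and key projections are applied to two different inputs, and the resulting weight matrices differ, whereas a fixed uniform average would be identical for both.

```python
import torch
import torch.nn as nn
from torch.nn import functional as F

torch.manual_seed(0)
T, C, head_size = 4, 8, 4  # small hypothetical sizes

query = nn.Linear(C, head_size, bias=False)
key = nn.Linear(C, head_size, bias=False)

def attention_weights(x):
    # x: (T, C) -> (T, T) row-normalized weights (no causal mask, for simplicity)
    q, k = query(x), key(x)
    return F.softmax(q @ k.transpose(-2, -1), dim=-1)

x_a = torch.randn(T, C)  # first hypothetical input
x_b = torch.randn(T, C)  # second hypothetical input

wei_a = attention_weights(x_a)
wei_b = attention_weights(x_b)

# a fixed uniform average ignores the input entirely
uniform = torch.full((T, T), 1.0 / T)

print('same projections, different inputs, same weights?', torch.allclose(wei_a, wei_b))
print('rows still sum to 1:', wei_a.sum(dim=-1).tolist())
```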

### Why Query, Key, Value?
Think of it like a search engine.
- **Query**: Your search terms (what you're looking for)
- **Key**: The titles/tags of documents (what each document contains)
- **Value**: The actual content of documents (what you get back)

The attention mechanism works in three steps.
1. Compute how well each query matches each key (dot product)
2. Normalize these scores to probabilities (softmax)
3. Use probabilities to weight the values (weighted sum)
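
The three steps above can be sketched for a single query against three "documents". The vectors `keys`, `values`, and `query_vec` are made-up toy data, not values from the notebook, so treat this as an illustration of the search-engine analogy rather than the notebook's implementation.

```python
import torch
from torch.nn import functional as F

# hypothetical toy data: 3 "documents", each with a 2-dim key and a 2-dim value
keys = torch.tensor([[1.0, 0.0],
                     [0.0, 1.0],
                     [1.0, 1.0]])
values = torch.tensor([[10.0, 0.0],
                       [0.0, 10.0],
                       [5.0, 5.0]])
query_vec = torch.tensor([2.0, 0.0])  # a query aligned with the first key direction

# step 1: score every key against the query (dot product)
scores = keys @ query_vec             # shape (3,)
# step 2: normalize the scores to probabilities
weights = F.softmax(scores, dim=-1)   # shape (3,), sums to 1
# step 3: weighted sum of the values
result = weights @ values             # shape (2,)

print('scores: ', scores.tolist())
print('weights:', weights.tolist())
print('result: ', result.tolist())
```

Unlike a real search engine, nothing is returned "exactly": every document contributes, just in proportion to its weight. Here the second document, whose key is orthogonal to the query, contributes the least.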

### The Shape Journey
| Step | Tensor | Shape | Meaning |
|------|--------|-------|---------|
| Input | x | (B, T, C) | Batch of sequences with C features per token |
| Keys | k | (B, T, head_size) | Each token's "what I have" vector |
| Queries | q | (B, T, head_size) | Each token's "what I'm looking for" vector |
| Values | v | (B, T, head_size) | Each token's "what I'll give" vector |
| Raw Attention | wei | (B, T, T) | How much each position attends to each other position |
| Output | out | (B, T, head_size) | Weighted sum of values for each position |

In [147]:
torch.manual_seed(42)

<torch._C.Generator at 0x117c8e4f0>

In [148]:
# define batch dimension
B = 4  # batch size: 4 independent sequences
B

4

In [149]:
# define time dimension
T = 8  # sequence length: 8 tokens/positions in each sequence
T

8

In [150]:
# define channel dimension
C = 32  # feature size: 32 features per token
C

32

In [151]:
# start with random data
x = torch.randn(B, T, C)
x

tensor([[[ 1.9269,  1.4873,  0.9007,  ...,  0.0418, -0.2516,  0.8599],
         [-1.3847, -0.8712, -0.2234,  ...,  1.8446, -1.1845,  1.3835],
         [ 1.4451,  0.8564,  2.2181,  ..., -0.8278,  1.3347,  0.4835],
         ...,
         [-1.9006,  0.2286,  0.0249,  ..., -0.5558,  0.7043,  0.7099],
         [ 1.7744, -0.9216,  0.9624,  ..., -0.5003,  1.0350,  1.6896],
         [-0.0045,  1.6668,  0.1539,  ...,  0.5655,  0.5058,  0.2225]],

        [[-0.6855,  0.5636, -1.5072,  ...,  1.1566,  0.2691, -0.0366],
         [ 0.9733, -1.0151, -0.5419,  ..., -0.0553,  1.2049, -0.9825],
         [ 0.4334, -0.7172,  1.0554,  ..., -0.6766, -0.5730, -0.3303],
         ...,
         [ 0.6839, -1.3246, -0.5161,  ...,  1.1895,  0.7607, -0.7463],
         [-1.3839,  0.4869, -1.0020,  ...,  1.9535,  2.0487, -1.0880],
         [ 1.6217,  0.8513, -0.4005,  ...,  0.4232, -0.3389,  0.5180]],

        [[-1.3638,  0.1930, -0.6103,  ...,  0.6110,  1.2208, -0.6076],
         [-1.7376, -0.1254, -1.3658,  ..., -0

In [152]:
# head size: dimension of queries, keys, and values
head_size = 16
head_size

16

In [153]:
# learnable linear transformation, projecting input to key vectors; "What I have."
key = nn.Linear(C, head_size, bias=False)
key

Linear(in_features=32, out_features=16, bias=False)

In [154]:
# learnable linear transformation, projecting input to query vectors; "What am I looking for?"
query = nn.Linear(C, head_size, bias=False)
query

Linear(in_features=32, out_features=16, bias=False)

In [155]:
# learnable linear transformation, projecting input to value vectors; "What I'll give if queried."
value = nn.Linear(C, head_size, bias=False)
value

Linear(in_features=32, out_features=16, bias=False)

In [156]:
# understand what nn.Linear does
print('understanding the linear projections')
print()
print('nn.Linear(C, head_size, bias=False) stores a weight matrix of shape (head_size, C)')
print(f'   C = {C} (input features)')
print(f'   head_size = {head_size} (output features)')
print()
print('when we call key(x), it computes: x @ key.weight.T')
print(f'   x shape: {x.shape} = (B={B}, T={T}, C={C})')
print(f'   key.weight shape: {key.weight.shape} = (head_size={head_size}, C={C})')
print(f'   key.weight.T shape: ({C}, {head_size})')
print()
print('matrix multiplication')
print(f'   (B, T, C) @ (C, head_size) = (B, T, head_size)')
print(f'   ({B}, {T}, {C}) @ ({C}, {head_size}) = ({B}, {T}, {head_size})')
print()
print('each token\'s C-dimensional vector gets projected to head_size dimensions')
print('these projections are LEARNED during training')

understanding the linear projections

nn.Linear(C, head_size, bias=False) stores a weight matrix of shape (head_size, C)
   C = 32 (input features)
   head_size = 16 (output features)

when we call key(x), it computes: x @ key.weight.T
   x shape: torch.Size([4, 8, 32]) = (B=4, T=8, C=32)
   key.weight shape: torch.Size([16, 32]) = (head_size=16, C=32)
   key.weight.T shape: (32, 16)

matrix multiplication
   (B, T, C) @ (C, head_size) = (B, T, head_size)
   (4, 8, 32) @ (32, 16) = (4, 8, 16)

each token's C-dimensional vector gets projected to head_size dimensions
these projections are LEARNED during training
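
The claim that `key(x)` computes `x @ key.weight.T` can be checked directly. The sketch below rebuilds `x` and `key` with the notebook's sizes but a different seed (an assumption made so the block is self-contained), so the numbers will not match the cells above. Only the equality matters.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
B, T, C, head_size = 4, 8, 32, 16  # same sizes as the notebook
x = torch.randn(B, T, C)
key = nn.Linear(C, head_size, bias=False)

k_module = key(x)             # what the notebook computes
k_manual = x @ key.weight.T   # explicit matrix multiplication, broadcast over the batch

print('weight shape:', tuple(key.weight.shape))  # stored as (out_features, in_features)
print('same result: ', torch.allclose(k_module, k_manual, atol=1e-6))
```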


### Step 1: Compute Keys

In [157]:
# compute keys
k = key(x)  # (B, T, head_size) = (4, 8, 16)
k

tensor([[[ 7.0784e-02, -9.4861e-01, -5.9983e-01, -8.8679e-01,  4.7325e-02,
          -5.0741e-03, -6.7452e-03,  6.4850e-01, -3.2939e-01,  6.5462e-01,
          -3.5305e-01,  3.8077e-01,  3.3350e-01, -1.9763e-01, -1.5752e-01,
          -3.8165e-01],
         [-1.1677e+00,  2.7538e-01,  1.6652e+00, -2.7140e-01,  1.4043e-01,
           2.7449e-03, -1.0794e+00,  2.6188e-01, -1.1814e-01, -5.4476e-01,
           7.9574e-02, -3.6371e-02,  8.4531e-01,  8.1885e-01,  2.1071e-01,
          -5.8136e-01],
         [ 4.3110e-01, -6.2462e-01, -1.8344e-01, -2.9284e-01,  2.5957e-01,
          -5.1398e-01, -7.4316e-01,  1.2174e-01, -1.1386e+00,  4.8304e-01,
          -1.7443e-02,  6.1590e-01, -5.5573e-02, -1.0868e+00, -1.1277e+00,
           8.2533e-02],
         [-1.5513e-01,  3.9932e-01,  6.9395e-01, -1.9858e-01,  9.4391e-02,
          -1.5222e-01, -5.7460e-01,  5.1392e-01,  8.2661e-01,  1.7277e-02,
           3.8279e-01, -9.2914e-01,  3.2609e-01,  1.7348e-01,  1.0189e-01,
          -2.2036e-01],
    

In [158]:
# k shape
k.shape

torch.Size([4, 8, 16])

In [159]:
# understand the keys tensor
print('understanding the keys (k)')
print()
print(f'k shape: {k.shape} = (B={B}, T={T}, head_size={head_size})')
print()
print('what does each dimension mean?')
print(f'   B={B}: we have {B} independent sequences (batch)')
print(f'   T={T}: each sequence has {T} tokens/positions')
print(f'   head_size={head_size}: each key vector has {head_size} dimensions')
print()
print('for batch 0, position 0')
print(f'   k[0, 0] = {k[0, 0].tolist()[:4]}... (first 4 of {head_size} values)')
print('   this is position 0\'s "what I contain" vector')
print()
print('for batch 0, position 1')
print(f'   k[0, 1] = {k[0, 1].tolist()[:4]}... (first 4 of {head_size} values)')
print('   this is position 1\'s "what I contain" vector')
print()
print('each position has its own unique key vector')
print('keys tell us what information each token has to offer')

understanding the keys (k)

k shape: torch.Size([4, 8, 16]) = (B=4, T=8, head_size=16)

what does each dimension mean?
   B=4: we have 4 independent sequences (batch)
   T=8: each sequence has 8 tokens/positions
   head_size=16: each key vector has 16 dimensions

for batch 0, position 0
   k[0, 0] = [0.07078436762094498, -0.9486113786697388, -0.5998315811157227, -0.8867908120155334]... (first 4 of 16 values)
   this is position 0's "what I contain" vector

for batch 0, position 1
   k[0, 1] = [-1.1676945686340332, 0.27537861466407776, 1.6651532649993896, -0.271398663520813]... (first 4 of 16 values)
   this is position 1's "what I contain" vector

each position has its own unique key vector
keys tell us what information each token has to offer


### Step 2: Compute Queries

In [160]:
# compute queries
q = query(x)  # (B, T, head_size) = (4, 8, 16)
q

tensor([[[-0.4467,  0.6160, -0.5806, -0.4032, -1.2520, -0.3569,  0.3134,
          -0.7361,  0.9993,  0.0667, -0.5147,  0.0405,  0.2954, -0.2314,
          -0.5547,  0.0771],
         [-0.5397,  0.7083, -0.7480, -0.7873,  0.8294, -0.5856,  0.1124,
          -0.3189,  0.3401, -0.4997,  0.9458, -0.9430,  0.4883,  0.3419,
          -0.2482, -0.1744],
         [-0.6314,  0.3053, -0.8942, -0.2582, -1.1889, -0.5720, -0.1693,
          -0.6428,  0.3851, -0.3830, -0.7606,  0.9779, -0.4845, -0.7168,
          -0.1794, -0.3390],
         [-0.4691, -0.1064,  0.0868, -0.1737,  0.3330,  0.1009,  0.0584,
          -0.0182,  0.1981, -0.0587,  0.4190, -0.6249,  0.1330,  0.0505,
          -0.7110, -0.1839],
         [ 0.9760,  0.0929, -0.5380, -0.2953, -0.6723,  0.3561,  0.2937,
          -0.0166, -0.4360, -0.4318, -1.0178, -1.0906, -0.0714,  0.4877,
           0.9537,  0.1052],
         [-0.2460,  0.1963, -0.0330, -0.0296,  0.4262,  0.5681, -0.5607,
           0.3858,  0.6359, -0.1462,  0.9695,  0.516

In [161]:
# query shape
q.shape

torch.Size([4, 8, 16])

In [162]:
# understand the queries tensor
print('understanding the queries (q)')
print()
print(f'q shape: {q.shape} = (B={B}, T={T}, head_size={head_size})')
print()
print('what does each dimension mean?')
print(f'   B={B}: we have {B} independent sequences (batch)')
print(f'   T={T}: each sequence has {T} tokens/positions')
print(f'   head_size={head_size}: each query vector has {head_size} dimensions')
print()
print('for batch 0, position 0')
print(f'   q[0, 0] = {q[0, 0].tolist()[:4]}... (first 4 of {head_size} values)')
print('   this is position 0\'s "what I\'m looking for" vector')
print()
print('for batch 0, position 1')
print(f'   q[0, 1] = {q[0, 1].tolist()[:4]}... (first 4 of {head_size} values)')
print('   this is position 1\'s "what I\'m looking for" vector')
print()
print('each position has its own unique query vector')
print('queries tell us what information each token is searching for')

understanding the queries (q)

q shape: torch.Size([4, 8, 16]) = (B=4, T=8, head_size=16)

what does each dimension mean?
   B=4: we have 4 independent sequences (batch)
   T=8: each sequence has 8 tokens/positions
   head_size=16: each query vector has 16 dimensions

for batch 0, position 0
   q[0, 0] = [-0.44672077894210815, 0.6160160899162292, -0.580593466758728, -0.4032447636127472]... (first 4 of 16 values)
   this is position 0's "what I'm looking for" vector

for batch 0, position 1
   q[0, 1] = [-0.5397410988807678, 0.7082756161689758, -0.7479986548423767, -0.7873462438583374]... (first 4 of 16 values)
   this is position 1's "what I'm looking for" vector

each position has its own unique query vector
queries tell us what information each token is searching for


### Step 3: Compute Raw Attention Scores

In [163]:
# compute raw attention scores (affinities)
# wei[b, i, j] = how much position i attends to position j
# computed as dot product of query[i] and key[j]
wei = q @ k.transpose(-2, -1)  # (B, T, 16) @ (B, 16, T) → (B, T, T)
wei

tensor([[[-0.3332, -1.1723, -1.0216, -0.0545, -1.0950,  0.2735,  0.1340,
          -0.8490],
         [-0.6597,  0.7869, -1.2725,  1.6851,  0.1159,  0.5450,  0.2356,
          -0.1962],
         [ 0.3630, -1.5219,  0.7821, -1.7215, -0.3494,  0.2884, -0.1021,
          -1.4271],
         [-0.1001,  0.8649, -0.0335,  1.0221, -0.1350, -0.3078,  0.1440,
          -0.3019],
         [ 0.0136, -1.6202, -1.9888, -0.3327, -1.2506, -0.8928, -2.2674,
           3.0561],
         [-0.5833,  1.2025, -0.3281,  0.9147,  0.9809, -0.4859,  1.7589,
           0.1650],
         [ 1.1351, -1.9940,  1.5545, -1.8037, -0.5062, -2.6109, -1.0739,
           1.6430],
         [-1.2784, -0.4554, -1.4118,  0.6392, -0.5780,  1.9291,  1.6689,
           0.1103]],

        [[ 0.0685,  0.8637, -0.6632, -0.6766,  0.3320,  0.5343, -0.1189,
          -0.4954],
         [-0.3457, -0.6529,  1.8824,  2.0650, -0.0657, -0.6715,  1.4419,
          -0.1583],
         [-0.0532, -2.0680,  0.5618, -1.8266,  0.9683, -0.3595, -0.1

In [164]:
# understand the attention score computation
print('understanding q @ k.transpose(-2, -1)')
print()
print(f'q shape: {q.shape} = (B, T, head_size) = ({B}, {T}, {head_size})')
print(f'k shape: {k.shape} = (B, T, head_size) = ({B}, {T}, {head_size})')
print()
print('k.transpose(-2, -1) swaps the last two dimensions')
print(f'   before: k shape = ({B}, {T}, {head_size})')
print(f'   after:  k.transpose(-2, -1) shape = ({B}, {head_size}, {T})')
print()
print('the matrix multiplication')
print(f'   (B, T, head_size) @ (B, head_size, T) = (B, T, T)')
print(f'   ({B}, {T}, {head_size}) @ ({B}, {head_size}, {T}) = ({B}, {T}, {T})')
print()
print('what does wei[b, i, j] mean?')
print('   wei[b, i, j] = dot product of query[i] and key[j] in batch b')
print('   higher value = position i is more interested in position j')
print('   lower value = position i is less interested in position j')
print()
print(f'wei shape: {wei.shape}')

understanding q @ k.transpose(-2, -1)

q shape: torch.Size([4, 8, 16]) = (B, T, head_size) = (4, 8, 16)
k shape: torch.Size([4, 8, 16]) = (B, T, head_size) = (4, 8, 16)

k.transpose(-2, -1) swaps the last two dimensions
   before: k shape = (4, 8, 16)
   after:  k.transpose(-2, -1) shape = (4, 16, 8)

the matrix multiplication
   (B, T, head_size) @ (B, head_size, T) = (B, T, T)
   (4, 8, 16) @ (4, 16, 8) = (4, 8, 8)

what does wei[b, i, j] mean?
   wei[b, i, j] = dot product of query[i] and key[j] in batch b
   higher value = position i is more interested in position j
   lower value = position i is less interested in position j

wei shape: torch.Size([4, 8, 8])


In [165]:
# trace through one attention score calculation manually
print('tracing through one attention score calculation')
print()
print('let\'s compute wei[0, 0, 1] manually')
print('this is: how much does position 0 attend to position 1 (in batch 0)?')
print()
print(f'q[0, 0] (query at position 0):')
print(f'   {q[0, 0].tolist()[:4]}... (first 4 values)')
print()
print(f'k[0, 1] (key at position 1):')
print(f'   {k[0, 1].tolist()[:4]}... (first 4 values)')
print()
print('dot product = sum of element-wise products')
manual_dot = (q[0, 0] * k[0, 1]).sum().item()
print(f'   q[0, 0] · k[0, 1] = {manual_dot:.4f}')
print()
print(f'actual wei[0, 0, 1] = {wei[0, 0, 1].item():.4f}')
print(f'match: {abs(manual_dot - wei[0, 0, 1].item()) < 1e-5}')

tracing through one attention score calculation

let's compute wei[0, 0, 1] manually
this is: how much does position 0 attend to position 1 (in batch 0)?

q[0, 0] (query at position 0):
   [-0.44672077894210815, 0.6160160899162292, -0.580593466758728, -0.4032447636127472]... (first 4 values)

k[0, 1] (key at position 1):
   [-1.1676945686340332, 0.27537861466407776, 1.6651532649993896, -0.271398663520813]... (first 4 values)

dot product = sum of element-wise products
   q[0, 0] · k[0, 1] = -1.1723

actual wei[0, 0, 1] = -1.1723
match: True
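
The single-entry check above generalizes to every entry at once. As a sketch, with random stand-ins for `q` and `k` (an assumption, so the values differ from the notebook's), the batched matrix product agrees with an explicit `torch.einsum` over the shared `head_size` axis and with a plain triple loop of dot products.

```python
import torch

torch.manual_seed(0)
B, T, head_size = 4, 8, 16
q = torch.randn(B, T, head_size)  # stand-in for query(x)
k = torch.randn(B, T, head_size)  # stand-in for key(x)

# the form the notebook uses
wei = q @ k.transpose(-2, -1)                      # (B, T, T)

# same contraction, with the indices spelled out: sum over h of q[b,i,h] * k[b,j,h]
wei_einsum = torch.einsum('bih,bjh->bij', q, k)    # (B, T, T)

# slow but explicit: one dot product per (b, i, j)
wei_loop = torch.empty(B, T, T)
for b in range(B):
    for i in range(T):
        for j in range(T):
            wei_loop[b, i, j] = torch.dot(q[b, i], k[b, j])

print(torch.allclose(wei, wei_einsum, atol=1e-5))
print(torch.allclose(wei, wei_loop, atol=1e-5))
```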


In [166]:
# examine one row of attention scores (before masking)
print('examining raw attention scores for position 0 (batch 0)')
print()
print(f'wei[0, 0] = attention scores from position 0 to all positions')
print(f'   {wei[0, 0].tolist()}')
print()
print('interpreting these scores')
for j in range(T):
    score = wei[0, 0, j].item()
    if score > 0:
        print(f'   position 0 → position {j}: {score:.4f} (positive = interested)')
    else:
        print(f'   position 0 → position {j}: {score:.4f} (negative = not interested)')
print()
print('note: these are RAW scores, not probabilities yet')
print('we still need to apply masking and softmax')

examining raw attention scores for position 0 (batch 0)

wei[0, 0] = attention scores from position 0 to all positions
   [-0.33323800563812256, -1.1722666025161743, -1.0215535163879395, -0.05452167987823486, -1.0950181484222412, 0.27353763580322266, 0.13403844833374023, -0.8489717841148376]

interpreting these scores
   position 0 → position 0: -0.3332 (negative = not interested)
   position 0 → position 1: -1.1723 (negative = not interested)
   position 0 → position 2: -1.0216 (negative = not interested)
   position 0 → position 3: -0.0545 (negative = not interested)
   position 0 → position 4: -1.0950 (negative = not interested)
   position 0 → position 5: 0.2735 (positive = interested)
   position 0 → position 6: 0.1340 (positive = interested)
   position 0 → position 7: -0.8490 (negative = not interested)

note: these are RAW scores, not probabilities yet
we still need to apply masking and softmax
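
The "interested / not interested" labels are a convenient shorthand. What softmax actually responds to is each score relative to the others in its row, not its sign. The plain-Python sketch below uses a hand-written `softmax` helper (hypothetical, written only for this illustration) on the first four raw scores printed above. Shifting every score by the same constant leaves the probabilities unchanged.

```python
import math

def softmax(scores):
    # subtract the max for numerical stability; this does not change the result
    m = max(scores)
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]

raw = [-0.3332, -1.1723, -1.0216, -0.0545]  # first four entries of wei[0, 0], rounded
shifted = [s + 5.0 for s in raw]            # every score is now positive

p_raw = softmax(raw)
p_shifted = softmax(shifted)

print('from raw scores:    ', [round(p, 4) for p in p_raw])
print('from shifted scores:', [round(p, 4) for p in p_shifted])
```

So a score of -0.0545 can still earn the largest weight in its row, as long as the other scores are lower.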


### Step 4: Mask Future Positions

In [167]:
# create lower-triangular mask
tril = torch.tril(torch.ones(T, T))
tril

tensor([[1., 0., 0., 0., 0., 0., 0., 0.],
        [1., 1., 0., 0., 0., 0., 0., 0.],
        [1., 1., 1., 0., 0., 0., 0., 0.],
        [1., 1., 1., 1., 0., 0., 0., 0.],
        [1., 1., 1., 1., 1., 0., 0., 0.],
        [1., 1., 1., 1., 1., 1., 0., 0.],
        [1., 1., 1., 1., 1., 1., 1., 0.],
        [1., 1., 1., 1., 1., 1., 1., 1.]])

In [168]:
# understand the lower triangular mask
print('understanding the lower triangular mask')
print()
print('tril = torch.tril(torch.ones(T, T))')
print(f'T = {T}, so we create an {T}x{T} matrix')
print()
print('examining each row of tril')
for i in range(T):
    row = tril[i].tolist()
    visible = [j for j, v in enumerate(row) if v == 1.0]
    print(f'row {i}: {[int(v) for v in row]}')
    print(f'       position {i} can see positions {visible}')
    print()

understanding the lower triangular mask

tril = torch.tril(torch.ones(T, T))
T = 8, so we create an 8x8 matrix

examining each row of tril
row 0: [1, 0, 0, 0, 0, 0, 0, 0]
       position 0 can see positions [0]

row 1: [1, 1, 0, 0, 0, 0, 0, 0]
       position 1 can see positions [0, 1]

row 2: [1, 1, 1, 0, 0, 0, 0, 0]
       position 2 can see positions [0, 1, 2]

row 3: [1, 1, 1, 1, 0, 0, 0, 0]
       position 3 can see positions [0, 1, 2, 3]

row 4: [1, 1, 1, 1, 1, 0, 0, 0]
       position 4 can see positions [0, 1, 2, 3, 4]

row 5: [1, 1, 1, 1, 1, 1, 0, 0]
       position 5 can see positions [0, 1, 2, 3, 4, 5]

row 6: [1, 1, 1, 1, 1, 1, 1, 0]
       position 6 can see positions [0, 1, 2, 3, 4, 5, 6]

row 7: [1, 1, 1, 1, 1, 1, 1, 1]
       position 7 can see positions [0, 1, 2, 3, 4, 5, 6, 7]
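
The rule encoded by `tril` is simply "position i may look at position j when j <= i". As a sketch, the same mask can be built from that rule directly by comparing position indices. The names `pos` and `can_see` are hypothetical.

```python
import torch

T = 8
tril = torch.tril(torch.ones(T, T))

pos = torch.arange(T)
# pos[None, :] is the column index j, pos[:, None] is the row index i
can_see = pos[None, :] <= pos[:, None]   # (T, T) boolean, True where j <= i

print(torch.equal(can_see, tril == 1))
print('visible pairs:', can_see.sum().item())
```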



In [169]:
# apply the mask to the attention scores
wei = wei.masked_fill(tril == 0, float('-inf'))
wei

tensor([[[-0.3332,    -inf,    -inf,    -inf,    -inf,    -inf,    -inf,
             -inf],
         [-0.6597,  0.7869,    -inf,    -inf,    -inf,    -inf,    -inf,
             -inf],
         [ 0.3630, -1.5219,  0.7821,    -inf,    -inf,    -inf,    -inf,
             -inf],
         [-0.1001,  0.8649, -0.0335,  1.0221,    -inf,    -inf,    -inf,
             -inf],
         [ 0.0136, -1.6202, -1.9888, -0.3327, -1.2506,    -inf,    -inf,
             -inf],
         [-0.5833,  1.2025, -0.3281,  0.9147,  0.9809, -0.4859,    -inf,
             -inf],
         [ 1.1351, -1.9940,  1.5545, -1.8037, -0.5062, -2.6109, -1.0739,
             -inf],
         [-1.2784, -0.4554, -1.4118,  0.6392, -0.5780,  1.9291,  1.6689,
           0.1103]],

        [[ 0.0685,    -inf,    -inf,    -inf,    -inf,    -inf,    -inf,
             -inf],
         [-0.3457, -0.6529,    -inf,    -inf,    -inf,    -inf,    -inf,
             -inf],
         [-0.0532, -2.0680,  0.5618,    -inf,    -inf,    -inf,    -

In [170]:
# understand the masked_fill operation
print('understanding masked_fill with -inf')
print()
print('wei = wei.masked_fill(tril == 0, float(\'-inf\'))')
print()
print('step 1: tril == 0 creates a boolean mask')
print('        True where we CANNOT look (future positions)')
print('        False where we CAN look (current and past)')
print()
print('step 2: masked_fill replaces True positions with -inf')
print()
print('examining batch 0, position 0')
print(f'   before masking: wei[0, 0] could see all positions')
print(f'   after masking:  wei[0, 0] = {wei[0, 0].tolist()}')
print('   positions 1-7 are now -inf (cannot look at future)')
print()
print('examining batch 0, position 3')
print(f'   wei[0, 3] = {wei[0, 3].tolist()}')
print('   positions 0-3 have real values, positions 4-7 are -inf')
print()
print('why -inf?')
print('   because e^(-inf) = 0')
print('   when we apply softmax, -inf values become 0 probability')
print('   this completely blocks information from future positions')

understanding masked_fill with -inf

wei = wei.masked_fill(tril == 0, float('-inf'))

step 1: tril == 0 creates a boolean mask
        True where we CANNOT look (future positions)
        False where we CAN look (current and past)

step 2: masked_fill replaces True positions with -inf

examining batch 0, position 0
   before masking: wei[0, 0] could see all positions
   after masking:  wei[0, 0] = [-0.33323800563812256, -inf, -inf, -inf, -inf, -inf, -inf, -inf]
   positions 1-7 are now -inf (cannot look at future)

examining batch 0, position 3
   wei[0, 3] = [-0.10006709396839142, 0.8648553490638733, -0.033506155014038086, 1.0221478939056396, -inf, -inf, -inf, -inf]
   positions 0-3 have real values, positions 4-7 are -inf

why -inf?
   because e^(-inf) = 0
   when we apply softmax, -inf values become 0 probability
   this completely blocks information from future positions
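
To see why the fill value must be `-inf` rather than `0`, the sketch below takes the raw scores for batch 0, position 3 (copied, rounded, from the output above) and compares the two choices. The variable names are hypothetical.

```python
import torch
from torch.nn import functional as F

# raw scores for batch 0, position 3, rounded from the notebook output
scores = torch.tensor([-0.1001, 0.8649, -0.0335, 1.0221,
                       -0.1350, -0.3078, 0.1440, -0.3019])
visible = torch.tensor([1, 1, 1, 1, 0, 0, 0, 0], dtype=torch.bool)

masked_inf = scores.masked_fill(~visible, float('-inf'))
masked_zero = scores.masked_fill(~visible, 0.0)   # the tempting but wrong choice

p_inf = F.softmax(masked_inf, dim=-1)
p_zero = F.softmax(masked_zero, dim=-1)
p_prefix = F.softmax(scores[:4], dim=-1)          # softmax over the visible scores only

print('-inf fill, weight on future:', p_inf[4:].sum().item())
print('0 fill,    weight on future:', round(p_zero[4:].sum().item(), 4))
print('-inf fill equals softmax over the visible prefix:',
      torch.allclose(p_inf[:4], p_prefix))
```

Filling with `0` leaves each future position with a score of 0, and since e^0 = 1 those positions would still receive real probability mass.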


### Step 5: Softmax to Get Probabilities

In [171]:
# normalize the attention scores to probabilities
wei = F.softmax(wei, dim=-1)
wei

tensor([[[1.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
          0.0000e+00, 0.0000e+00, 0.0000e+00],
         [1.9052e-01, 8.0948e-01, 0.0000e+00, 0.0000e+00, 0.0000e+00,
          0.0000e+00, 0.0000e+00, 0.0000e+00],
         [3.7418e-01, 5.6820e-02, 5.6900e-01, 0.0000e+00, 0.0000e+00,
          0.0000e+00, 0.0000e+00, 0.0000e+00],
         [1.2878e-01, 3.3800e-01, 1.3765e-01, 3.9557e-01, 0.0000e+00,
          0.0000e+00, 0.0000e+00, 0.0000e+00],
         [4.3106e-01, 8.4134e-02, 5.8193e-02, 3.0487e-01, 1.2175e-01,
          0.0000e+00, 0.0000e+00, 0.0000e+00],
         [5.3737e-02, 3.2051e-01, 6.9361e-02, 2.4035e-01, 2.5680e-01,
          5.9238e-02, 0.0000e+00, 0.0000e+00],
         [3.3957e-01, 1.4859e-02, 5.1650e-01, 1.7974e-02, 6.5784e-02,
          8.0182e-03, 3.7289e-02, 0.0000e+00],
         [1.6459e-02, 3.7486e-02, 1.4404e-02, 1.1200e-01, 3.3159e-02,
          4.0686e-01, 3.1364e-01, 6.5997e-02]],

        [[1.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.00

In [172]:
# understand softmax normalization
print('understanding softmax normalization')
print()
print('wei = F.softmax(wei, dim=-1)')
print()
print('softmax converts raw scores to probabilities')
print('   - all values become positive (e^x > 0)')
print('   - all values in a row sum to 1')
print('   - e^(-inf) = 0 (masked positions become 0)')
print()
print('examining batch 0, position 0')
print(f'   wei[0, 0] = {[round(v, 4) for v in wei[0, 0].tolist()]}')
print(f'   sum = {wei[0, 0].sum().item():.4f}')
print('   only position 0 is visible, so it gets weight 1.0')
print()
print('examining batch 0, position 3')
print(f'   wei[0, 3] = {[round(v, 4) for v in wei[0, 3].tolist()]}')
print(f'   sum = {wei[0, 3].sum().item():.4f}')
print('   positions 0-3 have non-zero weights that sum to 1')
print()
print('this is DATA-DEPENDENT attention!')
print('the weights are NOT uniform (like 0.25, 0.25, 0.25, 0.25)')
print('instead, they depend on the query-key dot products')

understanding softmax normalization

wei = F.softmax(wei, dim=-1)

softmax converts raw scores to probabilities
   - all values become positive (e^x > 0)
   - all values in a row sum to 1
   - e^(-inf) = 0 (masked positions become 0)

examining batch 0, position 0
   wei[0, 0] = [1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]
   sum = 1.0000
   only position 0 is visible, so it gets weight 1.0

examining batch 0, position 3
   wei[0, 3] = [0.1288, 0.338, 0.1376, 0.3956, 0.0, 0.0, 0.0, 0.0]
   sum = 1.0000
   positions 0-3 have non-zero weights that sum to 1

this is DATA-DEPENDENT attention!
the weights are NOT uniform (like 0.25, 0.25, 0.25, 0.25)
instead, they depend on the query-key dot products
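
The two rows inspected above are examples of properties that should hold for every row in every batch. As a sketch, using random stand-in scores rather than the notebook's `wei` (an assumption made to keep the block self-contained), they can be checked for the whole tensor at once.

```python
import torch
from torch.nn import functional as F

torch.manual_seed(0)
B, T = 4, 8
raw = torch.randn(B, T, T)                 # stand-in for q @ k.transpose(-2, -1)
tril = torch.tril(torch.ones(T, T))
wei = F.softmax(raw.masked_fill(tril == 0, float('-inf')), dim=-1)

row_sums = wei.sum(dim=-1)                 # (B, T): one sum per query position
future_mass = (wei * (1 - tril)).sum()     # total weight placed on masked positions
nonzero_per_row = (wei > 0).sum(dim=-1)    # (B, T): how many positions each row uses

print('all rows sum to 1:', torch.allclose(row_sums, torch.ones(B, T)))
print('weight on future positions:', future_mass.item())
print('non-zero weights per row (batch 0):', nonzero_per_row[0].tolist())
```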


### Step 6: Compute Values & Weighted Sum

In [173]:
# compute values
v = value(x)  # (B, T, head_size)
v

tensor([[[ 0.7630, -0.2412, -0.4150,  0.3833,  0.5740, -1.6738,  0.7954,
           0.6872, -0.3848,  0.5073, -0.5312, -0.1221,  0.0445,  1.2169,
           0.9940,  1.5281],
         [ 0.3218, -0.0569, -0.8477, -0.7261,  0.0893, -0.1100, -0.0939,
          -1.0305,  0.0200,  0.2691,  0.5359,  0.1426,  0.1681,  0.3577,
           0.2332, -0.5051],
         [-0.1803,  0.2362,  0.1637,  0.5017,  0.7742, -0.4373, -0.1290,
          -0.2560, -0.4637,  0.2134,  0.0475, -0.3715, -0.8919,  0.1818,
           1.0010,  0.8075],
         [-0.3670,  0.6613,  0.3709,  0.2785, -0.5886, -0.1584, -0.5308,
          -0.9533, -0.2020,  0.1104,  0.0643,  0.5399,  0.8794,  0.3147,
          -0.3706, -0.2829],
         [-0.2686,  0.1055, -0.7112,  1.0555,  0.3917, -0.2944, -1.1679,
           0.3537,  0.7702, -0.4535, -0.7852, -0.2078, -0.5898, -0.0895,
           0.1527, -0.2386],
         [-0.3207, -0.1223,  0.2610, -0.5424, -0.0771, -0.4658,  0.2526,
          -0.0036, -0.5229,  0.1433,  0.9100,  0.299

In [174]:
# understand values
print('understanding values (v)')
print()
print('v = value(x) projects each token to a "value" representation')
print(f'v shape: {v.shape} = (B={B}, T={T}, head_size={head_size})')
print()
print('what does each value vector represent?')
print('   it\'s the "content" that each token will contribute')
print('   if position j gets high attention weight, its value contributes more')
print()
print('for batch 0, position 0')
print(f'   v[0, 0] = {v[0, 0].tolist()[:4]}... (first 4 of {head_size} values)')
print('   this is what position 0 will "give" when attended to')
print()
print('for batch 0, position 1')
print(f'   v[0, 1] = {v[0, 1].tolist()[:4]}... (first 4 of {head_size} values)')
print('   this is what position 1 will "give" when attended to')

understanding values (v)

v = value(x) projects each token to a "value" representation
v shape: torch.Size([4, 8, 16]) = (B=4, T=8, head_size=16)

what does each value vector represent?
   it's the "content" that each token will contribute
   if position j gets high attention weight, its value contributes more

for batch 0, position 0
   v[0, 0] = [0.7629690766334534, -0.24118372797966003, -0.4150242507457733, 0.3832956552505493]... (first 4 of 16 values)
   this is what position 0 will "give" when attended to

for batch 0, position 1
   v[0, 1] = [0.3217601180076599, -0.0569150447845459, -0.8477029800415039, -0.7260561585426331]... (first 4 of 16 values)
   this is what position 1 will "give" when attended to


In [175]:
# compute the weighted sum of the values
out = wei @ v  # (B, T, T) @ (B, T, 16) → (B, T, 16)
out

tensor([[[ 7.6297e-01, -2.4118e-01, -4.1502e-01,  3.8330e-01,  5.7404e-01,
          -1.6738e+00,  7.9543e-01,  6.8724e-01, -3.8477e-01,  5.0733e-01,
          -5.3124e-01, -1.2214e-01,  4.4479e-02,  1.2169e+00,  9.9396e-01,
           1.5281e+00],
         [ 4.0582e-01, -9.2022e-02, -7.6527e-01, -5.1470e-01,  1.8168e-01,
          -4.0795e-01,  7.5564e-02, -7.0327e-01, -5.7126e-02,  3.1450e-01,
           3.3258e-01,  9.2198e-02,  1.4456e-01,  5.2138e-01,  3.7812e-01,
          -1.1776e-01],
         [ 2.0116e-01,  4.0931e-02, -1.1034e-01,  3.8763e-01,  6.6036e-01,
          -8.8138e-01,  2.1889e-01,  5.2931e-02, -4.0671e-01,  3.2654e-01,
          -1.4131e-01, -2.4897e-01, -4.8128e-01,  5.7909e-01,  9.5475e-01,
           1.0026e+00],
         [ 3.7002e-02,  2.4381e-01, -1.7073e-01, -1.6806e-02, -2.2144e-02,
          -3.7560e-01, -1.5701e-01, -6.7214e-01, -1.8652e-01,  2.2933e-01,
           1.4469e-01,  1.9491e-01,  2.8767e-01,  4.2711e-01,  1.9802e-01,
           2.5314e-02],
    

In [176]:
# understand the final output
print('understanding the weighted sum: out = wei @ v')
print()
print(f'wei shape: {wei.shape} = (B={B}, T={T}, T={T})')
print(f'v shape:   {v.shape} = (B={B}, T={T}, head_size={head_size})')
print(f'out shape: {out.shape} = (B={B}, T={T}, head_size={head_size})')
print()
print('what does out[b, i] represent?')
print('   it\'s a weighted average of all value vectors')
print('   weighted by how much position i attends to each position')
print()
print('for position 0 (can only see itself)')
print(f'   wei[0, 0] = {[round(v, 3) for v in wei[0, 0].tolist()]}')
print('   out[0, 0] = 1.0 * v[0, 0] + 0.0 * v[0, 1] + ... + 0.0 * v[0, 7]')
print('   out[0, 0] ≈ v[0, 0] (just itself)')
print()
print('for position 3 (can see positions 0, 1, 2, 3)')
weights_3 = [round(w, 3) for w in wei[0, 3].tolist()]
print(f'   wei[0, 3] = {weights_3}')
print(f'   out[0, 3] = {weights_3[0]} * v[0, 0] + {weights_3[1]} * v[0, 1] + {weights_3[2]} * v[0, 2] + {weights_3[3]} * v[0, 3]')
print('   out[0, 3] is a weighted mix of values from positions 0-3')

understanding the weighted sum: out = wei @ v

wei shape: torch.Size([4, 8, 8]) = (B=4, T=8, T=8)
v shape:   torch.Size([4, 8, 16]) = (B=4, T=8, head_size=16)
out shape: torch.Size([4, 8, 16]) = (B=4, T=8, head_size=16)

what does out[b, i] represent?
   it's a weighted average of all value vectors
   weighted by how much position i attends to each position

for position 0 (can only see itself)
   wei[0, 0] = [1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]
   out[0, 0] = 1.0 * v[0, 0] + 0.0 * v[0, 1] + ... + 0.0 * v[0, 7]
   out[0, 0] ≈ v[0, 0] (just itself)

for position 3 (can see positions 0, 1, 2, 3)
   wei[0, 3] = [0.129, 0.338, 0.138, 0.396, 0.0, 0.0, 0.0, 0.0]
   out[0, 3] = 0.129 * v[0, 0] + 0.338 * v[0, 1] + 0.138 * v[0, 2] + 0.396 * v[0, 3]
   out[0, 3] is a weighted mix of values from positions 0-3


In [177]:
# manually verify one output calculation
print('manually verifying out[0, 0]')
print()
print('out[0, 0] should equal v[0, 0] since position 0 only sees itself')
print()
print(f'out[0, 0] (first 4 values): {out[0, 0].tolist()[:4]}')
print(f'v[0, 0] (first 4 values):   {v[0, 0].tolist()[:4]}')
print()
print(f'are they equal? {torch.allclose(out[0, 0], v[0, 0])}')

manually verifying out[0, 0]

out[0, 0] should equal v[0, 0] since position 0 only sees itself

out[0, 0] (first 4 values): [0.7629690766334534, -0.24118372797966003, -0.4150242507457733, 0.3832956552505493]
v[0, 0] (first 4 values):   [0.7629690766334534, -0.24118372797966003, -0.4150242507457733, 0.3832956552505493]

are they equal? True


In [178]:
# manually verify position 1 calculation
print('manually verifying out[0, 1]')
print()
print('position 1 can see positions 0 and 1')
print(f'wei[0, 1] = {[round(w, 4) for w in wei[0, 1].tolist()]}')
print()
print('computing the weighted sum manually')
w0 = wei[0, 1, 0].item()
w1 = wei[0, 1, 1].item()
print(f'   weight for position 0: {w0:.4f}')
print(f'   weight for position 1: {w1:.4f}')
print()
manual_out = w0 * v[0, 0] + w1 * v[0, 1]
print(f'manual calculation:')
print(f'   {w0:.4f} * v[0, 0] + {w1:.4f} * v[0, 1]')
print(f'   = {manual_out.tolist()[:4]}... (first 4 values)')
print()
print(f'actual out[0, 1]:')
print(f'   = {out[0, 1].tolist()[:4]}... (first 4 values)')
print()
print(f'are they equal? {torch.allclose(manual_out, out[0, 1])}')

manually verifying out[0, 1]

position 1 can see positions 0 and 1
wei[0, 1] = [0.1905, 0.8095, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]

computing the weighted sum manually
   weight for position 0: 0.1905
   weight for position 1: 0.8095

manual calculation:
   0.1905 * v[0, 0] + 0.8095 * v[0, 1]
   = [0.4058188199996948, -0.09202173352241516, -0.7652695178985596, -0.5147035717964172]... (first 4 values)

actual out[0, 1]:
   = [0.4058188199996948, -0.09202173352241516, -0.7652694582939148, -0.5147035717964172]... (first 4 values)

are they equal? True
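
The two manual checks above generalize to every position. The following is a minimal, self-contained sketch (it builds its own small random `wei` and `v` rather than reusing the notebook's tensors, and the sizes are assumptions chosen for readability) that compares the batched matrix multiply against an explicit loop over `out[b, i] = sum_j wei[b, i, j] * v[b, j]`.

```python
import torch

torch.manual_seed(0)
B, T, head_size = 2, 4, 3  # small illustrative sizes (assumed, not the notebook's)

# causal attention weights: random scores, future positions masked, softmax per row
scores = torch.randn(B, T, T)
tril = torch.tril(torch.ones(T, T))
wei = torch.softmax(scores.masked_fill(tril == 0, float('-inf')), dim=-1)
v = torch.randn(B, T, head_size)

# batched matrix multiply: (B, T, T) @ (B, T, head_size) -> (B, T, head_size)
out = wei @ v

# explicit loop over the same weighted sum
manual = torch.zeros(B, T, head_size)
for b in range(B):
    for i in range(T):
        for j in range(i + 1):  # positions after i have weight 0 and are skipped
            manual[b, i] += wei[b, i, j] * v[b, j]

print(f'loop matches wei @ v: {torch.allclose(out, manual, atol=1e-5)}')
```

The loop is far slower than `wei @ v`, but it makes the meaning of each output row explicit.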


## Step 7: Visualize the Attention Weights
Let's look at what the attention pattern looks like. Remember that these are now **data-dependent** weights, not uniform averages!

#### Key Takeaways About Attention
| Concept | Explanation |
|---------|-------------|
| **Communication Mechanism** | Tokens can "talk" to each other and share information. |
| **No Position Awareness** | Attention treats tokens as an unordered set, so order must be supplied through positional encodings. |
| **Batch Independence** | Sequences in a batch don't interact. |
| **Decoder vs Encoder** | Decoder uses causal mask (triangle); Encoder allows full attention. |
| **Self vs Cross Attention** | Self: Q, K, V from same source; Cross: Q from one source, K/V from another. |

In [179]:
print('Attention Weights for First Sequence (batch 0)')
print('   Rows = \'from\' position (query)')
print('   Cols = \'to\' position (key)')
print()
print('        ', end='')
for j in range(T):
    print(f'  pos{j} ', end='')
print()
print('       ' + '-' * 55)
for i in range(T):
    print(f'   pos{i} |', end='')
    for j in range(T):
        val = wei[0, i, j].item()
        if val > 0.001:
            print(f' {val:.3f}', end=' ')
        else:
            print(f'  ---  ', end='')
    print()
print()
print('Observations')
print('   - Lower triangular (can\'t attend to future).')
print('   - Each row sums to 1 (valid probability distribution).')
print('   - Values are NOT uniform - they\'re data-dependent!')
print('   - Different positions attend differently to past tokens.')

Attention Weights for First Sequence (batch 0)
   Rows = 'from' position (query)
   Cols = 'to' position (key)

          pos0   pos1   pos2   pos3   pos4   pos5   pos6   pos7 
       -------------------------------------------------------
   pos0 | 1.000   ---    ---    ---    ---    ---    ---    ---  
   pos1 | 0.191  0.809   ---    ---    ---    ---    ---    ---  
   pos2 | 0.374  0.057  0.569   ---    ---    ---    ---    ---  
   pos3 | 0.129  0.338  0.138  0.396   ---    ---    ---    ---  
   pos4 | 0.431  0.084  0.058  0.305  0.122   ---    ---    ---  
   pos5 | 0.054  0.321  0.069  0.240  0.257  0.059   ---    ---  
   pos6 | 0.340  0.015  0.517  0.018  0.066  0.008  0.037   ---  
   pos7 | 0.016  0.037  0.014  0.112  0.033  0.407  0.314  0.066 

Observations
   - Lower triangular (can't attend to future).
   - Each row sums to 1 (valid probability distribution).
   - Values are NOT uniform - they're data-dependent!
   - Different positions attend differently to past tokens.
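
The first two observations can be checked mechanically. This is a small sketch under the assumption that random scores stand in for `q @ k^T`; it does not reuse the notebook's `wei`.

```python
import torch

torch.manual_seed(0)
T = 8
scores = torch.randn(T, T)  # stand-in for q @ k^T (assumed random scores)
tril = torch.tril(torch.ones(T, T))
wei = torch.softmax(scores.masked_fill(tril == 0, float('-inf')), dim=-1)

# each row should be a probability distribution
row_sums = wei.sum(dim=-1)
# everything strictly above the diagonal should be exactly zero (exp(-inf) = 0)
upper = torch.triu(wei, diagonal=1)

print(f'rows sum to 1:        {torch.allclose(row_sums, torch.ones(T))}')
print(f'no weight on future:  {bool((upper == 0).all())}')
```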


## Step 8: Scaled Attention
There's one more trick: we **scale** the attention scores by $\frac{1}{\sqrt{d_k}}$ (where $d_k$ is the head size).

### The Problem Without Scaling
When we compute Q @ K^T, each entry is a dot product of two head_size-long vectors, and its variance grows linearly with head_size.

For vectors whose components have zero mean and unit variance:
- If head_size = 16, the dot product has variance ≈ 16
- If head_size = 64, the dot product has variance ≈ 64
- If head_size = 512, the dot product has variance ≈ 512
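
These numbers can be reproduced empirically. The sketch below draws many independent query/key pairs from a standard normal distribution (the sample count `n` is an arbitrary assumption) and measures the variance of their dot products for each head size.

```python
import torch

torch.manual_seed(0)
n = 10000  # number of random (q, k) pairs per head size (assumed)

measured = {}
for head_size in (16, 64, 512):
    q = torch.randn(n, head_size)
    k = torch.randn(n, head_size)
    dots = (q * k).sum(dim=-1)  # one dot product per row
    measured[head_size] = dots.var().item()
    print(f'head_size = {head_size:>3}: dot product variance ≈ {measured[head_size]:.1f}')
```

With this many samples the measured variance should land within a few percent of head_size.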

### Why is Large Variance Bad?
Large values going into softmax cause problems.
- Softmax becomes "peaky" (one position gets nearly all attention)
- Gradients become very small (vanishing gradient problem)
- Training becomes unstable

### The Solution: Scale by 1/sqrt(head_size)
By dividing by sqrt(head_size), we bring the variance back to ~1.
- variance of Q @ K^T ≈ head_size
- variance of (Q @ K^T) / sqrt(head_size) ≈ 1

This is why the formula in "Attention Is All You Need" is:
$$\text{Attention}(Q, K, V) = \text{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right)V$$
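
The formula translates almost line for line into PyTorch. The function below is a minimal sketch, not the notebook's own implementation: the name `attention_sketch` and the `causal` flag are hypothetical, and the causal mask is an addition for the decoder setting used in this notebook (the formula as written has no mask).

```python
import math
import torch

def attention_sketch(q, k, v, causal=True):
    """softmax(Q K^T / sqrt(d_k)) V, optionally with a causal mask."""
    d_k = q.size(-1)
    scores = q @ k.transpose(-2, -1) / math.sqrt(d_k)  # (B, T, T)
    if causal:
        T = q.size(-2)
        tril = torch.tril(torch.ones(T, T, device=q.device))
        scores = scores.masked_fill(tril == 0, float('-inf'))
    wei = torch.softmax(scores, dim=-1)  # each row sums to 1
    return wei @ v, wei

torch.manual_seed(0)
B, T, head_size = 4, 8, 16
q, k, v = (torch.randn(B, T, head_size) for _ in range(3))
out, wei = attention_sketch(q, k, v)
print(out.shape, wei.shape)
```

Returning `wei` alongside the output is only for inspection; a production layer would normally return the output alone.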

In [180]:
# create random keys with unit variance
k = torch.randn(B, T, head_size)
k

tensor([[[-0.6744, -1.8893, -1.8424,  0.1323, -0.7929,  1.2297,  0.0777,
           1.8036, -0.3388, -0.4670, -0.4019, -1.3110,  0.0308, -0.5922,
          -1.1771,  1.7409],
         [-0.2961, -0.3474, -0.4967, -1.3010,  1.3099, -0.2666,  0.1970,
          -0.6992,  1.1396,  0.1912, -0.0095,  0.3546, -0.4238,  1.0712,
           2.7125, -0.1935],
         [ 1.7503, -0.1117, -0.8220,  0.7975, -0.7685,  1.5376, -1.7771,
          -1.0646,  1.0508,  1.3841, -1.5027, -1.0865,  2.1496, -0.9262,
          -0.8618, -0.0133],
         [ 0.9761, -0.0773, -2.1688,  1.2137, -1.8086,  0.1943,  0.6680,
          -1.1589, -0.7162, -1.0271, -1.4785,  0.0458, -0.1069,  0.3531,
           0.3302, -0.5309],
         [ 0.0363,  2.4673, -0.1655, -0.3069,  1.4189, -0.4566, -1.5976,
           0.7736, -0.6360, -0.2510,  0.7005,  1.4388, -1.0685, -0.1663,
           0.5176, -0.7325],
         [ 0.3359, -0.7604,  0.0566, -1.5039, -0.4485,  0.5257,  0.2619,
           0.7167, -0.6965,  0.8436,  1.9249, -0.340

In [181]:
# create random queries with unit variance
q = torch.randn(B, T, head_size)
q

tensor([[[-9.6348e-01, -1.0543e+00, -6.1099e-01,  1.1033e-01,  1.2356e-01,
          -1.4389e+00, -4.5936e-01,  7.1935e-01, -9.6226e-02, -6.8070e-01,
           7.3392e-01,  9.3939e-02,  1.0835e+00,  8.0898e-01, -9.7732e-01,
          -2.6084e-01],
         [ 9.0191e-01,  3.1770e-01,  1.5054e+00, -4.5409e-04, -8.3999e-01,
          -9.9635e-01,  1.9696e+00, -6.2411e-01,  7.8123e-01, -1.4737e+00,
           9.1280e-01, -8.1394e-01, -3.2805e-01, -1.6034e+00,  1.5658e-01,
           1.2400e+00],
         [-1.3389e+00, -1.0444e-01,  1.5695e-01, -1.5132e+00,  9.9128e-01,
           5.5732e-01, -6.7796e-01,  9.6848e-01,  8.3635e-01, -2.0765e+00,
           9.2636e-01,  1.8823e+00,  2.7995e-02, -3.6298e-01,  4.5504e-01,
           7.5949e-01],
         [-9.6253e-01,  9.5393e-01, -1.4123e+00,  8.1285e-01,  1.4346e+00,
           5.7747e-02, -8.9515e-01, -8.5902e-02, -6.0462e-01, -6.8750e-01,
           2.0560e-01, -7.1922e-01, -1.1453e+00,  8.8890e-01,  2.4767e-01,
           9.7610e-01],
    

In [182]:
# understand the scaling problem
print('understanding the scaling problem')
print()
print(f'head_size = {head_size}')
print()
print('we created random keys and queries with unit variance (~1)')
print(f'   key variance:   {k.var().item():.4f} (should be ~1)')
print(f'   query variance: {q.var().item():.4f} (should be ~1)')
print()
print('when we multiply unit-variance vectors, variance grows')
print('   for dot product of two vectors with unit variance')
print('   the resulting variance ≈ length of vectors = head_size')
print()
print(f'   expected variance after Q @ K^T: ~{head_size}')

understanding the scaling problem

head_size = 16

we created random keys and queries with unit variance (~1)
   key variance:   1.0392 (should be ~1)
   query variance: 0.9791 (should be ~1)

when we multiply unit-variance vectors, variance grows
   for dot product of two vectors with unit variance
   the resulting variance ≈ length of vectors = head_size

   expected variance after Q @ K^T: ~16


In [183]:
# demonstrate the problem: unscaled attention
print('demonstrating the problem: unscaled attention')
print()
wei_unscaled = q @ k.transpose(-2, -1)
print(f'wei_unscaled = q @ k.transpose(-2, -1)')
print(f'wei_unscaled variance: {wei_unscaled.var().item():.4f}')
print()
print(f'expected variance: ~{head_size}')
print(f'actual variance:   {wei_unscaled.var().item():.2f}')
print()
print(f'the variance grew from ~1 to ~{head_size}!')
print('this is because dot product of two unit-variance vectors')
print(f'has variance equal to the length of the vectors ({head_size})')

demonstrating the problem: unscaled attention

wei_unscaled = q @ k.transpose(-2, -1)
wei_unscaled variance: 12.4662

expected variance: ~16
actual variance:   12.47

the variance grew from ~1 to ~16!
this is because dot product of two unit-variance vectors
has variance equal to the length of the vectors (16)


In [184]:
# the solution: scaled attention
print('the solution: scaled attention')
print()
scale_factor = head_size ** -0.5
print(f'scale factor = 1 / sqrt(head_size) = 1 / sqrt({head_size}) = {scale_factor:.4f}')
print()
wei_scaled = (q @ k.transpose(-2, -1)) * scale_factor
print(f'wei_scaled = (q @ k.transpose(-2, -1)) * {scale_factor:.4f}')
print()
print(f'wei_scaled variance: {wei_scaled.var().item():.4f}')
print()
print('the variance is back to ~1!')
print()
print('why does this work?')
print(f'   unscaled variance ≈ {head_size}')
print(f'   multiplying by {scale_factor:.4f} = 1/sqrt({head_size})')
print(f'   divides variance by {head_size} (because Var(aX) = a²Var(X))')
print(f'   {head_size} * ({scale_factor:.4f})² = {head_size * scale_factor**2:.4f} ≈ 1')

the solution: scaled attention

scale factor = 1 / sqrt(head_size) = 1 / sqrt(16) = 0.2500

wei_scaled = (q @ k.transpose(-2, -1)) * 0.2500

wei_scaled variance: 0.7791

the variance is back to ~1!

why does this work?
   unscaled variance ≈ 16
   multiplying by 0.2500 = 1/sqrt(16)
   divides variance by 16 (because Var(aX) = a²Var(X))
   16 * (0.2500)² = 1.0000 ≈ 1
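
The same result can be derived on paper. Assuming the components $q_i$ and $k_i$ are independent with mean 0 and variance 1 (which is what `torch.randn` produces), each product $q_i k_i$ has mean 0 and variance 1, and the products are uncorrelated, so their variances add:

$$\mathrm{Var}(q \cdot k) = \mathrm{Var}\left(\sum_{i=1}^{d_k} q_i k_i\right) = \sum_{i=1}^{d_k} \mathbb{E}[q_i^2]\,\mathbb{E}[k_i^2] = d_k$$

$$\mathrm{Var}\left(\frac{q \cdot k}{\sqrt{d_k}}\right) = \frac{1}{d_k}\,\mathrm{Var}(q \cdot k) = 1$$

The measured values (12.47 and 0.78) sit below 16 and 1, most likely because they are estimated from only 4 × 8 × 8 = 256 scores that share the same 32 query and 32 key vectors, so the estimate is noisy.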


### Why Scaling Matters for Softmax
When values are too large, softmax becomes "peaky" - almost all probability goes to one position.
When values are small, softmax is "diffuse" - probability spreads more evenly.

For learning to work well, we need:
- Gradients to flow to multiple positions
- The network to learn which positions to attend to
- Smooth, not sharp, attention distributions early in training

Let's see what happens to softmax when values are too large.

In [185]:
# create example with small values
print('example: small values going into softmax')
print()
small_values = torch.tensor([0.1, -0.2, 0.3, -0.2, 0.5])
print(f'small_values = {small_values.tolist()}')
print()
print('these values are small (variance well below 1)')
print(f'   variance: {small_values.var().item():.4f}')
print(f'   max absolute value: {small_values.abs().max().item():.4f}')

example: small values going into softmax

small_values = [0.10000000149011612, -0.20000000298023224, 0.30000001192092896, -0.20000000298023224, 0.5]

these values are small (variance well below 1)
   variance: 0.0950
   max absolute value: 0.5000


In [186]:
# apply softmax to small values
print('applying softmax to small values')
print()
softmax_small = torch.softmax(small_values, dim=-1)
print(f'input:   {small_values.tolist()}')
print(f'softmax: {[round(x, 4) for x in softmax_small.tolist()]}')
print()
print('tracing through the softmax calculation')
import math
exp_vals = [math.exp(v.item()) for v in small_values]
exp_sum = sum(exp_vals)
print(f'   e^0.1 = {exp_vals[0]:.4f}')
print(f'   e^-0.2 = {exp_vals[1]:.4f}')
print(f'   e^0.3 = {exp_vals[2]:.4f}')
print(f'   e^-0.2 = {exp_vals[3]:.4f}')
print(f'   e^0.5 = {exp_vals[4]:.4f}')
print(f'   sum = {exp_sum:.4f}')
print()
print('softmax = e^x / sum(e^x)')
for i, (exp_v, soft_v) in enumerate(zip(exp_vals, softmax_small.tolist())):
    print(f'   position {i}: {exp_v:.4f} / {exp_sum:.4f} = {soft_v:.4f}')

applying softmax to small values

input:   [0.10000000149011612, -0.20000000298023224, 0.30000001192092896, -0.20000000298023224, 0.5]
softmax: [0.1925, 0.1426, 0.2351, 0.1426, 0.2872]

tracing through the softmax calculation
   e^0.1 = 1.1052
   e^-0.2 = 0.8187
   e^0.3 = 1.3499
   e^-0.2 = 0.8187
   e^0.5 = 1.6487
   sum = 5.7412

softmax = e^x / sum(e^x)
   position 0: 1.1052 / 5.7412 = 0.1925
   position 1: 0.8187 / 5.7412 = 0.1426
   position 2: 1.3499 / 5.7412 = 0.2351
   position 3: 0.8187 / 5.7412 = 0.1426
   position 4: 1.6487 / 5.7412 = 0.2872


In [187]:
# analyze the small values softmax result
print('analyzing small values softmax result')
print()
print(f'softmax: {[round(x, 4) for x in softmax_small.tolist()]}')
print()
print('observations')
print(f'   min probability: {softmax_small.min().item():.4f}')
print(f'   max probability: {softmax_small.max().item():.4f}')
print(f'   ratio max/min:   {softmax_small.max().item() / softmax_small.min().item():.2f}x')
print()
print('the distribution is DIFFUSE (spread out)')
print('   - all positions get meaningful probability')
print('   - no single position dominates')
print('   - gradients can flow to all positions')
print('   - this is GOOD for learning!')

analyzing small values softmax result

softmax: [0.1925, 0.1426, 0.2351, 0.1426, 0.2872]

observations
   min probability: 0.1426
   max probability: 0.2872
   ratio max/min:   2.01x

the distribution is DIFFUSE (spread out)
   - all positions get meaningful probability
   - no single position dominates
   - gradients can flow to all positions
   - this is GOOD for learning!


In [188]:
# create example with large values
print('example: large values going into softmax')
print()
large_values = small_values * 8
print(f'large_values = small_values * 8')
print(f'            = {large_values.tolist()}')
print()
print('this simulates what happens WITHOUT scaling')
print('if head_size = 64, values could be 8x larger')
print(f'   variance: {large_values.var().item():.4f}')
print(f'   max absolute value: {large_values.abs().max().item():.4f}')

example: large values going into softmax

large_values = small_values * 8
            = [0.800000011920929, -1.600000023841858, 2.4000000953674316, -1.600000023841858, 4.0]

this simulates what happens WITHOUT scaling
if head_size = 64, values could be 8x larger
   variance: 6.0800
   max absolute value: 4.0000


In [189]:
# apply softmax to large values
print('applying softmax to large values')
print()
softmax_large = torch.softmax(large_values, dim=-1)
print(f'input:   {large_values.tolist()}')
print(f'softmax: {[round(x, 4) for x in softmax_large.tolist()]}')
print()
print('tracing through the softmax calculation')
exp_vals_large = [math.exp(v.item()) for v in large_values]
exp_sum_large = sum(exp_vals_large)
print(f'   e^0.8 = {exp_vals_large[0]:.4f}')
print(f'   e^-1.6 = {exp_vals_large[1]:.4f}')
print(f'   e^2.4 = {exp_vals_large[2]:.4f}')
print(f'   e^-1.6 = {exp_vals_large[3]:.4f}')
print(f'   e^4.0 = {exp_vals_large[4]:.4f}')
print(f'   sum = {exp_sum_large:.4f}')
print()
print('notice how e^4.0 = {:.2f} dominates the sum!'.format(exp_vals_large[4]))

applying softmax to large values

input:   [0.800000011920929, -1.600000023841858, 2.4000000953674316, -1.600000023841858, 4.0]
softmax: [0.0326, 0.003, 0.1615, 0.003, 0.8]

tracing through the softmax calculation
   e^0.8 = 2.2255
   e^-1.6 = 0.2019
   e^2.4 = 11.0232
   e^-1.6 = 0.2019
   e^4.0 = 54.5982
   sum = 68.2507

notice how e^4.0 = 54.60 dominates the sum!


In [190]:
# analyze the large values softmax result
print('analyzing large values softmax result')
print()
print(f'softmax: {[round(x, 4) for x in softmax_large.tolist()]}')
print()
print('observations')
print(f'   min probability: {softmax_large.min().item():.6f}')
print(f'   max probability: {softmax_large.max().item():.4f}')
print(f'   ratio max/min:   {softmax_large.max().item() / softmax_large.min().item():.0f}x')
print()
print('the distribution is PEAKY (concentrated)')
print('   - position 4 gets most of the probability (~0.80)')
print('   - the other positions share the remaining ~0.20')
print('   - gradients flow mostly through position 4')
print('   - this is BAD for learning!')
print()
print('this is why we scale by 1/sqrt(head_size)')
print('   - keeps attention scores in a reasonable range')
print('   - prevents softmax from becoming too peaky')
print('   - allows gradients to flow during training')

analyzing large values softmax result

softmax: [0.0326, 0.003, 0.1615, 0.003, 0.8]

observations
   min probability: 0.002958
   max probability: 0.8000
   ratio max/min:   270x

the distribution is PEAKY (concentrated)
   - position 4 gets most of the probability (~0.80)
   - the other positions share the remaining ~0.20
   - gradients flow mostly through position 4
   - this is BAD for learning!

this is why we scale by 1/sqrt(head_size)
   - keeps attention scores in a reasonable range
   - prevents softmax from becoming too peaky
   - allows gradients to flow during training
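
The effect on gradients can be made concrete with the softmax Jacobian, whose entries are p_i * (delta_ij - p_j). The sketch below reuses the same five example values and multiplies them by 1, 8, and 64; the factor 64 is an extra assumption added here to show the fully saturated case, and the helper name `softmax_jacobian` is hypothetical.

```python
import torch

# same example values as small_values in the cells above
base = torch.tensor([0.1, -0.2, 0.3, -0.2, 0.5])

def softmax_jacobian(logits):
    # d softmax_i / d logit_j = p_i * (delta_ij - p_j)
    p = torch.softmax(logits, dim=-1)
    return torch.diag(p) - torch.outer(p, p)

norms = {}
for scale in (1, 8, 64):
    logits = base * scale
    p = torch.softmax(logits, dim=-1)
    norms[scale] = softmax_jacobian(logits).norm().item()
    print(f'scale {scale:>2}: max prob = {p.max().item():.4f}, '
          f'jacobian norm = {norms[scale]:.6f}')
```

As the inputs grow, the distribution approaches one-hot and the Jacobian shrinks toward zero, so very little gradient reaches the attention scores.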


## Step 9: Layer Normalization

#### What is Layer Normalization?
**Layer Normalization** normalizes the features of each sample independently, ensuring they have mean=0 and std=1.

#### The Formula
For each sample x with features [x₁, x₂, ..., xₙ]:
1. Compute mean: μ = (x₁ + x₂ + ... + xₙ) / n
2. Compute variance: σ² = Σ(xᵢ - μ)² / n
3. Normalize: x̂ᵢ = (xᵢ - μ) / √(σ² + ε)
4. Scale and shift: yᵢ = γ * x̂ᵢ + β

Where:
- ε is a small constant for numerical stability (avoid divide by zero)
- γ (gamma) and β (beta) are learnable parameters
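
The four steps can be traced by hand on a tiny input. The values below are made up for illustration, and the result is compared against PyTorch's built-in `F.layer_norm` as a sanity check.

```python
import torch
import torch.nn.functional as F

x = torch.tensor([2.0, 4.0, 6.0, 8.0])  # one sample, n = 4 features (made-up values)
eps = 1e-5
gamma = torch.ones(4)   # scale, initialized to 1
beta = torch.zeros(4)   # shift, initialized to 0

mu = x.mean()                              # step 1: (2 + 4 + 6 + 8) / 4 = 5
var = ((x - mu) ** 2).mean()               # step 2: (9 + 1 + 1 + 9) / 4 = 5
x_hat = (x - mu) / torch.sqrt(var + eps)   # step 3: normalize
y = gamma * x_hat + beta                   # step 4: scale and shift

reference = F.layer_norm(x, (4,), weight=gamma, bias=beta, eps=eps)
print(f'mu = {mu.item()}, var = {var.item()}')
print(f'y  = {[round(t, 4) for t in y.tolist()]}')
print(f'matches F.layer_norm: {torch.allclose(y, reference, atol=1e-6)}')
```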

#### Why Do We Need Layer Normalization?
Without normalization, activations can:
- Explode (grow very large) → gradients explode
- Vanish (become very small) → gradients vanish
- Have very different scales → training instability

Layer normalization keeps activations in a stable range, making training faster and more reliable.

#### BatchNorm vs LayerNorm

| Normalization | Normalizes across | Used in |
|--------------|-------------------|---------|
| **BatchNorm** | Same feature across batch samples (column-wise) | CNNs |
| **LayerNorm** | All features within one sample (row-wise) | Transformers |

#### Why LayerNorm for Transformers?
LayerNorm is preferred in Transformers because:
1. Works with any batch size (including batch_size=1 for inference)
2. Each token's representation is normalized independently
3. No dependency on other samples in the batch
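
A quick way to see points 1 and 3 is to run the same sequence through `nn.LayerNorm` once as part of a batch and once alone. This is a sketch with assumed sizes (`n_embd = 32` is not taken from the notebook).

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
n_embd = 32                # embedding width (assumed for illustration)
ln = nn.LayerNorm(n_embd)  # normalizes over the last dimension only

# (B, T, C) activations, deliberately shifted and scaled away from mean 0, std 1
x = torch.randn(4, 8, n_embd) * 3 + 2

y_batch = ln(x)        # whole batch
y_single = ln(x[:1])   # first sequence alone, batch_size = 1

max_token_mean = y_batch.mean(dim=-1).abs().max().item()
same_result = torch.allclose(y_batch[0], y_single[0], atol=1e-6)
print(f'largest per-token mean after LayerNorm: {max_token_mean:.2e}')
print(f'batch of 4 and batch of 1 agree:        {same_result}')
```

Each of the 4 × 8 token vectors is normalized on its own, so the result for one sequence does not change when the rest of the batch is removed.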

In [191]:
class LayerNorm1d:
    """
    Layer Normalization: normalizes each sample independently.
    
    For input of shape (batch_size, features):
    - Each row (sample) is normalized to have mean=0, std=1
    - Then scaled by learnable gamma and shifted by learnable beta
    """

    def __init__(self, dim, eps=1e-5):
        """
        Initialize the LayerNorm1d layer.
        
        Args:
            dim: Number of features.
            eps: Small constant for numerical stability (avoid divide by zero).
        """
        self.eps = eps
        # learnable parameters
        self.gamma = torch.ones(dim)   # scale (initialized to 1)
        self.beta = torch.zeros(dim)   # shift (initialized to 0)

    def __call__(self, x):
        """
        Forward pass.
        
        Args:
            x: Input tensor of shape (batch_size, dim).
            
        Returns:
            Normalized tensor of same shape.
        """
        # step 1: compute mean of each sample (across features)
        xmean = x.mean(dim=1, keepdim=True)  # (batch, 1)
        # step 2: compute variance of each sample
        # (torch.var defaults to the unbiased n-1 estimator; nn.LayerNorm divides by n)
        xvar = x.var(dim=1, keepdim=True)    # (batch, 1)
        # step 3: normalize to mean=0, var=1
        xhat = (x - xmean) / torch.sqrt(xvar + self.eps)
        # step 4: scale and shift (learnable)
        self.out = self.gamma * xhat + self.beta
        return self.out

    def parameters(self):
        """
        Return learnable parameters.
        
        Returns:
            List containing gamma (scale) and beta (shift) tensors.
        """
        return [self.gamma, self.beta]

In [192]:
# understanding the LayerNorm class step by step
print('understanding the LayerNorm1d class')
print()
print('__init__(self, dim, eps=1e-5)')
print('   dim: number of features to normalize over')
print('   eps: small constant to avoid division by zero')
print('   gamma: learnable scale parameter, initialized to 1')
print('   beta: learnable shift parameter, initialized to 0')
print()
print('__call__(self, x)')
print('   step 1: compute mean across features (dim=1)')
print('           xmean = x.mean(dim=1, keepdim=True)')
print()
print('   step 2: compute variance across features (dim=1)')
print('           xvar = x.var(dim=1, keepdim=True)')
print()
print('   step 3: normalize to mean=0, variance=1')
print('           xhat = (x - xmean) / sqrt(xvar + eps)')
print()
print('   step 4: scale and shift (learnable)')
print('           out = gamma * xhat + beta')
print()
print('parameters()')
print('   returns [gamma, beta] for optimization')

understanding the LayerNorm1d class

__init__(self, dim, eps=1e-5)
   dim: number of features to normalize over
   eps: small constant to avoid division by zero
   gamma: learnable scale parameter, initialized to 1
   beta: learnable shift parameter, initialized to 0

__call__(self, x)
   step 1: compute mean across features (dim=1)
           xmean = x.mean(dim=1, keepdim=True)

   step 2: compute variance across features (dim=1)
           xvar = x.var(dim=1, keepdim=True)

   step 3: normalize to mean=0, variance=1
           xhat = (x - xmean) / sqrt(xvar + eps)

   step 4: scale and shift (learnable)
           out = gamma * xhat + beta

parameters()
   returns [gamma, beta] for optimization


In [193]:
# test layer normalization
torch.manual_seed(42)

# create the layer
n_features = 100
module = LayerNorm1d(n_features)
print(f'created LayerNorm1d with {n_features} features')
print(f'   gamma shape: {module.gamma.shape}')
print(f'   beta shape:  {module.beta.shape}')
print()

# create test data
n_samples = 32
x = torch.randn(n_samples, n_features)
print(f'created test data with shape: {x.shape}')
print(f'   {n_samples} samples, {n_features} features each')

created LayerNorm1d with 100 features
   gamma shape: torch.Size([100])
   beta shape:  torch.Size([100])

created test data with shape: torch.Size([32, 100])
   32 samples, 100 features each


In [194]:
# examine the data BEFORE normalization
print('examining data BEFORE normalization')
print()
print('sample 0')
print(f'   mean: {x[0].mean().item():.4f}')
print(f'   std:  {x[0].std().item():.4f}')
print(f'   min:  {x[0].min().item():.4f}')
print(f'   max:  {x[0].max().item():.4f}')
print()
print('sample 1')
print(f'   mean: {x[1].mean().item():.4f}')
print(f'   std:  {x[1].std().item():.4f}')
print(f'   min:  {x[1].min().item():.4f}')
print(f'   max:  {x[1].max().item():.4f}')
print()
print('sample 2')
print(f'   mean: {x[2].mean().item():.4f}')
print(f'   std:  {x[2].std().item():.4f}')
print(f'   min:  {x[2].min().item():.4f}')
print(f'   max:  {x[2].max().item():.4f}')
print()
print('note: means are not 0, stds are not exactly 1')
print('each sample has slightly different statistics')

examining data BEFORE normalization

sample 0
   mean: 0.0320
   std:  1.0406
   min:  -2.5095
   max:  2.2181

sample 1
   mean: 0.0811
   std:  0.9597
   min:  -2.4801
   max:  2.0265

sample 2
   mean: 0.0571
   std:  0.9930
   min:  -2.5850
   max:  2.1120

note: means are not 0, stds are not exactly 1
each sample has slightly different statistics


In [195]:
# apply layer normalization
x_normed = module(x)
print('applied layer normalization')
print(f'   x_normed shape: {x_normed.shape}')

applied layer normalization
   x_normed shape: torch.Size([32, 100])


In [196]:
# examine the data AFTER normalization
print('examining data AFTER normalization')
print()
print('sample 0')
print(f'   mean: {x_normed[0].mean().item():.6f}')
print(f'   std:  {x_normed[0].std().item():.4f}')
print(f'   min:  {x_normed[0].min().item():.4f}')
print(f'   max:  {x_normed[0].max().item():.4f}')
print()
print('sample 1')
print(f'   mean: {x_normed[1].mean().item():.6f}')
print(f'   std:  {x_normed[1].std().item():.4f}')
print(f'   min:  {x_normed[1].min().item():.4f}')
print(f'   max:  {x_normed[1].max().item():.4f}')
print()
print('sample 2')
print(f'   mean: {x_normed[2].mean().item():.6f}')
print(f'   std:  {x_normed[2].std().item():.4f}')
print(f'   min:  {x_normed[2].min().item():.4f}')
print(f'   max:  {x_normed[2].max().item():.4f}')
print()
print('now all samples have mean ≈ 0 and std ≈ 1!')
print('layer normalization standardizes each sample independently')

examining data AFTER normalization

sample 0
   mean: -0.000000
   std:  1.0000
   min:  -2.4425
   max:  2.1008

sample 1
   mean: 0.000000
   std:  1.0000
   min:  -2.6687
   max:  2.0271

sample 2
   mean: 0.000000
   std:  1.0000
   min:  -2.6605
   max:  2.0694

now all samples have mean ≈ 0 and std ≈ 1!
layer normalization standardizes each sample independently


In [197]:
# verify normalization for all samples
print('verifying normalization for all samples')
print()
means = x_normed.mean(dim=1)
stds = x_normed.std(dim=1)
print(f'mean of all sample means: {means.mean().item():.8f} (should be ~0)')
print(f'mean of all sample stds:  {stds.mean().item():.4f} (should be ~1)')
print()
print('all samples are now normalized!')

verifying normalization for all samples

mean of all sample means: 0.00000000 (should be ~0)
mean of all sample stds:  1.0000 (should be ~1)

all samples are now normalized!
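
One detail worth checking is how this compares to PyTorch's built-in `nn.LayerNorm`. The sketch below assumes the same data shape and seed as the test above and uses a hypothetical helper `normalize`; it contrasts the biased variance (divide by n, as in the formula) with the unbiased variance (divide by n - 1, the default of `x.var`).

```python
import torch
import torch.nn as nn

torch.manual_seed(42)
x = torch.randn(32, 100)
eps = 1e-5

def normalize(x, var):
    # shared normalization step; only the variance estimate differs
    return (x - x.mean(dim=1, keepdim=True)) / torch.sqrt(var + eps)

mean = x.mean(dim=1, keepdim=True)
var_biased = ((x - mean) ** 2).mean(dim=1, keepdim=True)  # divide by n
var_unbiased = x.var(dim=1, keepdim=True)                 # divide by n - 1 (default)

ref = nn.LayerNorm(100, eps=eps)(x)  # gamma = 1, beta = 0 at initialization
biased = normalize(x, var_biased)
unbiased = normalize(x, var_unbiased)

max_diff = (unbiased - ref).abs().max().item()
print(f'biased version matches nn.LayerNorm: {torch.allclose(biased, ref, atol=1e-5)}')
print(f'max gap of unbiased version:         {max_diff:.4f}')
```

The gap is small (a factor of sqrt(99/100) with 100 features), which is why the hand-written class reports a std of exactly 1.0000 when checked with `x.std`, which is also unbiased by default.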


### Understanding the Difference: BatchNorm vs LayerNorm
Let's visualize what each normalization normalizes over.

For data of shape (batch_size, features) = (32, 100):
- **BatchNorm**: Normalizes each FEATURE across all samples
  - For feature 0: compute mean and std across all 32 samples
  - Result: each feature has mean=0, std=1 across the batch
  
- **LayerNorm**: Normalizes each SAMPLE across all features
  - For sample 0: compute mean and std across all 100 features
  - Result: each sample has mean=0, std=1 across its features
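
The only difference between the two is the axis over which the statistics are taken. The sketch below writes both by hand, leaving out gamma, beta, and BatchNorm's running statistics for brevity (a simplification, not a full BatchNorm).

```python
import torch

torch.manual_seed(0)
x = torch.randn(32, 100) * 2 + 5  # (batch_size, features), deliberately off-center
eps = 1e-5

# LayerNorm-style: one mean/variance per sample, taken across features (dim=1)
ln = (x - x.mean(dim=1, keepdim=True)) / torch.sqrt(x.var(dim=1, keepdim=True) + eps)

# BatchNorm-style (training mode): one mean/variance per feature, across the batch (dim=0)
bn = (x - x.mean(dim=0, keepdim=True)) / torch.sqrt(x.var(dim=0, keepdim=True) + eps)

print(f'ln: largest |row mean|    = {ln.mean(dim=1).abs().max().item():.2e}')
print(f'ln: largest |column mean| = {ln.mean(dim=0).abs().max().item():.2e}')
print(f'bn: largest |column mean| = {bn.mean(dim=0).abs().max().item():.2e}')
```

Rows of `ln` are centered but its columns are not, while `bn` centers the columns instead.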

In [198]:
# demonstrate the difference between BatchNorm and LayerNorm
print('demonstrating the difference between BatchNorm and LayerNorm')
print()
print('our data has shape (32, 100) = (batch_size, features)')
print()
print('LAYERNORM normalizes across FEATURES (row-wise)')
print('checking ONE SAMPLE across all features')
print(f'   sample 0 - mean: {x_normed[0,:].mean().item():.6f}, std: {x_normed[0,:].std().item():.4f}')
print(f'   sample 1 - mean: {x_normed[1,:].mean().item():.6f}, std: {x_normed[1,:].std().item():.4f}')
print(f'   sample 2 - mean: {x_normed[2,:].mean().item():.6f}, std: {x_normed[2,:].std().item():.4f}')
print('   → all samples are normalized (mean≈0, std≈1)')
print()
print('BATCHNORM would normalize across SAMPLES (column-wise)')
print('checking ONE FEATURE across all samples')
print(f'   feature 0 - mean: {x_normed[:,0].mean().item():.4f}, std: {x_normed[:,0].std().item():.4f}')
print(f'   feature 1 - mean: {x_normed[:,1].mean().item():.4f}, std: {x_normed[:,1].std().item():.4f}')
print(f'   feature 2 - mean: {x_normed[:,2].mean().item():.4f}, std: {x_normed[:,2].std().item():.4f}')
print('   → features are NOT normalized (this is what BatchNorm would normalize)')
print()
print('LayerNorm normalizes each sample independently')
print('BatchNorm would normalize each feature across the batch')

demonstrating the difference between BatchNorm and LayerNorm

our data has shape (32, 100) = (batch_size, features)

LAYERNORM normalizes across FEATURES (row-wise)
checking ONE SAMPLE across all features
   sample 0 - mean: -0.000000, std: 1.0000
   sample 1 - mean: 0.000000, std: 1.0000
   sample 2 - mean: 0.000000, std: 1.0000
   → all samples are normalized (mean≈0, std≈1)

BATCHNORM would normalize across SAMPLES (column-wise)
checking ONE FEATURE across all samples
   feature 0 - mean: -0.1428, std: 1.0629
   feature 1 - mean: 0.0941, std: 1.1109
   feature 2 - mean: 0.2947, std: 1.1421
   → features are NOT normalized (this is what BatchNorm would normalize)

LayerNorm normalizes each sample independently
BatchNorm would normalize each feature across the batch


In [199]:
# why LayerNorm is preferred in Transformers
print('why LayerNorm is preferred in Transformers')
print()
print('1. works with any batch size')
print('   - BatchNorm needs multiple samples to compute batch statistics')
print('   - LayerNorm only needs one sample (normalizes across features)')
print('   - important for inference with batch_size=1')
print()
print('2. each token is normalized independently')
print('   - in Transformers, each position is like a "sample"')
print('   - we want each token\'s representation to be well-conditioned')
print('   - no dependency on other samples in the batch')
print()
print('3. no running statistics needed')
print('   - BatchNorm maintains running mean/variance during training')
print('   - LayerNorm computes fresh statistics each forward pass')
print('   - simpler implementation and behavior')

why LayerNorm is preferred in Transformers

1. works with any batch size
   - BatchNorm needs multiple samples to compute batch statistics
   - LayerNorm only needs one sample (normalizes across features)
   - important for inference with batch_size=1

2. each token is normalized independently
   - in Transformers, each position is like a "sample"
   - we want each token's representation to be well-conditioned
   - no dependency on other samples in the batch

3. no running statistics needed
   - BatchNorm maintains running mean/variance during training
   - LayerNorm computes fresh statistics each forward pass
   - simpler implementation and behavior


In [200]:
# final summary of Part 6
print('SUMMARY: Full Self-Attention with Q, K, V')
print('=' * 60)
print()
print('SELF-ATTENTION MECHANISM')
print('   step 1: project input to queries, keys, values')
print('           Q = query(x)  # "what am I looking for?"')
print('           K = key(x)    # "what do I contain?"')
print('           V = value(x)  # "what will I give?"')
print()
print('   step 2: compute attention scores')
print('           wei = Q @ K^T')
print('           wei[i,j] = how much position i attends to position j')
print()
print('   step 3: scale by 1/sqrt(head_size)')
print('           wei = wei * (head_size ** -0.5)')
print('           prevents variance explosion')
print()
print('   step 4: mask future positions')
print('           wei = wei.masked_fill(tril == 0, -inf)')
print('           prevents looking at future tokens')
print()
print('   step 5: apply softmax')
print('           wei = softmax(wei)')
print('           converts to probabilities that sum to 1')
print()
print('   step 6: weighted sum of values')
print('           out = wei @ V')
print('           each position gets weighted average of values')
print()
print('LAYER NORMALIZATION')
print('   normalizes each sample to mean=0, std=1')
print('   stabilizes training by keeping activations in good range')
print('   applied before attention and feedforward layers')
print()
print('COMPLETE ATTENTION FORMULA')
print('   Attention(Q, K, V) = softmax(mask(Q @ K^T / sqrt(d_k))) @ V')

SUMMARY: Full Self-Attention with Q, K, V

SELF-ATTENTION MECHANISM
   step 1: project input to queries, keys, values
           Q = query(x)  # "what am I looking for?"
           K = key(x)    # "what do I contain?"
           V = value(x)  # "what will I give?"

   step 2: compute attention scores
           wei = Q @ K^T
           wei[i,j] = how much position i attends to position j

   step 3: scale by 1/sqrt(head_size)
           wei = wei * (head_size ** -0.5)
           prevents variance explosion

   step 4: mask future positions
           wei = wei.masked_fill(tril == 0, -inf)
           prevents looking at future tokens

   step 5: apply softmax
           wei = softmax(wei)
           converts to probabilities that sum to 1

   step 6: weighted sum of values
           out = wei @ V
           each position gets weighted average of values

LAYER NORMALIZATION
   normalizes each sample to mean=0, std=1
   stabilizes training by keeping activations in good range
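   applied before attention and feedforward layers

COMPLETE ATTENTION FORMULA
   Attention(Q, K, V) = softmax(mask(Q @ K^T / sqrt(d_k))) @ V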

## MIT License