In [None]:
import numpy as np
import math

# Small example for visualization
sequence_length = 8
head_dimension = 4

np.random.seed(42)
query = np.random.randn(sequence_length, head_dimension).astype(np.float32)
key = np.random.randn(sequence_length, head_dimension).astype(np.float32)
value = np.random.randn(sequence_length, head_dimension).astype(np.float32)

print(f"query shape: {query.shape}")
print(f"key shape: {key.shape}")
print(f"value shape: {value.shape}")

In [None]:
# =============================================================================
# NAIVE ATTENTION - materializes the full (sequence_length x sequence_length) matrix
# =============================================================================

def naive_attention(query, key, value):
    """
    Standard attention: O = softmax(Q @ K^T / sqrt(d)) @ V
    
    This materializes the full (N x N) attention matrix - what FlashAttention avoids.
    """
    scale = 1.0 / math.sqrt(query.shape[-1])
    
    # Step 1: Compute all scores at once -> (sequence_length x sequence_length) matrix
    scores = query @ key.T * scale
    
    # Step 2: Stable softmax (subtract max for numerical stability)
    scores_max = scores.max(axis=1, keepdims=True)
    exp_scores = np.exp(scores - scores_max)
    attention_weights = exp_scores / exp_scores.sum(axis=1, keepdims=True)
    
    # Step 3: Weighted sum of values
    output = attention_weights @ value
    
    return output, attention_weights

naive_output, attention_weights = naive_attention(query, key, value)
print("Naive attention output shape:", naive_output.shape)
print("\nAttention weights (the N×N matrix we want to avoid storing):")
print(attention_weights.round(3))

## The Online Softmax Trick

The key insight: we can compute softmax **incrementally** without seeing all scores at once.

**Problem**: For softmax, we need `exp(score_i) / sum(exp(all_scores))`. But we're processing scores in blocks!

**Solution**: Maintain two running statistics per query row:
- `running_max`: the maximum score seen so far
- `running_sum_exp`: the sum of `exp(score - running_max)` for all scores seen

When we see a new block of scores, we can **merge** the statistics:

```
new_max = max(running_max, block_max)
new_sum_exp = exp(running_max - new_max) * running_sum_exp + exp(block_max - new_max) * block_sum_exp
```

The output also needs rescaling when the max changes - this is the "correction factor".

In [19]:
# =============================================================================
# DEMO: Online softmax on a single row of scores
# =============================================================================
# Let's prove the online softmax trick works before using it in flash attention

scores_row = np.array([1.0, 3.0, 2.0, 4.0, 0.5, 2.5])  # pretend these are Q[i] @ K^T
block_size = 2

# Ground truth: compute softmax the normal way
scores_max = scores_row.max()
ground_truth = np.exp(scores_row - scores_max) / np.exp(scores_row - scores_max).sum()
print("Ground truth softmax:", ground_truth.round(4))

# Online computation: process in blocks of 2
running_max = -np.inf
running_sum_exp = 0.0

for block_start in range(0, len(scores_row), block_size):
    block_scores = scores_row[block_start : block_start + block_size]
    
    # Compute block statistics
    block_max = block_scores.max()
    block_sum_exp = np.exp(block_scores - block_max).sum()
    
    # Merge with running statistics
    new_max = max(running_max, block_max)
    
    # Key insight: rescale both old and new sums to the new max
    old_contribution = np.exp(running_max - new_max) * running_sum_exp
    new_contribution = np.exp(block_max - new_max) * block_sum_exp
    new_sum_exp = old_contribution + new_contribution
    
    running_max = new_max
    running_sum_exp = new_sum_exp
    
    print(f"After block {block_scores}: max={running_max:.2f}, sum_exp={running_sum_exp:.4f}")

# Now we can compute the final softmax
online_softmax = np.exp(scores_row - running_max) / running_sum_exp
print("\nOnline softmax:      ", online_softmax.round(4))
print("Match:", np.allclose(ground_truth, online_softmax))

Ground truth softmax: [0.0276 0.2037 0.0749 0.5536 0.0167 0.1235]
After block [1. 3.]: max=3.00, sum_exp=1.1353
After block [2. 4.]: max=4.00, sum_exp=1.5530
After block [0.5 2.5]: max=4.00, sum_exp=1.8063

Online softmax:       [0.0276 0.2037 0.0749 0.5536 0.0167 0.1235]
Match: True


## Flash Attention: Combining Tiling + Online Softmax

Now the tricky part: we need to maintain the **running output** too, not just softmax statistics.

For each query row, we maintain:
- `running_max`: maximum score seen so far  
- `running_sum_exp`: sum of exp(scores - running_max)
- `running_output`: the weighted sum, but we need to **rescale it** when max changes!

The update formula when processing a new block:
```
new_max = max(running_max, block_max)

# Correction factors for changing the "coordinate system"
old_scale = exp(running_max - new_max)  
new_scale = exp(block_max - new_max)

new_sum_exp = old_scale * running_sum_exp + new_scale * block_sum_exp

# The output update - this is the key insight!
new_output = (old_scale * running_sum_exp * running_output + new_scale * block_weighted_values) / new_sum_exp
```

In [None]:
# =============================================================================
# FLASH ATTENTION - never materializes the full (N x N) matrix
# =============================================================================

def flash_attention(query, key, value, block_size_query=2, block_size_key=2):
    """
    Flash Attention: computes exact attention without storing the full attention matrix.
    
    Key ideas:
    1. Process query and key/value in small blocks (tiles)
    2. Use online softmax to combine results from different blocks correctly
    
    Args:
        query: (sequence_length, head_dimension)
        key: (sequence_length, head_dimension)  
        value: (sequence_length, head_dimension)
        block_size_query: number of query rows to process together
        block_size_key: number of key rows to process together
    """
    sequence_length, head_dimension = query.shape
    scale = 1.0 / math.sqrt(head_dimension)
    
    # Initialize output and running statistics for ALL query rows
    output = np.zeros((sequence_length, head_dimension), dtype=np.float64)
    running_max = np.full(sequence_length, -np.inf, dtype=np.float64)
    running_sum_exp = np.zeros(sequence_length, dtype=np.float64)
    
    # Outer loop: iterate over blocks of keys/values
    for key_block_start in range(0, sequence_length, block_size_key):
        key_block_end = min(key_block_start + block_size_key, sequence_length)
        
        # Load one block of keys and values (in real GPU: load to SRAM)
        key_block = key[key_block_start:key_block_end].astype(np.float64)
        value_block = value[key_block_start:key_block_end].astype(np.float64)
        
        # Inner loop: iterate over blocks of queries
        for query_block_start in range(0, sequence_length, block_size_query):
            query_block_end = min(query_block_start + block_size_query, sequence_length)
            
            # Load one block of queries
            query_block = query[query_block_start:query_block_end].astype(np.float64)
            
            # Get current running statistics for this query block
            current_max = running_max[query_block_start:query_block_end]
            current_sum_exp = running_sum_exp[query_block_start:query_block_end]
            current_output = output[query_block_start:query_block_end]
            
            # =================================================================
            # STEP 1: Compute scores for this (query_block x key_block) tile
            # =================================================================
            # This is the ONLY score matrix we compute - size (block_size_query x block_size_key)
            # NOT (sequence_length x sequence_length)!
            scores_block = (query_block @ key_block.T) * scale
            
            # =================================================================
            # STEP 2: Compute LOCAL softmax statistics for this block
            # =================================================================
            block_max = scores_block.max(axis=1)  # max over keys, shape: (block_size_query,)
            block_exp_scores = np.exp(scores_block - block_max[:, None])
            block_sum_exp = block_exp_scores.sum(axis=1)  # shape: (block_size_query,)
            
            # Local weighted sum of values (not yet normalized)
            block_weighted_values = block_exp_scores @ value_block  # shape: (block_size_query, head_dimension)
            
            # =================================================================
            # STEP 3: MERGE with running statistics (the online softmax trick!)
            # =================================================================
            new_max = np.maximum(current_max, block_max)
            
            # Correction factors: rescale everything to the new maximum
            old_scale = np.exp(current_max - new_max)  # how much to scale old statistics
            new_scale = np.exp(block_max - new_max)    # how much to scale new block
            
            new_sum_exp = old_scale * current_sum_exp + new_scale * block_sum_exp
            
            # Update output: rescale old output and add new contribution, then normalize
            # old_scale * current_sum_exp * current_output = old unnormalized weighted sum (rescaled)
            # new_scale * block_weighted_values = new unnormalized weighted sum (rescaled)
            new_output = (
                old_scale[:, None] * current_sum_exp[:, None] * current_output +
                new_scale[:, None] * block_weighted_values
            ) / new_sum_exp[:, None]
            
            # =================================================================
            # STEP 4: Write back updated statistics
            # =================================================================
            output[query_block_start:query_block_end] = new_output
            running_max[query_block_start:query_block_end] = new_max
            running_sum_exp[query_block_start:query_block_end] = new_sum_exp
    
    return output.astype(query.dtype)

In [24]:
# =============================================================================
# COMPARISON: Verify flash attention matches naive attention
# =============================================================================

flash_output = flash_attention(query, key, value, block_size_query=2, block_size_key=2)

print("Naive output:")
print(naive_output.round(4))
print("\nFlash output:")
print(flash_output.round(4))
print("\nMax absolute error:", np.abs(naive_output - flash_output).max())
print("Results match:", np.allclose(naive_output, flash_output, rtol=1e-5))

Naive output:
[[-0.2606  0.1387  0.187   0.178 ]
 [-0.2657  0.1958  0.3021  0.1886]
 [-0.2045  0.1917 -0.1595 -0.1489]
 [ 0.129  -0.2446  0.2879  0.6396]
 [-0.2408  0.0981 -0.2086 -0.4503]
 [ 0.058   0.6045 -0.2326  0.2756]
 [-0.1434 -0.0073 -0.0186 -0.1872]
 [-0.1516 -0.0692  0.1787 -0.2666]]

Flash output:
[[-0.2606  0.1387  0.187   0.178 ]
 [-0.2657  0.1958  0.3021  0.1886]
 [-0.2045  0.1917 -0.1595 -0.1489]
 [ 0.129  -0.2446  0.2879  0.6396]
 [-0.2408  0.0981 -0.2086 -0.4503]
 [ 0.058   0.6045 -0.2326  0.2756]
 [-0.1434 -0.0073 -0.0186 -0.1872]
 [-0.1516 -0.0692  0.1787 -0.2666]]

Max absolute error: 5.9604645e-08
Results match: True


## Memory Analysis

**Naive attention** stores:
- Full score matrix: `(sequence_length × sequence_length)` 
- Full attention weights: `(sequence_length × sequence_length)`

**Flash attention** stores:
- One block of scores: `(block_size_query × block_size_key)` 
- Running statistics: `O(sequence_length)` for max and sum_exp

For sequence_length=4096 with block_size=64:
- Naive: 4096² × 2 = **33 million** floats
- Flash: 64² + 4096×2 = **~12 thousand** floats

That's a **2700x** memory reduction!

In [None]:
# =============================================================================
# Test with larger dimensions and various block sizes
# =============================================================================

np.random.seed(123)
large_sequence_length = 128
large_head_dimension = 64

large_query = np.random.randn(large_sequence_length, large_head_dimension).astype(np.float32)
large_key = np.random.randn(large_sequence_length, large_head_dimension).astype(np.float32)
large_value = np.random.randn(large_sequence_length, large_head_dimension).astype(np.float32)

large_naive_output, _ = naive_attention(large_query, large_key, large_value)

# Test various block sizes
for block_size in [8, 16, 32, 64]:
    large_flash_output = flash_attention(
        large_query, large_key, large_value,
        block_size_query=block_size, 
        block_size_key=block_size
    )
    max_error = np.abs(large_naive_output - large_flash_output).max()
    print(f"Block size {block_size:2d}: max error = {max_error:.2e}")

In [14]:
import numpy as np
import math

# Small example for visualization
sequence_length = 8
head_dimension = 4

np.random.seed(42)
query = np.random.randn(sequence_length, head_dimension).astype(np.float32)
key = np.random.randn(sequence_length, head_dimension).astype(np.float32)
value = np.random.randn(sequence_length, head_dimension).astype(np.float32)

print(f"query shape: {query.shape}")
print(f"key shape: {key.shape}")
print(f"value shape: {value.shape}")

query shape: (8, 4)
key shape: (8, 4)
value shape: (8, 4)


In [None]:
# NAIVE ATTENTION
def naive_attention(query, key, value):
    scale = 1.0 / math.sqrt(query.shape[-1])
    scores = query @ key.T * scale
    scores_max = scores.max(axis=1, keepdims=True)
    exp_scores = np.exp(scores - scores_max)
    attention_weights = exp_scores / exp_scores.sum(axis=1, keepdims=True)
    output = attention_weights @ value
    return output, attention_weights

naive_output, attention_weights = naive_attention(query, key, value)
print("Naive attention output shape:", naive_output.shape)
print("\nAttention weights (the NxN matrix we want to avoid storing):")
print(attention_weights.round(3))

Naive attention output shape: (8, 4)

Attention weights (the N×N matrix we want to avoid storing):
[[0.059 0.098 0.098 0.151 0.11  0.248 0.209 0.027]
 [0.124 0.042 0.067 0.124 0.121 0.298 0.198 0.026]
 [0.108 0.095 0.127 0.133 0.076 0.115 0.108 0.238]
 [0.055 0.577 0.032 0.053 0.138 0.005 0.02  0.119]
 [0.125 0.095 0.083 0.1   0.065 0.045 0.057 0.43 ]
 [0.261 0.115 0.198 0.016 0.199 0.029 0.027 0.155]
 [0.048 0.192 0.085 0.232 0.065 0.083 0.124 0.173]
 [0.02  0.152 0.042 0.361 0.05  0.133 0.205 0.037]]


In [16]:
# DEMO: Online softmax on a single row of scores
scores_row = np.array([1.0, 3.0, 2.0, 4.0, 0.5, 2.5])
block_size = 2

# Ground truth
scores_max = scores_row.max()
ground_truth = np.exp(scores_row - scores_max) / np.exp(scores_row - scores_max).sum()
print("Ground truth softmax:", ground_truth.round(4))

# Online computation
running_max = -np.inf
running_sum_exp = 0.0

for block_start in range(0, len(scores_row), block_size):
    block_scores = scores_row[block_start : block_start + block_size]
    block_max = block_scores.max()
    block_sum_exp = np.exp(block_scores - block_max).sum()
    
    new_max = max(running_max, block_max)
    old_contribution = np.exp(running_max - new_max) * running_sum_exp
    new_contribution = np.exp(block_max - new_max) * block_sum_exp
    new_sum_exp = old_contribution + new_contribution
    
    running_max = new_max
    running_sum_exp = new_sum_exp
    print(f"After block {block_scores}: max={running_max:.2f}, sum_exp={running_sum_exp:.4f}")

online_softmax = np.exp(scores_row - running_max) / running_sum_exp
print("\nOnline softmax:      ", online_softmax.round(4))
print("Match:", np.allclose(ground_truth, online_softmax))

Ground truth softmax: [0.0276 0.2037 0.0749 0.5536 0.0167 0.1235]
After block [1. 3.]: max=3.00, sum_exp=1.1353
After block [2. 4.]: max=4.00, sum_exp=1.5530
After block [0.5 2.5]: max=4.00, sum_exp=1.8063

Online softmax:       [0.0276 0.2037 0.0749 0.5536 0.0167 0.1235]
Match: True


In [17]:
# FLASH ATTENTION
def flash_attention(query, key, value, block_size_query=2, block_size_key=2):
    sequence_length, head_dimension = query.shape
    scale = 1.0 / math.sqrt(head_dimension)
    
    output = np.zeros((sequence_length, head_dimension), dtype=np.float64)
    running_max = np.full(sequence_length, -np.inf, dtype=np.float64)
    running_sum_exp = np.zeros(sequence_length, dtype=np.float64)
    
    for key_block_start in range(0, sequence_length, block_size_key):
        key_block_end = min(key_block_start + block_size_key, sequence_length)
        key_block = key[key_block_start:key_block_end].astype(np.float64)
        value_block = value[key_block_start:key_block_end].astype(np.float64)
        
        for query_block_start in range(0, sequence_length, block_size_query):
            query_block_end = min(query_block_start + block_size_query, sequence_length)
            query_block = query[query_block_start:query_block_end].astype(np.float64)
            
            current_max = running_max[query_block_start:query_block_end]
            current_sum_exp = running_sum_exp[query_block_start:query_block_end]
            current_output = output[query_block_start:query_block_end]
            
            scores_block = (query_block @ key_block.T) * scale
            
            block_max = scores_block.max(axis=1)
            block_exp_scores = np.exp(scores_block - block_max[:, None])
            block_sum_exp = block_exp_scores.sum(axis=1)
            block_weighted_values = block_exp_scores @ value_block
            
            new_max = np.maximum(current_max, block_max)
            old_scale = np.exp(current_max - new_max)
            new_scale = np.exp(block_max - new_max)
            new_sum_exp = old_scale * current_sum_exp + new_scale * block_sum_exp
            
            new_output = (
                old_scale[:, None] * current_sum_exp[:, None] * current_output +
                new_scale[:, None] * block_weighted_values
            ) / new_sum_exp[:, None]
            
            output[query_block_start:query_block_end] = new_output
            running_max[query_block_start:query_block_end] = new_max
            running_sum_exp[query_block_start:query_block_end] = new_sum_exp
    
    return output.astype(query.dtype)

# Compare
flash_output = flash_attention(query, key, value, block_size_query=2, block_size_key=2)

print("Naive output:")
print(naive_output.round(4))
print("\nFlash output:")
print(flash_output.round(4))
print("\nMax absolute error:", np.abs(naive_output - flash_output).max())
print("Results match:", np.allclose(naive_output, flash_output, rtol=1e-5))

Naive output:
[[-0.2606  0.1387  0.187   0.178 ]
 [-0.2657  0.1958  0.3021  0.1886]
 [-0.2045  0.1917 -0.1595 -0.1489]
 [ 0.129  -0.2446  0.2879  0.6396]
 [-0.2408  0.0981 -0.2086 -0.4503]
 [ 0.058   0.6045 -0.2326  0.2756]
 [-0.1434 -0.0073 -0.0186 -0.1872]
 [-0.1516 -0.0692  0.1787 -0.2666]]

Flash output:
[[-0.2606  0.1387  0.187   0.178 ]
 [-0.2657  0.1958  0.3021  0.1886]
 [-0.2045  0.1917 -0.1595 -0.1489]
 [ 0.129  -0.2446  0.2879  0.6396]
 [-0.2408  0.0981 -0.2086 -0.4503]
 [ 0.058   0.6045 -0.2326  0.2756]
 [-0.1434 -0.0073 -0.0186 -0.1872]
 [-0.1516 -0.0692  0.1787 -0.2666]]

Max absolute error: 5.9604645e-08
Results match: True


In [18]:
# Test with larger dimensions and various block sizes
np.random.seed(123)
large_sequence_length = 128
large_head_dimension = 64

large_query = np.random.randn(large_sequence_length, large_head_dimension).astype(np.float32)
large_key = np.random.randn(large_sequence_length, large_head_dimension).astype(np.float32)
large_value = np.random.randn(large_sequence_length, large_head_dimension).astype(np.float32)

large_naive_output, _ = naive_attention(large_query, large_key, large_value)

for block_size in [8, 16, 32, 64]:
    large_flash_output = flash_attention(
        large_query, large_key, large_value,
        block_size_query=block_size, 
        block_size_key=block_size
    )
    max_error = np.abs(large_naive_output - large_flash_output).max()
    print(f"Block size {block_size:2d}: max error = {max_error:.2e}")

Block size  8: max error = 3.13e-07
Block size 16: max error = 3.13e-07
Block size 32: max error = 3.13e-07
Block size 64: max error = 3.13e-07


In [20]:
# =============================================================================
# WHY CAN'T WE JUST ACCUMULATE exp(scores) NAIVELY?
# =============================================================================

# The problem: numerical overflow/underflow

# Imagine we have scores across two blocks:
block_1_scores = np.array([1.0, 2.0])
block_2_scores = np.array([100.0, 101.0])
all_scores = np.concatenate([block_1_scores, block_2_scores])

print("All scores:", all_scores)
print()

# NAIVE APPROACH: just accumulate exp(scores)
print("=== NAIVE (broken) ===")
print(f"exp(block_1) = {np.exp(block_1_scores)}")
print(f"exp(block_2) = {np.exp(block_2_scores)}")  # OVERFLOW!
print("^ See the problem? exp(100) and exp(101) overflow to inf!")
print()

# This is why standard softmax subtracts the max first
print("=== STANDARD STABLE SOFTMAX (needs global max) ===")
global_max = all_scores.max()
print(f"Global max = {global_max}")
stable_exp = np.exp(all_scores - global_max)
print(f"exp(scores - global_max) = {stable_exp}")
print(f"sum = {stable_exp.sum()}")
print(f"softmax = {stable_exp / stable_exp.sum()}")

All scores: [  1.   2. 100. 101.]

=== NAIVE (broken) ===
exp(block_1) = [2.71828183 7.3890561 ]
exp(block_2) = [2.68811714e+43 7.30705998e+43]
^ See the problem? exp(100) and exp(101) overflow to inf!

=== STANDARD STABLE SOFTMAX (needs global max) ===
Global max = 101.0
exp(scores - global_max) = [3.72007598e-44 1.01122149e-43 3.67879441e-01 1.00000000e+00]
sum = 1.3678794411714423
softmax = [2.71959346e-44 7.39262147e-44 2.68941421e-01 7.31058579e-01]


In [21]:
# =============================================================================
# OK, SO USE LOCAL MAX... BUT CAN WE JUST ADD THE SUMS?
# =============================================================================

print("=== USING LOCAL MAX PER BLOCK ===")
print()

# Block 1: use local max = 2
block_1_max = block_1_scores.max()
block_1_exp = np.exp(block_1_scores - block_1_max)
block_1_sum = block_1_exp.sum()
print(f"Block 1: scores={block_1_scores}, max={block_1_max}")
print(f"         exp(scores - max) = {block_1_exp}")
print(f"         sum_exp = {block_1_sum:.4f}")
print()

# Block 2: use local max = 101  
block_2_max = block_2_scores.max()
block_2_exp = np.exp(block_2_scores - block_2_max)
block_2_sum = block_2_exp.sum()
print(f"Block 2: scores={block_2_scores}, max={block_2_max}")
print(f"         exp(scores - max) = {block_2_exp}")
print(f"         sum_exp = {block_2_sum:.4f}")
print()

# WRONG: Can we just add them?
naive_total = block_1_sum + block_2_sum
print(f"WRONG: block_1_sum + block_2_sum = {naive_total:.4f}")
print()

# These sums are in DIFFERENT COORDINATE SYSTEMS!
# block_1_sum is relative to max=2
# block_2_sum is relative to max=101
# You can't add apples and oranges!

print("=== THE PROBLEM: DIFFERENT COORDINATE SYSTEMS ===")
print(f"block_1_sum = sum(exp(scores - 2))   <-- relative to max=2")
print(f"block_2_sum = sum(exp(scores - 101)) <-- relative to max=101")
print("These are in different 'units' - can't add directly!")

=== USING LOCAL MAX PER BLOCK ===

Block 1: scores=[1. 2.], max=2.0
         exp(scores - max) = [0.36787944 1.        ]
         sum_exp = 1.3679

Block 2: scores=[100. 101.], max=101.0
         exp(scores - max) = [0.36787944 1.        ]
         sum_exp = 1.3679

WRONG: block_1_sum + block_2_sum = 2.7358

=== THE PROBLEM: DIFFERENT COORDINATE SYSTEMS ===
block_1_sum = sum(exp(scores - 2))   <-- relative to max=2
block_2_sum = sum(exp(scores - 101)) <-- relative to max=101
These are in different 'units' - can't add directly!


In [22]:
# =============================================================================
# THE FIX: CONVERT TO A COMMON COORDINATE SYSTEM
# =============================================================================

print("=== CONVERTING TO COMMON COORDINATES ===")
print()

# To add sums, we need them in the SAME coordinate system (same reference max)
# Choose the global max = 101 as our reference

global_max = max(block_1_max, block_2_max)
print(f"Global max (our reference) = {global_max}")
print()

# Convert block_1_sum from "relative to max=2" to "relative to max=101"
# 
# Originally: block_1_sum = exp(1-2) + exp(2-2) = exp(-1) + exp(0)
# We want:    block_1_sum = exp(1-101) + exp(2-101) = exp(-100) + exp(-99)
#
# Notice: exp(1-101) = exp(1-2) * exp(2-101) = exp(1-2) * exp(-99)
#         exp(2-101) = exp(2-2) * exp(2-101) = exp(2-2) * exp(-99)
#
# So: new_sum = old_sum * exp(old_max - new_max)

correction_factor_1 = np.exp(block_1_max - global_max)  # exp(2 - 101) = exp(-99)
block_1_sum_corrected = block_1_sum * correction_factor_1

print(f"Block 1 correction: exp({block_1_max} - {global_max}) = exp(-99) = {correction_factor_1:.2e}")
print(f"Block 1 sum corrected: {block_1_sum:.4f} * {correction_factor_1:.2e} = {block_1_sum_corrected:.2e}")
print()

# Block 2 is already in the right coordinate system (its max IS the global max)
correction_factor_2 = np.exp(block_2_max - global_max)  # exp(101 - 101) = exp(0) = 1
block_2_sum_corrected = block_2_sum * correction_factor_2

print(f"Block 2 correction: exp({block_2_max} - {global_max}) = exp(0) = {correction_factor_2}")
print(f"Block 2 sum corrected: {block_2_sum:.4f} * {correction_factor_2} = {block_2_sum_corrected:.4f}")
print()

# NOW we can add them!
correct_total = block_1_sum_corrected + block_2_sum_corrected
print(f"CORRECT total sum_exp = {correct_total:.4f}")
print()

# Verify against ground truth
print("=== VERIFICATION ===")
ground_truth_sum = np.exp(all_scores - global_max).sum()
print(f"Ground truth sum_exp = {ground_truth_sum:.4f}")
print(f"Match: {np.isclose(correct_total, ground_truth_sum)}")

=== CONVERTING TO COMMON COORDINATES ===

Global max (our reference) = 101.0

Block 1 correction: exp(2.0 - 101.0) = exp(-99) = 1.01e-43
Block 1 sum corrected: 1.3679 * 1.01e-43 = 1.38e-43

Block 2 correction: exp(101.0 - 101.0) = exp(0) = 1.0
Block 2 sum corrected: 1.3679 * 1.0 = 1.3679

CORRECT total sum_exp = 1.3679

=== VERIFICATION ===
Ground truth sum_exp = 1.3679
Match: True


In [23]:
# =============================================================================
# THE INTUITION
# =============================================================================

print("""
THE KEY INSIGHT:
================

When you compute exp(score - local_max), you're working in a "coordinate system"
centered at that local_max.

   exp(score - max_A)  is in "coordinate system A"
   exp(score - max_B)  is in "coordinate system B"

You CANNOT add values from different coordinate systems!

To convert from system A to system B:
   
   value_in_B = value_in_A * exp(max_A - max_B)

This is like converting currencies:
   - Block 1 computed sums in "max=2 dollars"  
   - Block 2 computed sums in "max=101 dollars"
   - To add them, convert block 1 to "max=101 dollars" first
   - The exchange rate is exp(2 - 101) = exp(-99) ≈ 0

In this example, block 1's contribution becomes essentially ZERO because
its scores (1, 2) are so much smaller than block 2's scores (100, 101).

This is mathematically correct! In the final softmax:
   - scores [1, 2] get probability ≈ 0  
   - scores [100, 101] get probability ≈ [0.27, 0.73]
""")


THE KEY INSIGHT:

When you compute exp(score - local_max), you're working in a "coordinate system"
centered at that local_max.

   exp(score - max_A)  is in "coordinate system A"
   exp(score - max_B)  is in "coordinate system B"

You CANNOT add values from different coordinate systems!

To convert from system A to system B:

   value_in_B = value_in_A * exp(max_A - max_B)

This is like converting currencies:
   - Block 1 computed sums in "max=2 dollars"  
   - Block 2 computed sums in "max=101 dollars"
   - To add them, convert block 1 to "max=101 dollars" first
   - The exchange rate is exp(2 - 101) = exp(-99) ≈ 0

In this example, block 1's contribution becomes essentially ZERO because
its scores (1, 2) are so much smaller than block 2's scores (100, 101).

This is mathematically correct! In the final softmax:
   - scores [1, 2] get probability ≈ 0  
   - scores [100, 101] get probability ≈ [0.27, 0.73]

