# Week 3, Day 3: Stable Softmax

**Time:** ~1 hour

**Goal:** Implement a numerically stable softmax using the max-subtraction trick.

## The Challenge

Yesterday we saw softmax explode with `nan`. Today we fix it with a simple mathematical trick that doesn't change the answer but prevents overflow.

**The insight:** softmax(x) = softmax(x - c) for any constant c.

In [None]:
import numpy as np
import torch
import matplotlib.pyplot as plt
import time

np.set_printoptions(precision=6, suppress=True)
torch.set_printoptions(precision=6, sci_mode=False)

---
## Step 1: The Challenge (5 min)

Recall our broken softmax:

In [None]:
def naive_softmax(x):
    """Broken softmax that overflows."""
    exp_x = np.exp(x)
    return exp_x / np.sum(exp_x, axis=-1, keepdims=True)

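# This fails in float32: exp(100) overflows to inf, and inf / inf gives nan.
# (In float64, exp(100) is still finite, so the dtype matters here.)
scores = np.array([1.0, 100.0, 2.0, 3.0], dtype=np.float32)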
print(f"Naive softmax of {scores}:")
print(naive_softmax(scores))

**Goal:** Make this work without changing the mathematical result.
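
A quick way to see where the overflow threshold sits: `exp(x)` stops being finite once `x` exceeds the log of the largest representable float. The sketch below computes that cutoff for three common dtypes using `np.finfo`; the name `overflow_threshold` is illustrative, not something defined earlier in the notebook.

```python
import numpy as np

# exp(x) is finite only while x <= log(largest finite float of that dtype).
overflow_threshold = {}
for dtype in (np.float16, np.float32, np.float64):
    max_float = np.float64(np.finfo(dtype).max)  # widen before taking the log
    overflow_threshold[np.dtype(dtype).name] = float(np.log(max_float))

for name, threshold in overflow_threshold.items():
    print(f"{name}: exp(x) overflows for x > {threshold:.2f}")
```

A score of 100 is past the float16 and float32 cutoffs but well inside float64's, so whether the naive version fails depends on the dtype of the scores.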

---
## Step 2: Explore — The Invariance Property (15 min)

### The Key Mathematical Insight

For any constant $c$:

$$\text{softmax}(x_i - c) = \frac{e^{x_i - c}}{\sum_j e^{x_j - c}} = \frac{e^{x_i} \cdot e^{-c}}{\sum_j e^{x_j} \cdot e^{-c}} = \frac{e^{x_i}}{\sum_j e^{x_j}} = \text{softmax}(x_i)$$

The $e^{-c}$ cancels out! This is called **shift invariance**.

In [None]:
# Demonstrate shift invariance
x = np.array([1.0, 2.0, 3.0, 4.0])

# Original softmax
probs_original = naive_softmax(x)

# Shifted by various constants
for c in [0, 5, -3, 100, -100]:
    probs_shifted = naive_softmax(x - c)
    max_diff = np.abs(probs_original - probs_shifted).max()
    print(f"c={c:4d}: probs={probs_shifted}, max_diff={max_diff:.2e}")

### The Optimal Shift: max(x)

If we subtract $\max(x)$:
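- The largest value becomes 0
- All other values become negative (or 0, if they tie with the max)
- Every $e^{x_i - \max(x)}$ lies in $(0, 1]$, so nothing can overflow
- The max itself contributes $e^0 = 1$, so the denominator is at least 1 and we never divide by 0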

In [None]:
def stable_softmax(x):
    """Numerically stable softmax."""
    # Subtract max for numerical stability
    x_shifted = x - np.max(x, axis=-1, keepdims=True)
    exp_x = np.exp(x_shifted)
    return exp_x / np.sum(exp_x, axis=-1, keepdims=True)

# Now it works!
scores = np.array([1.0, 100.0, 2.0, 3.0])
print(f"Stable softmax of {scores}:")
print(stable_softmax(scores))

# Even extreme values work
extreme_scores = np.array([1.0, 1000.0, 2.0, 3.0])
print(f"\nStable softmax of {extreme_scores}:")
print(stable_softmax(extreme_scores))

In [None]:
# Step-by-step visualization
def visualize_stable_softmax(x):
    """Show each step of stable softmax."""
    print("=" * 60)
    print(f"Input: {x}")
    
    # Step 1: Find max
    max_x = np.max(x)
    print(f"\nStep 1: max(x) = {max_x}")
    
    # Step 2: Subtract max
    x_shifted = x - max_x
    print(f"Step 2: x - max = {x_shifted}")
    
    # Step 3: Compute exp
    exp_x = np.exp(x_shifted)
    print(f"Step 3: exp(x - max) = {exp_x}")
    
    # Step 4: Sum
    sum_exp = np.sum(exp_x)
    print(f"Step 4: sum = {sum_exp}")
    
    # Step 5: Divide
    probs = exp_x / sum_exp
    print(f"Step 5: probs = {probs}")
    print(f"        sum(probs) = {np.sum(probs)}")
    
    return probs

visualize_stable_softmax(np.array([1.0, 100.0, 2.0, 3.0]))

### What About Underflow?

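After subtracting the max, small values become very negative. $e^{-1000}$ is far smaller than the smallest positive float64, so it rounds to exactly 0. This is **underflow**.

But underflow is benign here! The true probability is so close to 0 that rounding it to 0 introduces a negligible absolute error. It only becomes a problem if you later take the log of that probability, which the log-sum-exp trick in Step 3 avoids.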

In [None]:
# Underflow example
scores = np.array([0.0, 100.0, -500.0, -1000.0])
probs = stable_softmax(scores)

print(f"Scores: {scores}")
print(f"After shifting: {scores - scores.max()}")
print(f"Probabilities: {probs}")
print(f"\nThe -1000 score gets probability ≈ 0, which is correct!")
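
Overflow and underflow also behave differently at the language level, which is part of why one is dangerous and the other is not. A minimal sketch with the standard-library `math` module (NumPy differs slightly: `np.exp` returns `inf` with a warning instead of raising):

```python
import math

# Overflow is loud: math.exp raises once the result exceeds the float64 range.
try:
    math.exp(1000.0)
    overflowed = False
except OverflowError:
    overflowed = True

# Underflow is quiet: the result simply rounds down to 0.0.
tiny = math.exp(-700.0)    # still representable, roughly 1e-304
zero = math.exp(-1000.0)   # below the smallest positive float64

print(f"exp(1000) overflowed: {overflowed}")
print(f"exp(-700)  = {tiny:.3e}")
print(f"exp(-1000) = {zero}")
```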

---
## Step 3: The Concept — Mathematical Proof (10 min)

Let's prove the shift invariance formally.

### Theorem: softmax(x - c) = softmax(x)

**Proof:**

$$\text{softmax}(x_i - c) = \frac{e^{x_i - c}}{\sum_{j=1}^{n} e^{x_j - c}}$$

Using the property $e^{a-b} = e^a \cdot e^{-b}$:

$$= \frac{e^{x_i} \cdot e^{-c}}{\sum_{j=1}^{n} e^{x_j} \cdot e^{-c}}$$

Factor out $e^{-c}$ from the sum:

$$= \frac{e^{x_i} \cdot e^{-c}}{e^{-c} \cdot \sum_{j=1}^{n} e^{x_j}}$$

Cancel $e^{-c}$:

$$= \frac{e^{x_i}}{\sum_{j=1}^{n} e^{x_j}} = \text{softmax}(x_i) \quad \square$$

In [None]:
# Numerical verification of the proof
np.random.seed(42)
x = np.random.randn(1000)  # 1000 random values

# Test with many different shift values
probs_original = stable_softmax(x)

max_errors = []
for c in np.linspace(-1000, 1000, 100):
    probs_shifted = stable_softmax(x - c)
    max_error = np.abs(probs_original - probs_shifted).max()
    max_errors.append(max_error)

print(f"Max error across 100 different shifts: {max(max_errors):.2e}")
print("(Should be essentially 0, limited by floating point precision)")
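
The shifted and unshifted results need not agree bit for bit. One plausible source of the leftover error is the subtraction `x - c` itself: for `|c|` near 1000 the shifted inputs are rounded to a coarser grid than the original values near 1. A minimal sketch using `math.ulp` (Python 3.9+) to compare the two grid spacings:

```python
import math

# Spacing between adjacent float64 values near 1.0 and near 1000.0.
ulp_near_1 = math.ulp(1.0)        # 2**-52
ulp_near_1000 = math.ulp(1000.0)  # 2**-43

print(f"ulp(1.0)    = {ulp_near_1:.3e}")
print(f"ulp(1000.0) = {ulp_near_1000:.3e}")
print(f"ratio       = {ulp_near_1000 / ulp_near_1:.0f}")
```

The grid is 512 times coarser near 1000, so a tiny but nonzero maximum error is what one would expect.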

### Log-Sum-Exp Trick

A related technique is computing log(softmax) stably:

$$\log(\text{softmax}(x_i)) = x_i - \log\sum_j e^{x_j}$$

The log-sum-exp (LSE) can be computed stably:

$$\text{LSE}(x) = \max(x) + \log\sum_j e^{x_j - \max(x)}$$
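
The identity follows from the same factoring step as the shift-invariance proof. Writing $m = \max(x)$ as shorthand:

$$\log\sum_j e^{x_j} = \log\left(e^{m} \sum_j e^{x_j - m}\right) = m + \log\sum_j e^{x_j - m}$$

Every exponent $x_j - m$ is at most 0 and the term for the maximum equals 1, so for $n$ inputs the inner sum lies in $[1, n]$ and its logarithm is always finite.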

In [None]:
def log_sum_exp(x):
    """Numerically stable log-sum-exp."""
    max_x = np.max(x, axis=-1, keepdims=True)
    return max_x + np.log(np.sum(np.exp(x - max_x), axis=-1, keepdims=True))

def log_softmax(x):
    """Numerically stable log-softmax."""
    return x - log_sum_exp(x)

# Compare
x = np.array([1.0, 100.0, 2.0, 3.0])

# Method 1: log of stable softmax
log_probs_v1 = np.log(stable_softmax(x))

# Method 2: log-softmax directly
log_probs_v2 = log_softmax(x)

print(f"log(stable_softmax): {log_probs_v1.flatten()}")
print(f"log_softmax:         {log_probs_v2.flatten()}")
print(f"Max difference: {np.abs(log_probs_v1 - log_probs_v2).max():.2e}")
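
The comparison above uses an input where both routes agree. They part ways once a probability underflows: `log(0)` is `-inf`, while the log-space formula keeps the finite value. A small self-contained sketch (the two functions are re-declared so the block runs on its own):

```python
import numpy as np

def stable_softmax(x):
    shifted = x - np.max(x, axis=-1, keepdims=True)
    exp_x = np.exp(shifted)
    return exp_x / np.sum(exp_x, axis=-1, keepdims=True)

def log_softmax(x):
    max_x = np.max(x, axis=-1, keepdims=True)
    lse = max_x + np.log(np.sum(np.exp(x - max_x), axis=-1, keepdims=True))
    return x - lse

x = np.array([0.0, -1000.0])

# Route 1: the second probability underflows to 0, so its log is -inf.
with np.errstate(divide="ignore"):
    via_probs = np.log(stable_softmax(x))

# Route 2: stays in log space, so the true value (-1000) survives.
direct = log_softmax(x)

print(f"log(stable_softmax(x)): {via_probs}")
print(f"log_softmax(x):         {direct}")
```

This is why loss functions that need log-probabilities are usually better served by a log-softmax than by taking the log of a softmax.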

---
## Step 4: Code It — Efficient Implementation (30 min)

### PyTorch Comparison

Let's verify our implementation matches PyTorch's built-in softmax.

In [None]:
def stable_softmax_torch(x):
    """Our implementation in PyTorch."""
    max_x = x.max(dim=-1, keepdim=True).values
    exp_x = torch.exp(x - max_x)
    return exp_x / exp_x.sum(dim=-1, keepdim=True)

# Test against PyTorch's implementation
x = torch.randn(100, 50)

our_result = stable_softmax_torch(x)
pytorch_result = torch.softmax(x, dim=-1)

max_diff = (our_result - pytorch_result).abs().max()
print(f"Max difference from torch.softmax: {max_diff:.2e}")

# Verify properties
print(f"All probabilities non-negative: {(our_result >= 0).all()}")
print(f"Rows sum to 1: {our_result.sum(dim=-1).allclose(torch.ones(100))}")

### Triton Implementation

On the GPU, we implement softmax in Triton. This first kernel makes three passes over each row: one to find the max, one to sum the exponentials, and one to normalize and store the result.

In [None]:
import triton
import triton.language as tl

@triton.jit
def softmax_kernel(
    input_ptr, output_ptr,
    n_cols,
    input_row_stride, output_row_stride,
    BLOCK_SIZE: tl.constexpr,
):
    """
    Stable softmax kernel.
    Each program handles one row of the input.
    """
    # Get row index
    row_idx = tl.program_id(0)
    
    # Pointers to the start of this row
    row_start_ptr = input_ptr + row_idx * input_row_stride
    
    # Column offsets within one BLOCK_SIZE-wide chunk of the row
    col_offsets = tl.arange(0, BLOCK_SIZE)
    
    # First pass: find max
    row_max = float('-inf')
    for start in range(0, n_cols, BLOCK_SIZE):
        offsets = start + col_offsets
        mask = offsets < n_cols
        row_vals = tl.load(row_start_ptr + offsets, mask=mask, other=float('-inf'))
        row_max = tl.maximum(row_max, tl.max(row_vals, axis=0))
    
    # Second pass: compute exp(x - max) and sum
    row_sum = 0.0
    for start in range(0, n_cols, BLOCK_SIZE):
        offsets = start + col_offsets
        mask = offsets < n_cols
        row_vals = tl.load(row_start_ptr + offsets, mask=mask, other=float('-inf'))
        row_vals = tl.exp(row_vals - row_max)
        row_sum += tl.sum(row_vals, axis=0)
    
    # Third pass: normalize and store
    out_row_start_ptr = output_ptr + row_idx * output_row_stride
    for start in range(0, n_cols, BLOCK_SIZE):
        offsets = start + col_offsets
        mask = offsets < n_cols
        row_vals = tl.load(row_start_ptr + offsets, mask=mask, other=float('-inf'))
        softmax_vals = tl.exp(row_vals - row_max) / row_sum
        tl.store(out_row_start_ptr + offsets, softmax_vals, mask=mask)

In [None]:
def triton_softmax(x):
    """Wrapper for Triton softmax kernel."""
    n_rows, n_cols = x.shape
    output = torch.empty_like(x)
    
    # Choose BLOCK_SIZE based on n_cols
    BLOCK_SIZE = triton.next_power_of_2(n_cols)
    BLOCK_SIZE = min(BLOCK_SIZE, 1024)  # Cap at 1024
    
    # Launch kernel with one program per row
    softmax_kernel[(n_rows,)](
        x, output,
        n_cols,
        x.stride(0), output.stride(0),
        BLOCK_SIZE=BLOCK_SIZE,
    )
    
    return output

# Test on GPU
if torch.cuda.is_available():
    x_gpu = torch.randn(128, 256, device='cuda')
    
    triton_result = triton_softmax(x_gpu)
    pytorch_result = torch.softmax(x_gpu, dim=-1)
    
    max_diff = (triton_result - pytorch_result).abs().max()
    print(f"Max diff from PyTorch: {max_diff:.2e}")
    print(f"Rows sum to 1: {triton_result.sum(dim=-1).allclose(torch.ones(128, device='cuda'))}")
else:
    print("GPU not available. Triton test skipped.")
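
The kernel's control flow can be mirrored on the CPU in plain Python, which is handy for checking the logic without a GPU. This is a sketch for a single row, not a translation of the Triton code; `blocked_softmax_3pass` and its `block_size` argument are names made up for illustration, and list slicing stands in for the kernel's masked loads.

```python
import math

def blocked_softmax_3pass(row, block_size):
    """CPU reference for one row, mirroring the kernel's three loops."""
    n = len(row)

    # Pass 1: running max over blocks.
    row_max = -math.inf
    for start in range(0, n, block_size):
        row_max = max(row_max, max(row[start:start + block_size]))

    # Pass 2: sum of exp(x - max) over blocks.
    row_sum = 0.0
    for start in range(0, n, block_size):
        row_sum += sum(math.exp(v - row_max) for v in row[start:start + block_size])

    # Pass 3: normalize block by block.
    out = []
    for start in range(0, n, block_size):
        out.extend(math.exp(v - row_max) / row_sum for v in row[start:start + block_size])
    return out

row = [1.0, 1000.0, 2.0, 3.0, -5.0]  # length is not a multiple of block_size
probs = blocked_softmax_3pass(row, block_size=2)
print(probs)
print(sum(probs))
```

Each of the three loops reads every element of the row once, which is the cost the next exercise tries to reduce.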

### Exercise: Two-Pass vs Three-Pass

Our kernel makes 3 passes over the data. Can we do better?

**Challenge:** Implement a 2-pass softmax that computes max and sum together, then normalizes.
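
A rough way to estimate what a third pass costs is to count bytes moved. The sketch below assumes every pass streams the whole input from memory with no cache reuse, which is an upper bound rather than a measurement; `softmax_traffic_bytes` is a hypothetical helper.

```python
def softmax_traffic_bytes(n_rows, n_cols, read_passes, bytes_per_elem=4):
    """Bytes moved if every pass streams the full input from memory.

    Assumes no cache reuse between passes, which overstates traffic
    for rows small enough to stay in cache.
    """
    elems = n_rows * n_cols
    reads = read_passes * elems * bytes_per_elem
    writes = elems * bytes_per_elem  # the output is written once
    return reads + writes

n_rows, n_cols = 1024, 2048  # the largest shape in the benchmark section, float32
three_pass = softmax_traffic_bytes(n_rows, n_cols, read_passes=3)
two_pass = softmax_traffic_bytes(n_rows, n_cols, read_passes=2)

print(f"3-pass: {three_pass / 2**20:.0f} MiB")
print(f"2-pass: {two_pass / 2**20:.0f} MiB")
print(f"ratio:  {two_pass / three_pass:.2f}")
```

Under that assumption the two-pass version moves three quarters of the data. It also does a little more arithmetic per block (the rescale), so the measured gain may well be smaller.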

In [None]:
@triton.jit
def softmax_kernel_2pass(
    input_ptr, output_ptr,
    n_cols,
    input_row_stride, output_row_stride,
    BLOCK_SIZE: tl.constexpr,
):
    """
    Two-pass softmax: 
    Pass 1: Compute max AND partial exp-sums
    Pass 2: Normalize
    
    Uses online algorithm to update sum when max changes.
    """
    row_idx = tl.program_id(0)
    row_start_ptr = input_ptr + row_idx * input_row_stride
    col_offsets = tl.arange(0, BLOCK_SIZE)
    
    # Pass 1: Online max and sum computation
    running_max = float('-inf')
    running_sum = 0.0
    
    for start in range(0, n_cols, BLOCK_SIZE):
        offsets = start + col_offsets
        mask = offsets < n_cols
        row_vals = tl.load(row_start_ptr + offsets, mask=mask, other=float('-inf'))
        
        # Find block max
        block_max = tl.max(row_vals, axis=0)
        
        # Update running max and rescale running sum
        new_max = tl.maximum(running_max, block_max)
        
        # Rescale old sum to new max
        running_sum = running_sum * tl.exp(running_max - new_max)
        
        # Add new block contribution
        running_sum += tl.sum(tl.where(mask, tl.exp(row_vals - new_max), 0.0), axis=0)
        
        running_max = new_max
    
    # Pass 2: Normalize and store
    out_row_start_ptr = output_ptr + row_idx * output_row_stride
    for start in range(0, n_cols, BLOCK_SIZE):
        offsets = start + col_offsets
        mask = offsets < n_cols
        row_vals = tl.load(row_start_ptr + offsets, mask=mask, other=float('-inf'))
        softmax_vals = tl.exp(row_vals - running_max) / running_sum
        tl.store(out_row_start_ptr + offsets, softmax_vals, mask=mask)

def triton_softmax_2pass(x):
    """Wrapper for 2-pass Triton softmax."""
    n_rows, n_cols = x.shape
    output = torch.empty_like(x)
    
    BLOCK_SIZE = min(triton.next_power_of_2(n_cols), 1024)
    
    softmax_kernel_2pass[(n_rows,)](
        x, output,
        n_cols,
        x.stride(0), output.stride(0),
        BLOCK_SIZE=BLOCK_SIZE,
    )
    
    return output

# Test
if torch.cuda.is_available():
    x_gpu = torch.randn(128, 256, device='cuda')
    
    result_2pass = triton_softmax_2pass(x_gpu)
    result_pytorch = torch.softmax(x_gpu, dim=-1)
    
    print(f"2-pass max diff: {(result_2pass - result_pytorch).abs().max():.2e}")
else:
    print("GPU not available.")
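
The online update is easier to follow on the CPU first. The sketch below is a plain-Python version of pass 1 for a single row, checked against the ordinary two-step computation; `online_max_and_sum` is an illustrative name, and slicing replaces the masked loads.

```python
import math

def online_max_and_sum(row, block_size):
    """Single sweep that returns (max, sum of exp(x - max)) for one row."""
    running_max = -math.inf
    running_sum = 0.0
    for start in range(0, len(row), block_size):
        block = row[start:start + block_size]
        new_max = max(running_max, max(block))
        # Rescale the old sum so it is expressed relative to new_max.
        # On the first block this is 0.0 * exp(-inf) = 0.0.
        running_sum *= math.exp(running_max - new_max)
        running_sum += sum(math.exp(v - new_max) for v in block)
        running_max = new_max
    return running_max, running_sum

row = [1.0, 3.0, 2.0, 1000.0, 0.5, 999.0]  # the max changes in blocks 1 and 2
online_max, online_sum = online_max_and_sum(row, block_size=2)

# Reference: find the max first, then sum the shifted exponentials.
ref_max = max(row)
ref_sum = sum(math.exp(v - ref_max) for v in row)

print(f"online: max={online_max}, sum={online_sum:.6f}")
print(f"ref:    max={ref_max}, sum={ref_sum:.6f}")
```

One caveat, shared with the kernel as written: if the data itself contains `-inf` and the running max is still `-inf` after a block, `running_max - new_max` becomes `-inf - (-inf) = nan`. Inputs like that would need a guard.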

### Benchmark

In [None]:
if torch.cuda.is_available():
    # Benchmark different implementations
    sizes = [(128, 128), (256, 512), (512, 1024), (1024, 2048)]
    
    print(f"{'Size':>15} {'PyTorch':>12} {'Triton 3-pass':>15} {'Triton 2-pass':>15}")
    print("-" * 60)
    
    for n_rows, n_cols in sizes:
        x = torch.randn(n_rows, n_cols, device='cuda')
        
        # Warmup
        for _ in range(10):
            _ = torch.softmax(x, dim=-1)
            _ = triton_softmax(x)
            _ = triton_softmax_2pass(x)
        
        torch.cuda.synchronize()
        
        # PyTorch
        start = time.perf_counter()
        for _ in range(100):
            _ = torch.softmax(x, dim=-1)
        torch.cuda.synchronize()
        pytorch_time = (time.perf_counter() - start) / 100 * 1000
        
        # Triton 3-pass
        start = time.perf_counter()
        for _ in range(100):
            _ = triton_softmax(x)
        torch.cuda.synchronize()
        triton3_time = (time.perf_counter() - start) / 100 * 1000
        
        # Triton 2-pass
        start = time.perf_counter()
        for _ in range(100):
            _ = triton_softmax_2pass(x)
        torch.cuda.synchronize()
        triton2_time = (time.perf_counter() - start) / 100 * 1000
        
        size_label = f"{n_rows}x{n_cols}"
        print(f"{size_label:>15} {pytorch_time:>10.3f}ms {triton3_time:>13.3f}ms {triton2_time:>13.3f}ms")
else:
    print("GPU not available for benchmarking.")

---
## Step 5: Verify — Quiz & Reflection (10 min)

### Quiz

In [None]:
def check_answer(question, your_answer, correct_answer):
    if your_answer == correct_answer:
        print(f"✓ Correct! {question}")
    else:
        print(f"✗ Incorrect. {question}")
        print(f"  Your answer: {your_answer}, Correct: {correct_answer}")

# Q1: What is the key property that makes stable softmax work?
# a) exp(x) is always positive
# b) softmax(x - c) = softmax(x) for any constant c
# c) max(x) is always finite
# d) Division distributes over addition
q1_answer = 'b'
check_answer("Key property", q1_answer, 'b')

In [None]:
# Q2: After subtracting max(x), the largest value in x becomes:
# a) 1
# b) max(x)
# c) 0
# d) -max(x)
q2_answer = 'c'
check_answer("Largest value after shift", q2_answer, 'c')

In [None]:
# Q3: stable_softmax([0, -1000, -2000]) produces approximately:
# a) [1/3, 1/3, 1/3]
# b) [1, 0, 0]
# c) [nan, nan, nan]
# d) [0, 0, 1]

# Let's verify:
result = stable_softmax(np.array([0.0, -1000.0, -2000.0]))
print(f"stable_softmax([0, -1000, -2000]) = {result}")

q3_answer = 'b'
check_answer("Softmax of [0, -1000, -2000]", q3_answer, 'b')

In [None]:
# Q4: How many passes over the input does the online-softmax kernel make?
# a) 1 (compute everything in one go)
# b) 2 (online max/sum, then normalize)
# c) 3 (max, exp-sum, normalize)
# d) It depends on the input size
q4_answer = 'b'
check_answer("Number of passes in 2-pass algorithm", q4_answer, 'b')

### Reflection Questions

1. **Memory bandwidth:** Why does reducing passes matter for GPU performance?

2. **The online algorithm:** How does it update the sum when a new max is found? (Hint: multiply by exp(old_max - new_max))

3. **Can we do 1-pass?** What information would we need to store to normalize in a single pass?

---

## Summary

| Technique | Key Insight |
|-----------|------------|
| Max subtraction | softmax(x - max) = softmax(x), prevents overflow |
| Log-sum-exp | LSE(x) = max(x) + log(Σexp(x - max)), stable log computation |
| Online algorithm | Update running sum when max changes: sum *= exp(old_max - new_max) |

**Tomorrow:** We'll apply stable softmax to full attention computation and discover the quadratic memory problem.

---

**Interactive Reference:** [attention-math.html](../attention-math.html) Section 3 — Online Softmax Simulation