# Week 3, Day 2: The Softmax Problem

**Time:** ~1 hour

**Goal:** Understand why naive softmax fails catastrophically and see the overflow in action.

## The Challenge

Yesterday we computed attention scores with dot products. Now we need to convert them to **probabilities** that sum to 1.

The tool: **softmax**. The problem: **the naive version silently turns your outputs into `inf` and `nan`**.

In [None]:
import numpy as np
import torch
import matplotlib.pyplot as plt
import warnings

np.set_printoptions(precision=4, suppress=True)
torch.set_printoptions(precision=4, sci_mode=False)

---
## Step 1: The Challenge — From Scores to Probabilities (5 min)

Softmax converts arbitrary real numbers into a probability distribution:

$$\text{softmax}(x_i) = \frac{e^{x_i}}{\sum_j e^{x_j}}$$

Properties:
- All outputs are positive (because $e^x > 0$)
- Outputs sum to 1
- Larger inputs get exponentially larger shares
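A quick way to convince yourself of these properties is to check them numerically. This is a minimal sketch in plain NumPy; the helper name `softmax_properties` is hypothetical, and the order check assumes the scores are distinct.

```python
import numpy as np

def softmax_properties(x):
    """Check the softmax properties for a 1-D array of distinct scores (hypothetical helper)."""
    x = np.asarray(x, dtype=np.float64)
    p = np.exp(x) / np.sum(np.exp(x))
    all_positive = bool(np.all(p > 0))
    sums_to_one = bool(np.isclose(p.sum(), 1.0))
    # Larger inputs get larger shares: sorting by score and by probability agree
    order_preserved = bool(np.array_equal(np.argsort(x), np.argsort(p)))
    return all_positive, sums_to_one, order_preserved

print(softmax_properties([1.0, 2.0, 3.0, 4.0]))
print(softmax_properties([-3.0, 0.5, 2.0, -1.0]))
```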

In [None]:
def naive_softmax(x):
    """The obvious softmax implementation. What could go wrong?"""
    exp_x = np.exp(x)
    return exp_x / np.sum(exp_x)

# Works fine for small values
scores_small = np.array([1.0, 2.0, 3.0, 4.0])
probs = naive_softmax(scores_small)
print(f"Scores: {scores_small}")
print(f"Probabilities: {probs}")
print(f"Sum: {probs.sum():.6f}")

In [None]:
# Visualize softmax behavior
def plot_softmax_effect(scores, title="Softmax Transformation"):
    probs = naive_softmax(scores)
    
    fig, axes = plt.subplots(1, 2, figsize=(12, 4))
    
    # Raw scores
    axes[0].bar(range(len(scores)), scores, color='steelblue')
    axes[0].set_title('Raw Scores')
    axes[0].set_xlabel('Index')
    axes[0].set_ylabel('Score')
    axes[0].axhline(y=0, color='k', linewidth=0.5)
    
    # Probabilities
    axes[1].bar(range(len(probs)), probs, color='coral')
    axes[1].set_title('After Softmax')
    axes[1].set_xlabel('Index')
    axes[1].set_ylabel('Probability')
    axes[1].set_ylim(0, 1)
    
    plt.suptitle(title)
    plt.tight_layout()
    plt.show()

plot_softmax_effect(np.array([1.0, 2.0, 3.0, 0.5]), "Normal Scores")

### The Exponential Amplification

Softmax doesn't just normalize — it **amplifies** differences exponentially.

A score of 10 vs 5 gets $e^{10}/e^5 = e^5 \approx 148\times$ more weight, not $2\times$.
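The ratio follows directly from the formula: the shared denominator cancels, so $p_i / p_j = e^{x_i - x_j}$ depends only on the difference between the two scores, not on the other entries. A small sketch checking this (the helper name `prob_ratio` is hypothetical):

```python
import numpy as np

def prob_ratio(scores, i, j):
    """Return p_i / p_j under softmax (hypothetical helper for illustration)."""
    scores = np.asarray(scores, dtype=np.float64)
    p = np.exp(scores) / np.exp(scores).sum()
    return p[i] / p[j]

# Same pair of scores (10 vs 5), different surrounding entries
r1 = prob_ratio([5.0, 10.0, 5.0, 5.0], 1, 0)
r2 = prob_ratio([5.0, 10.0, -2.0, 7.0, 0.0], 1, 0)
print(f"{r1:.2f}  {r2:.2f}  e^5 = {np.exp(5.0):.2f}")
```

Both ratios match $e^5$ up to rounding, whatever the other scores are.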

In [None]:
# Amplification effect
scores = np.array([5.0, 10.0, 5.0, 5.0])
probs = naive_softmax(scores)

print(f"Scores: {scores}")
print(f"Score ratio (10 vs 5): {10/5:.1f}x")
print(f"Probability ratio: {probs[1]/probs[0]:.1f}x")

plot_softmax_effect(scores, "Exponential Amplification")

---
## Step 2: Explore — When Softmax Breaks (15 min)

Now let's see what happens with larger values.

In [None]:
# What happens as scores get larger?
for max_val in [10, 50, 100, 500, 1000]:
    scores = np.array([0.0, max_val, 0.0, 0.0])
    
    with warnings.catch_warnings():
        warnings.simplefilter("ignore")
        probs = naive_softmax(scores)
        exp_max = np.exp(float(max_val))  # computed here too: exp(1000) overflows and would warn
    
    print(f"max_score={max_val:4d}: exp({max_val})={exp_max:.2e}, probs={probs}")

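**Observation:** At `max_score=1000`, the output is `[0, nan, 0, 0]` (`nan` = Not a Number) because:

1. `exp(1000)` ≈ $1.97 \times 10^{434}$, far beyond the largest `float64` value (≈ $1.8 \times 10^{308}$)
2. So `exp(1000)` becomes `inf`, and the sum of the exponentials is `inf` as well
3. The winning entry computes `inf / inf = nan`, while every other entry computes `1 / inf = 0`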
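To see exactly which entries break, you can inspect the intermediate arrays for the `max_score=1000` case. A minimal sketch in float64; `np.errstate` is used here only to silence the overflow and invalid-value warnings:

```python
import numpy as np

scores = np.array([0.0, 1000.0, 0.0, 0.0])
with np.errstate(over="ignore", invalid="ignore"):
    exp_x = np.exp(scores)      # [1, inf, 1, 1]
    total = exp_x.sum()         # inf
    probs = exp_x / total       # [1/inf, inf/inf, 1/inf, 1/inf] = [0, nan, 0, 0]

print(exp_x, total, probs)
```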

### The FP16 Disaster

In ML, we often use FP16 (16-bit floats) for speed. The problem is much worse there:

In [None]:
# FP16 limits
print("Float type limits:")
print(f"FP16 max: {np.finfo(np.float16).max:.2e}")
print(f"FP32 max: {np.finfo(np.float32).max:.2e}")
print(f"FP64 max: {np.finfo(np.float64).max:.2e}")

print("\nexp() overflow threshold (natural log of the max value):")
print(f"FP16: exp(x) overflows for x > {np.log(float(np.finfo(np.float16).max)):.1f}")
print(f"FP32: exp(x) overflows for x > {np.log(float(np.finfo(np.float32).max)):.1f}")

In [None]:
# FP16 softmax failure
def naive_softmax_fp16(x):
    """Softmax in FP16 — watch it fail."""
    x_fp16 = x.astype(np.float16)
    exp_x = np.exp(x_fp16)
    return exp_x / np.sum(exp_x)

# In FP16, we overflow much sooner
for max_val in [5, 10, 11, 12, 15, 20]:
    scores = np.array([0.0, max_val, 0.0, 0.0])
    
    with warnings.catch_warnings():
        warnings.simplefilter("ignore")
        probs = naive_softmax_fp16(scores)
        exp_val = np.exp(np.float16(max_val))  # inside the block: overflows to inf for max_val >= 12
    
    print(f"max_score={max_val:2d}: exp({max_val})={float(exp_val):.2e}, probs={probs}")

**FP16 breaks at just max_score=12!**

In attention, dot-product scores can easily exceed 12, especially before scaling by √d. This is a critical problem.
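There is a second, less obvious FP16 failure mode: the denominator can overflow even when every individual `exp()` fits. A minimal sketch, assuming the whole softmax (including the sum) is kept in float16 as in `naive_softmax_fp16`; the helper name `fp16_softmax_parts` is hypothetical:

```python
import numpy as np

def fp16_softmax_parts(x):
    """Naive softmax in float16, returning (exp values, denominator, probabilities)."""
    x16 = np.asarray(x, dtype=np.float16)
    with np.errstate(over="ignore", invalid="ignore"):
        exp_x = np.exp(x16)          # each entry may still be finite...
        total = np.sum(exp_x)        # ...but the float16 sum can overflow to inf
        probs = exp_x / total
    return exp_x, total, probs

# 32 keys, each with a modest score of 8: exp(8) ≈ 2981 fits in FP16,
# but 32 * 2981 ≈ 95,000 exceeds the FP16 max of 65,504.
exp_x, total, probs = fp16_softmax_parts(np.full(32, 8.0))
print(f"max exp: {float(exp_x.max()):.0f}, sum: {total}, sum(probs): {float(probs.sum())}")
```

In this case no `nan` appears at all: every probability silently becomes `finite / inf = 0`, so the row sums to 0 instead of 1.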

### Realistic Attention Scores

Let's see what attention scores look like for random, unit-variance queries and keys:

In [None]:
# Simulate realistic attention scores
seq_len = 128
d_model = 64

# Random Q and K (unit variance)
Q = np.random.randn(seq_len, d_model)
K = np.random.randn(seq_len, d_model)

# Attention scores (unscaled)
scores_unscaled = Q @ K.T

# Scaled by sqrt(d)
scores_scaled = scores_unscaled / np.sqrt(d_model)

print("Unscaled scores:")
print(f"  Min: {scores_unscaled.min():.2f}")
print(f"  Max: {scores_unscaled.max():.2f}")
print(f"  Std: {scores_unscaled.std():.2f}")

print("\nScaled scores (÷√d):")
print(f"  Min: {scores_scaled.min():.2f}")
print(f"  Max: {scores_scaled.max():.2f}")
print(f"  Std: {scores_scaled.std():.2f}")

In [None]:
# Distribution of attention scores
fig, axes = plt.subplots(1, 2, figsize=(12, 4))

axes[0].hist(scores_unscaled.flatten(), bins=50, color='steelblue', edgecolor='black')
axes[0].axvline(x=11, color='red', linestyle='--', label='FP16 exp() overflow')
axes[0].set_title(f'Unscaled QK^T Scores\n(d={d_model})')
axes[0].set_xlabel('Score')
axes[0].legend()

axes[1].hist(scores_scaled.flatten(), bins=50, color='coral', edgecolor='black')
axes[1].axvline(x=11, color='red', linestyle='--', label='FP16 exp() overflow')
axes[1].set_title('Scaled QK^T / √d Scores')
axes[1].set_xlabel('Score')
axes[1].legend()

plt.tight_layout()
plt.show()

### The Pattern of Failure

1. **Long sequences** → more chances for a high dot product
2. **Larger d_model** → scores have higher variance (before scaling)
3. **Correlated embeddings** → systematic high scores
4. **Layer norm effects** → can create outliers

In large real models running in FP16, overflow is practically certain at some point without proper handling.
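Point 2 can be made precise. If the entries of $q$ and $k$ are independent with mean 0 and variance 1, each product $q_i k_i$ has variance $E[q_i^2]E[k_i^2] = 1$, so $\text{Var}(q \cdot k) = d$ and the standard deviation is $\sqrt{d}$; dividing by $\sqrt{d}$ brings it back to 1. A quick empirical check under that i.i.d. standard-normal assumption (the seed and sample size are arbitrary):

```python
import numpy as np

rng = np.random.default_rng(0)
n_samples = 4000

results = {}
for d in [16, 64, 256]:
    q = rng.standard_normal((n_samples, d))
    k = rng.standard_normal((n_samples, d))
    dots = np.sum(q * k, axis=1)          # n_samples independent dot products
    results[d] = (dots.std(), (dots / np.sqrt(d)).std())
    print(f"d={d:3d}: std(q·k)={results[d][0]:6.2f} (√d={np.sqrt(d):5.2f}), "
          f"std(q·k/√d)={results[d][1]:.2f}")
```

The unscaled spread grows like $\sqrt{d}$, which is why larger `d_model` pushes more scores past the FP16 threshold.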

In [None]:
# Demonstrate failure in a realistic setting
def attention_forward_naive(Q, K, V, softmax_dtype=np.float64):
    """Naive attention implementation (will fail with large scores).

    softmax_dtype: precision used for the softmax step, e.g. np.float16
    to mimic a half-precision kernel.
    """
    d = Q.shape[-1]
    
    # Compute scores, then cast to the softmax precision
    scores = (Q @ K.T / np.sqrt(d)).astype(softmax_dtype)
    
    # Softmax (naive): no protection against overflow
    exp_scores = np.exp(scores)
    attention_weights = exp_scores / exp_scores.sum(axis=-1, keepdims=True)
    
    # Weighted sum of values
    output = attention_weights @ V
    
    return output, attention_weights

# Create inputs that will cause problems
np.random.seed(42)

Q = np.random.randn(32, 64)
K = np.random.randn(32, 64)
V = np.random.randn(32, 64)

# Make one Q-K pair very similar: an outlier score (common in real embeddings)
K[5] = Q[10] * 3  # Score ≈ 3*64 = 192 unscaled, ≈ 24 after dividing by √64

with warnings.catch_warnings():
    warnings.simplefilter("ignore")
    out64, w64 = attention_forward_naive(Q, K, V)                           # FP64 softmax
    out16, w16 = attention_forward_naive(Q, K, V, softmax_dtype=np.float16)  # FP16 softmax

print(f"Max scaled score: {(Q @ K.T / np.sqrt(64)).max():.2f}")
print(f"FP64 output contains NaN: {np.isnan(out64).any()}")
print(f"FP16 output contains NaN: {np.isnan(out16).any()}")
print(f"Rows with NaN (FP16): {np.where(np.isnan(out16).any(axis=1))[0]}")

---
## Step 3: The Concept — Why exp() Explodes (10 min)

### Floating Point Representation

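A normal floating-point number (finite and not subnormal) is stored as:

$$\text{value} = (-1)^s \times 2^{e-\text{bias}} \times (1 + m)$$

where:
- $s$ = sign bit (0 or 1)
- $e$ = exponent bits, read as an unsigned integer
- $m$ = mantissa (fractional bits), a value in $[0, 1)$
- bias = offset that allows negative exponents (15 for FP16, 127 for BF16 and FP32, 1023 for FP64)

| Format | Sign bits | Exponent bits | Mantissa bits | Max value | exp(x) overflows for x > |
|--------|-----------|---------------|---------------|-----------|--------------------------|
| FP16 | 1 | 5 | 10 | 65,504 | ~11.1 |
| BF16 | 1 | 8 | 7 | ≈3.39×10³⁸ | ~88.7 |
| FP32 | 1 | 8 | 23 | ≈3.40×10³⁸ | ~88.7 |
| FP64 | 1 | 11 | 52 | ≈1.80×10³⁰⁸ | ~709.8 |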
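The "Max value" column follows from the bit layout. With $E$ exponent bits the bias is $2^{E-1}-1$; the all-ones exponent field is reserved for `inf`/`nan`, so the largest usable unbiased exponent equals the bias, and the largest mantissa is all ones. That gives a maximum of $(2 - 2^{-M}) \times 2^{2^{E-1}-1}$ for $M$ mantissa bits, and `exp()` overflows once $x$ exceeds its natural log. A minimal sketch of that arithmetic, assuming IEEE-754-style binary formats (BF16 follows the same rule); the helper name `max_finite` is hypothetical:

```python
import math

def max_finite(exp_bits, mant_bits):
    """Largest finite value of an IEEE-754-style binary format (hypothetical helper)."""
    bias = 2 ** (exp_bits - 1) - 1
    # Top exponent field (all ones) is reserved for inf/nan, so the largest
    # usable unbiased exponent equals the bias; the mantissa is all ones.
    return math.ldexp(2.0 - 2.0 ** -mant_bits, bias)

formats = {"FP16": (5, 10), "BF16": (8, 7), "FP32": (8, 23), "FP64": (11, 52)}
limits = {}
for name, (e, m) in formats.items():
    limits[name] = max_finite(e, m)
    print(f"{name}: max = {limits[name]:.4g}, exp() overflows for x > {math.log(limits[name]):.2f}")
```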

In [None]:
# Visualize the exponential function
x = np.linspace(-5, 15, 1000)
y = np.exp(x)

fig, axes = plt.subplots(1, 2, figsize=(12, 4))

# Linear scale
axes[0].plot(x, y, 'b-', linewidth=2)
axes[0].axhline(y=65504, color='red', linestyle='--', label='FP16 max')
axes[0].axvline(x=11, color='red', linestyle=':', alpha=0.5)
axes[0].set_xlabel('x')
axes[0].set_ylabel('exp(x)')
axes[0].set_title('exp(x) — Linear Scale')
axes[0].legend()
axes[0].set_ylim(0, 100000)
axes[0].grid(True, alpha=0.3)

# Log scale
axes[1].semilogy(x, y, 'b-', linewidth=2)
axes[1].axhline(y=65504, color='red', linestyle='--', label='FP16 max')
axes[1].axhline(y=3.4e38, color='orange', linestyle='--', label='FP32 max')
axes[1].set_xlabel('x')
axes[1].set_ylabel('exp(x) [log scale]')
axes[1].set_title('exp(x) — Log Scale')
axes[1].legend()
axes[1].grid(True, alpha=0.3)

plt.tight_layout()
plt.show()

### The Inf/NaN Cascade

When softmax fails, here's the sequence:

1. `exp(large_value)` → `inf`
2. `sum([..., inf, ...])` → `inf`
3. `inf / inf` → `nan` for the overflowing entry, while every finite entry becomes `finite / inf` → `0`
4. `nan` propagates through all subsequent operations
5. Your model outputs garbage
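Step 4 is what makes a single overflow so damaging. A minimal sketch with made-up attention weights (the values are arbitrary illustrations): one `nan` in a row of the weight matrix turns that entire output row into `nan` after the matrix multiply, and any reduction over the output, such as a mean used as a loss, becomes `nan` too. Note that multiplying `nan` by 0 still gives `nan`, so the zero weights in that row do not help.

```python
import numpy as np

V = np.arange(12, dtype=np.float64).reshape(4, 3)    # 4 value vectors of dimension 3

weights = np.full((2, 4), 0.25)                     # row 0: a healthy softmax row
weights[1] = [0.0, np.nan, 0.0, 0.0]                # row 1: result of an overflowed softmax

output = weights @ V
loss = output.mean()

print(output)
print(f"loss = {loss}")
```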

In [None]:
# Trace the failure step by step
scores = np.array([1.0, 100.0, 2.0, 3.0])

print("Step-by-step failure:")
print(f"1. scores = {scores}")

with warnings.catch_warnings():
    warnings.simplefilter("ignore")
    exp_scores = np.exp(scores)
    print(f"2. exp(scores) = {exp_scores}")
    
    sum_exp = np.sum(exp_scores)
    print(f"3. sum(exp) = {sum_exp}")
    
    probs = exp_scores / sum_exp
    print(f"4. probs = {probs}")
    print(f"5. sum(probs) = {np.sum(probs)}  # Should be 1.0!")

### What We Need

A softmax implementation that:
1. **Never overflows** — even with large inputs
2. **Preserves accuracy** — produces correct probabilities
3. **Works in FP16** — for efficient GPU computation

Tomorrow's solution: the **max-subtraction trick**.

---
## Step 4: Code It — Document the Problem (30 min)

Let's write code that systematically identifies when softmax will fail.

In [None]:
def analyze_softmax_safety(scores, dtype='float32'):
    """
    Analyze whether softmax will overflow for given scores.
    
    Returns a dict with analysis results.
    """
    if dtype == 'float16':
        max_safe = 11.0  # exp(11) ≈ 59874, close to FP16 max of 65504
        dtype_max = 65504.0
    elif dtype == 'float32':
        max_safe = 88.0  # exp(88) ≈ 1.6e38, close to FP32 max
        dtype_max = 3.4e38
    else:
        max_safe = 709.0  # For float64
        dtype_max = 1.8e308
    
    scores_flat = scores.flatten()
    max_score = scores_flat.max()
    min_score = scores_flat.min()
    
    will_overflow = max_score > max_safe
    
    # Count how many values would cause overflow
    overflow_count = (scores_flat > max_safe).sum()
    
    return {
        'dtype': dtype,
        'max_safe_input': max_safe,
        'dtype_max': dtype_max,
        'scores_min': min_score,
        'scores_max': max_score,
        'will_overflow': will_overflow,
        'overflow_count': overflow_count,
        'total_values': len(scores_flat),
    }

# Test with various score distributions
test_cases = [
    ('Normal scores', np.random.randn(100, 100)),
    ('Large d_model', np.random.randn(100, 100) * np.sqrt(512)),
    ('With outliers', np.concatenate([np.random.randn(99, 100), np.ones((1, 100)) * 50])),
]

for name, scores in test_cases:
    for dtype in ['float16', 'float32']:
        result = analyze_softmax_safety(scores, dtype)
        status = "FAIL" if result['will_overflow'] else "OK"
        print(f"{name} ({dtype}): max={result['scores_max']:.1f}, safe<{result['max_safe_input']:.0f} → {status}")

In [None]:
def softmax_with_overflow_detection(x, dtype='float32'):
    """
    Softmax that reports if overflow occurred.
    
    Returns: (probabilities, overflow_detected)
    """
    # Check safety first
    analysis = analyze_softmax_safety(x, dtype)
    
    # Compute softmax in the requested precision (NumPy would otherwise use float64)
    x = np.asarray(x, dtype=dtype)
    with warnings.catch_warnings():
        warnings.simplefilter("ignore")
        exp_x = np.exp(x)
        sum_exp = exp_x.sum(axis=-1, keepdims=True)
        probs = exp_x / sum_exp
    
    # Check for actual overflow
    has_inf = bool(np.isinf(exp_x).any())
    has_nan = bool(np.isnan(probs).any())
    
    overflow_detected = has_inf or has_nan
    
    return probs, {
        'predicted_overflow': analysis['will_overflow'],
        'actual_overflow': overflow_detected,
        'has_inf': has_inf,
        'has_nan': has_nan,
    }

# Demo
scores_safe = np.array([1.0, 2.0, 3.0, 4.0])
scores_dangerous = np.array([1.0, 100.0, 3.0, 4.0])

probs, info = softmax_with_overflow_detection(scores_safe)
print(f"Safe scores: {info}")

probs, info = softmax_with_overflow_detection(scores_dangerous)
print(f"Dangerous scores: {info}")

### Exercise: Find the Breaking Point

Write code to find the exact threshold where softmax starts producing NaN for different dtypes.

In [None]:
def find_overflow_threshold(dtype_str):
    """
    Binary search (tolerance 0.01) for the input value where exp() overflows.
    """
    if dtype_str == 'float16':
        dtype = np.float16
    elif dtype_str == 'float32':
        dtype = np.float32
    else:
        dtype = np.float64
    
    low, high = 0.0, 1000.0
    
    while high - low > 0.01:
        mid = (low + high) / 2
        
        with warnings.catch_warnings():
            warnings.simplefilter("ignore")
            result = np.exp(dtype(mid))
        
        if np.isinf(result):
            high = mid
        else:
            low = mid
    
    return low

# Find thresholds
for dtype in ['float16', 'float32', 'float64']:
    threshold = find_overflow_threshold(dtype)
    print(f"{dtype}: exp(x) overflows at x ≈ {threshold:.2f}")
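The binary search has a closed-form answer to compare against: `exp(x)` overflows once $x > \ln(\text{max finite value})$. A short cross-check using `np.finfo`; the search results should agree to roughly two decimal places (FP16 can differ slightly, since both the input and the result are rounded to half precision):

```python
import numpy as np

closed_form = {}
for name, dt in [("float16", np.float16), ("float32", np.float32), ("float64", np.float64)]:
    max_val = float(np.finfo(dt).max)     # largest finite value of the dtype
    closed_form[name] = np.log(max_val)   # exp(x) exceeds max_val for x above this
    print(f"{name}: ln(max) = {closed_form[name]:.2f}")
```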

In [None]:
# Exercise: Create a "stress test" for softmax
def softmax_stress_test(seq_lengths, d_models, num_trials=100):
    """
    Test softmax failure rate across different configurations.
    """
    results = []
    
    for seq_len in seq_lengths:
        for d_model in d_models:
            failures = 0
            max_scores = []
            
            for _ in range(num_trials):
                # Generate random Q, K
                Q = np.random.randn(seq_len, d_model)
                K = np.random.randn(seq_len, d_model)
                
                # Compute scores (UNSCALED - to show the problem)
                scores = Q @ K.T
                max_scores.append(scores.max())
                
                # Check if softmax would fail in FP16
                if scores.max() > 11:
                    failures += 1
            
            results.append({
                'seq_len': seq_len,
                'd_model': d_model,
                'failure_rate': failures / num_trials,
                'avg_max_score': np.mean(max_scores),
            })
    
    return results

# Run stress test
results = softmax_stress_test(
    seq_lengths=[32, 64, 128, 256],
    d_models=[32, 64, 128],
    num_trials=50
)

print("FP16 Softmax Failure Rate (UNSCALED scores):")
print("-" * 60)
for r in results:
    print(f"seq={r['seq_len']:4d}, d={r['d_model']:4d}: "
          f"failure={r['failure_rate']*100:5.1f}%, avg_max={r['avg_max_score']:.1f}")
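For contrast, the same experiment with the $1/\sqrt{d}$ scaling applied shows how much the scaling helps for random inputs. A minimal sketch with a fixed seed; the configurations are a subset of the ones above, `11` is the same approximate FP16 threshold, and the helper name `fp16_failure_rate` is hypothetical:

```python
import numpy as np

def fp16_failure_rate(seq_len, d_model, scaled, num_trials=20, seed=0):
    """Fraction of trials where max(QK^T [/ sqrt(d)]) exceeds the FP16 exp() threshold (~11)."""
    rng = np.random.default_rng(seed)
    failures = 0
    for _ in range(num_trials):
        Q = rng.standard_normal((seq_len, d_model))
        K = rng.standard_normal((seq_len, d_model))
        scores = Q @ K.T
        if scaled:
            scores = scores / np.sqrt(d_model)
        failures += int(scores.max() > 11)
    return failures / num_trials

rates = {}
for seq_len, d_model in [(32, 32), (128, 64), (256, 128)]:
    rates[(seq_len, d_model)] = (fp16_failure_rate(seq_len, d_model, scaled=False),
                                 fp16_failure_rate(seq_len, d_model, scaled=True))
    unscaled, scaled = rates[(seq_len, d_model)]
    print(f"seq={seq_len:3d}, d={d_model:3d}: unscaled={unscaled:.0%}, scaled={scaled:.0%}")
```

Scaling fixes the typical case for random inputs, but it only controls the variance; outliers like the `K[5] = Q[10] * 3` example can still push scaled scores past the threshold.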

---
## Step 5: Verify — Quiz & Reflection (10 min)

### Quiz

In [None]:
def check_answer(question, your_answer, correct_answer):
    if your_answer == correct_answer:
        print(f"✓ Correct! {question}")
    else:
        print(f"✗ Incorrect. {question}")
        print(f"  Your answer: {your_answer}, Correct: {correct_answer}")

# Q1: What is exp(0)?
q1_answer = 1  # Your answer
check_answer("exp(0)", q1_answer, 1)

In [None]:
# Q2: At what approximate input value does exp() overflow in FP16?
# a) 5
# b) 11
# c) 88
# d) 709
q2_answer = 'b'  # Your answer
check_answer("FP16 exp() overflow threshold", q2_answer, 'b')

In [None]:
# Q3: If you have scores [1, 1000, 1, 1], what will naive_softmax return?
# a) [0, 1, 0, 0]
# b) [0.25, 0.25, 0.25, 0.25]
# c) [0, nan, 0, 0]
# d) An error will be raised
q3_answer = 'c'  # Your answer
check_answer("Softmax of [1, 1000, 1, 1]", q3_answer, 'c')

In [None]:
# Q4: Why does scaling QK^T by 1/√d help prevent overflow?
# a) It makes the scores smaller
# b) It converts scores to probabilities
# c) It normalizes the variance of scores
# d) Both a and c
q4_answer = 'd'  # Your answer
check_answer("Why scale by 1/√d?", q4_answer, 'd')

### Reflection Questions

1. **Why can't we just use FP64 everywhere?** (Think about memory bandwidth and Tensor Core support.)

2. **The √d scaling helps but doesn't solve the problem.** Can you construct an example where scaled scores still overflow?

3. **What if all scores are very negative (like -1000)?** Does that cause problems?

---

## Summary

| Problem | Cause | Impact |
|---------|-------|--------|
| exp() overflow | Large positive inputs | Returns inf |
| inf/inf | Overflow in numerator and denominator | Returns nan |
| NaN propagation | Any operation with nan produces nan | Model outputs garbage |

**Key numbers to remember:**
- FP16: exp(x) overflows at x ≈ 11
- FP32: exp(x) overflows at x ≈ 88

**Tomorrow:** The solution — stable softmax using the max-subtraction trick.

---

**Interactive Reference:** [attention-math.html](../attention-math.html) Section 2 — Softmax Visualization