# Week 3, Day 5: Online Softmax — The Streaming Algorithm

**Time:** ~1 hour

**Goal:** Understand and implement online softmax, the key algorithm behind FlashAttention.

## The Challenge

Standard softmax needs to see **all values** before computing any probability:
1. Find max(x) — requires seeing all x
2. Compute Σexp(x - max) — requires seeing all x again
3. Divide each exp(x - max) by the sum

**The question:** Can we compute softmax **one block at a time** without storing the entire sequence?

In [None]:
import numpy as np
import torch
import matplotlib.pyplot as plt

np.set_printoptions(precision=6, suppress=True)
torch.set_printoptions(precision=6, sci_mode=False)

---
## Step 1: The Challenge — Streaming Data (5 min)

Imagine processing attention in **blocks**:
- You load a block of keys (K) and values (V) into fast on-chip SRAM
- Compute partial attention
- Load the next block
- **Problem:** When you see a larger value in block 2, you need to update block 1's results!
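To make the problem concrete, here is a small sketch that normalizes each block on its own and concatenates the results (`naive_blockwise_softmax` is a hypothetical helper used only for illustration). Each block sums to 1 by itself, so the concatenation is not a valid distribution over the full sequence, and its values disagree with the true softmax.

```python
import numpy as np

def softmax(x):
    # Reference softmax over the whole array (max-subtracted for stability)
    e = np.exp(x - x.max())
    return e / e.sum()

def naive_blockwise_softmax(x, block_size):
    # Hypothetical helper: softmax each block independently, ignoring all other blocks
    parts = [softmax(x[i:i + block_size]) for i in range(0, len(x), block_size)]
    return np.concatenate(parts)

x = np.array([1.0, 3.0, 2.0, 5.0, 4.0, 6.0, 2.0, 1.0])
naive = naive_blockwise_softmax(x, block_size=2)
print("naive total:", naive.sum())   # each of the 4 blocks sums to 1, so roughly 4
print("matches true softmax:", np.allclose(naive, softmax(x)))
```

The fix is not to avoid blocks but to remember just enough about earlier blocks (their max and sum) to correct them later.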

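The **online softmax** algorithm solves this with cheap bookkeeping: it keeps a running max and a running sum of exponentials, and rescales them whenever a later block raises the max.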

In [None]:
# Standard softmax - needs all data at once
def standard_softmax(x):
    """Must see entire array to compute."""
    max_x = x.max()         # Pass 1: find max
    exp_x = np.exp(x - max_x)  # Pass 2: compute exp
    sum_exp = exp_x.sum()   # Pass 3: sum
    return exp_x / sum_exp  # Pass 4: normalize

# What if data comes in blocks?
x = np.array([1.0, 3.0, 2.0, 5.0, 4.0, 6.0, 2.0, 1.0])
print(f"Full data: {x}")
print(f"Standard softmax: {standard_softmax(x)}")
print(f"Sum: {standard_softmax(x).sum():.6f}")

---
## Step 2: Explore — The Rescaling Trick (15 min)

### Key Insight

When we find a new maximum, we can **rescale** our previous results!

If we computed exp(x - old_max), and then find new_max > old_max:

$$\exp(x - \text{new\_max}) = \exp(x - \text{old\_max}) \times \exp(\text{old\_max} - \text{new\_max})$$

The correction factor $\exp(\text{old\_max} - \text{new\_max})$ is just a scalar multiplication!
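Because the correction factor does not depend on the element index $i$, it factors out of a sum. A running sum can therefore be rescaled without revisiting the elements it was built from:

$$\sum_i \exp(x_i - \text{new\_max}) = \exp(\text{old\_max} - \text{new\_max}) \sum_i \exp(x_i - \text{old\_max})$$

For example, with block 1 $= [1, 3, 2]$, $\text{old\_max} = 3$ and $\text{new\_max} = 6$: $\sum_i \exp(x_i - 3) = e^{-2} + e^{0} + e^{-1} \approx 1.5032$, and $e^{-3} \times 1.5032 \approx 0.0748$, which matches $e^{-5} + e^{-3} + e^{-4} \approx 0.0748$ computed directly.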

In [None]:
# Demonstrate the rescaling
x = np.array([1.0, 3.0, 2.0])  # First block
old_max = 3.0

# Compute with old max
exp_old = np.exp(x - old_max)
print(f"Block 1: {x}")
print(f"exp(x - {old_max}) = {exp_old}")

# New block arrives with larger value
new_block = np.array([5.0, 4.0, 6.0])
new_max = 6.0

# We need to rescale block 1's exp values
correction = np.exp(old_max - new_max)  # = exp(3 - 6) = exp(-3)
exp_rescaled = exp_old * correction

print(f"\nNew max found: {new_max}")
print(f"Correction factor: exp({old_max} - {new_max}) = {correction:.6f}")
print(f"Rescaled block 1: {exp_rescaled}")

# Verify: this equals computing directly with new max
exp_direct = np.exp(x - new_max)
print(f"Direct computation: {exp_direct}")
print(f"Match: {np.allclose(exp_rescaled, exp_direct)}")

### The Online Softmax Algorithm

We maintain two running statistics:
- **m**: Running maximum
- **l**: Running sum of exp(x - m)

When processing a new block:
1. Find block's max (m_block)
2. Update global max: m_new = max(m_old, m_block)
3. Rescale old sum: l_old *= exp(m_old - m_new)
4. Add new block: l_new = l_old + sum(exp(block - m_new))
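The four steps above fit in one small update function. This is a minimal sketch (the name `update_stats` is hypothetical) that also checks, after every block, the invariant the algorithm relies on: $m$ is the max of everything seen so far, and $l$ is the sum of $\exp(x - m)$ over everything seen so far. A block size of 3 is used so that the last block is shorter than the others.

```python
import numpy as np

def update_stats(m, l, block):
    """Hypothetical helper: fold one new block into the running (m, l)."""
    m_new = max(m, block.max())            # steps 1-2: block max, then new global max
    l = l * np.exp(m - m_new)              # step 3: rescale old sum (exp(-inf) = 0 on the first block)
    l = l + np.exp(block - m_new).sum()    # step 4: add this block's terms
    return m_new, l

x = np.array([1.0, 3.0, 2.0, 5.0, 4.0, 6.0, 2.0, 1.0])
m, l = float('-inf'), 0.0
for start in range(0, len(x), 3):          # block size 3: the last block has only 2 elements
    m, l = update_stats(m, l, x[start:start + 3])
    seen = x[:start + 3]
    # Invariant: (m, l) summarize every element seen so far
    assert m == seen.max()
    assert np.isclose(l, np.exp(seen - m).sum())
    print(f"after x[:{min(start + 3, len(x))}]: m = {m}, l = {l:.6f}")

probs = np.exp(x - m) / l                  # final normalization uses the finished (m, l)
```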

In [None]:
def online_softmax_demo(x, block_size=2):
    """
    Online softmax with detailed logging.
    """
    n = len(x)
    
    # Initialize
    m = float('-inf')  # Running max
    l = 0.0            # Running sum of exp
    
    print("Online Softmax Computation")
    print("=" * 60)
    
    # Process blocks
    for block_start in range(0, n, block_size):
        block = x[block_start:block_start + block_size]
        print(f"\nBlock [{block_start}:{block_start + len(block)}]: {block}")
        
        # Find block max
        m_block = block.max()
        print(f"  Block max: {m_block}")
        
        # Update global max
        m_old = m
        m_new = max(m, m_block)
        print(f"  Global max: {m_old} → {m_new}")
        
        # Rescale old sum
        if m_old != float('-inf'):
            correction = np.exp(m_old - m_new)
            l_old = l
            l = l * correction
            print(f"  Rescale sum: {l_old:.6f} × exp({m_old}-{m_new}) = {l:.6f}")
        
        # Add new block contribution
        block_exp = np.exp(block - m_new)
        block_sum = block_exp.sum()
        l += block_sum
        print(f"  Block exp(x - {m_new}): {block_exp}")
        print(f"  Block sum: {block_sum:.6f}")
        print(f"  Running sum: {l:.6f}")
        
        m = m_new
    
    print(f"\n{'='*60}")
    print(f"Final: max = {m}, sum = {l:.6f}")
    
    # Compute final softmax
    probs = np.exp(x - m) / l
    print(f"\nSoftmax result: {probs}")
    print(f"Sum: {probs.sum():.6f}")
    
    return probs, m, l

x = np.array([1.0, 3.0, 2.0, 5.0, 4.0, 6.0, 2.0, 1.0])
probs_online, _, _ = online_softmax_demo(x, block_size=2)

print(f"\nVerification against standard softmax:")
probs_standard = standard_softmax(x)
print(f"Standard: {probs_standard}")
print(f"Match: {np.allclose(probs_online, probs_standard)}")

---
## Step 3: The Concept — Online Softmax for Attention (10 min)

### Applying to Attention

In attention, we need:
$$\text{output} = \text{softmax}(\text{scores}) \times V$$

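The online algorithm maintains:
- **m**: Running max of scores
- **l**: Running sum of exp(scores - m)
- **O**: Running output, i.e. the softmax-weighted average of the values seen so far

With $\text{exp\_block} = \exp(\text{scores}_{block} - m_{new})$ and $l_{new} = l_{old} \times \exp(m_{old} - m_{new}) + \sum \text{exp\_block}$, the update rule for O is:
$$O_{new} = O_{old} \times \frac{l_{old}}{l_{new}} \times \exp(m_{old} - m_{new}) + \frac{\text{exp\_block}}{l_{new}} \times V_{block}$$

The code below uses an equivalent form with fewer divisions: it keeps O **unnormalized** (dropping the $l$ factors), so each step is $O_{new} = O_{old} \times \exp(m_{old} - m_{new}) + \text{exp\_block} \times V_{block}$, and it divides by $l$ only once at the end.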
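The normalized update rule can be checked numerically. The sketch below is a single-query NumPy version that keeps O normalized after every block, exactly as in the formula above (`online_attention_normalized` is a hypothetical name used only here), and compares it with one-shot attention. The implementations in this notebook instead keep O unnormalized and divide by l once at the end; that saves a division per block and should agree up to rounding.

```python
import numpy as np

def online_attention_normalized(q, K, V, block_size):
    """Hypothetical sketch: O is a proper weighted average of V after every block."""
    scale = np.sqrt(K.shape[1])
    m, l = float('-inf'), 0.0
    O = np.zeros(V.shape[1])
    for start in range(0, K.shape[0], block_size):
        s = (K[start:start + block_size] @ q) / scale   # scores for this block
        m_new = max(m, s.max())
        exp_block = np.exp(s - m_new)                   # exp(scores - m_new)
        l_new = l * np.exp(m - m_new) + exp_block.sum()
        # O_new = O_old * (l_old / l_new) * exp(m_old - m_new) + (exp_block / l_new) @ V_block
        O = O * (l / l_new) * np.exp(m - m_new) + (exp_block / l_new) @ V[start:start + block_size]
        m, l = m_new, l_new
    return O

rng = np.random.default_rng(0)
q = rng.standard_normal(4)
K = rng.standard_normal((6, 4))
V = rng.standard_normal((6, 4))

# One-shot reference: softmax over all scores, then weight the values
s = K @ q / np.sqrt(4)
w = np.exp(s - s.max())
w /= w.sum()
reference = w @ V
print(np.abs(online_attention_normalized(q, K, V, block_size=2) - reference).max())
```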

In [None]:
def online_attention_demo(Q, K, V, block_size=2):
    """
    Online attention computation - processes K/V in blocks.
    
    Q: [d_k] - single query
    K: [seq_len, d_k] - all keys
    V: [seq_len, d_v] - all values
    """
    seq_len = K.shape[0]
    d_v = V.shape[1]
    d_k = K.shape[1]
    scale = np.sqrt(d_k)
    
    # Initialize accumulators
    m = float('-inf')  # Running max
    l = 0.0            # Running sum of exp
    O = np.zeros(d_v)  # Running output (unnormalized at first)
    
    print("Online Attention Computation")
    print("=" * 60)
    
    for block_start in range(0, seq_len, block_size):
        block_end = min(block_start + block_size, seq_len)
        
        # Get block of K, V
        K_block = K[block_start:block_end]
        V_block = V[block_start:block_end]
        
        # Compute scores for this block
        scores_block = (Q @ K_block.T) / scale
        
        print(f"\nBlock [{block_start}:{block_end}]")
        print(f"  Scores: {scores_block}")
        
        # Block statistics
        m_block = scores_block.max()
        m_old = m
        m_new = max(m, m_block)
        
        print(f"  Block max: {m_block:.4f}, Global max: {m_old:.4f} → {m_new:.4f}")
        
        # Rescale old accumulator
        if m_old != float('-inf'):
            correction = np.exp(m_old - m_new)
            l = l * correction
            O = O * correction  # Also rescale the output accumulator!
            print(f"  Rescaling: correction = {correction:.6f}")
        
        # Process new block
        exp_block = np.exp(scores_block - m_new)
        l += exp_block.sum()
        
        # Accumulate weighted values (using unnormalized weights for now)
        O += exp_block @ V_block
        
        m = m_new
        
        print(f"  exp(scores - m): {exp_block}")
        print(f"  Running l: {l:.6f}")
    
    # Final normalization
    O = O / l
    
    print(f"\n{'='*60}")
    print(f"Final output: {O}")
    
    return O, m, l

# Test
np.random.seed(42)
d_k, d_v = 4, 4
seq_len = 6

Q = np.random.randn(d_k)
K = np.random.randn(seq_len, d_k)
V = np.random.randn(seq_len, d_v)

output_online, _, _ = online_attention_demo(Q, K, V, block_size=2)

In [None]:
# Verify against standard attention
def standard_attention(Q, K, V):
    """Standard attention for single query."""
    d_k = K.shape[1]
    scores = (Q @ K.T) / np.sqrt(d_k)
    weights = standard_softmax(scores)
    return weights @ V

output_standard = standard_attention(Q, K, V)

print(f"\nVerification:")
print(f"Online:   {output_online}")
print(f"Standard: {output_standard}")
print(f"Max diff: {np.abs(output_online - output_standard).max():.2e}")

---
## Step 4: Code It — Efficient Online Softmax (30 min)

### Clean Implementation

In [None]:
def online_softmax(x, block_size):
    """
    Compute softmax using online algorithm.
    
    Args:
        x: Input array
        block_size: Process this many elements at a time
    
    Returns:
        Softmax probabilities (same shape as x)
    """
    n = len(x)
    
    # First pass: compute m (max) and l (sum of exp)
    m = float('-inf')
    l = 0.0
    
    for start in range(0, n, block_size):
        block = x[start:start + block_size]
        m_block = block.max()
        
        # Update max and rescale sum
        m_new = max(m, m_block)
        l = l * np.exp(m - m_new) + np.sum(np.exp(block - m_new))
        m = m_new
    
    # Second pass: compute final probabilities
    return np.exp(x - m) / l

# Test various block sizes
x = np.random.randn(100)
probs_standard = standard_softmax(x)

print("Block size | Max error")
print("-" * 25)
for bs in [1, 2, 5, 10, 25, 50, 100]:
    probs_online = online_softmax(x, bs)
    error = np.abs(probs_standard - probs_online).max()
    print(f"{bs:10d} | {error:.2e}")
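
The first pass only ever looks at one block at a time, so it also works when blocks arrive from an iterator and are never stored together. Below is a minimal sketch assuming the blocks come from a Python generator (`stream_blocks` and `streaming_stats` are hypothetical names). It also uses the identity $\log \sum_i e^{x_i} = m + \log l$, which follows directly from $l = \sum_i e^{x_i - m}$, so the finished pair gives the log-sum-exp without a second pass.

```python
import numpy as np

def stream_blocks(rng, num_blocks, block_size):
    # Hypothetical data source: yields one block at a time, like data arriving over a stream
    for _ in range(num_blocks):
        yield rng.standard_normal(block_size) * 3.0

def streaming_stats(blocks):
    # Consume any iterable of blocks, keeping only the running max m and running sum l
    m, l = float('-inf'), 0.0
    for block in blocks:
        m_new = max(m, block.max())
        l = l * np.exp(m - m_new) + np.exp(block - m_new).sum()
        m = m_new
    return m, l

# Keep a copy of the blocks only so the result can be checked afterwards
blocks = list(stream_blocks(np.random.default_rng(1), num_blocks=10, block_size=16))
m, l = streaming_stats(iter(blocks))

x = np.concatenate(blocks)
print("max matches:", m == x.max())
print("log-sum-exp:", m + np.log(l), "vs direct:", x.max() + np.log(np.exp(x - x.max()).sum()))
```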

### Online Attention (Full Implementation)

In [None]:
def online_attention(Q, K, V, block_size):
    """
    Online attention for a single query.
    Processes K/V in blocks, never materializing full attention matrix.
    
    Args:
        Q: [d_k] single query vector
        K: [seq_len, d_k] key matrix
        V: [seq_len, d_v] value matrix
        block_size: Number of K/V pairs to process at once
    
    Returns:
        Output vector [d_v]
    """
    seq_len, d_k = K.shape
    d_v = V.shape[1]
    scale = np.sqrt(d_k)
    
    # Initialize
    m = float('-inf')
    l = 0.0
    O = np.zeros(d_v)
    
    for start in range(0, seq_len, block_size):
        end = min(start + block_size, seq_len)
        
        # Compute scores for this block
        scores = (Q @ K[start:end].T) / scale
        
        # Update statistics
        m_block = scores.max()
        m_new = max(m, m_block)
        
        # Rescale old accumulators
        correction = np.exp(m - m_new) if m != float('-inf') else 1.0
        l = l * correction
        O = O * correction
        
        # Add new block
        exp_scores = np.exp(scores - m_new)
        l += exp_scores.sum()
        O += exp_scores @ V[start:end]
        
        m = m_new
    
    # Final normalization
    return O / l

def online_attention_batched(Q, K, V, block_size):
    """
    Online attention for multiple queries.
    
    Args:
        Q: [num_queries, d_k]
        K: [seq_len, d_k]
        V: [seq_len, d_v]
        block_size: K/V block size
    
    Returns:
        Output [num_queries, d_v]
    """
    num_queries = Q.shape[0]
    seq_len, d_k = K.shape
    d_v = V.shape[1]
    scale = np.sqrt(d_k)
    
    # Initialize per-query accumulators
    m = np.full(num_queries, float('-inf'))
    l = np.zeros(num_queries)
    O = np.zeros((num_queries, d_v))
    
    for start in range(0, seq_len, block_size):
        end = min(start + block_size, seq_len)
        
        K_block = K[start:end]  # [block_size, d_k]
        V_block = V[start:end]  # [block_size, d_v]
        
        # Scores: [num_queries, block_size]
        scores = (Q @ K_block.T) / scale
        
        # Per-query max for this block
        m_block = scores.max(axis=1)  # [num_queries]
        m_new = np.maximum(m, m_block)
        
        # Rescale (handle -inf case)
        correction = np.exp(np.where(m == float('-inf'), 0, m - m_new))
        l = l * correction
        O = O * correction[:, None]
        
        # Add block contribution
        exp_scores = np.exp(scores - m_new[:, None])  # [num_queries, block_size]
        l += exp_scores.sum(axis=1)
        O += exp_scores @ V_block  # [num_queries, d_v]
        
        m = m_new
    
    return O / l[:, None]

# Test batched version
np.random.seed(42)
num_queries = 4
seq_len = 16
d_k = d_v = 8

Q = np.random.randn(num_queries, d_k)
K = np.random.randn(seq_len, d_k)
V = np.random.randn(seq_len, d_v)

# Online attention
output_online = online_attention_batched(Q, K, V, block_size=4)

# Standard attention
scores = (Q @ K.T) / np.sqrt(d_k)
weights = np.exp(scores - scores.max(axis=1, keepdims=True))
weights = weights / weights.sum(axis=1, keepdims=True)
output_standard = weights @ V

print(f"Max difference: {np.abs(output_online - output_standard).max():.2e}")

### PyTorch Implementation

In [None]:
def online_attention_torch(Q, K, V, block_size):
    """
    Online attention in PyTorch.
    
    Q: [batch, num_queries, d_k]
    K: [batch, seq_len, d_k]
    V: [batch, seq_len, d_v]
    """
    batch, num_queries, d_k = Q.shape
    seq_len = K.shape[1]
    d_v = V.shape[2]
    scale = d_k ** 0.5
    
    # Initialize accumulators (same device and dtype as Q)
    m = torch.full((batch, num_queries), float('-inf'), device=Q.device, dtype=Q.dtype)
    l = torch.zeros((batch, num_queries), device=Q.device, dtype=Q.dtype)
    O = torch.zeros((batch, num_queries, d_v), device=Q.device, dtype=Q.dtype)
    
    for start in range(0, seq_len, block_size):
        end = min(start + block_size, seq_len)
        
        K_block = K[:, start:end, :]  # [batch, block_size, d_k]
        V_block = V[:, start:end, :]  # [batch, block_size, d_v]
        
        # Scores: [batch, num_queries, block_size]
        scores = torch.bmm(Q, K_block.transpose(-2, -1)) / scale
        
        # Block max: [batch, num_queries]
        m_block = scores.max(dim=-1).values
        m_new = torch.maximum(m, m_block)
        
        # Rescale
        correction = torch.exp(torch.where(
            m == float('-inf'),
            torch.zeros_like(m),
            m - m_new
        ))
        l = l * correction
        O = O * correction.unsqueeze(-1)
        
        # Add block
        exp_scores = torch.exp(scores - m_new.unsqueeze(-1))
        l = l + exp_scores.sum(dim=-1)
        O = O + torch.bmm(exp_scores, V_block)
        
        m = m_new
    
    return O / l.unsqueeze(-1)

# Test
Q_t = torch.randn(2, 8, 16)  # [batch, queries, d_k]
K_t = torch.randn(2, 32, 16) # [batch, seq_len, d_k]
V_t = torch.randn(2, 32, 16) # [batch, seq_len, d_v]

output_online = online_attention_torch(Q_t, K_t, V_t, block_size=8)

# Compare with standard attention
scores = torch.bmm(Q_t, K_t.transpose(-2, -1)) / (16 ** 0.5)
weights = torch.softmax(scores, dim=-1)
output_standard = torch.bmm(weights, V_t)

print(f"Max difference: {(output_online - output_standard).abs().max():.2e}")

### Visualization: Online vs Standard Memory
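
A worked example of the byte counts used by the memory model in the next cell (FP32 = 4 bytes, block size $B = 64$, $d = 64$, and $10^6$ bytes per MB as in the code):

$$\text{standard} = 4N^2, \qquad \text{online} = 4N(B + d), \qquad \text{savings} = \frac{4N^2}{4N(B + d)} = \frac{N}{B + d} = \frac{N}{128}$$

For $N = 16384$: standard $= 4 \times 16384^2 = 1{,}073{,}741{,}824$ bytes $\approx 1073.7$ MB, online $= 4 \times 16384 \times 128 = 8{,}388{,}608$ bytes $\approx 8.4$ MB, and savings $= 16384 / 128 = 128\times$. Under this model the savings grow linearly with $N$.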

In [None]:
def memory_comparison(seq_lengths, block_size=64, d_model=64):
    """
    Compare memory usage: standard vs online attention.
    """
    results = []
    
    for n in seq_lengths:
        # Standard: materializes the full N×N score matrix (output buffer not counted)
        standard_memory = n * n * 4  # bytes (FP32)
        
        # Online: each of the n queries holds only block_size scores at a time,
        # plus the n × d_model output accumulator
        online_memory = n * block_size * 4 + n * d_model * 4  # scores + output
        
        results.append({
            'seq_len': n,
            'standard_mb': standard_memory / 1e6,
            'online_mb': online_memory / 1e6,
            'savings': standard_memory / online_memory
        })
    
    return results

seq_lengths = [256, 512, 1024, 2048, 4096, 8192, 16384]
results = memory_comparison(seq_lengths, block_size=64)

print(f"{'Seq Len':>10} {'Standard':>12} {'Online':>12} {'Savings':>10}")
print("-" * 50)
for r in results:
    print(f"{r['seq_len']:>10} {r['standard_mb']:>10.1f}MB {r['online_mb']:>10.1f}MB {r['savings']:>9.1f}x")

In [None]:
# Plot the comparison
fig, axes = plt.subplots(1, 2, figsize=(12, 4))

# Memory usage
standard_mem = [r['standard_mb'] for r in results]
online_mem = [r['online_mb'] for r in results]

axes[0].semilogy(seq_lengths, standard_mem, 'r-o', label='Standard (O(N²))', linewidth=2)
axes[0].semilogy(seq_lengths, online_mem, 'b-o', label='Online (O(N))', linewidth=2)
axes[0].set_xlabel('Sequence Length')
axes[0].set_ylabel('Memory (MB, log scale)')
axes[0].set_title('Memory Usage Comparison')
axes[0].legend()
axes[0].grid(True, alpha=0.3)

# Savings factor
savings = [r['savings'] for r in results]
axes[1].plot(seq_lengths, savings, 'g-o', linewidth=2)
axes[1].set_xlabel('Sequence Length')
axes[1].set_ylabel('Memory Savings (×)')
axes[1].set_title('Online Memory Savings (grows with N)')
axes[1].grid(True, alpha=0.3)

plt.tight_layout()
plt.show()

---
## Step 5: Verify — Quiz & Reflection (10 min)

### Quiz

In [None]:
def check_answer(question, your_answer, correct_answer):
    if your_answer == correct_answer:
        print(f"✓ Correct! {question}")
    else:
        print(f"✗ Incorrect ({question}). Your answer: {your_answer}, Correct: {correct_answer}")

# Q1: What running statistics does online softmax maintain?
# a) Only max
# b) Max and sum
# c) Max, sum, and count
# d) Mean and variance
q1_answer = 'b'
check_answer("Running statistics", q1_answer, 'b')

In [None]:
# Q2: When a new max is found, how do we update the running sum?
# a) Reset it to 0
# b) Add the new max
# c) Multiply by exp(old_max - new_max)
# d) Divide by the new max
q2_answer = 'c'
check_answer("Updating running sum", q2_answer, 'c')

In [None]:
# Q3: Online softmax reduces memory from O(N²) to:
# a) O(N)
# b) O(N log N)
# c) O(√N)
# d) O(1)
q3_answer = 'a'
check_answer("Memory reduction", q3_answer, 'a')

In [None]:
# Q4: Why is the correction factor exp(old_max - new_max) always ≤ 1?
# a) Because old_max ≤ new_max (max can only increase)
# b) Because exp is always positive
# c) Because we normalize at the end
# d) It's not always ≤ 1
q4_answer = 'a'
check_answer("Correction factor", q4_answer, 'a')

### Reflection Questions

1. **Two passes:** Online softmax still needs two passes over the data (one for m/l, one to compute final probs). Can FlashAttention do better?

2. **Numerical precision:** Why might online softmax have slightly different numerical errors than standard softmax?

3. **Block size choice:** How does block size affect (a) memory usage, (b) numerical accuracy, (c) performance?

---

## Summary

| Concept | Key Insight |
|---------|------------|
| Shift invariance | softmax(x - c) = softmax(x), so we can update when max changes |
| Rescaling | When max changes, multiply old sum by exp(old_max - new_max) |
| Memory reduction | O(N²) → O(N) by processing in blocks |
| Output accumulation | Track weighted sum and rescale along with the sum |

**Tomorrow:** Tiled attention — applying online softmax to compute attention block-by-block on GPU.

---

**Interactive Reference:** [attention-math.html](../attention-math.html) Section 3 — Online Softmax Simulation