# FlashAttention: Solving the Memory Bottleneck

In the [previous notebook](./standard_attention.ipynb), we explored why attention becomes a memory bottleneck:

- **No tiling / 1D tiling**: O(N) HBM traffic, but can't scale — the N×N attention matrix (or K+V) overflows SRAM
- **2D tiling**: Scales to any N, but has O(N²) HBM traffic because softmax needs full rows, forcing us to write/read the N×N matrices S and P through HBM

**The question**: Can we do 2D tiling without materializing S and P?

**FlashAttention's insight**: We don't actually need the full row of S to compute softmax! We can compute it *incrementally* using **online softmax**: we maintain running statistics that let us correct partial results as we see more data. This eliminates the S and P round-trips through HBM, so the only quadratic HBM traffic left is re-reading K and V.

In [1]:
# Setup: Import GPU simulator and 2D tiled attention for comparison
from llms_for_dummies.gpu_sim import GPUSpec, Profiler, Tensor
import math

gpu = GPUSpec.sim_a100()
Tensor.verbose = False  # Less noise for this notebook

print(f"GPU: {gpu.name}")
print(f"SRAM capacity: {gpu.sram_size:,} floats ({gpu.sram_size * gpu.bytes_per_float / 1024:.0f} KB)")

GPU: Simulated A100
SRAM capacity: 131,072 floats (256 KB)


## 1. The Online Softmax Trick

Standard softmax requires the **global max** and **global sum** before generating any output:

$$
\text{softmax}(\textbf{x})_i = \frac{e^{\textbf{x}_i - m}}{l}
$$

Where $m = \max(\textbf{x})$ and $l = \sum_j e^{\textbf{x}_j - m}$.

**The Problem:** We can't compute $m$ or $l$ without seeing the entire row (all $N$ tokens), which requires huge HBM reads/writes.
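To make the dependency concrete, here is a minimal NumPy sketch of the standard approach. The function name `softmax_two_pass` is only illustrative; the point is that no output element can be produced until both the max and the sum over the full row are known.

```python
import numpy as np

def softmax_two_pass(x):
    """Numerically stable softmax that needs the whole row up front."""
    m = x.max()          # pass 1: global max over the full row
    e = np.exp(x - m)    # shift by m so the largest exponent is exp(0) = 1
    l = e.sum()          # pass 2: global sum of the shifted exponentials
    return e / l         # only now can any output element be produced

row = np.array([1.0, 4.0, 2.0, 5.0, 3.0])
probs = softmax_two_pass(row)
print(probs, probs.sum())
```

Every line of the function reads the entire row, which is exactly what a tile that holds only part of the row cannot do.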

**The Solution:** Compute softmax **iteratively**. As we load new chunks of $K$ and $V$, we update our running statistics and rescale previous results on the fly.

### The Iterative Update Logic

Let's say we have processed some previous blocks and have a running state ($m_{old}, l_{old}, O_{old}$). We now load a **new block** of size $B_c$.

**1. Compute Local Block Statistics**

Calculate the max and unnormalized scores for just the current tile:
$$m_{block} = \max(S_{block})$$
$$P_{block} = e^{S_{block} - m_{block}}$$
$$l_{block} = \sum P_{block}$$

**2. Update Global Statistics**

Compare the new block's max with our running max:
$$m_{new} = \max(m_{old}, m_{block})$$

Calculate **rescale factors** to shift everything to the common baseline $m_{new}$:
$$\alpha = e^{m_{old} - m_{new}} \quad \text{(Shrink factor for history)}$$
$$\beta = e^{m_{block} - m_{new}} \quad \text{(Shrink factor for current block)}$$

Update the running sum (denominator):
$$l_{new} = \underbrace{l_{old} \cdot \alpha}_{\text{decayed history}} + \underbrace{l_{block} \cdot \beta}_{\text{new contribution}}$$

**3. Update the Output Accumulator ($O$)**

This is the critical step: we accumulate the weighted sum directly into $O$. We must first **decay the old accumulator** so that it shares the new baseline $m_{new}$ with the incoming block.

$$
O_{new} = \underbrace{O_{old} \cdot \alpha}_{\text{Rescale old sums}} + \underbrace{(P_{block} \cdot V_{block}) \beta}_{\text{Add new weighted contribution}}
$$

> **Key Insight:** We never materialize the full $P$ (or $S$) matrices. We compute the contribution of a small tile ($P_{block} \times V_{block}$), add it to our running total, and immediately discard $P_{block}$.

**4. Final Normalization**

After iterating through all blocks, $O_{final}$ contains the unnormalized weighted sum. One final division gives the correct attention output:

$$
\text{Attention}(Q, K, V) = \frac{O_{final}}{l_{final}}
$$
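The four steps extend to any number of blocks. The following is a minimal sketch for a single query row, assuming the scores are already computed; the name `online_softmax_row` and the random test data are illustrative, and the running max starts at `-inf` so the first block needs no special case.

```python
import numpy as np

def online_softmax_row(scores, values, block_size):
    """Attention output for one query row, consuming keys/values block by block."""
    m = -np.inf                        # running max (nothing seen yet)
    l = 0.0                            # running denominator
    o = np.zeros(values.shape[1])      # running unnormalized output
    for start in range(0, len(scores), block_size):
        s_blk = scores[start:start + block_size]
        v_blk = values[start:start + block_size]
        m_blk = s_blk.max()            # step 1: local block statistics
        p_blk = np.exp(s_blk - m_blk)
        l_blk = p_blk.sum()
        m_new = max(m, m_blk)          # step 2: merge the statistics
        alpha = np.exp(m - m_new)      # exp(-inf) = 0 on the first block
        beta = np.exp(m_blk - m_new)
        l = l * alpha + l_blk * beta
        o = o * alpha + (p_blk @ v_blk) * beta   # step 3: rescale and accumulate
        m = m_new
    return o / l                       # step 4: final normalization

rng = np.random.default_rng(0)
scores = rng.normal(size=37)           # 37 is deliberately not a multiple of 8
values = rng.normal(size=(37, 4))
weights = np.exp(scores - scores.max())
reference = (weights / weights.sum()) @ values
result = online_softmax_row(scores, values, block_size=8)
print(np.allclose(result, reference))
```

Only one block of scores and values is alive at any time, yet the result agrees with the full-row softmax regardless of the block size chosen.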

### Worked Example: Online Softmax in Action

Let's trace through online softmax with concrete numbers. We'll compute attention for a single query attending to 5 keys, processing them in two blocks:

In [2]:
import numpy as np

# We are calculating Attention for a SINGLE Query token (i)
# attending to 5 Key/Value tokens.
d = 3  # Head dimension (vector size)

# 1. Attention Scores (S_i) = Q_i @ K.T
#    This is the row of raw dot products before softmax.
S_row = np.array([1.0, 4.0, 2.0, 5.0, 3.0])  

# 2. Value Matrix (V)
#    Each of the 5 tokens has a Value vector of size d=3
V = np.array([
    [0.1, 0.2, 0.3],  # Token 0
    [1.0, 1.0, 1.0],  # Token 1
    [0.5, 0.0, 0.5],  # Token 2
    [2.0, 2.0, 0.0],  # Token 3
    [0.1, 0.8, 0.1],  # Token 4
])

print(f"INPUTS:")
print(f"  Scores (S_row): {S_row}")
print(f"  Values (V): \n{V}")
print("=" * 60)

# ==========================================
# Step 1: Process First Block (The History)
# ==========================================
# We load the first 2 tokens (Block 1)
print("STEP 1: Processing Block 1 (Tokens 0, 1)")
S_block = S_row[:2]      # [1.0, 4.0]
V_block = V[:2]          # First two rows of V

# Local computations
m_old = S_block.max()                  # Scalar: 4.0
P_block = np.exp(S_block - m_old)      # Vector: [0.05, 1.0]
l_old = P_block.sum()                  # Scalar: 1.05

# Compute Accumulator (Weighted Sum of V vectors)
# Shape: (2,) @ (2, 3) -> (3,)
O_old = P_block @ V_block                   

print(f"  [Stats] m_old={m_old}, l_old={l_old:.4f}")
print(f"  [Accumulator O_old] {O_old}  <-- Vector of size {d}")
print("-" * 60)

# ==========================================
# Step 2: Process Second Block (The Update)
# ==========================================
# We load the next 3 tokens (Block 2)
print("STEP 2: Processing Block 2 (Tokens 2, 3, 4)")
S_block = S_row[2:]      # [2.0, 5.0, 3.0]
V_block = V[2:]          # Last three rows of V

# Local computations for the new block
m_block = S_block.max()                # Scalar: 5.0
P_block = np.exp(S_block - m_block)    # Vector: [0.05, 1.0, 0.14]
l_block = P_block.sum()                # Scalar: 1.185...

# Compute Contribution from this block
O_contrib = P_block @ V_block          # Shape (3,)

print(f"  [Stats] m_block={m_block}, l_block={l_block:.4f}")
print(f"  [Block Contrib] {O_contrib}")
print("-" * 60)

# ==========================================
# Step 3: The Online Softmax Merge
# ==========================================
print("STEP 3: Merging...")

# 1. Update Global Max
m_new = max(m_old, m_block)

# 2. Calculate Rescale Factors
#    alpha: decays the old history
#    beta:  decays the new block (exactly 1.0 when the new block holds the larger max)
alpha = np.exp(m_old - m_new)  # exp(4-5) = 0.3679
beta  = np.exp(m_block - m_new) # exp(5-5) = 1.0

# 3. Update Global Sum
l_new = (l_old * alpha) + (l_block * beta)

# 4. Update Output Accumulator (Vector operation)
#    Formula: O_new = (O_old * alpha) + (O_contrib * beta)
O_new = (O_old * alpha) + (O_contrib * beta)

print(f"  [Rescale] alpha={alpha:.4f}, beta={beta:.4f}")
print(f"  [Update]  O_new = ({O_old} * {alpha:.2f}) + ({O_contrib} * {beta:.2f})")
print(f"  [Result]  O_new = {O_new}")
print("=" * 60)

# ==========================================
# Step 4: Finalize
# ==========================================
print("STEP 4: Final Normalization")

# Divide vector by scalar sum
final_output = O_new / l_new

print(f"  Final Output Vector: {final_output}")

# Verify against Naive Numpy
print("-" * 60)
full_P = np.exp(S_row - S_row.max())
full_softmax = full_P / full_P.sum()
true_output = full_softmax @ V
print(f"  Ground Truth:        {true_output}")
print(f"  Match? {np.allclose(final_output, true_output)}")

INPUTS:
  Scores (S_row): [1. 4. 2. 5. 3.]
  Values (V): 
[[0.1 0.2 0.3]
 [1.  1.  1. ]
 [0.5 0.  0.5]
 [2.  2.  0. ]
 [0.1 0.8 0.1]]
STEP 1: Processing Block 1 (Tokens 0, 1)
  [Stats] m_old=4.0, l_old=1.0498
  [Accumulator O_old] [1.00497871 1.00995741 1.01493612]  <-- Vector of size 3
------------------------------------------------------------
STEP 2: Processing Block 2 (Tokens 2, 3, 4)
  [Stats] m_block=5.0, l_block=1.1851
  [Block Contrib] [2.03842706 2.10826823 0.03842706]
------------------------------------------------------------
STEP 3: Merging...
  [Rescale] alpha=0.3679, beta=1.0000
  [Update]  O_new = ([1.00497871 1.00995741 1.01493612] * 0.37) + ([2.03842706 2.10826823 0.03842706] * 1.00)
  [Result]  O_new = [2.40813807 2.4798108  0.4118012 ]
STEP 4: Final Normalization
  Final Output Vector: [1.53255989 1.57817303 0.26207384]
------------------------------------------------------------
  Ground Truth:        [1.53255989 1.57817303 0.26207384]
  Match? True


### Why Does Rescaling Work?

The key insight is that **softmax is invariant to shifting by a constant**:
$$\text{softmax}(x - c) = \text{softmax}(x)$$

When we see a new chunk with a larger max ($m_{new}$), we need to "shift" our previous computations to align with it. The rescaling factor `exp(m_old - m_new)` accounts for this shift.

1. **Scaling Individual Terms:**
   Each old value is effectively scaled down to make room for the new maximum:

```
Old: exp(x - m_old)
New: exp(x - m_new) = exp(x - m_old) × exp(m_old - m_new)
                                        └─── rescale ───┘
```

2. **Scaling the Sum (`l`):**
   Since *every* term in the old sum is scaled by the same factor, we can factor it out. This lets us rescale the old running sum without revisiting individual elements; the new block's contribution is then added on top, as in the update rule for $l_{new}$ in Section 1:

$$
\begin{aligned}
\sum_{i \in \text{old}} \exp(x_i - m_{new}) &= \sum_{i \in \text{old}} \left( \exp(x_i - m_{old}) \times \text{scale} \right) \\
&= \text{scale} \times \sum_{i \in \text{old}} \exp(x_i - m_{old}) \\
&= \text{scale} \times l_{old}
\end{aligned}
$$
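A quick numeric check of this identity, reusing the first block of the worked example (scores 1.0 and 4.0) and assuming a later block raises the running max to 5.0:

```python
import numpy as np

x_old = np.array([1.0, 4.0])          # scores already processed
m_old = x_old.max()                   # 4.0
l_old = np.exp(x_old - m_old).sum()   # stored running sum

m_new = 5.0                           # a later block raised the running max
scale = np.exp(m_old - m_new)

rescaled = l_old * scale                     # one multiply on the stored sum
recomputed = np.exp(x_old - m_new).sum()     # revisiting every old score instead
print(rescaled, recomputed)
```

The two values agree up to floating-point rounding, which is why the old scores never need to be reloaded.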

## 2. The FlashAttention Algorithm

For each row of Q, we maintain three running values as we iterate through K/V tiles:

| State | Meaning | Update Rule |
|-------|---------|-------------|
| `m` | Running max of S values seen so far | `m_new = max(m, m_tile)`, with `m_tile = rowmax(S_tile)` |
| `l` | Running sum of exp(S - m) | `l = l × α + rowsum(P_tile) × β`, with `P_tile = exp(S_tile - m_tile)` |
| `O` | Running weighted sum (unnormalized output) | `O = O × α + (P_tile @ V_tile) × β` |

Where:
- `α = exp(m_old - m_new)` — rescales old values when max increases
- `β = exp(m_tile - m_new)` — rescales new tile values

After processing all K/V tiles, we finalize: `O_final = O / l`

**The magic**: We never store the full S or P matrices. Each tile is processed and immediately used to update our running state!

### FlashAttention Processing Pattern

```
For each Q tile (outer loop):                        
┌─────────────────────────────────────────────┐      
│ Q tile    ──┐                               │      
│  (Br×d)     │                               │      
└─────────────│───────────────────────────────┘      
              │                                       
              ▼                                       
┌─────────┬─────────┬─────────┬─────────┐           
│ K tile 1│ K tile 2│ K tile 3│   ...   │ ◄── inner loop
│  (Bc×d) │  (Bc×d) │  (Bc×d) │         │           
└────┬────┴────┬────┴────┬────┴─────────┘           
     │         │         │                           
┌────▼────┬────▼────┬────▼────┐                      
│ S tile 1│ S tile 2│ S tile 3│  Never written to HBM!
│(Br×Bc)  │(Br×Bc)  │(Br×Bc)  │  Immediately:
└────┬────┴────┬────┴────┬────┘  1. Update m, l      
     │         │         │       2. Compute P tile    
     │         │         │       3. Accumulate into O 
     ▼         ▼         ▼       4. Free from SRAM   
┌─────────────────────────────┐                      
│      Running State          │                      
│  m (Br×1): running max      │                      
│  l (Br×1): running sum      │                      
│  O (Br×d): running output   │                      
└──────────────┬──────────────┘                      
               │ After all K tiles                   
               ▼                                     
┌─────────────────────────────┐                      
│  O_final = O / l            │──▶ Write to HBM     
└─────────────────────────────┘
```
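The processing pattern above can be sketched in plain NumPy to check that the tiled result matches ordinary attention. This is a numerical reference under simple assumptions (single head, no masking, hypothetical function names), not the simulator code: it tracks values rather than memory traffic.

```python
import numpy as np

def flash_attention_reference(Q, K, V, Br, Bc):
    """Tiled attention with online softmax; S and P exist only one tile at a time."""
    N, d = Q.shape
    out = np.empty((N, V.shape[1]))
    for i in range(0, N, Br):                      # outer loop over Q tiles
        q = Q[i:i + Br]
        rows = q.shape[0]
        m = np.full((rows, 1), -np.inf)            # running max per row
        l = np.zeros((rows, 1))                    # running sum per row
        o = np.zeros((rows, V.shape[1]))           # running unnormalized output
        for j in range(0, K.shape[0], Bc):         # inner loop over K/V tiles
            s = q @ K[j:j + Bc].T / np.sqrt(d)     # S tile, shape (rows, <=Bc)
            m_tile = s.max(axis=1, keepdims=True)
            p = np.exp(s - m_tile)                 # P tile
            l_tile = p.sum(axis=1, keepdims=True)
            m_new = np.maximum(m, m_tile)
            alpha = np.exp(m - m_new)
            beta = np.exp(m_tile - m_new)
            l = l * alpha + l_tile * beta
            o = o * alpha + (p @ V[j:j + Bc]) * beta
            m = m_new
        out[i:i + Br] = o / l                      # finalize this Q tile
    return out

def naive_attention(Q, K, V):
    """Ordinary attention that materializes the full N×N matrices."""
    s = Q @ K.T / np.sqrt(Q.shape[1])
    p = np.exp(s - s.max(axis=1, keepdims=True))
    return (p / p.sum(axis=1, keepdims=True)) @ V

rng = np.random.default_rng(0)
Q, K, V = (rng.normal(size=(50, 8)) for _ in range(3))
flash_out = flash_attention_reference(Q, K, V, Br=16, Bc=12)
print(np.allclose(flash_out, naive_attention(Q, K, V)))
```

The tile sizes 16 and 12 are chosen so that neither divides N=50, which exercises the ragged last tile in both loops.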

## 3. Implementation

Now let's implement FlashAttention with our GPU simulator. The code follows the pattern above, handling full matrices with proper tiling:

In [3]:
def flash_attention(prof: Profiler, N: int, d: int, tile_size: int):
    """
    FlashAttention: 2D tiling with online softmax.
    
    Key idea: Instead of computing full softmax rows, we:
    1. Process K/V in tiles
    2. Track running max (m) and sum (l) per row
    3. Rescale partial results as we see new tiles
    4. Never write S or P to HBM!
    
    HBM traffic: O(N²d²/M) where M is SRAM size. Still O(N²), but avoids
    the S and P round-trips that dominate naive 2D tiling. K and V are
    re-read once per Q tile — that's the remaining quadratic cost.
    """
    Q = Tensor((N, d), "Q", prof)
    K = Tensor((N, d), "K", prof)
    V = Tensor((N, d), "V", prof)
    
    scale_factor = 1.0 / math.sqrt(d)
    
    # Outer loop: process Q in row blocks
    for i in range(0, N, tile_size):
        i_end = min(i + tile_size, N)
        num_rows = i_end - i
        
        q_tile = Q.load(rows=(i, i_end))
        
        # Online softmax state — these are the running statistics from Section 1:
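        # NOTE: a kernel operating on real values would start the running max at -inf
        # (0 is not a lower bound for scores); zeros are kept here because this
        # notebook only uses the result to profile memory traffic.
        m = Tensor.zeros((num_rows, 1), "m", prof)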
        l = Tensor.zeros((num_rows, 1), "l", prof)
        o = Tensor.zeros((num_rows, d), "O_acc", prof)
        
        # Inner loop: iterate through K/V tiles
        for j in range(0, N, tile_size):
            j_end = min(j + tile_size, N)
            
            k_tile = K.load(rows=(j, j_end))
            v_tile = V.load(rows=(j, j_end))
            
            # --- Equation: S_block = Q_tile @ K_tile.T / sqrt(d) ---
            s_tile = (q_tile @ k_tile.T).scale(scale_factor, "/√d")
            k_tile.free()
            
            # --- Equation: m_block = max(S_block) ---
            m_tile = s_tile.rowmax()
            
            # --- Equation: P_block = exp(S_block - m_block) ---
            s_shifted = s_tile.sub_rowvec(m_tile)
            s_tile.free()
            p_tile = s_shifted.exp()
            s_shifted.free()
            
            # --- Equation: l_block = sum(P_block) ---
            l_tile = p_tile.rowsum()
            
            # --- Equation: m_new = max(m_old, m_block) ---
            m_new = Tensor.zeros((num_rows, 1), "m_new", prof)
            m_new.add_(m)
            m_new.max_(m_tile)
            
            # --- Equation: α = exp(m_old - m_new), β = exp(m_block - m_new) ---
            m_diff_old = m.sub_rowvec(m_new)
            m.free()
            alpha = m_diff_old.exp()           # α
            m_diff_old.free()
            
            m_diff_new = m_tile.sub_rowvec(m_new)
            m_tile.free()
            beta = m_diff_new.exp()            # β
            m_diff_new.free()
            
            # --- Equation: l_new = l_old * α + l_block * β ---
            l.mul_(alpha)                      # l_old * α  (in-place)
            l_tile.mul_(beta)                  # l_block * β (in-place)
            l.add_(l_tile)                     # l_old*α + l_block*β
            l_tile.free()
            
            # --- Equation: O_new = O_old * α + (P_block @ V_block) * β ---
            o_local = p_tile @ v_tile          # P_block @ V_block
            p_tile.free()
            v_tile.free()
            
            o.mul_(alpha)                      # O_old * α
            alpha.free()
            
            o_local.mul_(beta)                 # (P_block @ V_block) * β
            beta.free()
            
            o.add_(o_local)                    # O_new = O_old*α + (P@V)*β
            o_local.free()
            
            # Shift state for next iteration
            m = m_new
        
        # --- Equation: O_final = O / l_final ---
        o_final = o.div_rowvec(l)
        o.free()
        l.free()
        m.free()
        
        # Only the final output touches HBM — S and P never left SRAM!
        o_final.write_hbm()
        o_final.free()
        q_tile.free()

In [4]:
N, d = 32_768, 128
tile_size = gpu.optimal_tile_size_flash(d)

print(f"N={N}, d={d}, tile_size={tile_size}")
print(f"(Smaller than 2D naive because FlashAttention has more intermediate tensors)")
print()

Tensor.verbose = False
prof = Profiler(gpu, f"FlashAttention (N={N})")
flash_attention(prof, N, d, tile_size)
prof.report()

N=32768, d=128, tile_size=158
(Smaller than 2D naive because FlashAttention has more intermediate tensors)

FlashAttention (N=32768)

Memory Usage:
  Peak HBM:    24.0 MB
  Peak SRAM:  │████████████···│  85%

HBM Traffic:
  Reads:    1,749,024,768 floats (3336.0 MB)
  Writes:       4,194,304 floats (8.0 MB)
  Total:    1,753,219,072 floats (3344.0 MB)

Time Breakdown:
  ┌──────────────────────────────────────────────────┐
  │██████████████████████████████████░░░░░░░░░░░░░░░░│
  └──────────────────────────────────────────────────┘
   Computing (68%)                  Waiting for HBM (32%)

→ Compute-bound (GPU is busy computing)


## 4. Comparison: 2D Tiling vs FlashAttention

Let's compare naive 2D tiling (from the previous notebook) against FlashAttention on the same problem:

In [5]:
# Import the 2D tiled attention from the previous notebook for comparison
from ipynb.fs.defs.standard_attention import attention_2d_tiled

Tensor.verbose = False

N, d = 32_768, 128
tile_size_2d = gpu.optimal_tile_size_2d(d)
tile_size_flash = gpu.optimal_tile_size_flash(d)

# Run both 2D tiled and FlashAttention
prof_2d = Profiler(gpu, "2D Tiled (naive)")
attention_2d_tiled(prof_2d, N, d, tile_size_2d)

prof_flash = Profiler(gpu, "FlashAttention")
flash_attention(prof_flash, N, d, tile_size_flash)

# Compare
print(f"{'='*60}")
print(f"Comparison: N={N}, d={d}")
print(f"{'='*60}")
print()
print(f"{'Metric':<25} {'2D Tiled':>15} {'FlashAttention':>15}")
print(f"{'-'*60}")
print(f"{'Tile size':.<25} {tile_size_2d:>15} {tile_size_flash:>15}")
print(f"{'HBM Reads':.<25} {prof_2d.total_hbm_reads:>15,} {prof_flash.total_hbm_reads:>15,}")
print(f"{'HBM Writes':.<25} {prof_2d.total_hbm_writes:>15,} {prof_flash.total_hbm_writes:>15,}")
print(f"{'Total HBM Traffic':.<25} {prof_2d.total_hbm_reads + prof_2d.total_hbm_writes:>15,} {prof_flash.total_hbm_reads + prof_flash.total_hbm_writes:>15,}")
print(f"{'Peak SRAM (%)':.<25} {prof_2d.peak_sram_usage / gpu.sram_size * 100:>14.0f}% {prof_flash.peak_sram_usage / gpu.sram_size * 100:>14.0f}%")
print()

# This is a ratio of HBM traffic, not a measured wall-clock speedup
traffic_ratio = (prof_2d.total_hbm_reads + prof_2d.total_hbm_writes) / (prof_flash.total_hbm_reads + prof_flash.total_hbm_writes)
print(f"FlashAttention uses {traffic_ratio:.1f}× less HBM traffic!")

Comparison: N=32768, d=128

Metric                           2D Tiled  FlashAttention
------------------------------------------------------------
Tile size................             217             158
HBM Reads................   3,426,746,368   1,749,024,768
HBM Writes...............   2,151,677,952       4,194,304
Total HBM Traffic........   5,578,424,320   1,753,219,072
Peak SRAM (%)............            100%             85%

FlashAttention uses 3.2× less HBM traffic!


### How Does This Scale with Sequence Length?

The single comparison above shows a 3.2x difference at N=32,768. But how does the ratio behave as N grows?

Both approaches are O(N²) — in both cases, K and V are re-read for each Q tile. The difference is a **constant factor**: naive 2D tiling also reads/writes the N×N matrices S and P through HBM, while FlashAttention avoids those entirely. So we expect a roughly constant ratio:

In [6]:
d = 128
print(f"{'N':>8}  {'2D Tiled Traffic':>18}  {'Flash Traffic':>18}  {'Ratio':>7}")
print(f"{'-'*60}")

for N in [1024, 4096, 16384, 32768, 65536, 131072]:
    tile_2d = gpu.optimal_tile_size_2d(d)
    tile_flash = gpu.optimal_tile_size_flash(d)

    p2d = Profiler(gpu, "2D")
    attention_2d_tiled(p2d, N, d, tile_2d)
    traffic_2d = p2d.total_hbm_reads + p2d.total_hbm_writes

    pf = Profiler(gpu, "Flash")
    flash_attention(pf, N, d, tile_flash)
    traffic_flash = pf.total_hbm_reads + pf.total_hbm_writes

    ratio = traffic_2d / traffic_flash
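    bytes_per_float = gpu.bytes_per_float  # taken from the GPU spec instead of a hard-coded 2
    print(f"{N:>8,}  {traffic_2d * bytes_per_float / 1024**2:>14.0f} MB  {traffic_flash * bytes_per_float / 1024**2:>14.0f} MB  {ratio:>6.1f}x")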

print()
print("The ratio is roughly constant — both are O(N²).")
print("The ~3x improvement comes from eliminating the S and P round-trips.")

       N    2D Tiled Traffic       Flash Traffic    Ratio
------------------------------------------------------------
   1,024              11 MB               4 MB     2.8x
   4,096             168 MB              54 MB     3.1x
  16,384            2664 MB             840 MB     3.2x
  32,768           10640 MB            3344 MB     3.2x
  65,536           42496 MB           13312 MB     3.2x
 131,072          169856 MB           53184 MB     3.2x

The ratio is roughly constant — both are O(N²).
The ~3x improvement comes from eliminating the S and P round-trips.


## 5. Summary

FlashAttention eliminates the dominant cost of naive 2D tiling — the S and P round-trips through HBM — by computing softmax online. S and P are never materialized; each tile is computed in SRAM, used immediately to update the running output, and discarded.

**What remains**: K and V are still re-read for each Q tile (N/Br times each), giving O(N²d²/M) total traffic. This is still O(N²), but with a much better constant than naive 2D tiling: about 3.2× less traffic in the simulated runs above for N ≥ 16,384.

**Where does the 3x come from?** Naive 2D tiling writes S to HBM, reads it back for softmax, writes P, and reads P back for the final matmul — that's ~4 passes over an N×N matrix. FlashAttention avoids all of these. The remaining cost (re-reading K and V) is shared by both approaches.
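This accounting can be written as a back-of-the-envelope model. The FlashAttention formula follows the loop structure of `flash_attention` above; the 2D-tiled formula is an assumption inferred from the description in this paragraph (S and P each written once and read once), not taken from the source of `attention_2d_tiled`. Both function names are hypothetical.

```python
import math

def flash_traffic(N, d, tile):
    """Floats moved through HBM by the FlashAttention loop in this notebook."""
    q_tiles = math.ceil(N / tile)
    reads = N * d + q_tiles * 2 * N * d   # Q once; K and V once per Q tile
    writes = N * d                        # final output only
    return reads, writes

def tiled_2d_traffic(N, d, tile):
    """Assumed model: same Q/K/V reads, plus S and P each written and read once."""
    q_tiles = math.ceil(N / tile)
    reads = N * d + q_tiles * 2 * N * d + 2 * N * N
    writes = N * d + 2 * N * N
    return reads, writes

N, d = 32_768, 128
flash_total = sum(flash_traffic(N, d, tile=158))
tiled_total = sum(tiled_2d_traffic(N, d, tile=217))
print(f"{flash_total:,} {tiled_total:,} {tiled_total / flash_total:.1f}x")
```

With the tile sizes reported in Section 4 (158 and 217), this model gives 1,753,219,072 and 5,578,424,320 floats, matching the profiler totals in the comparison table. Of the 2D-tiled total, 4N² ≈ 4.3 billion floats are S and P traffic alone.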

**Why it matters**: Even a constant-factor improvement is significant when a kernel is memory-bound. In the profiler report above, FlashAttention at N=32,768 is compute-bound (68% of time computing, 32% waiting for HBM), so the GPU spends most of its time doing useful work instead of waiting for memory. This is one reason FlashAttention is widely used in modern LLM training and inference.