# Week 3, Day 7: FlashAttention — The Complete Implementation

**Time:** ~1.5 hours

**Goal:** Implement the complete FlashAttention algorithm in Triton, achieving O(N) memory and competitive speed.

## The Journey Complete

This week we built up to this moment:
- **Day 1:** Dot products — the foundation of attention
- **Day 2:** Softmax overflow — why naive implementation fails
- **Day 3:** Stable softmax — the max-subtraction trick
- **Day 4:** Full attention — the quadratic memory problem
- **Day 5:** Online softmax — streaming computation
- **Day 6:** Tiled attention — block-wise processing
- **Day 7:** FlashAttention — putting it all together

In [None]:
import numpy as np
import torch
import torch.nn.functional as F
import triton
import triton.language as tl
import time

torch.set_printoptions(precision=4, sci_mode=False)

---
## Step 1: The Challenge — FlashAttention Goals (5 min)

FlashAttention achieves:

| Metric | Standard attention | FlashAttention |
|--------|--------------------|----------------|
| Extra memory (beyond Q, K, V, O) | O(N²) | O(N) |
| HBM accesses | O(Nd + N²) | O(N²d²/M), where M = SRAM size |
| Speed | Baseline | Often 2-4x faster (depends on GPU and shapes) |

**Key innovations:**
1. Online softmax with output rescaling
2. Tiling to fit in SRAM
3. Fused kernel (no intermediate writes to HBM)
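To make the I/O row concrete, here is a back-of-the-envelope sketch that plugs numbers into the asymptotic counts from Dao et al. (2022): Θ(Nd + N²) HBM accesses for standard attention versus Θ(N²d²/M) for FlashAttention. Constant factors are ignored, the helper names are hypothetical, and the SRAM budget of 25,000 FP32 elements (about 100 KB) is an assumed, illustrative value rather than a property of any particular GPU.

```python
# Rough HBM-access counts from the asymptotic formulas (constant factors ignored).
# Assumption: M is the on-chip SRAM budget in elements (~100 KB of FP32 here).

def standard_hbm_accesses(n: int, d: int) -> int:
    """Standard attention: Theta(N*d + N^2) HBM accesses."""
    return n * d + n * n

def flash_hbm_accesses(n: int, d: int, m: int) -> float:
    """FlashAttention forward: Theta(N^2 * d^2 / M) HBM accesses."""
    return n * n * d * d / m

d, sram_elems = 64, 25_000  # illustrative head dim and SRAM size
for n in (1024, 4096, 16384):
    std = standard_hbm_accesses(n, d)
    fa = flash_hbm_accesses(n, d, sram_elems)
    print(f"N={n:>6}: standard ~{std:.3e}, flash ~{fa:.3e}, ratio ~{std / fa:.1f}x")
```

For large N the ratio approaches M/d², which is why the tiling pays off only when d² is much smaller than the SRAM size.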

---
## Step 2: The Algorithm — FlashAttention Pseudocode (15 min)

### FlashAttention Forward Pass

```
Input: Q, K, V ∈ R^(N×d), block sizes B_q, B_kv
Output: O ∈ R^(N×d)

1. Divide Q into T_q = ⌈N/B_q⌉ blocks and K, V into T_kv = ⌈N/B_kv⌉ blocks
2. For each Q block i = 1...T_q:
   a. Load Q_i from HBM to SRAM
   b. Initialize: O_i = 0, l_i = 0, m_i = -∞ (all in SRAM)
   c. For each K,V block j = 1...T_kv:
      i.   Load K_j, V_j from HBM to SRAM
      ii.  Compute S_ij = Q_i × K_j^T / √d (in SRAM)
      iii. Compute m̃_ij = rowmax(S_ij)
      iv.  Compute P̃_ij = exp(S_ij - m̃_ij)
      v.   Compute l̃_ij = rowsum(P̃_ij)
      vi.  Compute m_new = max(m_i, m̃_ij)
      vii. Compute l_new = e^(m_i - m_new) × l_i + e^(m̃_ij - m_new) × l̃_ij
      viii. Update O_i = (l_i × e^(m_i - m_new) × O_i + e^(m̃_ij - m_new) × P̃_ij × V_j) / l_new
      ix.  Update m_i = m_new, l_i = l_new
   d. Write O_i to HBM
```
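Before the tensorized version, a plain-Python sketch of the same steps for a single query row can help check the bookkeeping. It assumes no masking, applies the 1/√d scale inside S, and compares the blockwise result against a one-shot softmax; `flash_row` and `direct_row` are illustrative names, not part of any library.

```python
import math
import random

def direct_row(q, keys, values):
    """softmax(q . K^T / sqrt(d)) @ V for one query row, computed in one shot."""
    d = len(q)
    s = [sum(qi * ki for qi, ki in zip(q, k)) / math.sqrt(d) for k in keys]
    mx = max(s)
    w = [math.exp(x - mx) for x in s]
    z = sum(w)
    return [sum(wj * v[c] for wj, v in zip(w, values)) / z for c in range(d)]

def flash_row(q, keys, values, block_kv):
    """Same result via steps c.i-c.ix, one K/V block at a time (no masking)."""
    d = len(q)
    o, l, m = [0.0] * d, 0.0, float("-inf")                            # step b
    for start in range(0, len(keys), block_kv):
        k_blk = keys[start:start + block_kv]                           # step i
        v_blk = values[start:start + block_kv]
        s = [sum(qi * ki for qi, ki in zip(q, k)) / math.sqrt(d) for k in k_blk]  # ii
        m_blk = max(s)                                                 # iii
        p = [math.exp(x - m_blk) for x in s]                           # iv
        l_blk = sum(p)                                                 # v
        m_new = max(m, m_blk)                                          # vi
        a, b = math.exp(m - m_new), math.exp(m_blk - m_new)            # exp(-inf) == 0.0
        l_new = a * l + b * l_blk                                      # vii
        pv = [sum(pj * v[c] for pj, v in zip(p, v_blk)) for c in range(d)]
        o = [(l * a * oc + b * pc) / l_new for oc, pc in zip(o, pv)]   # viii
        m, l = m_new, l_new                                            # ix
    return o

random.seed(0)
N, D = 10, 4
q = [random.gauss(0, 1) for _ in range(D)]
keys = [[random.gauss(0, 1) for _ in range(D)] for _ in range(N)]
values = [[random.gauss(0, 1) for _ in range(D)] for _ in range(N)]
for block in (1, 3, 10):
    err = max(abs(x - y) for x, y in zip(direct_row(q, keys, values),
                                         flash_row(q, keys, values, block)))
    print(f"block_kv={block}: max abs diff {err:.1e}")
```

Any block size, including ones that do not divide N, should agree with the direct computation up to floating-point rounding.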

In [None]:
def flash_attention_reference(Q, K, V, block_q=64, block_kv=64, causal=False):
    """
    Reference FlashAttention implementation (Python).
    Follows the algorithm exactly for clarity.
    """
    batch, seq_len, d = Q.shape
    scale = d ** -0.5
    
    O = torch.zeros_like(Q)
    L = torch.zeros(batch, seq_len, device=Q.device)  # For debugging: stores final l
    M = torch.full((batch, seq_len), float('-inf'), device=Q.device)  # Final m
    
    # Number of blocks
    T_q = (seq_len + block_q - 1) // block_q
    T_kv = (seq_len + block_kv - 1) // block_kv
    
    for i in range(T_q):  # Loop over Q blocks
        q_start = i * block_q
        q_end = min(q_start + block_q, seq_len)
        
        # Load Q block
        Q_i = Q[:, q_start:q_end, :]  # [batch, B_q, d]
        
        # Initialize accumulators for this Q block
        O_i = torch.zeros_like(Q_i)  # [batch, B_q, d]
        l_i = torch.zeros(batch, q_end - q_start, device=Q.device)  # [batch, B_q]
        m_i = torch.full((batch, q_end - q_start), float('-inf'), device=Q.device)
        
        for j in range(T_kv):  # Loop over K,V blocks
            kv_start = j * block_kv
            kv_end = min(kv_start + block_kv, seq_len)
            
            # Causal: skip blocks entirely in the future
            if causal and kv_start >= q_end:
                break
            
            # Load K, V blocks
            K_j = K[:, kv_start:kv_end, :]  # [batch, B_kv, d]
            V_j = V[:, kv_start:kv_end, :]  # [batch, B_kv, d]
            
            # Compute attention scores
            S_ij = torch.bmm(Q_i, K_j.transpose(-2, -1)) * scale  # [batch, B_q, B_kv]
            
            # Apply causal mask
            if causal:
                q_idx = torch.arange(q_start, q_end, device=Q.device).view(-1, 1)
                k_idx = torch.arange(kv_start, kv_end, device=Q.device).view(1, -1)
                mask = k_idx > q_idx  # [B_q, B_kv]
                S_ij = S_ij.masked_fill(mask.unsqueeze(0), float('-inf'))
            
            # Block max and exp
            m_ij = S_ij.max(dim=-1).values  # [batch, B_q]
            P_ij = torch.exp(S_ij - m_ij.unsqueeze(-1))  # [batch, B_q, B_kv]
            P_ij = torch.where(torch.isinf(S_ij), torch.zeros_like(P_ij), P_ij)
            l_ij = P_ij.sum(dim=-1)  # [batch, B_q]
            
            # Update max
            m_new = torch.maximum(m_i, m_ij)
            
            # Compute scaling factors
            alpha = torch.exp(m_i - m_new)  # Scale for old accumulator
            alpha = torch.where(m_i == float('-inf'), torch.zeros_like(alpha), alpha)
            beta = torch.exp(m_ij - m_new)   # Scale for new block
            beta = torch.where(m_ij == float('-inf'), torch.zeros_like(beta), beta)
            
            # Update sum
            l_new = alpha * l_i + beta * l_ij
            
            # Update output
            # O_new = (l_i * alpha * O_i + beta * P_ij @ V_j) / l_new
            O_i = (l_i.unsqueeze(-1) * alpha.unsqueeze(-1) * O_i + 
                   beta.unsqueeze(-1) * torch.bmm(P_ij, V_j))
            
            # Normalize (avoid div by zero)
            l_new_safe = torch.where(l_new == 0, torch.ones_like(l_new), l_new)
            O_i = O_i / l_new_safe.unsqueeze(-1)
            
            # Update accumulators
            m_i = m_new
            l_i = l_new
        
        # Store results
        O[:, q_start:q_end, :] = O_i
        L[:, q_start:q_end] = l_i
        M[:, q_start:q_end] = m_i
    
    return O

# Test
torch.manual_seed(42)
batch, seq_len, d = 2, 64, 32
Q = torch.randn(batch, seq_len, d)
K = torch.randn(batch, seq_len, d)
V = torch.randn(batch, seq_len, d)

# Standard attention
scores = torch.bmm(Q, K.transpose(-2, -1)) / (d ** 0.5)
mask = torch.triu(torch.ones(seq_len, seq_len), diagonal=1).bool()
scores = scores.masked_fill(mask, float('-inf'))
weights = F.softmax(scores, dim=-1)
output_std = torch.bmm(weights, V)

# FlashAttention
output_flash = flash_attention_reference(Q, K, V, block_q=16, block_kv=16, causal=True)

print(f"Max difference: {(output_std - output_flash).abs().max():.2e}")
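The `torch.where` guards in the reference are not decoration. The hypothetical helper below counts how often the causal loop visits a (query row, K/V block) pair in which every key is masked, so the block max is -∞ and `exp(-∞ - (-∞))` would be NaN. With equal block sizes this never happens; with unequal sizes such as `block_q=32, block_kv=16` it does.

```python
import math

def fully_masked_rows(seq_len, block_q, block_kv):
    """Count (query row, K/V block) pairs visited by the causal reference loop
    in which every key position k exceeds the query position q (all masked)."""
    count = 0
    for q_start in range(0, seq_len, block_q):
        q_end = min(q_start + block_q, seq_len)
        for kv_start in range(0, seq_len, block_kv):
            if kv_start >= q_end:  # same early exit as the reference
                break
            # keys in [kv_start, kv_end) are all > q exactly when kv_start > q
            count += sum(1 for q in range(q_start, q_end) if kv_start > q)
    return count

print(fully_masked_rows(64, 16, 16))  # 0: equal block sizes
print(fully_masked_rows(64, 32, 16))  # 32: larger Q blocks
# What an unguarded fully-masked row would compute:
print(math.exp(float("-inf") - float("-inf")))  # nan
```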

---
## Step 3: Triton Implementation (30 min)

### FlashAttention Kernel

In [None]:
@triton.jit
def flash_attention_kernel(
    Q_ptr, K_ptr, V_ptr, O_ptr,
    stride_qb, stride_qm, stride_qk,
    stride_kb, stride_kn, stride_kk,
    stride_vb, stride_vn, stride_vd,
    stride_ob, stride_om, stride_od,
    seq_len, d_model,
    scale,
    BLOCK_M: tl.constexpr,  # Q block size
    BLOCK_N: tl.constexpr,  # K/V block size
    BLOCK_D: tl.constexpr,  # Head dimension (must cover full d)
    CAUSAL: tl.constexpr,
):
    """
    FlashAttention forward kernel.
    
    Grid: (num_q_blocks, batch)
    Each program processes one Q block for one batch element.
    """
    # Program IDs
    q_block_idx = tl.program_id(0)
    batch_idx = tl.program_id(1)
    
    # Compute Q block start position
    q_start = q_block_idx * BLOCK_M
    
    # Offsets within the block
    offs_m = q_start + tl.arange(0, BLOCK_M)  # Q positions
    offs_n = tl.arange(0, BLOCK_N)  # K/V positions (will be updated in loop)
    offs_d = tl.arange(0, BLOCK_D)  # Head dimension
    
    # Pointers to Q block (stays constant)
    q_ptrs = (Q_ptr + 
              batch_idx * stride_qb + 
              offs_m[:, None] * stride_qm + 
              offs_d[None, :] * stride_qk)
    
    # Load Q block
    q_mask = (offs_m[:, None] < seq_len) & (offs_d[None, :] < d_model)
    Q_block = tl.load(q_ptrs, mask=q_mask, other=0.0)
    
    # Initialize accumulators
    m_i = tl.full([BLOCK_M], float('-inf'), dtype=tl.float32)  # Max scores
    l_i = tl.zeros([BLOCK_M], dtype=tl.float32)  # Sum of exp
    O_i = tl.zeros([BLOCK_M, BLOCK_D], dtype=tl.float32)  # Output accumulator
    
    # Determine K/V block range
    if CAUSAL:
        kv_block_end = (q_start + BLOCK_M + BLOCK_N - 1) // BLOCK_N
    else:
        kv_block_end = (seq_len + BLOCK_N - 1) // BLOCK_N
    
    # Loop over K/V blocks
    for kv_block_idx in range(0, kv_block_end):
        kv_start = kv_block_idx * BLOCK_N
        offs_n = kv_start + tl.arange(0, BLOCK_N)
        
        # Pointers to K, V blocks
        k_ptrs = (K_ptr + 
                  batch_idx * stride_kb + 
                  offs_n[:, None] * stride_kn + 
                  offs_d[None, :] * stride_kk)
        v_ptrs = (V_ptr + 
                  batch_idx * stride_vb + 
                  offs_n[:, None] * stride_vn + 
                  offs_d[None, :] * stride_vd)
        
        # Load K, V blocks
        kv_mask = (offs_n[:, None] < seq_len) & (offs_d[None, :] < d_model)
        K_block = tl.load(k_ptrs, mask=kv_mask, other=0.0)
        V_block = tl.load(v_ptrs, mask=kv_mask, other=0.0)
        
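        # Compute attention scores: Q @ K^T
        # [BLOCK_M, BLOCK_D] @ [BLOCK_D, BLOCK_N] = [BLOCK_M, BLOCK_N]
        # With FP32 inputs, tl.dot may default to TF32 on recent GPUs, so expect
        # small differences from an FP32 PyTorch reference.
        S = tl.dot(Q_block, tl.trans(K_block)) * scale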
        
        # Apply causal mask
        if CAUSAL:
            causal_mask = offs_m[:, None] < offs_n[None, :]
            S = tl.where(causal_mask, float('-inf'), S)
        
        # Also mask out-of-bounds positions
        S = tl.where(offs_n[None, :] >= seq_len, float('-inf'), S)
        
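        # Compute block max and exp. Unlike the Python reference there is no -inf
        # guard here: a row whose keys are all masked would give m_ij = -inf and NaN.
        # The wrapper's BLOCK_M == BLOCK_N keeps at least one valid key in every
        # row of every visited block, so that case does not arise.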
        m_ij = tl.max(S, axis=1)  # [BLOCK_M]
        P = tl.exp(S - m_ij[:, None])  # [BLOCK_M, BLOCK_N]
        l_ij = tl.sum(P, axis=1)  # [BLOCK_M]
        
        # Update max
        m_new = tl.maximum(m_i, m_ij)
        
        # Compute scaling factors
        alpha = tl.exp(m_i - m_new)
        beta = tl.exp(m_ij - m_new)
        
        # Update sum
        l_new = alpha * l_i + beta * l_ij
        
        # Update output
        # Scale old output
        O_i = O_i * (alpha * l_i)[:, None]
        # Add new contribution: beta * P @ V
        PV = tl.dot(P.to(V_block.dtype), V_block)  # [BLOCK_M, BLOCK_D]
        O_i = O_i + beta[:, None] * PV
        # Normalize
        O_i = O_i / l_new[:, None]
        
        # Update accumulators
        m_i = m_new
        l_i = l_new
    
    # Store output
    o_ptrs = (O_ptr + 
              batch_idx * stride_ob + 
              offs_m[:, None] * stride_om + 
              offs_d[None, :] * stride_od)
    o_mask = (offs_m[:, None] < seq_len) & (offs_d[None, :] < d_model)
    tl.store(o_ptrs, O_i.to(O_ptr.dtype.element_ty), mask=o_mask)
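The kernel divides by `l_new` on every iteration, exactly like the pseudocode. The FlashAttention-2 paper (Dao, 2023) instead keeps an unnormalized accumulator and divides once at the end, which trims non-matmul work inside the loop. A scalar sketch (one query row, 1-D values, hypothetical function name) suggests the two schedules agree with a direct softmax-weighted sum:

```python
import math

def online_weighted_sum(scores, values, block, defer):
    """Softmax(scores)-weighted sum of scalar values, streamed in blocks.
    defer=False: renormalize by l_new every block (as the kernel above does).
    defer=True:  keep an unnormalized accumulator and divide once at the end."""
    m, l, acc = float("-inf"), 0.0, 0.0
    for start in range(0, len(scores), block):
        s, v = scores[start:start + block], values[start:start + block]
        m_new = max(m, max(s))
        alpha = math.exp(m - m_new)              # rescales the old state
        p = [math.exp(x - m_new) for x in s]     # equals beta * exp(s - m_ij)
        pv = sum(pj * vj for pj, vj in zip(p, v))
        l_new = alpha * l + sum(p)
        if defer:
            acc = alpha * acc + pv               # no division inside the loop
        else:
            acc = (alpha * l * acc + pv) / l_new
        m, l = m_new, l_new
    return acc / l if defer else acc

scores = [0.3, -1.2, 2.5, 0.0, 4.1, -0.7, 1.9]
values = [1.0, 2.0, -1.0, 0.5, 3.0, -2.0, 0.25]
mx = max(scores)
w = [math.exp(x - mx) for x in scores]
direct = sum(wi * vi for wi, vi in zip(w, values)) / sum(w)
for defer in (False, True):
    print(f"defer={defer}: error {abs(online_weighted_sum(scores, values, 3, defer) - direct):.1e}")
```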

In [None]:
def flash_attention_triton(Q, K, V, causal=False):
    """
    FlashAttention using Triton kernel.
    
    Q, K, V: [batch, seq_len, d_model]
    """
    batch, seq_len, d_model = Q.shape
    scale = d_model ** -0.5
    
    # Output tensor
    O = torch.empty_like(Q)
    
    # Block sizes
    BLOCK_M = 64
    BLOCK_N = 64
    # tl.dot requires every matrix dimension to be at least 16
    BLOCK_D = max(16, triton.next_power_of_2(d_model))
    
    # Grid
    num_q_blocks = (seq_len + BLOCK_M - 1) // BLOCK_M
    grid = (num_q_blocks, batch)
    
    # Launch kernel
    flash_attention_kernel[grid](
        Q, K, V, O,
        Q.stride(0), Q.stride(1), Q.stride(2),
        K.stride(0), K.stride(1), K.stride(2),
        V.stride(0), V.stride(1), V.stride(2),
        O.stride(0), O.stride(1), O.stride(2),
        seq_len, d_model,
        scale,
        BLOCK_M=BLOCK_M,
        BLOCK_N=BLOCK_N,
        BLOCK_D=BLOCK_D,
        CAUSAL=causal,
    )
    
    return O
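The wrapper launches one program per (Q block, batch element), and the causal bound in the kernel stops each program at the diagonal K/V block. A small pure-Python mirror of that bound (hypothetical helper names, same 64×64 block sizes as the wrapper) counts the K/V tiles each launch visits:

```python
def kv_block_end(q_block_idx, seq_len, block_m, block_n, causal):
    """Pure-Python mirror of the kernel's K/V loop bound for one Q block."""
    q_start = q_block_idx * block_m
    if causal:
        return (q_start + block_m + block_n - 1) // block_n
    return (seq_len + block_n - 1) // block_n

def tiles_visited(seq_len, block_m=64, block_n=64, causal=True):
    """Total (Q block, K/V block) tiles processed per batch element."""
    num_q_blocks = (seq_len + block_m - 1) // block_m  # grid dimension 0
    return sum(kv_block_end(i, seq_len, block_m, block_n, causal)
               for i in range(num_q_blocks))

for n in (128, 512, 2048):
    c, f = tiles_visited(n, causal=True), tiles_visited(n, causal=False)
    print(f"seq_len={n}: causal {c} / full {f} tiles ({c / f:.0%})")
```

For `seq_len=2048` the causal kernel touches 528 of the 1,024 tiles, so roughly half of the work is skipped as N grows.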

In [None]:
# Test Triton implementation
if torch.cuda.is_available():
    torch.manual_seed(42)
    batch, seq_len, d = 2, 128, 64
    
    Q = torch.randn(batch, seq_len, d, device='cuda')
    K = torch.randn(batch, seq_len, d, device='cuda')
    V = torch.randn(batch, seq_len, d, device='cuda')
    
    # Standard attention (causal)
    scores = torch.bmm(Q, K.transpose(-2, -1)) / (d ** 0.5)
    mask = torch.triu(torch.ones(seq_len, seq_len, device='cuda'), diagonal=1).bool()
    scores = scores.masked_fill(mask, float('-inf'))
    weights = F.softmax(scores, dim=-1)
    output_std = torch.bmm(weights, V)
    
    # Triton FlashAttention
    output_flash = flash_attention_triton(Q, K, V, causal=True)
    
    print(f"Max difference: {(output_std - output_flash).abs().max():.2e}")
    print(f"Mean difference: {(output_std - output_flash).abs().mean():.2e}")
else:
    print("GPU not available. Triton test skipped.")

---
## Step 4: Benchmark and Analysis (20 min)

### Speed Comparison

In [None]:
def benchmark_attention_implementations(seq_lengths, d_model=64, batch=4, num_runs=100):
    """Benchmark different attention implementations."""
    results = []
    
    for seq_len in seq_lengths:
        Q = torch.randn(batch, seq_len, d_model, device='cuda')
        K = torch.randn(batch, seq_len, d_model, device='cuda')
        V = torch.randn(batch, seq_len, d_model, device='cuda')
        
        # Warmup
        for _ in range(10):
            _ = F.scaled_dot_product_attention(Q, K, V, is_causal=True)
            _ = flash_attention_triton(Q, K, V, causal=True)
        torch.cuda.synchronize()
        
        # Benchmark PyTorch SDPA
        torch.cuda.synchronize()
        start = time.perf_counter()
        for _ in range(num_runs):
            _ = F.scaled_dot_product_attention(Q, K, V, is_causal=True)
        torch.cuda.synchronize()
        pytorch_time = (time.perf_counter() - start) / num_runs * 1000
        
        # Benchmark our FlashAttention
        torch.cuda.synchronize()
        start = time.perf_counter()
        for _ in range(num_runs):
            _ = flash_attention_triton(Q, K, V, causal=True)
        torch.cuda.synchronize()
        triton_time = (time.perf_counter() - start) / num_runs * 1000
        
        results.append({
            'seq_len': seq_len,
            'pytorch_ms': pytorch_time,
            'triton_ms': triton_time,
            'ratio': triton_time / pytorch_time
        })
        
        del Q, K, V
        torch.cuda.empty_cache()
    
    return results

if torch.cuda.is_available():
    print("Speed Benchmark: PyTorch SDPA vs Our Triton FlashAttention")
    print("=" * 65)
    
    results = benchmark_attention_implementations(
        seq_lengths=[128, 256, 512, 1024, 2048],
        d_model=64,
        batch=4
    )
    
    print(f"{'Seq Len':>10} {'PyTorch SDPA':>15} {'Our Triton':>15} {'Triton/SDPA':>12}")
    print("-" * 55)
    for r in results:
        print(f"{r['seq_len']:>10} {r['pytorch_ms']:>13.3f}ms {r['triton_ms']:>13.3f}ms {r['ratio']:>11.2f}x")
    
    print("\nNote: SDPA uses a fused FlashAttention or memory-efficient kernel only when dtype, shape,")
    print("and GPU allow it; these FP32, 3-D inputs may run a slower fallback. A ratio > 1 means ours is slower.")
    print("Our implementation is for learning; production code should use torch.nn.functional.scaled_dot_product_attention.")
else:
    print("GPU not available for benchmarking.")
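Milliseconds are hard to compare across sequence lengths. A common convention counts about 4·B·N²·d FLOPs for the attention forward pass (2·N²·d each for QKᵀ and PV, softmax ignored), roughly halved under causal masking. This sketch converts a timing into achieved TFLOP/s under that assumption; the function names are illustrative.

```python
def attention_fwd_flops(batch, seq_len, head_dim, causal=True):
    """Approximate forward FLOPs: 2*N*N*d for Q@K^T plus 2*N*N*d for P@V per
    batch element (softmax ignored); causal masking skips about half."""
    flops = 4 * batch * seq_len * seq_len * head_dim
    return flops // 2 if causal else flops

def achieved_tflops(batch, seq_len, head_dim, ms, causal=True):
    """Convert a measured time in milliseconds into TFLOP/s."""
    return attention_fwd_flops(batch, seq_len, head_dim, causal) / (ms * 1e-3) / 1e12

print(f"{achieved_tflops(4, 2048, 64, ms=1.0):.3f} TFLOP/s")
```

Each `r['triton_ms']` or `r['pytorch_ms']` from the benchmark above can be passed as `ms` with `batch=4, head_dim=64` to compare both kernels against the GPU's peak throughput.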

### Memory Comparison

In [None]:
def memory_benchmark(seq_lengths, d_model=64, batch=1):
    """Compare memory usage."""
    results = []
    
    for seq_len in seq_lengths:
        # Standard attention (materializes N×N matrix)
        torch.cuda.empty_cache()
        torch.cuda.reset_peak_memory_stats()
        
        Q = torch.randn(batch, seq_len, d_model, device='cuda')
        K = torch.randn(batch, seq_len, d_model, device='cuda')
        V = torch.randn(batch, seq_len, d_model, device='cuda')
        
        # Standard (non-causal) path: explicitly materializes the N×N scores and weights
        scores = torch.bmm(Q, K.transpose(-2, -1)) / (d_model ** 0.5)
        weights = F.softmax(scores, dim=-1)
        out = torch.bmm(weights, V)
        torch.cuda.synchronize()
        standard_mem = torch.cuda.max_memory_allocated() / 1e6
        
        del scores, weights, out, Q, K, V
        
        # FlashAttention
        torch.cuda.empty_cache()
        torch.cuda.reset_peak_memory_stats()
        
        Q = torch.randn(batch, seq_len, d_model, device='cuda')
        K = torch.randn(batch, seq_len, d_model, device='cuda')
        V = torch.randn(batch, seq_len, d_model, device='cuda')
        
        out = flash_attention_triton(Q, K, V, causal=False)  # non-causal, matching the standard path
        torch.cuda.synchronize()
        flash_mem = torch.cuda.max_memory_allocated() / 1e6
        
        del out, Q, K, V
        
        # Theoretical attention matrix size
        attn_matrix_mb = batch * seq_len * seq_len * 4 / 1e6  # FP32
        
        results.append({
            'seq_len': seq_len,
            'standard_mb': standard_mem,
            'flash_mb': flash_mem,
            'attn_matrix_mb': attn_matrix_mb,
            'savings': standard_mem / flash_mem if flash_mem > 0 else float('inf')
        })
    
    return results

if torch.cuda.is_available():
    print("\nMemory Benchmark: Standard vs FlashAttention")
    print("=" * 70)
    
    mem_results = memory_benchmark([128, 256, 512, 1024, 2048, 4096])
    
    print(f"{'Seq Len':>10} {'Standard':>12} {'Flash':>12} {'Attn Matrix':>14} {'Savings':>10}")
    print("-" * 60)
    for r in mem_results:
        print(f"{r['seq_len']:>10} {r['standard_mb']:>10.1f}MB {r['flash_mb']:>10.1f}MB "
              f"{r['attn_matrix_mb']:>12.1f}MB {r['savings']:>9.1f}x")
else:
    print("GPU not available.")
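Why the measured savings grow with N: a rough lower-bound model (an assumption that ignores allocator rounding and library workspaces) counts Q, K, V and the output for both paths, plus two N×N FP32 buffers (scores and weights) that the standard path keeps alive at the same time. The helper name is hypothetical.

```python
def rough_peak_mb(batch, seq_len, d, bytes_per_el=4):
    """Rough lower bounds on peak allocation in MB (1e6 bytes) for the memory
    benchmark above. Assumes FP32 and ignores allocator rounding/workspaces."""
    qkvo = 4 * batch * seq_len * d * bytes_per_el   # Q, K, V and the output
    nn = batch * seq_len * seq_len * bytes_per_el   # one N x N buffer
    standard = (qkvo + 2 * nn) / 1e6                # scores + weights alive together
    flash = qkvo / 1e6                              # no N x N buffer at all
    return standard, flash

for n in (1024, 4096):
    std, fa = rough_peak_mb(1, n, 64)
    print(f"N={n}: standard >= {std:.1f} MB, flash >= {fa:.1f} MB, ratio ~{std / fa:.0f}x")
```

Under this model the ratio is 1 + N/(2d), so it grows roughly linearly with N; the measured table will differ in the constants but should show the same trend.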

---
## Step 5: Verify — Final Quiz & Reflection (10 min)

### Week 3 Comprehensive Quiz

In [None]:
def check_answer(q, your_answer, correct):
    if your_answer == correct:
        print(f"✓ {q}")
    else:
        print(f"✗ {q}\n  Your: {your_answer}, Correct: {correct}")

print("Week 3 Final Quiz")
print("=" * 50)

# Q1
check_answer(
    "Q1: Dot product measures ___ between vectors",
    'similarity',  # Your answer
    'similarity'
)

# Q2
check_answer(
    "Q2: exp(x) overflows in FP16 at approximately x =",
    11,  # Your answer
    11
)

# Q3
check_answer(
    "Q3: The max-subtraction trick works because softmax(x-c) = softmax(x)",
    True,  # Your answer
    True
)

# Q4
check_answer(
    "Q4: Standard attention memory complexity is",
    'O(N^2)',  # Your answer
    'O(N^2)'
)

# Q5
check_answer(
    "Q5: Online softmax enables streaming because it can ___ when max changes",
    'rescale',  # Your answer
    'rescale'
)

# Q6
check_answer(
    "Q6: FlashAttention achieves memory complexity of",
    'O(N)',  # Your answer
    'O(N)'
)

# Q7
check_answer(
    "Q7: Tiling enables attention to fit in fast ___ memory",
    'SRAM',  # Your answer (or 'shared')
    'SRAM'
)
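Q2's threshold follows from the largest finite FP16 value, 65504: exp(x) overflows once x > ln(65504). A quick standard-library check:

```python
import math

FP16_MAX = 65504.0                  # largest finite IEEE 754 half-precision value
threshold = math.log(FP16_MAX)      # exp(x) > FP16_MAX once x exceeds this
print(f"exp(x) overflows FP16 for x > {threshold:.3f}")
print(math.exp(11.0) < FP16_MAX < math.exp(11.1))
```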

### Reflection: What We Built

In one week, we went from basic dot products to a working FlashAttention implementation:

```
Day 1: a·b = Σaᵢbᵢ (similarity measurement)
  ↓
Day 2: exp(x) → overflow in FP16 at x≈11
  ↓
Day 3: softmax(x-max) = softmax(x) (stability trick)
  ↓
Day 4: Attention = softmax(QK^T/√d)V (O(N²) memory problem)
  ↓
Day 5: Online softmax (streaming with rescaling)
  ↓
Day 6: Tiled computation (SRAM utilization)
  ↓
Day 7: FlashAttention (O(N) memory, fused kernel)
```

### Key Takeaways

1. **Numerical stability matters:** Understanding FP limits prevents silent failures
2. **Algorithms can reduce memory:** O(N²) → O(N) is possible with clever bookkeeping
3. **Hardware awareness:** SRAM vs HBM determines kernel design
4. **Fusion reduces I/O:** Combining operations saves memory bandwidth

---

## Summary: Week 3 Complete

| Day | Topic | Key Insight |
|-----|-------|-------------|
| 1 | Dot Product | Similarity = a·b = \|a\|\|b\|cos(θ) |
| 2 | Softmax Problem | exp(x) overflows at x≈11 (FP16) |
| 3 | Stable Softmax | softmax(x-c) = softmax(x) |
| 4 | Full Attention | O(N²) memory from N×N matrix |
| 5 | Online Softmax | Rescale when max changes |
| 6 | Tiled Attention | Process in SRAM-sized blocks |
| 7 | FlashAttention | O(N) memory, fused kernel |

**Next Week:** Quantization and production deployment — taking our kernels from learning exercises to real-world performance.

---

## References

- [FlashAttention Paper (Dao et al., 2022)](https://arxiv.org/abs/2205.14135)
- [FlashAttention-2 Paper (Dao, 2023)](https://arxiv.org/abs/2307.08691)
- [Attention Is All You Need (Vaswani et al., 2017)](https://arxiv.org/abs/1706.03762)
- [Online Softmax (Milakov & Gimelshein, 2018)](https://arxiv.org/abs/1805.02867)

**Interactive Reference:** [attention-math.html](../attention-math.html) — Full attention visualization and calculators