# Week 3, Day 6: Tiled Attention — Block-by-Block on GPU

**Time:** ~1 hour

**Goal:** Apply online softmax to compute attention in tiles, leveraging fast SRAM on GPU.

## The Challenge

We have the online softmax algorithm from yesterday. Now we need to:
1. **Tile** the computation to fit in GPU shared memory (SRAM)
2. **Fuse** the operations to minimize memory traffic
3. **Handle both Q and K/V tiling** for the complete FlashAttention algorithm

In [None]:
import numpy as np
import torch
import torch.nn.functional as F
import matplotlib.pyplot as plt
import time

np.set_printoptions(precision=4, suppress=True)
torch.set_printoptions(precision=4, sci_mode=False)

---
## Step 1: The Challenge — GPU Memory Hierarchy (5 min)

### GPU Memory Hierarchy Recap

| Memory Type | Size | Bandwidth | Latency |
|-------------|------|-----------|--------|
| Registers | ~256KB per SM | N/A | ~1 cycle |
| Shared Memory (SRAM) | 64-228KB per SM | ~19 TB/s | ~20 cycles |
| L2 Cache | 40-60MB | ~8 TB/s | ~200 cycles |
| HBM (Global) | 40-80GB | ~2 TB/s | ~400 cycles |

**Key insight:** We want to:
1. Load Q, K, V tiles into **shared memory** (fast)
2. Compute attention **within** shared memory
3. Only write the final output to HBM
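
As a rough sanity check on step 1, the sketch below estimates how many bytes one set of tiles would occupy. The helper name `tile_sram_bytes`, the layout (Q, K, V, score tile and output accumulator all resident at once, all in the same dtype) and the 64KB budget are assumptions for illustration. Real kernels often keep accumulators in registers and in higher precision.

```python
def tile_sram_bytes(block_q, block_kv, d, dtype_bytes=2):
    """Bytes for one Q tile, one K tile, one V tile, the score tile S,
    and the output accumulator O (hypothetical layout, single dtype)."""
    q_tile = block_q * d
    k_tile = block_kv * d
    v_tile = block_kv * d
    s_tile = block_q * block_kv
    o_tile = block_q * d
    return (q_tile + k_tile + v_tile + s_tile + o_tile) * dtype_bytes

# Assumed budget: the low end of the 64-228KB range in the table above
sram_budget = 64 * 1024

for block in (32, 64, 128):
    used = tile_sram_bytes(block, block, d=128)
    verdict = "fits" if used <= sram_budget else "too big"
    print(f"block={block:3d}: {used / 1024:6.1f} KB -> {verdict}")
```

Under these assumptions only the 32×32 tiling fits in 64KB at d=128, which is why block sizes are tuned per GPU and per head dimension.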

In [None]:
# Visualize the tiling strategy
def visualize_tiling(seq_len, block_q, block_kv):
    """Show how attention is tiled."""
    fig, ax = plt.subplots(figsize=(8, 8))
    
    # Draw the full attention matrix
    ax.set_xlim(0, seq_len)
    ax.set_ylim(seq_len, 0)  # Flip y-axis
    
    # Draw grid lines for tiles
    for i in range(0, seq_len + 1, block_q):
        ax.axhline(y=i, color='blue', linewidth=0.5, alpha=0.5)
    for j in range(0, seq_len + 1, block_kv):
        ax.axvline(x=j, color='red', linewidth=0.5, alpha=0.5)
    
    # Highlight one Q block processing all K blocks
    q_block_idx = 1
    for kv_idx in range(seq_len // block_kv):
        rect = plt.Rectangle(
            (kv_idx * block_kv, q_block_idx * block_q),
            block_kv, block_q,
            fill=True, facecolor='green', alpha=0.3, edgecolor='green', linewidth=2
        )
        ax.add_patch(rect)
    
    ax.set_xlabel('K/V position (columns)')
    ax.set_ylabel('Q position (rows)')
    ax.set_title(f'Tiled Attention: seq_len={seq_len}, block_Q={block_q}, block_KV={block_kv}\n'
                 f'Green = one Q block iterating over all K/V blocks')
    
    # Add annotations
    num_q_blocks = seq_len // block_q
    num_kv_blocks = seq_len // block_kv
    ax.text(seq_len/2, -2, f'{num_kv_blocks} K/V blocks', ha='center', fontsize=10)
    ax.text(-2, seq_len/2, f'{num_q_blocks} Q blocks', ha='center', va='center', 
            rotation=90, fontsize=10)
    
    plt.tight_layout()
    plt.show()

visualize_tiling(seq_len=16, block_q=4, block_kv=4)

---
## Step 2: Explore — The Tiling Strategy (15 min)

### FlashAttention's Key Insight

Standard attention:
1. Compute all of QK^T → Store N×N matrix to HBM
2. Apply softmax → Read/write N×N from HBM
3. Multiply by V → Read N×N, store N×d

**Memory I/O:** O(N² + N²) = O(N²)

FlashAttention:
1. For each Q block:
   - For each K/V block:
     - Load Q, K, V tiles into SRAM
     - Compute partial attention in SRAM
     - Update running output using online softmax
2. Write final output to HBM

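**Memory I/O:** Q is read and O is written once, O(N × d). K and V are re-read for every Q block, O(N² × d / B_q) in total. The N×N matrix never reaches HBM, so the extra memory footprint is linear in N.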
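
The "update running output using online softmax" step can be sketched for a single query row before looking at the full tiled version. This is a minimal illustration with made-up data: the names `m`, `l` and `acc` are the running max, running denominator and unnormalized output. It uses `exp(-inf) = 0` as the first-block correction, which is harmless because the accumulators start at zero.

```python
import numpy as np

rng = np.random.default_rng(0)
scores = rng.standard_normal(8)        # one query's scores against 8 keys
values = rng.standard_normal((8, 4))   # the matching value rows

# Reference: softmax over all 8 scores at once
p = np.exp(scores - scores.max())
reference = (p / p.sum()) @ values

# Tiled: process keys 0-3, then keys 4-7, carrying (m, l, acc)
m, l, acc = -np.inf, 0.0, np.zeros(4)
for start in (0, 4):
    s_blk = scores[start:start + 4]
    v_blk = values[start:start + 4]
    m_new = max(m, s_blk.max())
    scale_old = np.exp(m - m_new)      # rescales earlier blocks; 0 on the first block
    p_blk = np.exp(s_blk - m_new)
    l = l * scale_old + p_blk.sum()
    acc = acc * scale_old + p_blk @ v_blk
    m = m_new

tiled = acc / l                        # normalize once, at the end
print(np.allclose(reference, tiled))
```

The division by `l` happens only after the last block, so no intermediate softmax row is ever stored.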

In [None]:
def tiled_attention_reference(Q, K, V, block_q, block_kv):
    """
    Reference implementation of tiled attention.
    Processes Q in blocks, K/V in blocks, using online softmax.
    
    Q: [seq_len, d_k]
    K: [seq_len, d_k]
    V: [seq_len, d_v]
    """
    seq_len, d_k = Q.shape
    d_v = V.shape[1]
    scale = np.sqrt(d_k)
    
    # Output matrix
    O = np.zeros((seq_len, d_v))
    
    # Process Q in blocks
    for q_start in range(0, seq_len, block_q):
        q_end = min(q_start + block_q, seq_len)
        q_block_size = q_end - q_start
        
        # Get Q block
        Q_block = Q[q_start:q_end]  # [block_q, d_k]
        
        # Initialize per-query accumulators for this Q block
        m = np.full(q_block_size, float('-inf'))  # [block_q]
        l = np.zeros(q_block_size)                 # [block_q]
        O_block = np.zeros((q_block_size, d_v))    # [block_q, d_v]
        
        # Iterate over K/V blocks
        for kv_start in range(0, seq_len, block_kv):
            kv_end = min(kv_start + block_kv, seq_len)
            
            # Get K, V blocks
            K_block = K[kv_start:kv_end]  # [block_kv, d_k]
            V_block = V[kv_start:kv_end]  # [block_kv, d_v]
            
            # Compute attention scores for this tile
            # [block_q, d_k] @ [d_k, block_kv] = [block_q, block_kv]
            S = (Q_block @ K_block.T) / scale
            
            # Online softmax update
            m_block = S.max(axis=1)  # [block_q]
            m_new = np.maximum(m, m_block)
            
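            # Rescale the previous accumulators to the new running max.
            # On the first block m is -inf; l and O_block are still zero there,
            # so any finite factor works and we use exp(0) = 1.
            correction = np.exp(np.where(m == float('-inf'), 0, m - m_new))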
            
            # Update l and O
            l = l * correction
            O_block = O_block * correction[:, np.newaxis]
            
            # Add this block's contribution
            exp_S = np.exp(S - m_new[:, np.newaxis])  # [block_q, block_kv]
            l = l + exp_S.sum(axis=1)
            O_block = O_block + exp_S @ V_block  # [block_q, d_v]
            
            m = m_new
        
        # Final normalization for this Q block
        O[q_start:q_end] = O_block / l[:, np.newaxis]
    
    return O

# Test
np.random.seed(42)
seq_len, d_k, d_v = 16, 8, 8

Q = np.random.randn(seq_len, d_k)
K = np.random.randn(seq_len, d_k)
V = np.random.randn(seq_len, d_v)

# Standard attention
scores = (Q @ K.T) / np.sqrt(d_k)
weights = np.exp(scores - scores.max(axis=1, keepdims=True))
weights = weights / weights.sum(axis=1, keepdims=True)
output_standard = weights @ V

# Tiled attention
output_tiled = tiled_attention_reference(Q, K, V, block_q=4, block_kv=4)

print(f"Max difference: {np.abs(output_standard - output_tiled).max():.2e}")

### Adding Causal Masking

For causal attention, we need to mask positions where query index < key index.
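
With tiles, the causal rule sorts every (Q block, K/V block) pair into one of three cases: fully visible, straddling the diagonal, or entirely in the future. The sketch below counts them. `classify_causal_tiles` is a hypothetical helper written for this illustration and is not part of the implementation that follows.

```python
def classify_causal_tiles(seq_len, block_q, block_kv):
    """Label each (Q block, K/V block) tile as 'full', 'partial', or 'skip'."""
    labels = {}
    for q_start in range(0, seq_len, block_q):
        q_end = min(q_start + block_q, seq_len)
        for kv_start in range(0, seq_len, block_kv):
            kv_end = min(kv_start + block_kv, seq_len)
            if kv_start > q_end - 1:
                label = "skip"      # every key comes after every query
            elif kv_end - 1 <= q_start:
                label = "full"      # every key is visible to every query
            else:
                label = "partial"   # straddles the diagonal, needs a mask
            labels[(q_start // block_q, kv_start // block_kv)] = label
    return labels

labels = classify_causal_tiles(seq_len=16, block_q=4, block_kv=4)
counts = {name: list(labels.values()).count(name)
          for name in ("full", "partial", "skip")}
print(counts)
```

For a 16-token sequence with 4×4 tiles this gives 6 full, 4 partial and 6 skipped tiles, so only the diagonal tiles need an element-wise mask.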

In [None]:
def tiled_causal_attention(Q, K, V, block_q, block_kv):
    """
    Tiled attention with causal masking.
    """
    seq_len, d_k = Q.shape
    d_v = V.shape[1]
    scale = np.sqrt(d_k)
    
    O = np.zeros((seq_len, d_v))
    
    for q_start in range(0, seq_len, block_q):
        q_end = min(q_start + block_q, seq_len)
        q_block_size = q_end - q_start
        
        Q_block = Q[q_start:q_end]
        
        m = np.full(q_block_size, float('-inf'))
        l = np.zeros(q_block_size)
        O_block = np.zeros((q_block_size, d_v))
        
        for kv_start in range(0, seq_len, block_kv):
            kv_end = min(kv_start + block_kv, seq_len)
            
            # Skip blocks entirely in the future (optimization)
            if kv_start > q_end - 1:
                break
            
            K_block = K[kv_start:kv_end]
            V_block = V[kv_start:kv_end]
            
            # Compute scores
            S = (Q_block @ K_block.T) / scale
            
            # Apply causal mask
            # Create mask where True means "mask this position"
            q_indices = np.arange(q_start, q_end)[:, np.newaxis]  # [block_q, 1]
            k_indices = np.arange(kv_start, kv_end)[np.newaxis, :]  # [1, block_kv]
            causal_mask = k_indices > q_indices  # [block_q, block_kv]
            
            S = np.where(causal_mask, float('-inf'), S)
            
            # Online softmax update
            # S is already masked, so fully masked rows yield -inf here
            m_block = S.max(axis=1)
            m_new = np.maximum(m, m_block)
            
            correction = np.exp(np.where(m == float('-inf'), 0, m - m_new))
            
            l = l * correction
            O_block = O_block * correction[:, np.newaxis]
            
            exp_S = np.exp(np.where(causal_mask, float('-inf'), S) - m_new[:, np.newaxis])
            exp_S = np.where(np.isinf(exp_S) | np.isnan(exp_S), 0, exp_S)
            
            l = l + exp_S.sum(axis=1)
            O_block = O_block + exp_S @ V_block
            
            m = m_new
        
        # Avoid division by zero for rows that are fully masked
        l = np.where(l == 0, 1, l)
        O[q_start:q_end] = O_block / l[:, np.newaxis]
    
    return O

# Test causal
# Standard causal attention
scores = (Q @ K.T) / np.sqrt(d_k)
causal_mask = np.triu(np.ones((seq_len, seq_len)), k=1).astype(bool)
scores_masked = np.where(causal_mask, float('-inf'), scores)
weights = np.exp(scores_masked - scores_masked.max(axis=1, keepdims=True))
weights = weights / weights.sum(axis=1, keepdims=True)
output_standard_causal = weights @ V

# Tiled causal attention
output_tiled_causal = tiled_causal_attention(Q, K, V, block_q=4, block_kv=4)

print(f"Causal max difference: {np.abs(output_standard_causal - output_tiled_causal).max():.2e}")

---
## Step 3: The Concept — Memory I/O Analysis (10 min)

### Standard Attention I/O

```
S = QK^T                    # Write N×N to HBM
P = softmax(S)              # Read N×N, write N×N
O = PV                      # Read N×N + N×d, write N×d
```

- Total HBM reads: O(N×d + N² + N²) = O(N²)
- Total HBM writes: O(N² + N² + N×d) = O(N²)

### Tiled (FlashAttention) I/O

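```
For each Q block (N/Bq of them):
  Load Q block: Bq × d                     # N × d over all Q blocks
  For each K/V block (N/Bkv of them):
    Load K, V blocks: Bkv × d + Bkv × d    # 2 × N × d per Q block
    Compute in SRAM (no HBM access)
  Write O block: Bq × d                    # N × d over all Q blocks
```

Total: O(N × d) for Q and O, plus O(N² × d / Bq) for re-reading K/V once per Q block. The N×N scores never reach HBM, and larger Q blocks cut the K/V traffic proportionally.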
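
The counts above can be turned into arithmetic. This is a simplified sketch: `hbm_traffic_elements` is a hypothetical helper, it assumes N is divisible by the Q block size, and it ignores caching and the small per-row statistics. Under those assumptions the standard-to-tiled ratio tends toward 2 × Bq / d for long sequences, because K and V are re-read once per Q block.

```python
def hbm_traffic_elements(seq_len, d, block_q):
    """Elements moved to/from HBM (hypothetical helper; N divisible by block_q)."""
    n = seq_len
    # Standard: read Q, K; write S; read S; write P; read P, V; write O
    standard = 4 * n * n + 4 * n * d
    # Tiled: read Q once, write O once, re-read all of K and V per Q block
    num_q_blocks = n // block_q
    tiled = 2 * n * d + num_q_blocks * 2 * n * d
    return standard, tiled

for n in (1024, 8192):
    for block_q in (64, 256):
        standard, tiled = hbm_traffic_elements(n, d=128, block_q=block_q)
        print(f"N={n:5d} block_q={block_q:3d}: "
              f"standard/tiled = {standard / tiled:.2f}")
```

The practical takeaway is that the I/O saving grows with the Q block size, which in turn is limited by how much SRAM a tile may occupy.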

In [None]:
def io_analysis(seq_len, d_model, block_q, block_kv, dtype_bytes=2):
    """
    Analyze HBM I/O for standard vs tiled attention.
    """
    # Standard attention
    # Read Q, K for S=QK^T: 2 * N * d
    # Write S: N * N
    # Read S for softmax: N * N
    # Write P: N * N
    # Read P, V for O=PV: N * N + N * d
    # Write O: N * d
    standard_reads = (2 * seq_len * d_model + seq_len * seq_len + 
                     seq_len * seq_len + seq_len * d_model) * dtype_bytes
    standard_writes = (seq_len * seq_len + seq_len * seq_len + 
                      seq_len * d_model) * dtype_bytes
    standard_total = standard_reads + standard_writes
    
    # Tiled (FlashAttention)
    # Outer loop: N/Bq iterations
    # Each outer: read Q block (Bq * d), write O block (Bq * d)
    # Inner loop: N/Bkv iterations
    # Each inner: read K block (Bkv * d), read V block (Bkv * d)
    
    num_q_blocks = seq_len // block_q
    num_kv_blocks = seq_len // block_kv
    
    # Q is loaded once per Q block
    q_reads = num_q_blocks * block_q * d_model * dtype_bytes
    # K, V are loaded for each (Q block, KV block) pair
    kv_reads = num_q_blocks * num_kv_blocks * 2 * block_kv * d_model * dtype_bytes
    # O is written once per Q block
    o_writes = seq_len * d_model * dtype_bytes
    
    tiled_total = q_reads + kv_reads + o_writes
    
    return {
        'seq_len': seq_len,
        'standard_gb': standard_total / 1e9,
        'tiled_gb': tiled_total / 1e9,
        'reduction': standard_total / tiled_total if tiled_total > 0 else float('inf')
    }

print(f"{'Seq Len':>10} {'Standard':>12} {'Tiled':>12} {'Reduction':>12}")
print("-" * 50)

for seq_len in [256, 512, 1024, 2048, 4096, 8192]:
    stats = io_analysis(seq_len, d_model=128, block_q=64, block_kv=64)
    print(f"{stats['seq_len']:>10} {stats['standard_gb']:>10.3f}GB "
          f"{stats['tiled_gb']:>10.3f}GB {stats['reduction']:>10.1f}x")

---
## Step 4: Code It — PyTorch Tiled Attention (30 min)

### Full Implementation

In [None]:
def tiled_attention_torch(Q, K, V, block_q, block_kv, causal=False):
    """
    Tiled attention in PyTorch.
    
    Q, K, V: [batch, seq_len, d_model]
    """
    batch, seq_len, d_k = Q.shape
    d_v = V.shape[-1]
    scale = d_k ** 0.5
    
    O = torch.zeros_like(V)
    
    for q_start in range(0, seq_len, block_q):
        q_end = min(q_start + block_q, seq_len)
        q_block_size = q_end - q_start
        
        Q_block = Q[:, q_start:q_end, :]  # [batch, block_q, d_k]
        
        # Initialize accumulators (match Q's dtype so float64 inputs stay float64)
        m = torch.full((batch, q_block_size), float('-inf'), device=Q.device, dtype=Q.dtype)
        l = torch.zeros((batch, q_block_size), device=Q.device, dtype=Q.dtype)
        O_block = torch.zeros((batch, q_block_size, d_v), device=Q.device, dtype=Q.dtype)
        
        # With causal masking, K/V blocks starting at or after q_end are fully masked
        kv_end_limit = q_end if causal else seq_len
        
        for kv_start in range(0, kv_end_limit, block_kv):
            kv_end = min(kv_start + block_kv, seq_len)
            
            K_block = K[:, kv_start:kv_end, :]  # [batch, block_kv, d_k]
            V_block = V[:, kv_start:kv_end, :]  # [batch, block_kv, d_v]
            
            # Compute scores: [batch, block_q, block_kv]
            S = torch.bmm(Q_block, K_block.transpose(-2, -1)) / scale
            
            # Apply causal mask
            if causal:
                q_indices = torch.arange(q_start, q_end, device=Q.device).view(-1, 1)
                k_indices = torch.arange(kv_start, kv_end, device=Q.device).view(1, -1)
                mask = k_indices > q_indices  # [block_q, block_kv]
                S = S.masked_fill(mask.unsqueeze(0), float('-inf'))
            
            # Online softmax
            m_block = S.max(dim=-1).values  # [batch, block_q]
            m_new = torch.maximum(m, m_block)
            
            # Handle -inf case
            correction = torch.exp(torch.where(
                m == float('-inf'),
                torch.zeros_like(m),
                m - m_new
            ))
            
            l = l * correction
            O_block = O_block * correction.unsqueeze(-1)
            
            exp_S = torch.exp(S - m_new.unsqueeze(-1))
            exp_S = torch.where(torch.isinf(S), torch.zeros_like(exp_S), exp_S)
            
            l = l + exp_S.sum(dim=-1)
            O_block = O_block + torch.bmm(exp_S, V_block)
            
            m = m_new
        
        # Normalize and store
        l = torch.where(l == 0, torch.ones_like(l), l)
        O[:, q_start:q_end, :] = O_block / l.unsqueeze(-1)
    
    return O

# Test
batch, seq_len, d_model = 2, 32, 64
Q = torch.randn(batch, seq_len, d_model)
K = torch.randn(batch, seq_len, d_model)
V = torch.randn(batch, seq_len, d_model)

# Standard attention
scores = torch.bmm(Q, K.transpose(-2, -1)) / (d_model ** 0.5)
# Causal mask
mask = torch.triu(torch.ones(seq_len, seq_len), diagonal=1).bool()
scores_masked = scores.masked_fill(mask, float('-inf'))
weights = F.softmax(scores_masked, dim=-1)
output_standard = torch.bmm(weights, V)

# Tiled attention
output_tiled = tiled_attention_torch(Q, K, V, block_q=8, block_kv=8, causal=True)

print(f"Max difference: {(output_standard - output_tiled).abs().max():.2e}")

### Benchmark: Standard vs Tiled

In [None]:
def benchmark_attention(Q, K, V, num_runs=100):
    """Benchmark standard vs tiled attention."""
    # Standard attention
    def standard_attn():
        scores = torch.bmm(Q, K.transpose(-2, -1)) / (Q.shape[-1] ** 0.5)
        weights = F.softmax(scores, dim=-1)
        return torch.bmm(weights, V)
    
    # Tiled attention
    def tiled_attn():
        return tiled_attention_torch(Q, K, V, block_q=64, block_kv=64)
    
    # Warmup
    for _ in range(10):
        _ = standard_attn()
        _ = tiled_attn()
    
    if Q.is_cuda:
        torch.cuda.synchronize()
    
    # Benchmark standard
    start = time.perf_counter()
    for _ in range(num_runs):
        _ = standard_attn()
    if Q.is_cuda:
        torch.cuda.synchronize()
    standard_time = (time.perf_counter() - start) / num_runs * 1000
    
    # Benchmark tiled
    start = time.perf_counter()
    for _ in range(num_runs):
        _ = tiled_attn()
    if Q.is_cuda:
        torch.cuda.synchronize()
    tiled_time = (time.perf_counter() - start) / num_runs * 1000
    
    return standard_time, tiled_time

# Run benchmark on CPU (tiled will be slower due to Python overhead)
print("CPU Benchmark (illustrates memory pattern, not optimized):")
print("-" * 60)

for seq_len in [64, 128, 256, 512]:
    Q = torch.randn(1, seq_len, 64)
    K = torch.randn(1, seq_len, 64)
    V = torch.randn(1, seq_len, 64)
    
    std_time, tiled_time = benchmark_attention(Q, K, V, num_runs=10)
    print(f"seq_len={seq_len:4d}: standard={std_time:7.2f}ms, tiled={tiled_time:7.2f}ms")

print("\nNote: Tiled is slower in Python. Real gains come from Triton/CUDA implementation.")
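
For comparison with a fused implementation, PyTorch 2.0 and later ship `torch.nn.functional.scaled_dot_product_attention`. The sketch below only checks that it agrees numerically with the unfused causal reference on small made-up inputs. Whether a FlashAttention-style kernel is actually selected depends on the device, dtype and PyTorch build, so treat this as an assumption to verify rather than a guarantee.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
Q = torch.randn(2, 128, 64)
K = torch.randn(2, 128, 64)
V = torch.randn(2, 128, 64)

# Unfused reference: materializes the full [batch, N, N] score matrix
scores = torch.bmm(Q, K.transpose(-2, -1)) / (Q.shape[-1] ** 0.5)
mask = torch.triu(torch.ones(128, 128, dtype=torch.bool), diagonal=1)
weights = F.softmax(scores.masked_fill(mask, float('-inf')), dim=-1)
reference = torch.bmm(weights, V)

# Fused path (requires PyTorch >= 2.0); the default scale is 1/sqrt(d)
fused = F.scaled_dot_product_attention(Q, K, V, is_causal=True)

print(f"Max difference: {(reference - fused).abs().max():.2e}")
```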

### Memory Usage Comparison

In [None]:
if torch.cuda.is_available():
    print("GPU Memory Comparison:")
    print("-" * 60)
    
    d_model = 64
    
    for seq_len in [256, 512, 1024, 2048]:
        # Standard attention
        torch.cuda.empty_cache()
        torch.cuda.reset_peak_memory_stats()
        
        Q = torch.randn(1, seq_len, d_model, device='cuda')
        K = torch.randn(1, seq_len, d_model, device='cuda')
        V = torch.randn(1, seq_len, d_model, device='cuda')
        
        scores = torch.bmm(Q, K.transpose(-2, -1)) / (d_model ** 0.5)
        weights = F.softmax(scores, dim=-1)
        out = torch.bmm(weights, V)
        torch.cuda.synchronize()
        standard_mem = torch.cuda.max_memory_allocated() / 1e6
        
        # Free the standard-path tensors so they do not count toward the tiled peak
        del scores, weights, out, Q, K, V
        
        # Tiled attention
        torch.cuda.empty_cache()
        torch.cuda.reset_peak_memory_stats()
        
        Q = torch.randn(1, seq_len, d_model, device='cuda')
        K = torch.randn(1, seq_len, d_model, device='cuda')
        V = torch.randn(1, seq_len, d_model, device='cuda')
        
        _ = tiled_attention_torch(Q, K, V, block_q=64, block_kv=64)
        torch.cuda.synchronize()
        tiled_mem = torch.cuda.max_memory_allocated() / 1e6
        
        print(f"seq_len={seq_len:5d}: standard={standard_mem:8.1f}MB, tiled={tiled_mem:8.1f}MB, "
              f"savings={standard_mem/tiled_mem:.1f}x")
        
        del Q, K, V
else:
    print("GPU not available for memory comparison.")

---
## Step 5: Verify — Quiz & Reflection (10 min)

### Quiz

In [None]:
def check_answer(question, your_answer, correct_answer):
    if your_answer == correct_answer:
        print(f"✓ Correct! {question}")
    else:
        print(f"✗ Incorrect. Your answer: {your_answer}, Correct: {correct_answer}")

# Q1: Why does tiled attention reduce memory I/O?
# a) It uses compression
# b) It never stores the N×N attention matrix
# c) It uses smaller data types
# d) It skips some computations
q1_answer = 'b'
check_answer("Memory I/O reduction", q1_answer, 'b')

In [None]:
# Q2: In tiled attention, what determines how much SRAM is needed?
# a) Total sequence length
# b) Block sizes (block_q, block_kv)
# c) Number of attention heads
# d) Batch size
q2_answer = 'b'
check_answer("SRAM requirement", q2_answer, 'b')

In [None]:
# Q3: For causal attention, which K/V blocks can be skipped entirely?
# a) Blocks where kv_start < q_start
# b) Blocks where kv_start >= q_end
# c) All diagonal blocks
# d) None can be skipped
q3_answer = 'b'
check_answer("Causal block skipping", q3_answer, 'b')

In [None]:
# Q4: How does tiled attention's HBM memory footprint scale
#     (Q, K, V, O and per-row statistics, with no N×N matrix)?
# a) O(N²)
# b) O(N log N)
# c) O(N × d)
# d) O(d²)
q4_answer = 'c'
check_answer("Memory footprint", q4_answer, 'c')

### Reflection Questions

1. **Block size tradeoff:** Larger blocks mean more SRAM usage but fewer iterations. How would you choose optimal block sizes?

2. **K/V reuse:** In our implementation, K/V are loaded once per (Q block, KV block) pair. Can we do better?

3. **Parallelism:** Each Q block is independent. How does this affect GPU parallelization?

---

## Summary

| Concept | Key Insight |
|---------|------------|
| Tiling | Process attention in blocks that fit in SRAM |
| Fusion | Combine QK^T, softmax, and ×V into one kernel |
| Online softmax | Enables block-by-block processing with correct results |
| I/O reduction | The N×N matrix never touches HBM; traffic is Q and O once plus K/V once per Q block |

**Tomorrow:** The complete FlashAttention algorithm with all optimizations.

---

**Interactive Reference:** [lessons/memory-hierarchy.html](../lessons/memory-hierarchy.html) — GPU Memory Hierarchy