# Standard Attention: A Memory-Bound Operation

This notebook explores the **standard attention mechanism** and demonstrates why it becomes a bottleneck for long sequences. We'll use a simple GPU simulator to make the memory hierarchy constraints tangible.

## Prerequisites

You should be familiar with:
- Matrix multiplication
- The transformer architecture (at a high level)
- Basic GPU concepts (GPUs have fast compute but limited memory bandwidth)

## 1. The Attention Equation

Given input matrices **Q** (queries), **K** (keys), and **V** (values), attention computes:

$$\text{Attention}(Q, K, V) = \text{softmax}\left(\frac{QK^T}{\sqrt{d}}\right) V$$

Where:
- $Q, K, V \in \mathbb{R}^{N \times d}$: N is sequence length, d is head dimension
- $QK^T \in \mathbb{R}^{N \times N}$: the **attention matrix** (this is the problem!)
- $\sqrt{d}$: scaling factor that keeps the dot products from growing with d, so the softmax does not saturate and gradients stay stable
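
As a concrete reference point, here is a minimal pure-Python sketch of the equation above. It is meant for tiny inputs only, and the helper names (`softmax`, `attention`) and the toy matrices are illustrative assumptions, not part of the simulator used later in this notebook. Note that it builds the full N×N matrices `S` and `P` explicitly.

```python
import math

def softmax(row):
    """Numerically stable softmax over one row of scores."""
    row_max = max(row)
    exps = [math.exp(x - row_max) for x in row]
    total = sum(exps)
    return [e / total for e in exps]

def attention(Q, K, V):
    """Return (S, P, O) for softmax(Q K^T / sqrt(d)) V on lists of lists."""
    d = len(Q[0])
    scale = 1.0 / math.sqrt(d)
    # S[i][j] = (q_i . k_j) / sqrt(d): the N x N score matrix
    S = [[scale * sum(q * k for q, k in zip(q_row, k_row)) for k_row in K]
         for q_row in Q]
    # P: row-wise softmax of S, also N x N
    P = [softmax(row) for row in S]
    # O[i] = sum_j P[i][j] * V[j], an N x d result
    O = [[sum(p * v_row[c] for p, v_row in zip(p_row, V))
          for c in range(len(V[0]))]
         for p_row in P]
    return S, P, O

# Toy inputs (hypothetical values): N = 3, d = 2
Q = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
K = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
V = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]
S, P, O = attention(Q, K, V)
print(f"S is {len(S)}x{len(S[0])}, O is {len(O)}x{len(O[0])}")
```

Each row of `P` sums to 1, so every output row is a weighted average of the rows of `V`.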

### The Quadratic Problem

Notice that $QK^T$ produces an $N \times N$ matrix. For a sequence of length 4096:

$$4096 \times 4096 = 16,777,216 \text{ floats}$$

At FP16 (2 bytes/float), that's **32 MB just for one attention matrix**. With multiple heads and layers, this adds up fast.
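
The arithmetic above can be wrapped in a small helper to see how the footprint grows. This is a sketch: the function name and the head count of 32 are assumptions chosen for illustration, not values taken from any particular model.

```python
MIB = 1024 * 1024

def attention_matrix_bytes(seq_len, bytes_per_element=2, num_heads=1):
    """Bytes needed to hold one N x N attention matrix per head."""
    return seq_len * seq_len * bytes_per_element * num_heads

for n in (1024, 4096, 16384):
    one_head = attention_matrix_bytes(n) / MIB
    # 32 heads is a hypothetical configuration
    many_heads = attention_matrix_bytes(n, num_heads=32) / MIB
    print(f"N={n:>6}: {one_head:>8.1f} MiB per head, {many_heads:>10.1f} MiB for 32 heads")
```

Quadrupling the sequence length multiplies the footprint by sixteen, which is the quadratic growth this section is warning about.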

## 2. Setup: GPU Simulator

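We'll use a simple simulator that models the key characteristics of the GPU memory hierarchy: a small, fast on-chip **SRAM** and a large but much slower off-chip **HBM** (high-bandwidth memory). Data must be loaded from HBM into SRAM before the GPU can compute on it.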

### Quick Terminology

- **FLOP** (Floating Point Operation): A single arithmetic operation on floating-point numbers, such as one addition or one multiplication.

  *Example*: A matrix multiply $C = A \times B$ where $A$ and $B$ are both 1000×1000 requires:
  - Each element of $C$ needs 1000 multiplications and 1000 additions (dot product of a row and column)
  - $C$ has 1000×1000 = 1M elements
  - Total: 1M elements × 2000 ops = **2 billion FLOPs**
  
  General formula: multiplying $(M \times K)$ by $(K \times N)$ costs $2 \times M \times N \times K$ FLOPs.

- **Cycle**: One "tick" of the GPU clock. A modern GPU runs at ~1-2 GHz, meaning 1-2 billion cycles per second. We use cycles as our unit of time because it lets us reason about *ratios* without worrying about actual clock speeds.
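
The two definitions above translate directly into a few lines of Python. This is a sketch with hypothetical helper names; `attention_flops` counts only the two matmuls and ignores the softmax.

```python
def matmul_flops(m, k, n):
    """FLOPs for an (m x k) @ (k x n) product: one multiply and one add per term."""
    return 2 * m * n * k

def attention_flops(seq_len, head_dim):
    """FLOPs for Q @ K.T plus P @ V (softmax cost ignored in this sketch)."""
    scores = matmul_flops(seq_len, head_dim, seq_len)   # (N x d) @ (d x N)
    output = matmul_flops(seq_len, seq_len, head_dim)   # (N x N) @ (N x d)
    return scores + output

def cycles_to_seconds(cycles, clock_hz):
    """Convert a cycle count to wall-clock time at a given clock rate."""
    return cycles / clock_hz

print(f"1000^3 matmul: {matmul_flops(1000, 1000, 1000):,} FLOPs")
print(f"Attention matmuls (N=4096, d=64): {attention_flops(4096, 64):,} FLOPs")
print(f"1.5e9 cycles at 1.5 GHz: {cycles_to_seconds(1.5e9, 1.5e9):.1f} s")
```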

### Our Simulated GPUs

We model three generations of NVIDIA GPUs. Notice how the compute-to-bandwidth ratio climbs with each generation: compute throughput grows faster than memory bandwidth, so the same workload becomes increasingly **memory-bound** on newer hardware:

In [1]:
from src.gpu_sim import GPUSpec, Profiler, Tensor

# Compare three generations of GPUs
gpus = [GPUSpec.sim_v100(), GPUSpec.sim_a100(), GPUSpec.sim_h100()]

print(f"{'GPU':<20} | {'SRAM':>12} | {'Compute/HBM':>12} | {'Trend'}")
print("-" * 65)
for g in gpus:
    sram_kb = g.sram_size * g.bytes_per_float / 1024
    # Compute throughput relative to HBM bandwidth: higher means memory is the scarcer resource
    ratio = g.flop_rate / g.hbm_bandwidth
    print(f"{g.name:<20} | {sram_kb:>9.0f} KB | {ratio:>10.0f}x | ", end="")
    if ratio < 100:
        print("Less memory-bound")
    elif ratio < 200:
        print("Memory-bound")
    else:
        print("Very memory-bound")

print("\n→ Newer GPUs are MORE memory-bound, making memory optimization MORE important.")

GPU                  |         SRAM |  Compute/HBM | Trend
-----------------------------------------------------------------
Simulated V100       |       128 KB |         70x | Less memory-bound
Simulated A100       |       256 KB |        150x | Memory-bound
Simulated H100       |       384 KB |        300x | Very memory-bound

→ Newer GPUs are MORE memory-bound, making memory optimization MORE important.


In [2]:
# We'll use the A100 for our examples
gpu = GPUSpec.sim_a100()
print(f"Using: {gpu.name}")

Using: Simulated A100
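
One way to read the compute-to-HBM ratios printed for the three simulated GPUs (70, 150, 300) is as a break-even point: an operation needs at least that many FLOPs per float moved to keep the compute units busy. The sketch below assumes the ratio is in FLOPs per float transferred and that each operand crosses HBM exactly once; both are simplifying assumptions, and the helper names are hypothetical.

```python
def matmul_intensity(m, k, n):
    """FLOPs per float moved for (m x k) @ (k x n), reading both inputs and writing the output once."""
    flops = 2 * m * n * k
    floats_moved = m * k + k * n + m * n
    return flops / floats_moved

def bound_by(intensity, machine_balance):
    """Classify an operation against a GPU's compute-to-bandwidth ratio."""
    return "compute" if intensity >= machine_balance else "memory"

# Ratios copied from the table printed earlier in this notebook
SIM_BALANCE = {"Simulated V100": 70, "Simulated A100": 150, "Simulated H100": 300}

scores = matmul_intensity(4096, 64, 4096)  # Q @ K.T with N=4096, d=64
print(f"Q @ K.T intensity: {scores:.1f} FLOPs per float")
for name, balance in SIM_BALANCE.items():
    print(f"  {name}: {bound_by(scores, balance)}-bound")
```

Under these assumptions the same score matmul sits on different sides of the break-even point depending on the GPU generation.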


## 3. Standard Attention: The Naive Implementation

Here's how standard attention works, step by step:

```
1. Load Q, K from HBM
2. Compute S = Q @ K.T        # N×N attention scores
3. Store S to HBM             # Can't fit in SRAM!
4. Load S from HBM
5. Compute P = softmax(S)     # N×N attention weights  
6. Store P to HBM             # Still N×N
7. Load P, V from HBM
8. Compute O = P @ V          # Final output
9. Store O to HBM
```

The problem: we **materialize the full N×N matrix** and shuttle it back and forth to HBM multiple times.
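
The nine steps above can be tallied by hand. The following sketch counts the floats each load or store moves, assuming every tensor crosses HBM exactly once per step; the function name is hypothetical and the count is independent of the simulator.

```python
def standard_attention_traffic(seq_len, head_dim):
    """List (direction, tensor, floats) for each HBM transfer in the step list."""
    nd = seq_len * head_dim   # size of Q, K, V, or O
    nn = seq_len * seq_len    # size of S or P
    return [
        ("read", "Q", nd),
        ("read", "K", nd),
        ("write", "S", nn),
        ("read", "S", nn),
        ("write", "P", nn),
        ("read", "P", nn),
        ("read", "V", nd),
        ("write", "O", nd),
    ]

log = standard_attention_traffic(4096, 64)
total = sum(size for _, _, size in log)
nxn = sum(size for _, name, size in log if name in ("S", "P"))
for direction, name, size in log:
    print(f"{direction:>5} {name}: {size:>12,} floats")
print(f"total: {total:,} floats, of which {nxn:,} are S and P")
```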

We'll use a `Tensor` class that wraps the profiler calls, giving us PyTorch-like syntax while tracking all the memory operations under the hood:

In [3]:
def standard_attention(prof: Profiler, N: int, d: int):
    """
    Simulate standard attention for a single head.
    
    The key problem: S and P are materialized in HBM, causing extra traffic.
    
    Args:
        prof: Profiler instance to track cycles and memory
        N: Sequence length
        d: Head dimension
    """
    # Allocate inputs (would be loaded from previous layer in practice)
    Q = Tensor((N, d), name="Q", profiler=prof)
    K = Tensor((N, d), name="K", profiler=prof)
    V = Tensor((N, d), name="V", profiler=prof)
    
    # Step 1: Compute attention scores
    # S = Q @ K.T produces an N×N matrix that must be stored in HBM
    S = Q @ K.T
    
    # Step 2: Apply softmax
    # P = softmax(S) produces another N×N matrix stored in HBM
    P = S.softmax()
    S.free()  # Free S - no longer needed
    
    # Step 3: Apply attention weights to values
    # O = P @ V produces the final N×d output
    O = P @ V
    P.free()  # Free P - no longer needed
    
    return O

# Enable verbose output to see memory operations
Tensor.verbose = True

## 4. Running the Simulation

Let's first try with a small sequence that fits in SRAM:

In [4]:
# Small sequence
N = 256
d = 64

print(f"Sequence length N = {N}, head dimension d = {d}")
print(f"Attention matrix S: {N}×{N} = {N*N:,} floats\n")

prof = Profiler(gpu, f"Standard Attention (N={N})")
standard_attention(prof, N, d)
prof.report()

Sequence length N = 256, head dimension d = 64
Attention matrix S: 256×256 = 65,536 floats

Allocate Q (256, 64)                          | HBM:   32.0 KB | SRAM:    0.0 KB (  0.0%)
Allocate K (256, 64)                          | HBM:   64.0 KB | SRAM:    0.0 KB (  0.0%)
Allocate V (256, 64)                          | HBM:   96.0 KB | SRAM:    0.0 KB (  0.0%)

>>> Q @ K.T
Allocate Q@K.T (256, 256)                     | HBM:  224.0 KB | SRAM:    0.0 KB (  0.0%)
    [HBM → SRAM] Load Q                           | HBM:  224.0 KB | SRAM:   32.0 KB ( 12.5%)
    [HBM → SRAM] Load K                           | HBM:  224.0 KB | SRAM:   64.0 KB ( 25.0%)
    [Compute] matmul → (256, 256)                 | HBM:  224.0 KB | SRAM:   64.0 KB ( 25.0%)
    [SRAM → HBM] Store Q@K.T                      | HBM:  224.0 KB | SRAM:   64.0 KB ( 25.0%)
    [SRAM] Clear working set                      | HBM:  224.0 KB | SRAM:    0.0 KB (  0.0%)

>>> softmax(Q@K.T)
Allocate softmax(Q@K.T) (256, 256)           

Now let's try a realistic sequence length. We'll keep verbose output on to see how the simulator tiles the much larger matrices:

In [5]:
# Keep verbose output on to show the tiling for the large test
Tensor.verbose = True

# Realistic sequence length
N = 4096
d = 64

print(f"Sequence length N = {N}, head dimension d = {d}")
print(f"Attention matrix S: {N}×{N} = {N*N:,} floats")
print(f"SRAM capacity: {gpu.sram_size:,} floats")
print(f"S is {N*N / gpu.sram_size:.0f}x larger than SRAM!\n")

prof = Profiler(gpu, f"Standard Attention (N={N})")
standard_attention(prof, N, d)
prof.report()

Sequence length N = 4096, head dimension d = 64
Attention matrix S: 4096×4096 = 16,777,216 floats
SRAM capacity: 131,072 floats
S is 128x larger than SRAM!

Allocate Q (4096, 64)                         | HBM:  512.0 KB | SRAM:    0.0 KB (  0.0%)
Allocate K (4096, 64)                         | HBM:    1.0 MB | SRAM:    0.0 KB (  0.0%)
Allocate V (4096, 64)                         | HBM:    1.5 MB | SRAM:    0.0 KB (  0.0%)

>>> Q @ K.T
Allocate Q@K.T (4096, 4096)                   | HBM:   33.5 MB | SRAM:    0.0 KB (  0.0%)
    [2D tiling] 4×4 = 16 tiles (1024×1024 each)   | HBM:   33.5 MB | SRAM:    0.0 KB (  0.0%)
    [Tile 1/16] A[0:1024,:] @ B[:,0:1024]         | HBM:   33.5 MB | SRAM:  256.0 KB (100.0%)
    [Tile 2/16] A[0:1024,:] @ B[:,1024:2048]      | HBM:   33.5 MB | SRAM:  256.0 KB (100.0%)
    [Tile 3/16] A[0:1024,:] @ B[:,2048:3072]      | HBM:   33.5 MB | SRAM:  256.0 KB (100.0%)
    ... (12 more tiles) ...
    [Tile 16/16] A[3072:4096,:] @ B[:,3072:4096]  | HBM:   33.5 MB
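
The tile count in the log above can be reproduced with a simple rule. This is an inference from the printed log, not the simulator's actual source: it assumes a tile edge `T` is chosen so that a `T×d` slice of the left operand and a `d×T` slice of the right operand together fill SRAM, and that the output tile is not counted against SRAM.

```python
import math

def square_tiling(seq_len, head_dim, sram_floats):
    """Return (tile_edge, num_tiles) under the assumed rule 2 * T * d <= SRAM."""
    tile_edge = min(seq_len, sram_floats // (2 * head_dim))
    tiles_per_side = math.ceil(seq_len / tile_edge)
    return tile_edge, tiles_per_side * tiles_per_side

SRAM_FLOATS = 131_072  # SRAM capacity printed for the simulated A100

for n in (256, 4096):
    edge, count = square_tiling(n, 64, SRAM_FLOATS)
    print(f"N={n}: {count} tile(s) of {edge}x{edge}")
```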

## 5. The HBM Traffic Problem

The issue isn't that standard attention *can't run* — it's that we're doing **unnecessary HBM round-trips**:

| Operation | What we wanted | What actually happens |
|-----------|---------------|----------------------|
| Compute Q@K.T | Keep result in SRAM | N×N too big → write to HBM |
| Softmax | Use Q@K.T from SRAM | In HBM → read it back |
| Compute P | Keep result in SRAM | N×N too big → write to HBM |
| P @ V | Use P from SRAM | In HBM → read it back |

Every time we read/write the N×N matrix, we pay the HBM bandwidth cost.

In [6]:
# Show HBM traffic breakdown
print("HBM Traffic Breakdown:")
print("-" * 70)

for op, name, size in prof.get_hbm_traffic_log():
    direction = "←" if op == "hbm_read" else "→"
    mb = size * gpu.bytes_per_float / 1024 / 1024
    print(f"  {direction} {op.split('_')[1].upper():>5}: {size:>12,} floats ({mb:>5.1f} MB)  {name}")

print("-" * 70)
print(f"\nNotice: The N×N attention matrix is written then immediately read back — twice!")
print(f"        That's {2 * N * N * 2:,} floats of unnecessary HBM traffic.")

HBM Traffic Breakdown:
----------------------------------------------------------------------
  ←  READ:      262,144 floats (  0.5 MB)  Q
  ←  READ:      262,144 floats (  0.5 MB)  K
  → WRITE:   16,777,216 floats ( 32.0 MB)  Q@K.T
  ←  READ:   16,777,216 floats ( 32.0 MB)  Q@K.T
  → WRITE:   16,777,216 floats ( 32.0 MB)  softmax(Q@K.T)
  ←  READ:   16,777,216 floats ( 32.0 MB)  softmax(Q@K.T)
  ←  READ:      262,144 floats (  0.5 MB)  V
  → WRITE:      262,144 floats (  0.5 MB)  softmax(Q@K.T)@V
----------------------------------------------------------------------

Notice: The N×N attention matrix is written then immediately read back — twice!
        That's 67,108,864 floats of unnecessary HBM traffic.


## Key Takeaways

1. **Standard attention is memory-bound**: Most time is spent on HBM traffic, not compute.

2. **The N×N matrices S and P cause the problem**: They're too large for SRAM, so each one is written to HBM and then read back, for four N×N transfers in total.

3. **HBM traffic scales quadratically**: `O(N²)` floats for S and P, vs only `O(Nd)` for Q, K, V, O.

4. **The compute is actually cheap**: The matmuls and softmax are fast once data is in SRAM.
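
To make the quadratic scaling concrete, the sketch below computes what share of total HBM traffic comes from S and P. It uses the counts from the traffic breakdown above (four N×N transfers, four N×d transfers); the function name is hypothetical.

```python
def nxn_share(seq_len, head_dim):
    """Fraction of standard-attention HBM traffic spent on the N x N matrices."""
    quadratic = 4 * seq_len * seq_len   # S and P, each written and read once
    linear = 4 * seq_len * head_dim     # Q, K, V read and O written
    return quadratic / (quadratic + linear)

for n in (256, 1024, 4096, 16384):
    print(f"N={n:>6}: {nxn_share(n, 64) * 100:5.1f}% of traffic is S and P")
```

The share simplifies to N / (N + d), so it approaches 100% as the sequence grows.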

---

**The FlashAttention Insight**: What if we never wrote S and P to HBM at all? 

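By carefully tiling the computation, FlashAttention keeps the intermediate attention scores in SRAM and only writes the final output O to HBM. This cuts HBM traffic from `O(N² + Nd)` to roughly `O(N²d²/M)` for an SRAM of M floats (the bound given in the FlashAttention paper), which is many times smaller whenever `d²` is much less than `M`. That is a massive win for long sequences.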

## 6. FlashAttention: Eliminating the N×N Round-trips

FlashAttention's key insight: **never materialize S or P in HBM**. Instead:
1. Tile Q, K, V into blocks that fit in SRAM
2. For each block: compute local attention scores, apply a block-wise softmax with running rescaling (the online softmax), and accumulate into the output
3. Only write the final O to HBM
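
The three steps above can be sketched in pure Python to show why the result is exact even though no full row of scores is ever held at once. This is a simplified illustration, not the kernel itself: it handles queries one row at a time rather than in blocks, and the names (`tiled_attention`, `block`) and test data are assumptions.

```python
import math

def reference_attention(Q, K, V):
    """Plain softmax(Q K^T / sqrt(d)) V, one full score row at a time."""
    d = len(Q[0])
    out = []
    for q in Q:
        s = [sum(a * b for a, b in zip(q, k)) / math.sqrt(d) for k in K]
        m = max(s)
        w = [math.exp(x - m) for x in s]
        z = sum(w)
        out.append([sum(wi * v[c] for wi, v in zip(w, V)) / z
                    for c in range(len(V[0]))])
    return out

def tiled_attention(Q, K, V, block=2):
    """Same result, visiting K and V in blocks and keeping only running statistics."""
    d, dv = len(Q[0]), len(V[0])
    scale = 1.0 / math.sqrt(d)
    out = []
    for q in Q:
        running_max = float("-inf")
        running_sum = 0.0
        acc = [0.0] * dv
        for start in range(0, len(K), block):
            k_blk = K[start:start + block]
            v_blk = V[start:start + block]
            scores = [scale * sum(a * b for a, b in zip(q, k)) for k in k_blk]
            new_max = max(running_max, max(scores))
            # Rescale everything accumulated under the previous maximum
            correction = math.exp(running_max - new_max)
            weights = [math.exp(s - new_max) for s in scores]
            running_sum = running_sum * correction + sum(weights)
            acc = [a * correction + sum(w * v[c] for w, v in zip(weights, v_blk))
                   for c, a in enumerate(acc)]
            running_max = new_max
        out.append([a / running_sum for a in acc])
    return out

# Deterministic toy data: N = 5 (not a multiple of the block size), d = 3
Q = [[math.sin(3 * i + j) for j in range(3)] for i in range(5)]
K = [[math.cos(2 * i + j) for j in range(3)] for i in range(5)]
V = [[float(i + j) for j in range(3)] for i in range(5)]
ref = reference_attention(Q, K, V)
tiled = tiled_attention(Q, K, V, block=2)
max_diff = max(abs(a - b) for r, t in zip(ref, tiled) for a, b in zip(r, t))
print(f"max difference vs reference: {max_diff:.2e}")
```

The only state carried between blocks is a running maximum, a running sum, and an accumulator of size d, which is what lets the scores stay in fast memory.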

Let's simulate this to see the HBM traffic savings:

In [None]:
def flash_attention(prof: Profiler, N: int, d: int, block_size: int = 64):
    """
    Simulate FlashAttention's memory access pattern.
    
    Key difference from standard attention:
    - S and P are NEVER written to HBM
    - Everything is computed in tiles that fit in SRAM
    - Only Q, K, V are read and O is written
    
    Args:
        prof: Profiler instance
        N: Sequence length  
        d: Head dimension
        block_size: Tile size (B_r = B_c in the paper)
    """
    # Allocate inputs and output in HBM
    prof.allocate_hbm(N * d, "Q")
    prof.allocate_hbm(N * d, "K")
    prof.allocate_hbm(N * d, "V")
    prof.allocate_hbm(N * d, "O")
    
    B = block_size
    num_blocks = (N + B - 1) // B
    
    if Tensor.verbose:
        print(f"FlashAttention: {num_blocks}×{num_blocks} = {num_blocks**2} tiles of {B}×{B}")
        print(f"SRAM per tile: Q_block({B}×{d}) + K_block({B}×{d}) + V_block({B}×{d}) + S_block({B}×{B}) + O_block({B}×{d})")
        sram_needed = B*d + B*d + B*d + B*B + B*d
        print(f"             = {sram_needed:,} floats = {sram_needed * 2 / 1024:.1f} KB")
        print()
    
    # FlashAttention outer loop: iterate over K, V blocks
    for j in range(num_blocks):
        kv_start = j * B
        kv_end = min(kv_start + B, N)
        kv_size = kv_end - kv_start
        
        # Load K_j and V_j once per outer iteration
        prof.load_from_hbm(kv_size * d, f"K[{kv_start}:{kv_end}]")
        prof.load_from_hbm(kv_size * d, f"V[{kv_start}:{kv_end}]")
        
        # Inner loop: iterate over Q blocks
        for i in range(num_blocks):
            q_start = i * B
            q_end = min(q_start + B, N)
            q_size = q_end - q_start
            
            # Load Q_i block
            prof.load_from_hbm(q_size * d, f"Q[{q_start}:{q_end}]")
            
            # Track SRAM usage for this tile
            tile_sram = q_size * d + kv_size * d + kv_size * d + q_size * kv_size + q_size * d
            prof.sram_push(tile_sram, f"tile[{i},{j}]")
            
            # Compute S_ij = Q_i @ K_j.T (stays in SRAM!)
            prof.matmul(q_size, kv_size, d, f"S[{i},{j}] = Q_i @ K_j.T")
            
            # Compute P_ij = softmax(S_ij) (stays in SRAM!)
            prof.elementwise(q_size * kv_size, 5, f"P[{i},{j}] = softmax(S)")
            
            # Compute O_i += P_ij @ V_j (accumulate in SRAM)
            prof.matmul(q_size, d, kv_size, f"O[{i}] += P @ V_j")
            
            prof.sram_pop(tile_sram, f"tile[{i},{j}]")
    
    # Write final output O to HBM (only once!)
    prof.store_to_hbm(N * d, "O")
    
    if Tensor.verbose:
        print(f"Done! S and P never touched HBM.")

# Compare standard vs flash attention
print("="*70)
print("COMPARISON: Standard Attention vs FlashAttention")
print("="*70)

N, d = 4096, 64

# Standard attention
Tensor.verbose = False
prof_std = Profiler(gpu, "Standard Attention")
standard_attention(prof_std, N, d)

# FlashAttention  
prof_flash = Profiler(gpu, "FlashAttention")
flash_attention(prof_flash, N, d, block_size=64)

print(f"\n{'Metric':<30} {'Standard':>15} {'FlashAttention':>15} {'Savings':>12}")
print("-" * 75)

std_traffic = prof_std.total_hbm_reads + prof_std.total_hbm_writes
flash_traffic = prof_flash.total_hbm_reads + prof_flash.total_hbm_writes
print(f"{'HBM Traffic (floats)':<30} {std_traffic:>15,} {flash_traffic:>15,} {(1 - flash_traffic/std_traffic)*100:>11.1f}%")

std_mb = std_traffic * gpu.bytes_per_float / 1024 / 1024
flash_mb = flash_traffic * gpu.bytes_per_float / 1024 / 1024
print(f"{'HBM Traffic (MB)':<30} {std_mb:>15.1f} {flash_mb:>15.1f} {(1 - flash_mb/std_mb)*100:>11.1f}%")

print(f"{'Peak HBM (floats)':<30} {prof_std.peak_hbm_usage:>15,} {prof_flash.peak_hbm_usage:>15,}")

std_pct = prof_std.cycles_hbm / (prof_std.cycles_hbm + prof_std.cycles_compute) * 100
flash_pct = prof_flash.cycles_hbm / (prof_flash.cycles_hbm + prof_flash.cycles_compute) * 100
print(f"{'Time on HBM (%)':<30} {std_pct:>14.1f}% {flash_pct:>14.1f}%")

print("\n→ FlashAttention eliminates O(N²) HBM traffic for S and P!")
print(f"  Standard: writes and reads back both {N}×{N} matrices (S and P) = {4*N*N:,} floats")
print(f"  Flash: moves only Q, K, V blocks and the final O = {flash_traffic:,} floats (no N×N matrix touches HBM)")
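
The loads in the simulated loop above can also be counted by hand. The sketch below assumes every `load_from_hbm` and `store_to_hbm` call is charged in full, which may not match how the profiler accounts for them; the function names are hypothetical. Under that assumption, Q is re-read once per K/V block, so the block size directly controls the remaining traffic.

```python
def flash_loop_traffic(seq_len, head_dim, block_size):
    """Floats moved by the simulated loop: K and V once, Q once per K/V block, O once."""
    num_blocks = -(-seq_len // block_size)          # ceiling division
    kv_reads = 2 * seq_len * head_dim
    q_reads = num_blocks * seq_len * head_dim
    o_writes = seq_len * head_dim
    return kv_reads + q_reads + o_writes

def standard_traffic(seq_len, head_dim):
    """Floats moved by standard attention: four N x N and four N x d transfers."""
    return 4 * seq_len * seq_len + 4 * seq_len * head_dim

N, d = 4096, 64
for block in (64, 128, 256):
    print(f"block_size={block:>3}: {flash_loop_traffic(N, d, block):>12,} floats "
          f"(standard: {standard_traffic(N, d):,})")
```

Larger blocks mean fewer passes over Q, which is why the block size is chosen as large as SRAM allows.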

## Key Takeaways

1. **Standard attention is memory-bound**: Even with 2D tiling, 86% of time is spent on HBM traffic.

2. **The N×N matrices S and P are the bottleneck**: Each is written to HBM once and read back once, so N×N data crosses the HBM boundary four times.

3. **FlashAttention eliminates this**: By fusing operations and keeping S, P in SRAM:
   - **90% less HBM traffic**
   - **97% less peak HBM usage** (no N×N allocation)
   - **Shifts from memory-bound to compute-bound**

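4. **The compute is essentially the same**: Both perform the same matmuls and softmax (the real algorithm adds a small amount of rescaling work for the online softmax). FlashAttention just accesses memory smarter.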

---

**Next**: In the FlashAttention notebook, we'll implement the actual tiled algorithm with online softmax.