# How to Write a Flash Attention Kernel in Pallas
## Introduction

In the previous posts in this series, we learnt how to write a [matrix multiplication kernel](https://blog.vikrampawar.com/pallas-matmul.html) and a [softmax kernel](https://blog.vikrampawar.com/pallas-softmax.html) using Pallas. Building on these, we will design a fused self-attention kernel. Self-attention is a major bottleneck in deep learning architectures because its memory cost grows as O(N²) in the sequence length N. A naive implementation materializes the full N×N attention matrix, so it must write and re-read O(N²) values. This creates a severe bottleneck on GPUs, as the time to move this data to and from high-bandwidth memory (HBM) dominates the actual computation time.

For our experiment, we will use the NVIDIA RTX 4000 Ada Generation GPU. This is a fairly powerful (and cheap!) modern GPU that is well suited to demonstrating Flash Attention.

![rtx-sys-diag](rtx_4000_ada_system_diagram.png)

## Self Attention

Mathematically, the self-attention operation is `softmax(QK^T / √d) @ V`, where Q is a set of queries, K is a set of keys and V is a set of values.

The Queries (Q) are usually a tensor of shape `(B, H, T, D)`, where `B` is batch size, `H` is number of heads, `T` is sequence length, and `D` is the embedding dimension (or head dimension). Each query vector represents a position in the sequence that attends to keys. Keys are used to compute attention scores with queries while values are the information retrieved based on attention weights.
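Scaling by $1/\sqrt{d}$ keeps the variance of the logits roughly constant. If the components of a query and a key are independent with unit variance, their dot product has variance $d$, so dividing by $\sqrt{d}$ brings it back to about 1. This keeps the softmax from saturating as the embedding dimension grows, which helps optimization.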
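
To see the variance argument concretely, here is a small sketch (not part of the kernels in this post) that draws unit-variance queries and keys and compares the spread of raw and scaled dot products. The function name `logit_variance`, the sample count and the seed are arbitrary choices made for illustration.

```python
import jax
import jax.numpy as jnp

def logit_variance(d, n=4096, seed=0):
    """Variance of q.k with and without the 1/sqrt(d) scaling."""
    kq, kk = jax.random.split(jax.random.key(seed))
    q = jax.random.normal(kq, (n, d), dtype=jnp.float32)
    k = jax.random.normal(kk, (n, d), dtype=jnp.float32)
    raw = jnp.sum(q * k, axis=-1)      # one dot product per row
    scaled = raw / jnp.sqrt(d)         # the scaling used in attention
    return float(jnp.var(raw)), float(jnp.var(scaled))

for d in (16, 64, 256):
    raw_var, scaled_var = logit_variance(d)
    print(f"d={d:4d}  var(q.k)={raw_var:8.1f}  var(q.k/sqrt(d))={scaled_var:.2f}")
```

The unscaled variance should track `d`, while the scaled variance should stay close to 1 for every `d`.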

To understand the basics of self-attention in more detail, [here's](https://sebastianraschka.com/blog/2023/self-attention-from-scratch.html) an excellent blogpost by Sebastian Raschka.

### Why it is Slow

The naive implementation must handle a T×T attention matrix that quickly exceeds the capacity of GPU shared memory (SMEM, the on-chip SRAM), which is typically 48-96KB.

| T | T×T matrix (bf16) | Fits in SRAM? |
|---|------------------|---------------|
| 128 | 32KB | ✅ Yes (barely) |
| 256 | 128KB | ❌ No |
| 512 | 512KB | ❌ No |
| 1024 | 2MB | ❌ No |
| 2048 | 8MB | ❌ No |

For T≥256, the attention matrix cannot fit in SMEM, so JAX's naive implementation must materialize it in HBM (high-bandwidth memory). Here's the actual data flow:

1. Read Q, K, V from HBM
2. Compute Q @ K^T in tiles (for SMEM fit), but write each tile to HBM
3. Read back the full T×T matrix from HBM for softmax
4. Write softmax output to HBM
5. Read softmax output from HBM for P @ V
6. Write final output to HBM


In [14]:
import jax
import jax.numpy as jnp

@jax.jit
def naive_attention(q, k, v):
    d = q.shape[-1]
    logits = jnp.einsum('bhqd,bhkd->bhqk', q, k) / jnp.sqrt(d)
    probs = jax.nn.softmax(logits, axis=-1)
    o = jnp.einsum('bhqk,bhkd->bhqd', probs, v)
    return o


The T×T matrix is written to HBM (step 2), read back (step 3), written again after softmax (step 4), and read back again (step 5). Even though each operation uses tiling internally, intermediate results live in HBM between operations.

For T=1024 with bfloat16, that's 2 bytes × 1024 × 1024 = 2MB per head, written and read multiple times. HBM bandwidth (~300-900 GB/s) is one to two orders of magnitude lower than SRAM bandwidth (~10-30 TB/s), so these transfers dominate execution time.
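
A back-of-the-envelope version of this argument, as a sketch: it counts only the score-matrix traffic from steps 2-5 (two writes and two reads) and treats bandwidth as a parameter. The helper names and the 360 GB/s default are assumptions for illustration, not measurements.

```python
def score_traffic_bytes(t, bytes_per_elem=2, passes=4):
    """Bytes moved for one head's T x T matrix: 2 writes + 2 reads."""
    return passes * t * t * bytes_per_elem

def transfer_time_us(num_bytes, bandwidth_gb_s=360.0):
    """Lower bound on transfer time in microseconds at a given bandwidth."""
    return num_bytes / (bandwidth_gb_s * 1e9) * 1e6

for t in (256, 1024, 4096):
    nbytes = score_traffic_bytes(t)
    print(f"T={t:5d}  {nbytes / 2**20:7.1f} MiB per head  "
          f">= {transfer_time_us(nbytes):8.1f} us")
```

Multiply by the number of batch elements and heads to get the total for one attention call.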

## The Flash Attention Algorithm

The key insight is that we can compute the attention output without ever materializing the full T×T attention matrix in HBM. Instead, we process it in small tiles that fit in SRAM, discarding each tile after using it.

**Flash attention data flow:**
- Load Q tile, K tile, V tile into SRAM
- Compute Q @ K^T → a small score tile in SRAM (`BLOCK_R × BLOCK_C` in the code below)
- Compute online softmax on the tile (using running statistics)
- Multiply with V and accumulate into output in SRAM
- **Discard the score tile** (never written to HBM!)
- Repeat for all K/V tiles, accumulating into same output
- Write each final output tile to HBM once (T×D in total)

**Algorithm steps:**
- Tile Q (outer loop, parallel across the grid)
- Tile K and V (inner sequential loop)
- Maintain a running max `m`, running sum `l`, and output accumulator `o` for each query row
- When a new tile raises the max, rescale `l` and `o` by the correction factor `exp(m_old - m_new)` (online softmax)
- After the last K/V tile, normalize `o` by `l`
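
Before moving to Pallas, the steps above can be sketched in plain JAX for a single head. This is an illustrative reference rather than the kernel itself: the name `flash_attention_reference` and the `block_c` argument are my own, it loops only over K/V tiles (Q tiles are independent, so all query rows are handled at once), and it assumes T is a multiple of `block_c`.

```python
import jax
import jax.numpy as jnp

def flash_attention_reference(q, k, v, block_c=64):
    """Tiled attention for one head; q, k, v have shape (T, D).

    Never forms the full T x T matrix, only (T, block_c) score tiles.
    """
    t, d = q.shape
    scale = jnp.sqrt(jnp.float32(d))
    q = q.astype(jnp.float32)
    m = jnp.full((t,), -jnp.inf, dtype=jnp.float32)  # running row max
    l = jnp.zeros((t,), dtype=jnp.float32)           # running row sum
    o = jnp.zeros((t, d), dtype=jnp.float32)         # unnormalized output

    for start in range(0, t, block_c):
        k_blk = k[start:start + block_c].astype(jnp.float32)
        v_blk = v[start:start + block_c].astype(jnp.float32)
        s = q @ k_blk.T / scale                      # score tile
        m_new = jnp.maximum(m, s.max(axis=-1))
        p = jnp.exp(s - m_new[:, None])
        alpha = jnp.exp(m - m_new)                   # correction for old statistics
        l = l * alpha + p.sum(axis=-1)
        o = o * alpha[:, None] + p @ v_blk
        m = m_new

    return o / l[:, None]                            # final normalization

kq, kk, kv = jax.random.split(jax.random.key(0), 3)
q = jax.random.normal(kq, (256, 64))
k = jax.random.normal(kk, (256, 64))
v = jax.random.normal(kv, (256, 64))
expected = jax.nn.softmax(q @ k.T / jnp.sqrt(64.0), axis=-1) @ v
print("max abs diff:", float(jnp.max(jnp.abs(flash_attention_reference(q, k, v) - expected))))
```

On the first tile `m` is `-inf`, so `alpha` is 0 and the empty accumulators are simply overwritten; the same update rule appears in the return statement of the Pallas kernel's loop body.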

In [None]:
from functools import partial
import math

from jax.experimental import pallas as pl
from jax.experimental.pallas import triton as plgpu

INTERPRET_MODE = False  # False compiles for the GPU; set to True to run in interpreter mode (e.g. on CPU)

BLOCK_R = 64  # Block size for rows (Q blocks)
BLOCK_C = 64  # Block size for columns (KV blocks)
NUM_WARPS = 4
NUM_STAGES = 3

In [None]:
def flash_attention_fwd_kernel(q_ref, k_ref, v_ref, o_ref, logsumexp_ref, *, scale, num_k_blocks):
    """Flash attention forward kernel.
    
    Precision strategy:
    - Load Q, K, V as bfloat16 (saves memory bandwidth)
    - Matmuls use bf16 tensor cores with float32 accumulation
    - All intermediate math (softmax, corrections) in float32
    - Accumulators (o_reg, max_reg, l_reg) in float32
    - Store outputs as bfloat16
    """
    q_reg = plgpu.load(q_ref.at[0, :, :])  # Keep as bf16
    o_reg = jnp.zeros(q_reg.shape, jnp.float32)  # float32 accumulator
    max_reg = jnp.full((BLOCK_R,), -jnp.inf, dtype=jnp.float32)
    l_reg = jnp.zeros((BLOCK_R,), dtype=jnp.float32)

    def body(t, args):
        max_reg, l_reg, o_reg = args
        idx = pl.dslice(t * BLOCK_C, BLOCK_C)
        k_blk = plgpu.load(k_ref.at[0, idx, :])  # Keep as bf16
        v_blk = plgpu.load(v_ref.at[0, idx, :])  # Keep as bf16
        
        # Q @ K^T: bf16 inputs, float32 accumulation via tensor cores
        s_blk = pl.dot(q_reg, k_blk, trans_b=True) / scale  # float32 output
        
        # Softmax math in float32
        max_blk = jnp.maximum(max_reg, jnp.max(s_blk, axis=-1))
        s_blk = jnp.exp(s_blk - max_blk[:, None])
        l_blk = jnp.sum(s_blk, axis=-1)
        
        # P @ V: cast P back to bf16 for tensor core efficiency
        o_blk = pl.dot(s_blk.astype(v_blk.dtype), v_blk)
        
        # Online softmax correction (float32)
        return (max_blk, 
                l_reg * jnp.exp(max_reg - max_blk) + l_blk, 
                o_reg * jnp.exp(max_reg - max_blk)[:, None] + o_blk)

    max_reg, l_reg, o_reg = jax.lax.fori_loop(0, num_k_blocks, body, (max_reg, l_reg, o_reg))
    logsumexp_reg = max_reg + jnp.log(l_reg)
    o_reg = o_reg / l_reg[:, None]
    
    # Store as bf16
    plgpu.store(o_ref.at[0, :, :], o_reg.astype(o_ref.dtype))
    plgpu.store(logsumexp_ref.at[0, :], logsumexp_reg.astype(logsumexp_ref.dtype))

In [4]:
@jax.jit
def flash_attention_fwd(q, k, v):
    """Flash attention forward pass."""
    B, H, T, C = q.shape
    B_flat = B*H
    q_flat = q.reshape(-1, T, C)
    k_flat = k.reshape(-1, T, C)
    v_flat = v.reshape(-1, T, C)
    scale = math.sqrt(C)
    num_k_blocks = pl.cdiv(T, BLOCK_C)
    grid = (B_flat, pl.cdiv(T, BLOCK_R))

    out_flat, logsumexp = pl.pallas_call(
        partial(flash_attention_fwd_kernel, scale=scale, num_k_blocks=num_k_blocks),
        out_shape=[
            jax.ShapeDtypeStruct(q_flat.shape, q_flat.dtype),
            jax.ShapeDtypeStruct((B*H, T), q_flat.dtype)
        ],
        grid=grid,
        in_specs=[
            pl.BlockSpec((1, BLOCK_R, C), lambda b, t: (b, t, 0)),
            pl.BlockSpec((1, T, C), lambda b, _: (b, 0, 0)),
            pl.BlockSpec((1, T, C), lambda b, _: (b, 0, 0))
        ],
        out_specs=[
            pl.BlockSpec((1, BLOCK_R, C), lambda b, t: (b, t, 0)),
            pl.BlockSpec((1, BLOCK_R), lambda b, t: (b, t))
        ],
        interpret=INTERPRET_MODE,
        compiler_params=plgpu.CompilerParams(
            num_warps=NUM_WARPS,
            num_stages=NUM_STAGES
        )
    )(q_flat, k_flat, v_flat)
    out = out_flat.reshape(q.shape)
    logsumexp = logsumexp.reshape(B, H, T)
    return out, logsumexp
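
One caveat worth making explicit: the kernel loads its tiles without masking, so the wrapper appears to assume that T is a multiple of both block sizes and that Q, K and V share one shape. A hypothetical guard (the name `check_attention_shapes` is mine, not part of Pallas) could look like this:

```python
def check_attention_shapes(q_shape, k_shape, v_shape, block_r=64, block_c=64):
    """Raise ValueError if the shapes do not fit an unmasked tiled kernel."""
    if not (q_shape == k_shape == v_shape):
        raise ValueError(
            f"q, k, v must share one shape, got {q_shape}, {k_shape}, {v_shape}")
    if len(q_shape) != 4:
        raise ValueError(f"expected (B, H, T, C), got {q_shape}")
    t = q_shape[2]
    if t % block_r or t % block_c:
        raise ValueError(
            f"T={t} must be a multiple of block_r={block_r} and block_c={block_c}")

# Passes silently for the shapes used later in this post.
check_attention_shapes((2, 4, 256, 64), (2, 4, 256, 64), (2, 4, 256, 64))
```

Calling it with `q.shape, k.shape, v.shape` before the `pallas_call` turns a silent out-of-bounds read into a clear error.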

## Choosing Block Size and Embedding Size

### Why the Kernel Crashes with Embedding Dimension > 64

The flash attention kernel in this implementation uses hardcoded block sizes (`BLOCK_R = 64` and `BLOCK_C = 64`). This becomes problematic as the embedding dimension (head dimension `C`) increases due to **GPU shared memory (SRAM) overflow**.

### Memory Calculation

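Let's estimate the on-chip memory usage for the forward pass kernel. As a rough upper bound, assume all of the following tiles are resident at once and held in float32 (in the actual kernel Q, K and V are loaded as bf16, and the compiler decides what lives in registers versus shared memory):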

**Active tensors in SRAM during kernel execution:**
- `q_reg`: Shape `(BLOCK_R, C)` = `(64, C)`
- `k_blk`: Shape `(BLOCK_C, C)` = `(64, C)`
- `v_blk`: Shape `(BLOCK_C, C)` = `(64, C)`
- `o_reg`: Shape `(BLOCK_R, C)` = `(64, C)`
- `s_blk`: Shape `(BLOCK_R, BLOCK_C)` = `(64, 64)`

**Total elements:** `256 × C + 4,096`

**Memory with float32 (4 bytes/element):**

| Embedding (C) | Total Elements | Memory |
|---------------|----------------|--------|
| 64 | 256×64 + 4,096 = 20,480 | 81,920 bytes ≈ **80KB** |
| 128 | 256×128 + 4,096 = 36,864 | 147,456 bytes ≈ **144KB** |
| 256 | 256×256 + 4,096 = 69,632 | 278,528 bytes ≈ **272KB** |

### The Problem

NVIDIA GPUs typically have **48-96KB of shared memory per streaming multiprocessor (SM)**. As shown in the table:
- **C=64**: 80KB - fits only at the upper end of that range
- **C=128**: 144KB - exceeds shared memory limit (crash!)
- **C=256**: 272KB - far exceeds shared memory limit (crash!)
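
The table can be reproduced with a few lines of Python. This is a sketch of the same upper-bound estimate, with a helper name (`estimate_tile_bytes`) chosen here for illustration; it does not query the hardware or the compiler.

```python
def estimate_tile_bytes(head_dim, block_r=64, block_c=64, bytes_per_elem=4):
    """Upper-bound footprint of the q, k, v, o tiles plus the score tile."""
    elements = (
        2 * block_r * head_dim      # q_reg and o_reg
        + 2 * block_c * head_dim    # k_blk and v_blk
        + block_r * block_c         # s_blk
    )
    return elements * bytes_per_elem

for c in (64, 128, 256):
    print(f"C={c:3d}: {estimate_tile_bytes(c) // 1024} KB")
```

Passing smaller `block_r` and `block_c` values shows how quickly the footprint drops, for example `estimate_tile_bytes(128, 32, 32)`.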

### How the Reference Implementation Handles This

The official JAX reference implementation (`pallas_flash_attn_ref.py`) addresses this issue with:

1. **Power-of-2 padding for efficient memory alignment:**
```python
head_dim_padded = pl.next_power_of_2(head_dim)
```

2. **Warp count chosen from the head dimension:**
```python
num_warps_ = 4 if head_dim <= 64 else 8
```

3. **Configurable `BlockSizes` dataclass** instead of hardcoded values:
```python
@dataclasses.dataclass(frozen=True, slots=True)
class BlockSizes:
    block_q: int
    block_k: int
    block_q_dkv: int | None = None
    block_kv_dkv: int | None = None
    block_q_dq: int | None = None
    block_kv_dq: int | None = None
```

### Potential Fixes

To support larger embedding dimensions, you can:

1. **Reduce block sizes dynamically** when C is large:
```python
BLOCK_R = BLOCK_C = min(64, max(16, 4096 // C))  # four float32 tiles take ~16*B*C bytes, so this aims for ~64KB (assumes C is a power of two)
```

2. **Add head dimension padding** like the reference implementation

3. **Make block sizes configurable** with sensible defaults based on C:
```python
def get_block_sizes(head_dim: int) -> tuple[int, int]:
    if head_dim <= 64:
        return 64, 64
    elif head_dim <= 128:
        return 64, 32
    else:
        return 32, 32
```

The takeaway is that on-chip memory is a hard constraint for tiled kernels like this one, so block sizes must be chosen to stay within hardware limits while maintaining good memory access patterns.

## Backward Pass

The backward pass computes gradients dQ, dK, dV given the upstream gradient dO. The key insight is that we can recompute the attention weights P from the stored logsumexp values rather than storing them:

$$P = \exp(QK^T / \sqrt{d} - \text{logsumexp})$$

We use three separate kernels so that every program instance owns the output block it writes, which avoids atomic accumulation:
1. **Preprocess**: Compute $D = \text{rowsum}(O \odot dO)$ which is used in the softmax backward
2. **dK/dV kernel**: Outer loop over KV blocks, inner loop over Q blocks
3. **dQ kernel**: Outer loop over Q blocks, inner loop over KV blocks

The gradient formulas are:
- $dP = dO \cdot V^T$
- $dS = P \odot (dP - D) / \sqrt{d}$ (softmax backward with scaling)
- $dQ = dS \cdot K$
- $dK = dS^T \cdot Q$  
- $dV = P^T \cdot dO$
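
As a sanity check on these formulas, the sketch below evaluates them in plain JAX for a single small head and compares against `jax.grad`. The function names and the tiny shapes are illustrative choices; nothing here uses Pallas.

```python
import jax
import jax.numpy as jnp

def attention(q, k, v):
    s = q @ k.T / jnp.sqrt(jnp.float32(q.shape[-1]))
    return jax.nn.softmax(s, axis=-1) @ v

def manual_backward(q, k, v, do):
    """dQ, dK, dV from the formulas above, for one head of shape (T, D)."""
    scale = jnp.sqrt(jnp.float32(q.shape[-1]))
    s = q @ k.T / scale
    lse = jax.nn.logsumexp(s, axis=-1)
    p = jnp.exp(s - lse[:, None])        # attention weights recomputed from logsumexp
    o = p @ v
    d = jnp.sum(o * do, axis=-1)         # D = rowsum(O * dO)
    dp = do @ v.T
    ds = p * (dp - d[:, None]) / scale
    return ds @ k, ds.T @ q, p.T @ do

keys = jax.random.split(jax.random.key(0), 4)
q, k, v, do = (jax.random.normal(keys[i], (32, 16)) for i in range(4))
auto = jax.grad(lambda q, k, v: jnp.sum(attention(q, k, v) * do),
                argnums=(0, 1, 2))(q, k, v)
manual = manual_backward(q, k, v, do)
for name, a, m in zip(("dQ", "dK", "dV"), auto, manual):
    print(name, "max abs diff:", float(jnp.max(jnp.abs(a - m))))
```

The three kernels that follow compute the same quantities tile by tile.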

In [None]:
# Kernel 1: Preprocess - compute D = rowsum(O * dO)
def flash_attention_bwd_preprocess_kernel(o_ref, do_ref, d_ref):
    """Compute D = rowsum(O * dO) for backward pass.
    
    Precision: Load O, dO as bf16, cast product to float32 before sum.
    """
    o_reg = plgpu.load(o_ref)   # Keep as bf16
    do_reg = plgpu.load(do_ref) # Keep as bf16
    # Element-wise multiply in bf16, reduce in float32
    d_reg = jnp.sum((o_reg * do_reg).astype(jnp.float32), axis=-1)
    plgpu.store(d_ref, d_reg.astype(d_ref.dtype))


def flash_attention_bwd_preprocess(o_flat, do_flat):
    """Preprocess for backward: compute D = rowsum(O * dO)."""
    B_flat, T, C = o_flat.shape
    grid = (B_flat, pl.cdiv(T, BLOCK_R))

    d_flat = pl.pallas_call(
        flash_attention_bwd_preprocess_kernel,
        out_shape=jax.ShapeDtypeStruct((B_flat, T), o_flat.dtype),  # Match input dtype
        grid=grid,
        in_specs=[
            pl.BlockSpec((1, BLOCK_R, C), lambda b, t: (b, t, 0)),
            pl.BlockSpec((1, BLOCK_R, C), lambda b, t: (b, t, 0)),
        ],
        out_specs=pl.BlockSpec((1, BLOCK_R), lambda b, t: (b, t)),
        interpret=INTERPRET_MODE,
        compiler_params=plgpu.CompilerParams(num_warps=NUM_WARPS, num_stages=NUM_STAGES)
    )(o_flat, do_flat)
    return d_flat

In [None]:

# Kernel 2: dK/dV - outer loop over KV blocks, inner loop over Q blocks
def flash_attention_bwd_dkv_kernel(
    q_ref, k_ref, v_ref, do_ref, logsumexp_ref, d_ref,
    dk_ref, dv_ref,
    *, scale, num_q_blocks
):
    """Compute dK and dV gradients.
    
    Precision strategy:
    - Load tensors as bf16 (saves bandwidth)
    - Matmuls use bf16 tensor cores with float32 accumulation
    - Intermediate math (softmax recompute) in float32
    - Accumulators dk_acc, dv_acc in float32
    """
    k_reg = plgpu.load(k_ref.at[0, :, :])  # bf16
    v_reg = plgpu.load(v_ref.at[0, :, :])  # bf16

    dk_acc = jnp.zeros(dk_ref.shape, dtype=jnp.float32)
    dv_acc = jnp.zeros(dv_ref.shape, dtype=jnp.float32)

    def body(t, carry):
        dk_acc, dv_acc = carry
        idx = pl.dslice(t * BLOCK_R, BLOCK_R)
        q_blk = plgpu.load(q_ref.at[0, idx, :])        # bf16
        do_blk = plgpu.load(do_ref.at[0, idx, :])      # bf16
        logsumexp_blk = plgpu.load(logsumexp_ref.at[0, idx])  # bf16
        d_blk = plgpu.load(d_ref.at[0, idx])           # bf16
        
        # Recompute P = softmax(Q @ K^T / scale) using logsumexp
        s_blk = pl.dot(q_blk, k_reg, trans_b=True) / scale  # float32
        p_blk = jnp.exp(s_blk - logsumexp_blk[..., None])   # float32
        
        # dP = dO @ V^T, dS = P * (dP - D) / scale
        dp_blk = pl.dot(do_blk, v_reg, trans_b=True)  # float32
        ds_blk = p_blk * (dp_blk - d_blk[..., None]) / scale  # float32
        
        # Accumulate: dV += P^T @ dO, dK += dS^T @ Q
        # Cast P, dS to bf16 for tensor core matmuls
        dv_acc += pl.dot(p_blk.astype(do_blk.dtype), do_blk, trans_a=True)
        dk_acc += pl.dot(ds_blk.astype(q_blk.dtype), q_blk, trans_a=True)
        return dk_acc, dv_acc
        
    dk_acc, dv_acc = jax.lax.fori_loop(0, num_q_blocks, body, (dk_acc, dv_acc))
    plgpu.store(dk_ref, dk_acc.astype(dk_ref.dtype))
    plgpu.store(dv_ref, dv_acc.astype(dv_ref.dtype))


def flash_attention_bwd_dkv(q_flat, k_flat, v_flat, do_flat, logsumexp_flat, d_flat, scale):
    """Compute dK and dV using pallas_call."""
    B_flat, T, C = q_flat.shape
    num_q_blocks = pl.cdiv(T, BLOCK_R)
    grid = (B_flat, pl.cdiv(T, BLOCK_C))

    dk_flat, dv_flat = pl.pallas_call(
        partial(flash_attention_bwd_dkv_kernel, scale=scale, num_q_blocks=num_q_blocks),
        out_shape=[
            jax.ShapeDtypeStruct(k_flat.shape, k_flat.dtype),
            jax.ShapeDtypeStruct(v_flat.shape, v_flat.dtype),
        ],
        grid=grid,
        in_specs=[
            pl.BlockSpec((1, T, C), lambda b, _: (b, 0, 0)),       # q (full)
            pl.BlockSpec((1, BLOCK_C, C), lambda b, t: (b, t, 0)), # k (blocked)
            pl.BlockSpec((1, BLOCK_C, C), lambda b, t: (b, t, 0)), # v (blocked)
            pl.BlockSpec((1, T, C), lambda b, _: (b, 0, 0)),       # do (full)
            pl.BlockSpec((1, T), lambda b, _: (b, 0)),             # logsumexp (full)
            pl.BlockSpec((1, T), lambda b, _: (b, 0)),             # d (full)
        ],
        out_specs=[
            pl.BlockSpec((1, BLOCK_C, C), lambda b, t: (b, t, 0)),
            pl.BlockSpec((1, BLOCK_C, C), lambda b, t: (b, t, 0)),
        ],
        interpret=INTERPRET_MODE,
        compiler_params=plgpu.CompilerParams(num_warps=NUM_WARPS, num_stages=NUM_STAGES)
    )(q_flat, k_flat, v_flat, do_flat, logsumexp_flat, d_flat)
    return dk_flat, dv_flat

In [None]:

# Kernel 3: dQ - outer loop over Q blocks, inner loop over KV blocks
def flash_attention_bwd_dq_kernel(
    q_ref, k_ref, v_ref, do_ref, logsumexp_ref, d_ref,
    dq_ref,
    *, scale, num_kv_blocks
):
    """Compute dQ gradient.
    
    Precision strategy: same as dK/dV kernel.
    """
    q_reg = plgpu.load(q_ref.at[0, :, :])              # bf16
    do_reg = plgpu.load(do_ref.at[0, :, :])            # bf16
    logsumexp_reg = plgpu.load(logsumexp_ref.at[0, :]) # bf16
    d_reg = plgpu.load(d_ref.at[0, :])                 # bf16
    dq_acc = jnp.zeros(dq_ref.shape, dtype=jnp.float32)  # float32 accumulator

    def body(t, carry):
        dq_acc = carry
        idx = pl.dslice(t * BLOCK_C, BLOCK_C)
        k_blk = plgpu.load(k_ref.at[0, idx, :])  # bf16
        v_blk = plgpu.load(v_ref.at[0, idx, :])  # bf16
        
        # Recompute P
        s_blk = pl.dot(q_reg, k_blk, trans_b=True) / scale  # float32
        p_blk = jnp.exp(s_blk - logsumexp_reg[..., None])   # float32
        
        # dP = dO @ V^T, dS = P * (dP - D) / scale
        dp_blk = pl.dot(do_reg, v_blk, trans_b=True)  # float32
        ds_blk = p_blk * (dp_blk - d_reg[..., None]) / scale  # float32
        
        # Accumulate: dQ += dS @ K (cast dS to bf16 for tensor cores)
        dq_acc += pl.dot(ds_blk.astype(k_blk.dtype), k_blk)
        return dq_acc

    dq_acc = jax.lax.fori_loop(0, num_kv_blocks, body, dq_acc)
    plgpu.store(dq_ref, dq_acc.astype(dq_ref.dtype))


def flash_attention_bwd_dq(q_flat, k_flat, v_flat, do_flat, logsumexp_flat, d_flat, scale):
    """Compute dQ using pallas_call."""
    B_flat, T, C = q_flat.shape
    num_kv_blocks = pl.cdiv(T, BLOCK_C)
    grid = (B_flat, pl.cdiv(T, BLOCK_R))

    dq_flat = pl.pallas_call(
        partial(flash_attention_bwd_dq_kernel, scale=scale, num_kv_blocks=num_kv_blocks),
        out_shape=jax.ShapeDtypeStruct(q_flat.shape, q_flat.dtype),
        grid=grid,
        in_specs=[
            pl.BlockSpec((1, BLOCK_R, C), lambda b, t: (b, t, 0)), # q (blocked)
            pl.BlockSpec((1, T, C), lambda b, _: (b, 0, 0)),       # k (full)
            pl.BlockSpec((1, T, C), lambda b, _: (b, 0, 0)),       # v (full)
            pl.BlockSpec((1, BLOCK_R, C), lambda b, t: (b, t, 0)), # do (blocked)
            pl.BlockSpec((1, BLOCK_R), lambda b, t: (b, t)),       # logsumexp (blocked)
            pl.BlockSpec((1, BLOCK_R), lambda b, t: (b, t)),       # d (blocked)
        ],
        out_specs=pl.BlockSpec((1, BLOCK_R, C), lambda b, t: (b, t, 0)),
        interpret=INTERPRET_MODE,
        compiler_params=plgpu.CompilerParams(num_warps=NUM_WARPS, num_stages=NUM_STAGES)
    )(q_flat, k_flat, v_flat, do_flat, logsumexp_flat, d_flat)
    return dq_flat


@jax.jit
def flash_attention_bwd(q, k, v, o, logsumexp, do):
    """Flash attention backward pass using 3 separate kernels."""
    B, H, T, C = q.shape
    scale = math.sqrt(C)

    # Flatten batch and head dimensions
    q_flat = q.reshape(-1, T, C)
    k_flat = k.reshape(-1, T, C)
    v_flat = v.reshape(-1, T, C)
    o_flat = o.reshape(-1, T, C)
    do_flat = do.reshape(-1, T, C)
    logsumexp_flat = logsumexp.reshape(-1, T)

    # Kernel 1: Preprocess - compute D = rowsum(O * dO)
    d_flat = flash_attention_bwd_preprocess(o_flat, do_flat)

    # Kernel 2: Compute dK, dV
    dk_flat, dv_flat = flash_attention_bwd_dkv(
        q_flat, k_flat, v_flat, do_flat, logsumexp_flat, d_flat, scale
    )

    # Kernel 3: Compute dQ
    dq_flat = flash_attention_bwd_dq(
        q_flat, k_flat, v_flat, do_flat, logsumexp_flat, d_flat, scale
    )

    return (
        dq_flat.reshape(q.shape),
        dk_flat.reshape(k.shape),
        dv_flat.reshape(v.shape),
    )

## Precision Optimization: From Float32 to Bfloat16

A key optimization in our implementation is the careful management of numerical precision. The naive approach of casting everything to float32 wastes memory bandwidth, while pure bfloat16 causes numerical instability. Our optimized approach uses **mixed precision**: bfloat16 for memory transfers and tensor core operations, float32 for sensitive intermediate computations.

### The Problem with Full Float32

Our initial implementation cast all inputs to float32 immediately after loading:

```python
# BEFORE: Suboptimal - wastes memory bandwidth
q_reg = plgpu.load(q_ref.at[0, :, :]).astype(jnp.float32)  # 4 bytes per element on chip
k_blk = plgpu.load(k_ref.at[0, idx, :]).astype(jnp.float32)
v_blk = plgpu.load(v_ref.at[0, idx, :]).astype(jnp.float32)
```

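This approach has two problems:
1. **Extra data movement and footprint**: A float32 element takes 4 bytes instead of 2. If the inputs are stored as float32 in HBM, every load moves twice the data; if they are stored as bf16 and only cast after the load, HBM traffic is unchanged but each tile takes twice the register and shared-memory space
2. **Slower tensor cores**: Float32 matmuls use TF32 tensor cores, which are slower than bf16 tensor cores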

### The Mixed Precision Strategy

Our optimized implementation follows these principles:

| Operation | Dtype | Reason |
|-----------|-------|--------|
| Load Q, K, V, dO | bf16 | Half the bytes of float32 |
| Matmul inputs | bf16 | Fast bf16 tensor cores |
| Matmul outputs | float32 | Tensor cores accumulate in float32 |
| Softmax (exp, max, sum) | float32 | Numerical stability |
| Running accumulators | float32 | Avoid precision loss across blocks |
| Store outputs | bf16 | Match input dtype |

### Key Implementation Details

**1. Keep tensor loads as bfloat16:**
```python
# AFTER: Optimal - native bf16 loads
q_reg = plgpu.load(q_ref.at[0, :, :])  # bf16, half the bandwidth
k_blk = plgpu.load(k_ref.at[0, idx, :])  # bf16
v_blk = plgpu.load(v_ref.at[0, idx, :])  # bf16
```

**2. Matmuls naturally output float32:**
```python
# Q @ K^T: bf16 inputs, but tensor cores accumulate in float32
s_blk = pl.dot(q_reg, k_blk, trans_b=True) / scale  # Output is float32
```

**3. Keep softmax computation in float32:**
```python
# These operations need float32 precision
max_blk = jnp.maximum(max_reg, jnp.max(s_blk, axis=-1))  # float32
s_blk = jnp.exp(s_blk - max_blk[:, None])  # float32 - exp is sensitive!
l_blk = jnp.sum(s_blk, axis=-1)  # float32 accumulation
```

**4. Cast back to bf16 for the next matmul:**
```python
# P @ V: cast P (attention weights) back to bf16 for fast tensor cores
o_blk = pl.dot(s_blk.astype(v_blk.dtype), v_blk)  # bf16 inputs, float32 output
```

**5. Accumulators must be float32:**
```python
# These accumulate across many blocks - bf16 would lose small contributions
o_reg = jnp.zeros(q_reg.shape, jnp.float32)
max_reg = jnp.full((BLOCK_R,), -jnp.inf, dtype=jnp.float32)
l_reg = jnp.zeros((BLOCK_R,), dtype=jnp.float32)
```

### Performance Impact

The precision optimization significantly improves performance by:

1. **Halving memory traffic** for Q, K, V, O, dO loads/stores relative to float32 storage
2. **Using faster bf16 tensor cores** for matrix multiplications
3. **Maintaining numerical correctness** through float32 intermediates

Before optimization (float32 everywhere):
```
Forward:  0.642 ms  (T=1024)
Backward: 7.309 ms  (T=1024)
```

After optimization (mixed bf16/float32):
```
Forward:  0.294 ms  (T=1024)  - 2.2x faster
Backward: 0.943 ms  (T=1024)  - 7.7x faster
```

The backward pass sees a larger improvement, most likely because it has more memory traffic (loading Q, K, V, O, dO, logsumexp, D) and more matmuls per tile that benefit from the smaller dtype.

### Why Certain Values Must Stay Float32

**Running max (`max_reg`)**: Could technically be bf16 since it's just tracking maximums, but keeping it float32 costs nothing (only BLOCK_R=64 elements) and avoids edge cases.

**Running sum (`l_reg`)**: Must be float32. It accumulates across all K blocks:
```python
l_reg = l_reg * jnp.exp(max_reg - max_blk) + l_blk
```
With T=4096 and BLOCK_C=64, that's 64 iterations. Bf16 would lose small contributions when adding to large sums.

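**Logsumexp**: Used in the backward pass as `exp(s_blk - logsumexp)`, so errors in it are amplified exponentially. The kernel computes it in float32, but note that `flash_attention_fwd` as written allocates this output in the input dtype, so with bf16 inputs it is rounded when stored. Allocating it as `jnp.float32` would keep the full precision.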

**Output accumulator (`o_reg`)**: Same accumulation issue as `l_reg` - must be float32 to avoid losing small corrections.
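
The accumulation problem is easy to reproduce outside a kernel. In this sketch (the helper `running_sum` and the values 1.0 and 1e-3 are arbitrary illustrations), a small increment is added 64 times to a running total that is rounded to the chosen dtype after every step:

```python
import jax.numpy as jnp

def running_sum(dtype, start=1.0, step=1e-3, n=64):
    """Add `step` to `start` n times, rounding to `dtype` after every add."""
    total = jnp.asarray(start, dtype=dtype)
    inc = jnp.asarray(step, dtype=dtype)
    for _ in range(n):
        total = (total + inc).astype(dtype)
    return float(total)

print("bf16:   ", running_sum(jnp.bfloat16))
print("float32:", running_sum(jnp.float32))
```

Near 1.0 the spacing between adjacent bf16 values is 2^-7 ≈ 0.0078, so each increment of 0.001 rounds away and the bf16 total never moves, while the float32 total reaches about 1.064.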

## Custom VJP Integration

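To make the kernels differentiable, we wire the forward and backward passes together with `jax.custom_vjp`. The forward rule returns the output along with the residuals the backward pass needs: Q, K, V, O and logsumexp.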

In [6]:
@jax.custom_vjp
def flash_attention(q, k, v):
    """Flash attention with custom backward pass."""
    o, _ = flash_attention_fwd(q, k, v)
    return o


def flash_attention_fwd_rule(q, k, v):
    """Forward rule for custom_vjp.
    
    Returns the output and residuals needed for backward pass.
    """
    o, logsumexp = flash_attention_fwd(q, k, v)
    return o, (q, k, v, o, logsumexp)


def flash_attention_bwd_rule(res, do):
    """Backward rule for custom_vjp.
    
    Takes residuals from forward and upstream gradient dO,
    returns gradients (dQ, dK, dV).
    """
    q, k, v, o, logsumexp = res
    dq, dk, dv = flash_attention_bwd(q, k, v, o, logsumexp, do)
    return dq, dk, dv


flash_attention.defvjp(flash_attention_fwd_rule, flash_attention_bwd_rule)
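
The same three-part pattern (primal function, forward rule returning residuals, backward rule returning one cotangent per input) can be tried on a toy function that runs on any backend. The example below is a sketch using a row-wise logsumexp; the function names are my own.

```python
import jax
import jax.numpy as jnp

@jax.custom_vjp
def row_logsumexp(x):
    return jax.nn.logsumexp(x, axis=-1)

def row_logsumexp_fwd(x):
    out = jax.nn.logsumexp(x, axis=-1)
    return out, (x, out)                 # residuals saved for the backward rule

def row_logsumexp_bwd(res, g):
    x, out = res
    p = jnp.exp(x - out[..., None])      # softmax recomputed from the residuals
    return (g[..., None] * p,)           # tuple with one cotangent per primal input

row_logsumexp.defvjp(row_logsumexp_fwd, row_logsumexp_bwd)

x = jnp.arange(6.0).reshape(2, 3)
custom = jax.grad(lambda x: row_logsumexp(x).sum())(x)
auto = jax.grad(lambda x: jax.nn.logsumexp(x, axis=-1).sum())(x)
print(jnp.allclose(custom, auto))
```

Note that the backward rule recomputes the softmax from the saved logsumexp instead of storing it, which mirrors how the flash attention backward pass recomputes P.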

## Evaluation

We verify correctness by comparing our flash attention implementation against the reference (materialized) attention for both forward and backward passes.

In [None]:
B, H, T, D = 2, 4, 256, 64
key = jax.random.key(0)
keys = jax.random.split(key, 4)

# Use bfloat16 for optimal performance
q = jax.random.normal(keys[0], (B, H, T, D), dtype=jnp.bfloat16)
k = jax.random.normal(keys[1], (B, H, T, D), dtype=jnp.bfloat16)
v = jax.random.normal(keys[2], (B, H, T, D), dtype=jnp.bfloat16)
do = jax.random.normal(keys[3], (B, H, T, D), dtype=jnp.bfloat16)

# Forward check
o_ref = naive_attention(q, k, v)  # materialized reference defined earlier
print(f"Reference output shape: {o_ref.shape}")

o_flash = flash_attention(q, k, v)
print(f"Flash attention output shape: {o_flash.shape}")
print(f"Forward pass matches: {jnp.allclose(o_flash, o_ref, atol=1e-2, rtol=1e-2)}")

In [None]:
# Backward check
def loss_ref(q, k, v):
    return jnp.sum(naive_attention(q, k, v) * do)

dq_ref, dk_ref, dv_ref = jax.grad(loss_ref, argnums=(0, 1, 2))(q, k, v)
print(f"Reference gradient shapes: dq={dq_ref.shape}, dk={dk_ref.shape}, dv={dv_ref.shape}")

# Flash attention backward pass
def loss_flash(q, k, v):
    return jnp.sum(flash_attention(q, k, v) * do)

dq_flash, dk_flash, dv_flash = jax.grad(loss_flash, argnums=(0, 1, 2))(q, k, v)
print(f"Flash attention gradient shapes: dq={dq_flash.shape}, dk={dk_flash.shape}, dv={dv_flash.shape}")

print(f"dQ matches: {jnp.allclose(dq_flash, dq_ref, atol=1e-2, rtol=1e-2)}")
print(f"dK matches: {jnp.allclose(dk_flash, dk_ref, atol=1e-2, rtol=1e-2)}")
print(f"dV matches: {jnp.allclose(dv_flash, dv_ref, atol=1e-2, rtol=1e-2)}")
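
`jnp.allclose` only returns a boolean, so when a check fails it helps to see how far off the values are. A small hypothetical helper (`error_report` is not part of JAX) that upcasts to float32 and prints the largest deviation:

```python
import jax.numpy as jnp

def error_report(name, got, want):
    """Print and return max absolute error and its size relative to max |want|."""
    got = jnp.asarray(got, dtype=jnp.float32)
    want = jnp.asarray(want, dtype=jnp.float32)
    abs_err = float(jnp.max(jnp.abs(got - want)))
    rel_err = abs_err / (float(jnp.max(jnp.abs(want))) + 1e-12)
    print(f"{name}: max abs err {abs_err:.3e}, relative to max |ref| {rel_err:.3e}")
    return abs_err, rel_err

# Stand-in values so the sketch runs on its own.
error_report("demo", [1.0, 2.0], [1.0, 2.5])
```

In the cells above it would be called as `error_report("dQ", dq_flash, dq_ref)`, and likewise for the output and the other gradients.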

## Performance Comparison

We compare our Pallas flash attention implementation against:
1. **JAX cuDNN**: `jax.nn.dot_product_attention(implementation='cudnn')` - NVIDIA's highly optimized implementation
2. **Reference (materialized)**: Standard attention that materializes the full N×N attention matrix

Note: The cuDNN implementation requires a GPU with cuDNN installed and expects half-precision inputs (the benchmark below uses bfloat16). Make sure `INTERPRET_MODE = False` so that the Pallas kernels are compiled for the GPU rather than interpreted.

In [9]:
import time

def bench(fn, *args, iters=10):
    for _ in range(3):  # warmup
        result = fn(*args)
        if isinstance(result, tuple):
            result[0].block_until_ready()
        else:
            result.block_until_ready()
    times = []
    for _ in range(iters):
        t0 = time.perf_counter()
        result = fn(*args)
        if isinstance(result, tuple):
            result[0].block_until_ready()
        else:
            result.block_until_ready()
        times.append(time.perf_counter() - t0)
    return sum(times) / len(times)

In [None]:
# Performance benchmark (requires GPU with cuDNN)
# Skip this cell if running on CPU

def benchmark_attention():
    """Benchmark attention implementations."""
    import time
    
    # Use bfloat16 for optimal tensor core performance
    B, H, T, D = 4, 8, 1024, 64
    key = jax.random.key(42)
    keys = jax.random.split(key, 4)
    
    q = jax.random.normal(keys[0], (B, H, T, D), dtype=jnp.bfloat16)
    k = jax.random.normal(keys[1], (B, H, T, D), dtype=jnp.bfloat16)
    v = jax.random.normal(keys[2], (B, H, T, D), dtype=jnp.bfloat16)
    do = jax.random.normal(keys[3], (B, H, T, D), dtype=jnp.bfloat16)
    
    print(f"Benchmark shape: B={B}, H={H}, T={T}, D={D}, dtype=bfloat16")
    print("=" * 60)
    
    def bench_fwd(fn, q, k, v, iters=20):
        # Warmup
        for _ in range(3):
            out = fn(q, k, v)
            jax.block_until_ready(out)
        # Bench
        times = []
        for _ in range(iters):
            t0 = time.perf_counter()
            out = fn(q, k, v)
            jax.block_until_ready(out)
            times.append(time.perf_counter() - t0)
        return sum(times) / len(times) * 1000  # ms

    def bench_bwd(fn, q, k, v, do, iters=20):
        # Warmup
        for _ in range(3):
            grads = jax.grad(lambda q, k, v: jnp.sum(fn(q, k, v) * do), argnums=(0, 1, 2))(q, k, v)
            jax.block_until_ready(grads)
        # Bench
        times = []
        for _ in range(iters):
            t0 = time.perf_counter()
            grads = jax.grad(lambda q, k, v: jnp.sum(fn(q, k, v) * do), argnums=(0, 1, 2))(q, k, v)
            jax.block_until_ready(grads)
            times.append(time.perf_counter() - t0)
        return sum(times) / len(times) * 1000  # ms

    # JAX cuDNN (requires GPU)
    @jax.jit
    def jax_cudnn_attention(q, k, v):
        # Transpose from (B, H, T, D) to (B, T, H, D) for jax.nn.dot_product_attention
        q_t = jnp.transpose(q, (0, 2, 1, 3))
        k_t = jnp.transpose(k, (0, 2, 1, 3))
        v_t = jnp.transpose(v, (0, 2, 1, 3))
        out = jax.nn.dot_product_attention(q_t, k_t, v_t, implementation='cudnn')
        return jnp.transpose(out, (0, 2, 1, 3))

    # Our Pallas implementation
    @jax.jit
    def pallas_attention(q, k, v):
        return flash_attention(q, k, v)

    # Reference (materialized attention matrix)
    def reference_attention(q, k, v):
        return naive_attention(q, k, v)

    print("\nForward pass:")
    try:
        t_cudnn = bench_fwd(jax_cudnn_attention, q, k, v)
        print(f"  JAX cuDNN:              {t_cudnn:.3f} ms")
    except Exception as e:
        print(f"  JAX cuDNN:              N/A (cuDNN not available)")
        t_cudnn = None
    
    t_pallas = bench_fwd(pallas_attention, q, k, v)
    print(f"  Our Pallas:             {t_pallas:.3f} ms")
    
    t_ref = bench_fwd(reference_attention, q, k, v)
    print(f"  Reference (materialized): {t_ref:.3f} ms")
    
    if t_cudnn:
        print(f"\n  Pallas vs cuDNN: {t_pallas/t_cudnn:.2f}x")

    print("\nBackward pass:")
    try:
        t_cudnn_bwd = bench_bwd(jax_cudnn_attention, q, k, v, do)
        print(f"  JAX cuDNN:              {t_cudnn_bwd:.3f} ms")
    except Exception as e:
        print(f"  JAX cuDNN:              N/A (cuDNN not available)")
        t_cudnn_bwd = None
    
    t_pallas_bwd = bench_bwd(pallas_attention, q, k, v, do)
    print(f"  Our Pallas:             {t_pallas_bwd:.3f} ms")
    
    t_ref_bwd = bench_bwd(reference_attention, q, k, v, do)
    print(f"  Reference (materialized): {t_ref_bwd:.3f} ms")
    
    if t_cudnn_bwd:
        print(f"\n  Pallas vs cuDNN: {t_pallas_bwd/t_cudnn_bwd:.2f}x")

# Run the benchmark (requires GPU; comment out on CPU):
benchmark_attention()
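
One thing to keep in mind when reading backward timings: `bench_bwd` builds `jax.grad` of a fresh lambda inside the timed loop, so each measurement may include Python-side tracing overhead on top of kernel time. A possible alternative, sketched here with a stand-in attention function so that it runs on any backend, is to build and jit the gradient function once and time only the calls. The names `time_grad` and `toy_attention` are illustrative.

```python
import time
import jax
import jax.numpy as jnp

def time_grad(fn, q, k, v, do, iters=20, warmup=3):
    """Mean milliseconds per call of a jitted gradient of sum(fn(q, k, v) * do)."""
    grad_fn = jax.jit(jax.grad(
        lambda q, k, v: jnp.sum(fn(q, k, v) * do), argnums=(0, 1, 2)))
    for _ in range(warmup):                  # first call triggers compilation
        jax.block_until_ready(grad_fn(q, k, v))
    t0 = time.perf_counter()
    for _ in range(iters):
        jax.block_until_ready(grad_fn(q, k, v))
    return (time.perf_counter() - t0) / iters * 1000

def toy_attention(q, k, v):                  # stand-in for the functions under test
    s = jnp.einsum('bhqd,bhkd->bhqk', q, k) / jnp.sqrt(q.shape[-1])
    return jnp.einsum('bhqk,bhkd->bhqd', jax.nn.softmax(s, axis=-1), v)

keys = jax.random.split(jax.random.key(0), 4)
q, k, v, do = (jax.random.normal(keys[i], (1, 2, 128, 32)) for i in range(4))
print(f"{time_grad(toy_attention, q, k, v, do):.3f} ms")
```

Swapping `toy_attention` for `flash_attention` or `naive_attention` gives numbers that are easier to compare across implementations.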

### Example Results (RTX 4000 Ada)

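When run on an NVIDIA RTX 4000 Ada GPU with the optimized bfloat16 implementation, we achieve performance competitive with cuDNN at most sequence lengths (the forward pass at T=512 is the main exception). In the tables below, the AI columns report arithmetic intensity (FLOPs per byte of memory traffic):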

```
========================================================================================================================
FORWARD PASS SUMMARY
========================================================================================================================
T      Naive      Flash      cuDNN      Naive        Flash        cuDNN        Naive      Flash      cuDNN     
       (ms)       (ms)       (ms)       (GFLOP/s)    (GFLOP/s)    (GFLOP/s)    (AI)       (AI)       (AI)      
------------------------------------------------------------------------------------------------------------------------
128    0.166      0.116      0.199      822          1183         688          22         32         32        
256    0.127      0.147      0.228      4294         3720         2401         33         65         65        
512    0.235      0.290      0.171      9331         7543         12836        44         130        130       
1024   1.039      0.294      0.335      8429         29833        26127        52         260        260       
2048   4.590      0.746      0.891      7631         46967        39311        58         520        520       
4096   17.383     2.372      2.294      8061         59075        61082        61         1040       1040      

========================================================================================================================
BACKWARD PASS SUMMARY
========================================================================================================================
T      Naive      Flash      cuDNN      Naive        Flash        cuDNN        Naive      Flash      cuDNN     
       (ms)       (ms)       (ms)       (GFLOP/s)    (GFLOP/s)    (GFLOP/s)    (AI)       (AI)       (AI)      
------------------------------------------------------------------------------------------------------------------------
128    0.398      0.304      0.432      675          1545         777          26         56         40        
256    0.462      0.301      0.464      2324         6232         2893         43         112        80        
512    0.630      0.469      0.493      6816         16028        10882        64         224        160       
1024   2.408      0.943      0.976      7136         31888        22008        85         447        319       
2048   8.582      2.710      2.220      8007         44375        38694        102        894        639       
4096   31.144     9.422      7.106      8826         51053        48351        114        1789       1278      
```

**Key observations:**
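- **Forward pass**: Our Pallas implementation is faster than cuDNN at T=1024 and T=2048 and about 3% slower at T=4096, where it reaches ~59 TFLOP/s
- **Backward pass**: Our implementation reports higher GFLOP/s than cuDNN (up to 51 vs 48 TFLOP/s at T=4096), but the two columns use different FLOP counts (14 vs ~10 × B×H×T²×D, see the FLOPs calculation below). In wall-clock time we are faster than cuDNN up to T=1024 and about 1.2-1.3x slower at T=2048 and T=4096
- **Arithmetic intensity**: Our backward pass shows ~1.4x higher AI than cuDNN (1789 vs 1278 at T=4096). This matches the 14:10 ratio of the FLOP counts, so it reflects the extra recomputation in our backward pass rather than better data reuse
- **Massive speedup over naive**: At T=4096, both flash implementations are ~7x faster than naive attention in the forward pass and 3.3-4.4x faster in the backward pass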

## Roofline Analysis: Understanding Performance Bottlenecks

The roofline model is a visual framework for understanding whether a kernel is **compute-bound** or **memory-bound**. It helps explain why flash attention significantly outperforms naive attention despite doing the same mathematical computation.

### The Roofline Model

The roofline model plots **Arithmetic Intensity (AI)** on the x-axis against **Performance (GFLOP/s)** on the y-axis:

- **Arithmetic Intensity (AI)** = FLOPs / Bytes transferred
  - Measures how much computation you do per byte of data moved
  - Higher AI means the kernel reuses data more efficiently
  
- **Performance** = Achieved GFLOP/s
  - How fast the kernel actually runs

The "roofline" consists of two lines:
1. **Memory Roof** (diagonal): `Performance = Bandwidth × AI`
   - When AI is low, performance is limited by how fast you can move data
2. **Compute Roof** (horizontal): `Performance = Peak TFLOP/s`
   - When AI is high, performance is limited by how fast you can compute

The intersection is called the **ridge point**:
$$\text{Ridge AI} = \frac{\text{Peak Compute (FLOP/s)}}{\text{Peak Bandwidth (Bytes/s)}}$$

Kernels with AI below the ridge are memory-bound; above the ridge are compute-bound.
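As a rough illustration, the two roofs can be folded into one function. The sketch below is plain Python; the peak figures are the ones reported for the RTX 4000 Ada in the benchmark output later in this post, and the function names are made up for this example.

```python
# Peak figures as reported for the RTX 4000 Ada in this post's benchmark output.
PEAK_GFLOPS = 106_900.0      # 106.9 TFLOP/s (tensor cores)
PEAK_BANDWIDTH_GBS = 360.0   # 360 GB/s


def ridge_ai(peak_gflops, bandwidth_gbs):
    """Arithmetic intensity (FLOP/byte) at which the two roofs meet."""
    return peak_gflops / bandwidth_gbs


def attainable_gflops(ai, peak_gflops, bandwidth_gbs):
    """Roofline bound: the lower of the memory roof and the compute roof."""
    return min(peak_gflops, bandwidth_gbs * ai)


def regime(ai, peak_gflops, bandwidth_gbs):
    """Classify a kernel by which side of the ridge its AI falls on."""
    if ai >= ridge_ai(peak_gflops, bandwidth_gbs):
        return "compute-bound"
    return "memory-bound"


ridge = ridge_ai(PEAK_GFLOPS, PEAK_BANDWIDTH_GBS)
print(f"ridge point: {ridge:.0f} FLOP/byte")

# Sample AI values taken from the forward-pass tables in this post.
for ai in (61, 260, 1040):
    bound = attainable_gflops(ai, PEAK_GFLOPS, PEAK_BANDWIDTH_GBS)
    label = regime(ai, PEAK_GFLOPS, PEAK_BANDWIDTH_GBS)
    print(f"AI={ai:5d}: {label}, roof = {bound:.0f} GFLOP/s")
```

With these figures the ridge sits near 297 FLOP/byte, so a kernel with AI around 61 is firmly memory-bound, while one with AI around 1040 is past the ridge and limited by compute.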

### FLOPs Calculation for Attention

For attention with shape `(B, H, T, D)` where B=batch, H=heads, T=sequence length, D=head dimension:

#### Forward Pass (same for all implementations)

The forward pass computes `softmax(Q @ K^T / √d) @ V`:

1. **Q @ K^T**: Matrix multiply of `(T, D) × (D, T) → (T, T)`
   - Each element requires D multiplications and D-1 additions ≈ 2D FLOPs
   - Total: `B × H × T × T × 2D` FLOPs
   
2. **Softmax**: For each row of the T×T attention matrix:
   - Find max (T ops), subtract max (T ops), exp (T ops), sum (T ops), divide (T ops) ≈ 5T ops per row
   - Total: `B × H × T × 5T` = `5 × B × H × T²` FLOPs
   
3. **P @ V**: Matrix multiply of `(T, T) × (T, D) → (T, D)`
   - Total: `B × H × T × T × 2D` FLOPs

**Total Forward FLOPs** = `4 × B × H × T² × D + 5 × B × H × T²`

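Both terms scale as T², so their ratio depends only on D: the `4 × B × H × T² × D` term dominates whenever 4D ≫ 5, which holds for any practical head dimension (for D=64 the softmax term is under 2% of the total).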
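The forward count can be written as a small helper. This is a minimal sketch in plain Python; the function name and the configuration (`B=4, H=8, D=64`) are assumptions chosen for illustration, not values restated from the benchmark.

```python
def attention_fwd_flops(B, H, T, D):
    """Forward FLOPs: two matmuls (2*T*T*D each) plus ~5 ops per softmax element."""
    matmul_flops = 4 * B * H * T * T * D
    softmax_flops = 5 * B * H * T * T
    return matmul_flops + softmax_flops


# Hypothetical configuration: B, H and D are assumptions for this example.
B, H, T, D = 4, 8, 4096, 64
total = attention_fwd_flops(B, H, T, D)

# Share of the total contributed by softmax; independent of B, H and T.
softmax_share = 5 / (4 * D + 5)

print(f"forward FLOPs: {total / 1e9:.1f} GFLOP")
print(f"softmax share: {softmax_share:.1%}")
```

Under these assumed shapes, dividing the count by the 2.372 ms forward time from the T=4096 row of the results gives about 59 TFLOP/s, which is how a GFLOP/s column of this kind is typically derived.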

#### Backward Pass (varies by implementation)

The backward pass is where naive and flash attention differ significantly:

**Naive Attention Backward** (stores full attention matrix):
- dV = P^T @ dO: `2 × T² × D`
- dP = dO @ V^T: `2 × T² × D`
- dQ = dS @ K: `2 × T² × D`
- dK = dS^T @ Q: `2 × T² × D`
- **Total: `8 × B × H × T² × D`**

**Pallas Flash Attention Backward** (recomputes attention twice):
- dK/dV kernel: recomputes S = Q @ K^T, then dP, dV, dK (4 matmuls)
- dQ kernel: recomputes S = Q @ K^T, then dP, dQ (3 matmuls)
- **Total: `14 × B × H × T² × D`**

**cuDNN Flash Attention Backward** (optimized single recompute):
- Fused backward: recomputes S once, computes dQ, dK, dV together
- **Total: ~`10 × B × H × T² × D`**
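The three totals above differ only in how many `T × T × D`-sized matmuls each variant performs. A short sketch in plain Python makes that explicit; the dictionary keys and the configuration are illustrative assumptions, and the cuDNN entry follows the approximate count given above rather than anything measured.

```python
# Number of T x T x D matmuls per (batch, head) in each backward variant,
# following the breakdown above. Each matmul costs 2 * T * T * D FLOPs.
BWD_MATMULS = {
    "naive": 4,         # dV, dP, dQ, dK
    "pallas_flash": 7,  # dK/dV kernel (4 matmuls) + dQ kernel (3 matmuls)
    "cudnn_flash": 5,   # one recompute of S, then dP, dQ, dK, dV (approximate)
}


def attention_bwd_flops(impl, B, H, T, D):
    """Backward matmul FLOPs for the given implementation."""
    return BWD_MATMULS[impl] * 2 * B * H * T * T * D


# Hypothetical configuration: B, H and D are assumptions for this example.
B, H, T, D = 4, 8, 4096, 64
for impl in BWD_MATMULS:
    gflop = attention_bwd_flops(impl, B, H, T, D) / 1e9
    print(f"{impl:13s}: {gflop:6.1f} GFLOP")

ratio = attention_bwd_flops("pallas_flash", B, H, T, D) / attention_bwd_flops(
    "cudnn_flash", B, H, T, D
)
print(f"pallas / cuDNN FLOP ratio: {ratio:.2f}")
```

Because each implementation is credited with its own FLOP count, GFLOP/s figures derived this way are not directly comparable: for the same wall-clock time, the 14× count yields 1.4x the GFLOP/s of the 10× count.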

### Memory Transfer (Bytes) Calculation

The key insight of flash attention is **reducing memory traffic**, not FLOPs. Here's where the implementations differ:

#### Forward Pass Memory Traffic

**Naive MHA Forward** (materializes full attention matrix):
- Read Q, K, V: `3 × B × H × T × D × bytes_per_elem`
- Write attention matrix P: `B × H × T × T × bytes_per_elem` ← **THE BIG ONE!**
- Write output O: `B × H × T × D × bytes_per_elem`

**Flash Attention Forward** (tiled, no attention matrix):
- Read Q, K, V: `3 × B × H × T × D × bytes_per_elem`
- Write logsumexp: `B × H × T × bytes_per_elem` ← **Much smaller!**
- Write output O: `B × H × T × D × bytes_per_elem`

The difference is **O(T²) vs O(T)**. For sequence length T=1024, the attention matrix alone requires T²=1M elements per head, while logsumexp only requires T=1K elements.
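To see how this changes arithmetic intensity, the forward traffic model above can be combined with the forward FLOP count. The following is a sketch in plain Python; the helper names, the configuration, and the choice of 4-byte elements are all assumptions made for this example.

```python
def fwd_hbm_bytes(B, H, T, D, bytes_per_elem, flash):
    """HBM traffic for the forward pass under the accounting above."""
    qkv_and_out = 4 * B * H * T * D                  # read Q, K, V and write O
    extra = B * H * T if flash else B * H * T * T    # logsumexp vs. attention matrix
    return (qkv_and_out + extra) * bytes_per_elem


def fwd_flops(B, H, T, D):
    """Forward FLOPs: matmul term plus softmax term."""
    return 4 * B * H * T * T * D + 5 * B * H * T * T


def fwd_ai(B, H, T, D, bytes_per_elem, flash):
    """Arithmetic intensity (FLOPs per byte) of the forward pass."""
    return fwd_flops(B, H, T, D) / fwd_hbm_bytes(B, H, T, D, bytes_per_elem, flash)


B, H, D = 4, 8, 64     # hypothetical configuration
BYTES_PER_ELEM = 4     # assumption: 4-byte elements; use 2 for bfloat16 storage

for T in (1024, 4096):
    naive = fwd_ai(B, H, T, D, BYTES_PER_ELEM, flash=False)
    flash = fwd_ai(B, H, T, D, BYTES_PER_ELEM, flash=True)
    print(f"T={T}: naive AI = {naive:.0f}, flash AI = {flash:.0f}")
```

With 4-byte elements these values come out close to the AI columns in the example results (roughly 52 vs 260 at T=1024). Halving `BYTES_PER_ELEM` doubles both values without changing their ratio.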

#### Backward Pass Memory Traffic

**Naive MHA Backward**:
- Read Q, K, V, O, dO: `5 × B × H × T × D × bytes_per_elem`
- Read attention matrix: `B × H × T × T × bytes_per_elem`
- Write dQ, dK, dV: `3 × B × H × T × D × bytes_per_elem`

**Flash Attention Backward**:
- Read Q, K, V, O, dO: `5 × B × H × T × D × bytes_per_elem`
- Read logsumexp: `B × H × T × bytes_per_elem`
- Write dQ, dK, dV: `3 × B × H × T × D × bytes_per_elem`
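The same exercise for the backward pass shows how the FLOP multipliers from the previous section feed into arithmetic intensity. Again this is a sketch in plain Python with assumed helper names, an assumed configuration, and an assumed element size of 4 bytes.

```python
def bwd_hbm_bytes(B, H, T, D, bytes_per_elem, flash):
    """HBM traffic for the backward pass under the accounting above."""
    tensors = 8 * B * H * T * D                      # read Q, K, V, O, dO; write dQ, dK, dV
    extra = B * H * T if flash else B * H * T * T    # logsumexp vs. attention matrix
    return (tensors + extra) * bytes_per_elem


def bwd_ai(flops_multiplier, B, H, T, D, bytes_per_elem, flash):
    """AI for a backward pass costing flops_multiplier * B*H*T^2*D FLOPs."""
    flops = flops_multiplier * B * H * T * T * D
    return flops / bwd_hbm_bytes(B, H, T, D, bytes_per_elem, flash)


B, H, T, D = 4, 8, 4096, 64   # hypothetical configuration
BYTES_PER_ELEM = 4            # assumption: 4-byte elements

naive_ai = bwd_ai(8, B, H, T, D, BYTES_PER_ELEM, flash=False)
pallas_ai = bwd_ai(14, B, H, T, D, BYTES_PER_ELEM, flash=True)
cudnn_ai = bwd_ai(10, B, H, T, D, BYTES_PER_ELEM, flash=True)

print(f"naive  AI: {naive_ai:.0f}")
print(f"pallas AI: {pallas_ai:.0f}")
print(f"cuDNN  AI: {cudnn_ai:.0f}")
print(f"pallas / cuDNN: {pallas_ai / cudnn_ai:.2f}")
```

Since both flash variants are charged the same bytes in this model, their AI ratio is exactly the ratio of their FLOP multipliers (14/10 = 1.4).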

### Example Roofline Benchmark Output

Here's example output from running the roofline benchmark on an RTX 4000 Ada with bfloat16:

```
GPU: NVIDIA RTX 4000 Ada Generation
  Peak Compute (Tensor cores): 106.9 TFLOP/s
  Peak Memory Bandwidth:       360.0 GB/s

========================================================================================================================
FORWARD PASS SUMMARY
========================================================================================================================
T      Naive      Flash      cuDNN      Naive        Flash        cuDNN        Naive      Flash      cuDNN     
       (ms)       (ms)       (ms)       (GFLOP/s)    (GFLOP/s)    (GFLOP/s)    (AI)       (AI)       (AI)      
------------------------------------------------------------------------------------------------------------------------
1024   1.039      0.294      0.335      8429         29833        26127        52         260        260       
2048   4.590      0.746      0.891      7631         46967        39311        58         520        520       
4096   17.383     2.372      2.294      8061         59075        61082        61         1040       1040      

========================================================================================================================
BACKWARD PASS SUMMARY
========================================================================================================================
T      Naive      Flash      cuDNN      Naive        Flash        cuDNN        Naive      Flash      cuDNN     
       (ms)       (ms)       (ms)       (GFLOP/s)    (GFLOP/s)    (GFLOP/s)    (AI)       (AI)       (AI)      
------------------------------------------------------------------------------------------------------------------------
1024   2.408      0.943      0.976      7136         31888        22008        85         447        319       
2048   8.582      2.710      2.220      8007         44375        38694        102        894        639       
4096   31.144     9.422      7.106      8826         51053        48351        114        1789       1278      

Peak verification:
  Pallas fwd: 59075 GFLOP/s (55.3% of TC peak)
  Pallas bwd: 51053 GFLOP/s (47.8% of TC peak)
  cuDNN fwd:  61082 GFLOP/s (57.1% of TC peak)
  cuDNN bwd:  48351 GFLOP/s (45.2% of TC peak)
```

**Reading the table:**
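- **Time (ms)**: Lower is better. Flash and cuDNN are much faster than naive
- **GFLOP/s**: Higher is better, but each implementation is credited with its own FLOP count, so compare across columns with care. Our flash kernel achieves 50-60 TFLOP/s at T=4096
- **AI**: Arithmetic intensity (FLOPs per byte). In the backward pass, flash shows ~1.4x higher AI than cuDNN, mirroring the 14:10 ratio of their FLOP counts

Note that our Pallas implementation reports **higher GFLOP/s than cuDNN in the backward pass** (51 vs 48 TFLOP/s at T=4096), yet cuDNN finishes sooner (7.1 ms vs 9.4 ms). The apparent advantage comes from the extra recomputation in our backward kernels, which raises both the FLOP count and the arithmetic intensity without reducing wall-clock time.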
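Because the GFLOP/s columns are derived from per-implementation FLOP counts, a direct wall-clock comparison is a useful cross-check. The sketch below is plain Python that only re-reads the backward-pass times printed in the table above; the dictionary layout and the `speedup` helper are illustrative.

```python
# Backward-pass times in ms, copied from the T=1024..4096 rows above.
BWD_MS = {
    1024: {"naive": 2.408, "flash": 0.943, "cudnn": 0.976},
    2048: {"naive": 8.582, "flash": 2.710, "cudnn": 2.220},
    4096: {"naive": 31.144, "flash": 9.422, "cudnn": 7.106},
}


def speedup(times, baseline, candidate):
    """How many times faster `candidate` is than `baseline` (>1 means faster)."""
    return times[baseline] / times[candidate]


for T, times in BWD_MS.items():
    vs_naive = speedup(times, "naive", "flash")
    vs_cudnn = speedup(times, "cudnn", "flash")
    print(f"T={T}: flash vs naive {vs_naive:.2f}x, flash vs cuDNN {vs_cudnn:.2f}x")
```

A value below 1.00 in the last column means cuDNN finishes sooner at that sequence length, regardless of which implementation shows the higher GFLOP/s.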

## Limitations and Future Work

### Performance Achievement

Our Pallas implementation now achieves performance **competitive with NVIDIA's cuDNN flash attention**:

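- **Forward pass**: Faster than cuDNN at T=1024 and T=2048 (12-16% less time) and about 3% slower at T=4096
- **Backward pass**: Faster than cuDNN in wall-clock time up to T=1024, but about 1.2-1.3x slower at T=2048 and T=4096, even though the reported GFLOP/s is higher because our backward pass performs more recomputation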

This was achieved through careful mixed-precision optimization (see the Precision Optimization section above).

### Remaining Gaps

1. **Small sequence lengths**: At T≤512, kernel launch overhead dominates and naive attention can be faster (it wins the forward pass at T=256 and T=512 in our results). Production implementations often fall back to naive attention for short sequences.

2. **Forward pass at T=4096**: cuDNN is slightly faster (2.29ms vs 2.37ms). This is likely due to better autotuning of block sizes.

3. **No causal masking optimization**: Causal attention can skip computation for masked positions, but our implementation still computes every tile of the attention matrix.

### Pallas Limitations

Pallas provides a high-level abstraction for writing GPU kernels, but it doesn't expose certain low-level primitives:

- **No warp-level programming**: Pallas doesn't provide access to `warp_id` or warp shuffle operations. You can configure `num_warps` but cannot coordinate work between warps within a block.

- **Limited shared memory control**: Pallas manages shared memory implicitly through `BlockSpec`. You cannot explicitly allocate shared memory or control synchronization barriers.

- **No atomic operations**: Pallas on GPU doesn't expose `atomic_add`, requiring separate kernels for reductions (like our three-kernel backward pass).

### Path to Further Improvement

1. **Autotuning block sizes**: Our fixed BLOCK_R=BLOCK_C=64 may not be optimal for all configurations. Dynamic tuning could help.

2. **Fusing backward kernels**: The three-kernel approach adds overhead. With careful synchronization, these could potentially be fused.

3. **Causal masking**: Skip computation for masked positions in the attention matrix.
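To get a feel for how much work causal masking could save, we can count the tiles that contain at least one unmasked position. This is a counting sketch in plain Python, not kernel code; the function name is made up, and the block sizes follow the `BLOCK_R=BLOCK_C=64` setting mentioned above.

```python
def causal_tile_counts(T, block_r, block_c):
    """Count (query block, key block) tiles with at least one unmasked position
    under a causal mask, where a query may attend to keys at index <= its own."""
    num_q_blocks = -(-T // block_r)   # ceiling division
    num_k_blocks = -(-T // block_c)
    needed = 0
    for i in range(num_q_blocks):
        last_query = min((i + 1) * block_r, T) - 1
        for j in range(num_k_blocks):
            first_key = j * block_c
            # The tile is fully masked only if its first key lies after the last query.
            if first_key <= last_query:
                needed += 1
    return needed, num_q_blocks * num_k_blocks


needed, total = causal_tile_counts(T=4096, block_r=64, block_c=64)
print(f"{needed} of {total} tiles needed ({1 - needed / total:.1%} skippable)")
```

At T=4096 with 64×64 tiles, just under half of the tiles are fully masked. A kernel exploiting this would bound its loop over key blocks per query block and still apply the mask inside the tiles on the diagonal.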

### Educational Value

Despite targeting production-level performance, this implementation retains significant educational value:

- **Algorithm clarity**: The tiled computation with online softmax correction is clearly visible in the code
- **Gradient derivation**: The backward pass shows exactly how gradients flow through attention
- **Precision analysis**: The mixed-precision strategy demonstrates real-world optimization thinking
- **Debugging**: `INTERPRET_MODE=True` allows stepping through the algorithm on CPU

## References

1. Dao, T., Fu, D., Ermon, S., Rudra, A., & Ré, C. (2022). FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness. *NeurIPS 2022*. https://arxiv.org/abs/2205.14135
2. Dao, T. (2023). FlashAttention-2: Faster Attention with Better Parallelism and Work Partitioning. *arXiv preprint arXiv:2307.08691*. https://arxiv.org/abs/2307.08691
3. JAX Official Flash Attention (TPU): https://github.com/jax-ml/jax/blob/main/jax/experimental/pallas/ops/tpu/flash_attention.py
4. JAX Official Fused Attention (GPU): https://github.com/jax-ml/jax/blob/main/jax/experimental/pallas/ops/gpu/attention.py
5. Umar Jamil's Triton Flash Attention: https://github.com/hkproj/triton-flash-attention
6. Sebastian Raschka - Understanding and Coding Self-Attention from Scratch: https://sebastianraschka.com/blog/2023/self-attention-from-scratch.html