# How to Write a Flash Attention Kernel in Pallas
## Introduction

- Link back to the softmax post (this builds on online softmax)
- Why attention is a bottleneck: O(N²) memory for the attention matrix
- Flash attention's promise: O(N) extra memory, with exactly the same result rather than an approximation

## Standard Attention

- Mathematical definition: `softmax(QK^T / √d) @ V`
- Naive JAX implementation with einsum
- The memory problem: for sequence length N, we materialize an N×N matrix for every batch element and head

In [1]:
import jax
import jax.numpy as jnp

@jax.jit
def mha_reference(q, k, v):
    """Reference multi-head attention: softmax(Q @ K^T / sqrt(d)) @ V"""
    d = q.shape[-1]
    scale = 1.0 / jnp.sqrt(d)
    logits = jnp.einsum('bhqd,bhkd->bhqk', q, k) * scale
    probs = jax.nn.softmax(logits, axis=-1)
    o = jnp.einsum('bhqk,bhkd->bhqd', probs, v)
    return o
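
To put numbers on the memory problem, here is a back-of-the-envelope sketch of how large the `logits` array in `mha_reference` becomes. It assumes float32 and an illustrative batch of 1 with 8 heads; `attention_matrix_bytes` is a hypothetical helper, not part of JAX.

```python
def attention_matrix_bytes(batch, heads, seq_len, bytes_per_elem=4):
    """Bytes needed to materialize the (batch, heads, seq_len, seq_len) logits."""
    return batch * heads * seq_len * seq_len * bytes_per_elem

# Illustrative sizes: batch=1, heads=8, float32.
for n in (1024, 4096, 16384):
    mib = attention_matrix_bytes(1, 8, n) / 2**20
    print(f"N={n:>6}: {mib:.0f} MiB")
```

Under these assumptions the matrix grows from 32 MiB at N=1024 to 8 GiB at N=16384: a 16× longer sequence costs 256× the memory.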

## The Flash Attention Algorithm

- Key insight: we never need the full attention matrix
- Combine online softmax with output accumulation
- Walk through the algorithm:
  - Tile Q (outer parallel loop)
  - Tile K, V (inner sequential loop)
  - Maintain running max `m`, sum `l`, and output accumulator `o`
  - Correction factor when max changes
  - Final normalization
- Python reference implementation (like the `online_softmax` function from the softmax post)
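
The steps above can be sketched in plain JAX before writing any kernel. This is a minimal single-head reference, not the kernel itself. The name `flash_attention_reference`, the block sizes, and the test shapes are illustrative assumptions, and the Python loops stand in for the parallel grid (outer) and the sequential K/V loop (inner).

```python
import jax
import jax.numpy as jnp

def flash_attention_reference(q, k, v, block_q=64, block_k=64):
    """Tiled attention for one head; q, k, v have shape (T, d)."""
    T, d = q.shape
    scale = d ** -0.5
    out_blocks = []
    for qs in range(0, T, block_q):              # outer loop: independent Q tiles
        q_blk = q[qs:qs + block_q]
        rows = q_blk.shape[0]
        m = jnp.full((rows,), -jnp.inf)          # running row max
        l = jnp.zeros((rows,))                   # running softmax denominator
        o = jnp.zeros((rows, d))                 # unnormalized output accumulator
        for ks in range(0, T, block_k):          # inner loop: sequential K/V tiles
            s = (q_blk @ k[ks:ks + block_k].T) * scale
            m_new = jnp.maximum(m, s.max(axis=-1))
            p = jnp.exp(s - m_new[:, None])
            alpha = jnp.exp(m - m_new)           # correction factor when the max changes
            l = alpha * l + p.sum(axis=-1)
            o = alpha[:, None] * o + p @ v[ks:ks + block_k]
            m = m_new
        out_blocks.append(o / l[:, None])        # final normalization
    return jnp.concatenate(out_blocks, axis=0)

kq, kk, kv = jax.random.split(jax.random.key(0), 3)
q = jax.random.normal(kq, (128, 32))
k = jax.random.normal(kk, (128, 32))
v = jax.random.normal(kv, (128, 32))

out = flash_attention_reference(q, k, v, block_q=32, block_k=32)
expected = jax.nn.softmax((q @ k.T) * 32 ** -0.5, axis=-1) @ v
print(jnp.allclose(out, expected, atol=1e-4))
```

Only one `(block_q, block_k)` tile of scores exists at any time, which is where the memory saving comes from.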

## Forward Pass Kernel

- BlockSpec design: Q is tiled along the sequence axis; each program receives the full K/V sequence and slices it in the inner loop
- The kernel implementation
- Storing logsumexp for backward pass
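
The BlockSpec design can be illustrated without Pallas. The sketch below assumes the default blocked indexing, where the index map returns block indices and each is multiplied by the corresponding block dimension to get element offsets. `block_slices` is a hypothetical helper, not a Pallas API, and the sizes are illustrative.

```python
BLOCK_T, T, C = 64, 256, 64  # illustrative sizes

def block_slices(block_shape, index_map, *grid_idx):
    """(start, stop) element ranges one program sees for a block spec."""
    block_idx = index_map(*grid_idx)
    return tuple((i * n, (i + 1) * n) for i, n in zip(block_idx, block_shape))

q_spec = ((1, BLOCK_T, C), lambda b, t: (b, t, 0))   # Q: one BLOCK_T tile per program
kv_spec = ((1, T, C), lambda b, _: (b, 0, 0))        # K/V: the whole sequence

# Program (b=3, t=2) of the grid:
print(block_slices(*q_spec, 3, 2))   # ((3, 4), (128, 192), (0, 64))
print(block_slices(*kv_spec, 3, 2))  # ((3, 4), (0, 256), (0, 64))
```

Every program along the `t` axis sees a different Q tile but the same K/V block, so the kernel is free to walk K and V in `BLOCK_T` steps itself.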

In [2]:
from functools import partial
import math

from jax.experimental import pallas as pl
from jax.experimental.pallas import triton as plgpu

INTERPRET_MODE = True  # Set to False on GPU

BLOCK_T = 64
NUM_WARPS = 4
NUM_STAGES = 2

In [3]:
def flash_attention_fwd_kernel(q_ref, k_ref, v_ref, o_ref, logsumexp_ref, *, scale, num_k_blocks):
    """Flash attention forward kernel."""
    q_reg = plgpu.load(q_ref.at[0, :, :]).astype(jnp.float32)
    o_reg = jnp.zeros(q_reg.shape, jnp.float32)
    max_reg = jnp.full((BLOCK_T,), -jnp.inf, dtype=jnp.float32)
    l_reg = jnp.zeros((BLOCK_T,), dtype=jnp.float32)

    def body(t, args):
        max_reg, l_reg, o_reg = args
        idx = pl.dslice(t * BLOCK_T, BLOCK_T)
        k_blk = plgpu.load(k_ref.at[0, idx, :]).astype(jnp.float32)
        v_blk = plgpu.load(v_ref.at[0, idx, :]).astype(jnp.float32)
        # `scale` is sqrt(C), so this computes Q @ K^T / sqrt(d).
        s_blk = pl.dot(q_reg, k_blk, trans_b=True) / scale
        # Online softmax: update the running row max, then rescale the old sum
        # and accumulator by exp(old_max - new_max) before adding this block.
        max_blk = jnp.maximum(max_reg, jnp.max(s_blk, axis=-1))
        p_blk = jnp.exp(s_blk - max_blk[:, None])
        correction = jnp.exp(max_reg - max_blk)
        l_blk = l_reg * correction + jnp.sum(p_blk, axis=-1)
        o_blk = o_reg * correction[:, None] + pl.dot(p_blk, v_blk)
        return max_blk, l_blk, o_blk

    max_reg, l_reg, o_reg = jax.lax.fori_loop(0, num_k_blocks, body, (max_reg, l_reg, o_reg))
    o_reg = o_reg / l_reg[:, None]
    logsumexp_reg = max_reg + jnp.log(l_reg)
    plgpu.store(o_ref.at[0, :, :], o_reg.astype(o_ref.dtype))
    plgpu.store(logsumexp_ref.at[0, :], logsumexp_reg.astype(logsumexp_ref.dtype))

In [4]:
@jax.jit
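def flash_attention_fwd(q, k, v):
    """Flash attention forward pass for (B, H, T, C) inputs."""
    B, H, T, C = q.shape
    # The kernel loads full BLOCK_T tiles without masking, so T must divide evenly.
    assert T % BLOCK_T == 0, "T must be a multiple of BLOCK_T"
    B_flat = B*H
    q_flat = q.reshape(-1, T, C)
    k_flat = k.reshape(-1, T, C)
    v_flat = v.reshape(-1, T, C)
    scale = math.sqrt(C)  # the kernel divides the logits by this value
    num_k_blocks = pl.cdiv(T, BLOCK_T)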

    grid = (B_flat, pl.cdiv(T, BLOCK_T))

    out_flat, logsumexp = pl.pallas_call(
        partial(flash_attention_fwd_kernel, scale=scale, num_k_blocks=num_k_blocks),
        out_shape=[
            jax.ShapeDtypeStruct(q_flat.shape, q_flat.dtype),
            # Keep logsumexp in float32 so low-precision inputs do not degrade the backward pass.
            jax.ShapeDtypeStruct((B_flat, T), jnp.float32)
        ],
        grid=grid,
        in_specs=[
            pl.BlockSpec((1, BLOCK_T, C), lambda b, t: (b, t, 0)),
            pl.BlockSpec((1, T, C), lambda b, _: (b, 0, 0)),
            pl.BlockSpec((1, T, C), lambda b, _: (b, 0, 0))
        ],
        out_specs=[
            pl.BlockSpec((1, BLOCK_T, C), lambda b, t: (b, t, 0)),
            pl.BlockSpec((1, BLOCK_T), lambda b, t: (b, t))
        ],
        interpret=INTERPRET_MODE,
        compiler_params=plgpu.CompilerParams(
            num_warps=NUM_WARPS,
            num_stages=NUM_STAGES
        )
    )(q_flat, k_flat, v_flat)
    out = out_flat.reshape(q.shape)
    logsumexp = logsumexp.reshape(B, H, T)
    return out, logsumexp

## Performance

- Benchmark against JAX's standard attention
- Memory comparison (if possible to demonstrate)

In [5]:
import time

def bench(fn, *args, iters=10):
    """Return the mean wall-clock time of fn(*args) in seconds."""
    for _ in range(3):  # warmup, includes JIT compilation
        jax.block_until_ready(fn(*args))
    times = []
    for _ in range(iters):
        t0 = time.perf_counter()
        # jax.block_until_ready waits on every array in the result, tuple or not.
        jax.block_until_ready(fn(*args))
        times.append(time.perf_counter() - t0)
    return sum(times) / len(times)

## Backward Pass

- Mathematical derivation of gradients dQ, dK, dV
- The recomputation strategy: recompute attention weights from logsumexp
- Two-pass approach: one for dK/dV, one for dQ (as in the reference)
- Kernel implementation
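
The recomputation strategy can be checked in plain JAX before writing the kernel. The sketch below is an untiled, single-head reference: it rebuilds the attention weights as `exp(S - logsumexp)` instead of storing them, then applies the standard softmax backward formulas. The function names and shapes are illustrative assumptions. A real kernel would apply the same formulas block by block.

```python
import jax
import jax.numpy as jnp

def attention_fwd_reference(q, k, v):
    """Single-head attention returning the output and the per-row logsumexp."""
    scale = q.shape[-1] ** -0.5
    s = (q @ k.T) * scale
    lse = jax.nn.logsumexp(s, axis=-1)
    return jnp.exp(s - lse[:, None]) @ v, lse

def attention_bwd_reference(q, k, v, o, lse, do):
    """Gradients from the residuals (q, k, v, o, lse); P is recomputed, not stored."""
    scale = q.shape[-1] ** -0.5
    p = jnp.exp((q @ k.T) * scale - lse[:, None])  # recomputed softmax weights
    dv = p.T @ do
    dp = do @ v.T
    delta = jnp.sum(do * o, axis=-1)               # rowsum(dO * O) == rowsum(P * dP)
    ds = p * (dp - delta[:, None])                 # backward through the softmax
    dq = (ds @ k) * scale
    dk = (ds.T @ q) * scale
    return dq, dk, dv

keys = jax.random.split(jax.random.key(0), 4)
q, k, v, do = (jax.random.normal(key, (64, 16)) for key in keys)

o, lse = attention_fwd_reference(q, k, v)
dq, dk, dv = attention_bwd_reference(q, k, v, o, lse, do)

# Compare against JAX autodiff of the same forward function.
loss = lambda q, k, v: jnp.sum(attention_fwd_reference(q, k, v)[0] * do)
gq, gk, gv = jax.grad(loss, argnums=(0, 1, 2))(q, k, v)
print(all(bool(jnp.allclose(a, b, atol=1e-4)) for a, b in [(dq, gq), (dk, gk), (dv, gv)]))
```

`delta` needs only `O` and `dO`, so the backward pass never has to hold a full row of attention weights at once.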

In [6]:
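def flash_attention_bwd_kernel(q_ref, k_ref, v_ref, o_ref, do_ref, logsumexp_ref,
                                dq_ref, dk_ref, dv_ref, *, num_kv_blocks, scale):
    """Flash attention backward kernel (not implemented yet)."""
    raise NotImplementedError("flash_attention_bwd_kernel is not implemented yet")


@jax.jit
def flash_attention_bwd(q, k, v, o, logsumexp, do):
    """Flash attention backward pass (not implemented yet)."""
    # Fail loudly instead of returning None, which would break the unpacking
    # of (dq, dk, dv) in the custom_vjp backward rule.
    raise NotImplementedError("flash_attention_bwd is not implemented yet")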

## Custom VJP Integration

- Wiring up forward and backward with `jax.custom_vjp`
- The residuals needed (Q, K, V, O, logsumexp)

In [7]:
@jax.custom_vjp
def flash_attention(q, k, v):
    """Flash attention with custom backward pass."""
    o, logsumexp = flash_attention_fwd(q, k, v)
    return o


def flash_attention_fwd_rule(q, k, v):
    """Forward rule for custom_vjp."""
    o, logsumexp = flash_attention_fwd(q, k, v)
    return o, (q, k, v, o, logsumexp)


def flash_attention_bwd_rule(res, do):
    """Backward rule for custom_vjp."""
    q, k, v, o, logsumexp = res
    dq, dk, dv = flash_attention_bwd(q, k, v, o, logsumexp, do)
    return dq, dk, dv


flash_attention.defvjp(flash_attention_fwd_rule, flash_attention_bwd_rule)

## Evaluation

- Gradient correctness check against JAX autodiff
- Train a small transformer or attention layer to verify end-to-end

In [8]:
B, H, T, D = 2, 4, 256, 64
key = jax.random.key(0)
keys = jax.random.split(key, 4)

q = jax.random.normal(keys[0], (B, H, T, D), dtype=jnp.float32)
k = jax.random.normal(keys[1], (B, H, T, D), dtype=jnp.float32)
v = jax.random.normal(keys[2], (B, H, T, D), dtype=jnp.float32)
do = jax.random.normal(keys[3], (B, H, T, D), dtype=jnp.float32)

# Forward check
o_ref = mha_reference(q, k, v)
print(f"Reference output shape: {o_ref.shape}")

o_flash = flash_attention(q, k, v)
print(f"Flash attention output shape: {o_flash.shape}")
print(f"Forward pass matches: {jnp.allclose(o_flash, o_ref, atol=1e-2, rtol=1e-2)}")

Reference output shape: (2, 4, 256, 64)
Flash attention output shape: (2, 4, 256, 64)
Forward pass matches: True


In [9]:
# Backward check (reference)
def loss_ref(q, k, v):
    return jnp.sum(mha_reference(q, k, v) * do)

dq_ref, dk_ref, dv_ref = jax.grad(loss_ref, argnums=(0, 1, 2))(q, k, v)
print(f"Reference gradient shapes: dq={dq_ref.shape}, dk={dk_ref.shape}, dv={dv_ref.shape}")

# TODO: Uncomment once backward pass is implemented
# def loss_flash(q, k, v):
#     return jnp.sum(flash_attention(q, k, v) * do)
#
# dq_flash, dk_flash, dv_flash = jax.grad(loss_flash, argnums=(0, 1, 2))(q, k, v)
# print(f"Flash attention gradient shapes: dq={dq_flash.shape}, dk={dk_flash.shape}, dv={dv_flash.shape}")
# print(f"dQ matches: {jnp.allclose(dq_flash, dq_ref, atol=1e-2, rtol=1e-2)}")
# print(f"dK matches: {jnp.allclose(dk_flash, dk_ref, atol=1e-2, rtol=1e-2)}")
# print(f"dV matches: {jnp.allclose(dv_flash, dv_ref, atol=1e-2, rtol=1e-2)}")

Reference gradient shapes: dq=(2, 4, 256, 64), dk=(2, 4, 256, 64), dv=(2, 4, 256, 64)


## Conclusion

- Recap of key ideas
- Potential extensions: causal masking, multi-query attention
- Link to further resources (FlashAttention paper, JAX Pallas docs)

## References

- Dao, T. (2023). FlashAttention-2: Faster Attention with Better Parallelism and Work Partitioning. *arXiv preprint arXiv:2307.08691*. https://arxiv.org/abs/2307.08691