
# From Naïve Attention to FlashAttention (in Pure PyTorch)
### Building the algorithm from first principles, one improvement at a time

This notebook is written in **blog-post style**. We’ll start from the standard “textbook” implementation of scaled dot-product attention, and then progressively refine it.

The journey is the point: we begin with the simplest possible code, understand why it becomes expensive, fix numerical stability, and then redesign the computation so we no longer need to materialize the `T×T` attention matrix. That final step is the core idea behind **FlashAttention**: do the same math, but stream it in blocks, keeping memory usage small and computation stable.

A follow-up post can take the final PyTorch version and translate it into a CUDA kernel. Here, we stay entirely in PyTorch so we can focus on clarity, correctness, and intuition.



## 0. Setup

We’ll use PyTorch and run on GPU if available. Even on CPU, you can follow the logic and verify correctness, though the benchmarks will be less dramatic.

Throughout, we use these shapes:

- `Q, K, V`: `[B, H, T, D]`
  - `B`: batch size
  - `H`: number of attention heads
  - `T`: sequence length
  - `D`: head dimension

We’ll compute attention independently for each head, but keep the head axis `H` in the tensors so the code is realistic.


In [None]:

import math
import time
from dataclasses import dataclass

import torch

torch.manual_seed(0)

def get_device():
    return "cuda" if torch.cuda.is_available() else "cpu"

device = get_device()
device



## 1. What attention computes

Scaled dot-product attention is usually written as:

\[
\begin{aligned}
S &= \frac{QK^\top}{\sqrt{D}} \\
A &= \mathrm{softmax}(S) \\
O &= AV
\end{aligned}
\]

You can read this as a very specific kind of “content-based lookup”. For each query vector `q_i` (a row of `Q`), you compare it to every key vector `k_j` (rows of `K`) by taking dot products. Those dot products form the scores `S[i, j]`.

The softmax turns those scores into a probability distribution over positions `j`. Finally, `O[i]` becomes a weighted average of the value vectors `v_j`.

This is straightforward to write down. The cost is in `S`: it has shape `[T, T]` per head, so the memory needed to store it grows quadratically with the sequence length `T`.
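To make the “lookup” reading concrete, here is a minimal sketch for a single query vector. The sizes `T=4` and `D=8` are arbitrary toy values chosen for illustration, not anything the rest of the notebook depends on.

```python
import math
import torch

torch.manual_seed(0)

T, D = 4, 8                  # toy sizes, chosen only for illustration
q = torch.randn(D)           # one query vector q_i
K = torch.randn(T, D)        # keys k_1..k_T as rows
V = torch.randn(T, D)        # values v_1..v_T as rows

scores = (K @ q) / math.sqrt(D)          # S[i, :], one score per key -> shape [T]
weights = torch.softmax(scores, dim=0)   # A[i, :], non-negative and summing to 1
o = weights @ V                          # O[i] = sum_j A[i, j] * v_j -> shape [D]

print(weights, weights.sum())
print(o.shape)
```

The full attention computation does the same thing for every query row at once, which is where the `[T, T]` score matrix comes from.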



### A quick note on the scale factor

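Why divide by `sqrt(D)`? If the components of `q` and `k` are independent with zero mean and roughly unit variance, then the dot product `q·k` has variance proportional to `D`. As `D` grows, the raw dot products become larger in magnitude and softmax saturates: it puts nearly all of its weight on a single position, and the gradients flowing through it become very small. Dividing by `sqrt(D)` brings the variance of the scores back to roughly 1, which keeps them in a friendlier range.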
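The variance claim is easy to check empirically. The sketch below assumes standard normal components for `q` and `k`; the helper name `dot_product_variance` and the sample count are made up for this illustration.

```python
import math
from typing import Tuple

import torch

torch.manual_seed(0)

def dot_product_variance(D: int, n_samples: int = 20000) -> Tuple[float, float]:
    """Empirical variance of q.k before and after dividing by sqrt(D)."""
    q = torch.randn(n_samples, D)
    k = torch.randn(n_samples, D)
    dots = (q * k).sum(dim=-1)               # one dot product per sample
    return dots.var().item(), (dots / math.sqrt(D)).var().item()

for D in (16, 64, 256):
    raw, scaled = dot_product_variance(D)
    print(f"D={D:3d}  var(q.k)={raw:7.1f}  var(q.k / sqrt(D))={scaled:.2f}")
```

The raw variance should track `D`, while the scaled variance should stay close to 1 regardless of `D`.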



## 2. Version A: naïve attention (clear but memory hungry)

Let’s implement the most direct version. This matches the math exactly. It also allocates the full score matrix `S` of shape `[B, H, T, T]`.

When `T` is small, this is fine. When `T` is large, the `[B, H, T, T]` allocation becomes the bottleneck. The core reason FlashAttention exists is to avoid ever storing `S` in full.


In [None]:

def attention_naive(Q: torch.Tensor, K: torch.Tensor, V: torch.Tensor, causal: bool = False) -> torch.Tensor:
    """
    Naïve scaled dot-product attention.
    Allocates [B,H,T,T] score matrix.

    Q, K, V: [B, H, T, D]
    Returns: [B, H, T, D]
    """
    assert Q.ndim == 4 and K.ndim == 4 and V.ndim == 4
    B, H, T, D = Q.shape
    assert K.shape == (B, H, T, D)
    assert V.shape == (B, H, T, D)

    scale = 1.0 / math.sqrt(D)

    scores = torch.matmul(Q, K.transpose(-2, -1)) * scale  # [B,H,T,T]

    if causal:
        mask = torch.triu(torch.ones(T, T, device=scores.device, dtype=torch.bool), diagonal=1)
        scores = scores.masked_fill(mask, float("-inf"))

    attn = torch.softmax(scores, dim=-1)                    # [B,H,T,T]
    out = torch.matmul(attn, V)                             # [B,H,T,D]
    return out
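As a sanity check on the textbook formula, the result can be compared with PyTorch’s built-in `torch.nn.functional.scaled_dot_product_attention`. This sketch assumes PyTorch 2.0 or newer, where that function is available; `reference_attention` is a compact restatement of the naïve version, repeated only so the snippet runs on its own.

```python
import math
import torch
import torch.nn.functional as F

torch.manual_seed(0)

def reference_attention(Q, K, V, causal=False):
    # Same math as the naive version: scores -> optional causal mask -> softmax -> weighted sum.
    T, D = Q.shape[-2], Q.shape[-1]
    scores = Q @ K.transpose(-2, -1) / math.sqrt(D)
    if causal:
        mask = torch.triu(torch.ones(T, T, dtype=torch.bool, device=Q.device), diagonal=1)
        scores = scores.masked_fill(mask, float("-inf"))
    return torch.softmax(scores, dim=-1) @ V

Q, K, V = (torch.randn(2, 4, 16, 8) for _ in range(3))   # [B, H, T, D], toy sizes

for causal in (False, True):
    ours = reference_attention(Q, K, V, causal=causal)
    builtin = F.scaled_dot_product_attention(Q, K, V, is_causal=causal)
    print(f"causal={causal}: match={torch.allclose(ours, builtin, atol=1e-5)}")
```

The built-in uses the same default scale of `1/sqrt(D)`, so the two should agree up to floating-point rounding.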



## 3. The first practical problem: numerical stability

Even if we ignore memory for a moment, softmax itself can be numerically tricky.

Softmax is:

\[
\mathrm{softmax}(s)_j = \frac{\exp(s_j)}{\sum_k \exp(s_k)}
\]

If any `s_j` is large, `exp(s_j)` can overflow in finite precision. In practice, implementations compute:

\[
\mathrm{softmax}(s)_j = \frac{\exp(s_j - m)}{\sum_k \exp(s_k - m)}, \quad m = \max_k s_k
\]

Subtracting `m` does not change the result, but it prevents overflow.

PyTorch’s `torch.softmax` already does this internally, but it’s worth calling out because the stable trick becomes essential when we start streaming in blocks and computing softmax ourselves.
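A small experiment shows the failure and the fix. The scores below are chosen so that `exp` overflows in fp32 (whose largest finite value is about `3.4e38`); the specific numbers are just an illustration.

```python
import torch

s = torch.tensor([100.0, 101.0, 102.0])   # fp32; exp(100) is already beyond the fp32 range

# Textbook formula: every exp overflows to inf, and inf / inf gives nan.
naive = torch.exp(s) / torch.exp(s).sum()

# Stable formula: subtract the max first, so every exponent is <= 0.
m = s.max()
stable = torch.exp(s - m) / torch.exp(s - m).sum()

print(naive)
print(stable)
print(torch.allclose(stable, torch.softmax(s, dim=0)))
```

The stable version agrees with `torch.softmax`, while the textbook version produces only `nan`.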



## 4. The second practical problem: the `T×T` matrix

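For a single head, the score matrix contains `T×T` elements. For `T=4096`, that is about 16.8 million entries (4096² = 16,777,216).
If you store them in fp16 (2 bytes each), that’s about 33.5 MB per head, per batch element, just for the scores. Multiply by `B·H` to get the size of the full `[B, H, T, T]` tensor.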

The key observation behind FlashAttention is:

To compute `O = softmax(S)V`, you do not need to store all of `S` at once.

You can compute it in tiles: load a block of keys/values, compute a block of scores, update a running output accumulator, and then move on to the next block.
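The arithmetic is worth spelling out. This sketch compares the full score matrix with a single tile; `score_bytes` is a hypothetical helper, and the 128×128 tile size and fp32 tile dtype are assumptions made for illustration.

```python
def score_bytes(rows: int, cols: int, bytes_per_elem: int, B: int = 1, H: int = 1) -> int:
    """Bytes needed to hold a [B, H, rows, cols] block of scores."""
    return B * H * rows * cols * bytes_per_elem

T = 4096
full_fp16 = score_bytes(T, T, 2)        # full T x T matrix, one head, fp16
tile_fp32 = score_bytes(128, 128, 4)    # one 128 x 128 tile, fp32

print(f"entries:       {T * T:,}")                   # 16,777,216
print(f"full fp16:     {full_fp16 / 1e6:.1f} MB")    # 33.6 MB
print(f"one tile fp32: {tile_fp32 / 1e3:.1f} kB")    # 65.5 kB
```

The tile size does not depend on `T` at all, which is the whole appeal of the tiled approach.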



## 5. Version B: tiled computation (but still not enough)

A natural first attempt is: compute scores in tiles and softmax each tile separately.

Unfortunately, that is not equivalent to the full softmax. Softmax normalizes across *all* keys. If you softmax block-by-block, each block is normalized independently, which changes the result.

So tiling alone is not sufficient. We need tiling plus a way to compute the global softmax normalization incrementally.
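The mismatch is easy to see on a single row of scores. In this sketch the row length of 8 and the block size of 4 are arbitrary toy values.

```python
import torch

torch.manual_seed(0)

s = torch.randn(8)                  # scores of one query against 8 keys
full = torch.softmax(s, dim=0)      # normalized over all 8 keys

# Naive tiling: softmax each block of 4 keys on its own, then concatenate.
blockwise = torch.cat([torch.softmax(blk, dim=0) for blk in s.split(4)])

print(full.sum().item())            # about 1: one distribution over all keys
print(blockwise.sum().item())       # about 2: each block sums to 1 separately
print(torch.allclose(full, blockwise))
```

Each block gets its own normalizer, so the concatenated weights no longer form a single probability distribution over all keys.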



## 6. Online softmax: making tiling exact

This is the heart of the story.

Consider one query row `i`. Let its scores against all keys be:

\[
s = [s_1, s_2, \dots, s_T]
\]

If we processed all keys at once, we would compute a stable softmax using:

- `m = max(s)`
- `l = sum(exp(s - m))`

Now imagine we process keys in blocks. We want to update the same quantities as we discover more scores.

Suppose we have processed some keys already, and we have a running max `m_old`, a running normalizer `l_old`, and a running weighted sum `acc_old` (which will become the output numerator).

Now we receive a new block of scores `s_blk` and values `V_blk`.

Let:
- `m_blk = max(s_blk)`
- `m_new = max(m_old, m_blk)`

When the max changes, the scale of the exp terms changes. The trick is to rescale the old accumulators into the new coordinate system defined by `m_new`.

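The updated normalizer becomes:

\[
l_{new} = e^{m_{old} - m_{new}} \, l_{old} + \sum_{j \in blk} \exp(s_j - m_{new})
\]

The updated weighted sum becomes:

\[
acc_{new} = e^{m_{old} - m_{new}} \, acc_{old} + \sum_{j \in blk} \exp(s_j - m_{new}) \, v_j
\]

Here `j` ranges over the keys in the new block, `s_j` are the entries of `s_blk`, and `v_j` are the corresponding rows of `V_blk`.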

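At the end, `m` is the global max, `l` holds the sum of `exp(s_j - m)` over all keys, and `acc` holds the sum of `exp(s_j - m) v_j` over all keys, so the output row is simply `acc / l`. This gives an *exact* softmax over all keys, computed incrementally.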
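Before scaling this up to batches and heads, the update rules above can be checked on a single query row. This is a minimal sketch with toy sizes (`T=10`, `D=4`, block size 3, all chosen arbitrarily); it starts from `m = -inf`, `l = 0`, `acc = 0`.

```python
import torch

torch.manual_seed(0)

T, D, block = 10, 4, 3            # toy sizes; T is deliberately not a multiple of block
s = torch.randn(T)                # scores of one query against all keys
V = torch.randn(T, D)             # one value vector per key

m = torch.tensor(float("-inf"))   # running max
l = torch.tensor(0.0)             # running normalizer
acc = torch.zeros(D)              # running weighted sum of values

for k0 in range(0, T, block):
    s_blk, V_blk = s[k0:k0 + block], V[k0:k0 + block]
    m_new = torch.maximum(m, s_blk.max())
    alpha = torch.exp(m - m_new)      # rescales the old state; exp(-inf) = 0 on the first block
    p = torch.exp(s_blk - m_new)      # unnormalized weights for this block
    l = alpha * l + p.sum()
    acc = alpha * acc + p @ V_blk
    m = m_new

online = acc / l
reference = torch.softmax(s, dim=0) @ V
print(torch.allclose(online, reference, atol=1e-6))
```

The streamed result matches the one-shot softmax up to floating-point rounding, even though no step ever saw more than three scores at a time.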



## 7. Version C: streaming attention with online softmax (FlashAttention core)

Now we implement the algorithm above.

We will tile across queries in blocks of `block_q` and keys/values in blocks of `block_k`.

For each query block, we maintain `(m, l, acc)` for every row in that block.

A detail that matters in practice is precision. Even if `Q, K, V` are fp16 or bf16, you typically compute the softmax and accumulators in fp32 to keep error under control. We’ll do the same.
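Two properties of fp16 show why the accumulators are a poor fit for it. The sketch below only uses dtype casts, and the specific numbers are picked to sit on fp16’s limits.

```python
import torch

fp16 = torch.finfo(torch.float16)
print(fp16.max)    # 65504.0, the largest finite fp16 value
print(fp16.eps)    # 0.0009765625 (2**-10), the spacing of fp16 values just above 1.0

# A running normalizer that has reached 2048 cannot absorb a further 0.5 in fp16,
# because representable fp16 values between 2048 and 4096 are 2 apart.
l_plus_term = torch.tensor(2048.5)               # exact in fp32
print(l_plus_term.item())                        # 2048.5
print(l_plus_term.to(torch.float16).item())      # 2048.0

# Values beyond the fp16 range become inf when cast down.
print(torch.tensor(70000.0).to(torch.float16).item())   # inf
```

Since `l` sums up to `T` terms that are each at most 1, it can reach this range for long sequences, and small contributions would be silently dropped if it were kept in fp16.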


In [None]:

def attention_streaming_online_softmax(
    Q: torch.Tensor,
    K: torch.Tensor,
    V: torch.Tensor,
    causal: bool = False,
    block_q: int = 128,
    block_k: int = 128,
    softmax_dtype: torch.dtype = torch.float32,
) -> torch.Tensor:
    """
    Streaming attention using tiled score computation + online softmax.
    This avoids allocating [B,H,T,T] and is mathematically equivalent to full softmax attention.

    Q, K, V: [B, H, T, D]
    Returns: [B, H, T, D]
    """
    assert Q.ndim == 4 and K.ndim == 4 and V.ndim == 4
    B, H, T, D = Q.shape
    assert K.shape == (B, H, T, D)
    assert V.shape == (B, H, T, D)

    scale = 1.0 / math.sqrt(D)
    out = torch.empty_like(V)

    Qc = Q.to(dtype=softmax_dtype)
    Kc = K.to(dtype=softmax_dtype)
    Vc = V.to(dtype=softmax_dtype)

    all_q_idx = torch.arange(T, device=Q.device)
    all_k_idx = torch.arange(T, device=Q.device)

    for q0 in range(0, T, block_q):
        q1 = min(q0 + block_q, T)
        q_len = q1 - q0

        Qblk = Qc[:, :, q0:q1, :]  # [B,H,q_len,D]

        m = torch.full((B, H, q_len), float("-inf"), device=Q.device, dtype=softmax_dtype)
        l = torch.zeros((B, H, q_len), device=Q.device, dtype=softmax_dtype)
        acc = torch.zeros((B, H, q_len, D), device=Q.device, dtype=softmax_dtype)

        q_idx = all_q_idx[q0:q1]

        for k0 in range(0, T, block_k):
            k1 = min(k0 + block_k, T)
            Kblk = Kc[:, :, k0:k1, :]  # [B,H,k_len,D]
            Vblk = Vc[:, :, k0:k1, :]  # [B,H,k_len,D]
            k_idx = all_k_idx[k0:k1]

            scores = torch.matmul(Qblk, Kblk.transpose(-2, -1)) * scale  # [B,H,q_len,k_len]

            if causal:
                causal_mask = (k_idx[None, :] > q_idx[:, None])  # [q_len,k_len]
                scores = scores.masked_fill(causal_mask[None, None, :, :], float("-inf"))

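            # Online softmax update: fold this block into the running (m, l, acc).
            m_blk = scores.max(dim=-1).values          # [B,H,q_len] row max within this block
            m_new = torch.maximum(m, m_blk)            # [B,H,q_len] running row max

            # Unnormalized weights relative to the new max; masked entries become exp(-inf) = 0.
            exp_scores = torch.exp(scores - m_new[..., None])

            # Rescale the old state from m to m_new (alpha is 0 on the first block, where m = -inf).
            alpha = torch.exp(m - m_new)
            l = alpha * l + exp_scores.sum(dim=-1)
            acc = alpha[..., None] * acc + torch.matmul(exp_scores, Vblk)

            m = m_new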

        l_safe = torch.clamp(l, min=torch.finfo(softmax_dtype).tiny)
        Oblk = acc / l_safe[..., None]
        out[:, :, q0:q1, :] = Oblk.to(dtype=out.dtype)

    return out
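One refinement the function above does not make: under the causal mask, key blocks that lie entirely to the right of a query block are fully masked, contribute `exp(-inf) = 0`, and could be skipped. The sketch below only counts tiles; `visited_tiles` is a hypothetical helper for illustration, not part of the implementation above.

```python
from typing import List, Tuple

def visited_tiles(T: int, block_q: int, block_k: int, causal: bool) -> List[Tuple[int, int]]:
    """Return the (q0, k0) tile origins that actually need computing."""
    tiles = []
    for q0 in range(0, T, block_q):
        q1 = min(q0 + block_q, T)
        # With the mask "k > q is hidden", the last query in the block (q1 - 1)
        # sees keys 0..q1-1, so key blocks starting at k0 >= q1 are fully masked.
        k_end = q1 if causal else T
        for k0 in range(0, k_end, block_k):
            tiles.append((q0, k0))
    return tiles

full = visited_tiles(2048, 128, 128, causal=False)
causal = visited_tiles(2048, 128, 128, causal=True)
print(len(full), len(causal))   # 256 136
```

With equal block sizes this removes a bit under half of the tiles. In the loop above it would amount to ending the inner `range` at `q1` instead of `T` when `causal` is set, assuming the mask convention stays as written.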



## 8. Correctness: showing the versions agree

We now compare the streaming output to the naïve output across a range of shapes and dtypes, including sizes that are not multiples of the block sizes.


In [None]:

@dataclass
class TestCase:
    B: int
    H: int
    T: int
    D: int
    causal: bool
    dtype: torch.dtype

def max_rel_err(a: torch.Tensor, b: torch.Tensor, eps: float = 1e-8) -> float:
    num = (a - b).abs().max().item()
    denom = (b.abs().max().item() + eps)
    return num / denom

def run_one_test(tc: TestCase, block_q=128, block_k=128):
    B,H,T,D = tc.B, tc.H, tc.T, tc.D
    Q = torch.randn(B,H,T,D, device=device, dtype=tc.dtype)
    K = torch.randn(B,H,T,D, device=device, dtype=tc.dtype)
    V = torch.randn(B,H,T,D, device=device, dtype=tc.dtype)

    ref = attention_naive(Q.float(), K.float(), V.float(), causal=tc.causal).to(tc.dtype)
    out = attention_streaming_online_softmax(Q, K, V, causal=tc.causal, block_q=block_q, block_k=block_k)

    abs_err = (out - ref).abs().max().item()
    rel_err = max_rel_err(out.float(), ref.float())
    return abs_err, rel_err

test_cases = [
    TestCase(B=1,H=1,T=128,D=64,  causal=False, dtype=torch.float32),
    TestCase(B=2,H=4,T=257,D=64,  causal=False, dtype=torch.float32),
    TestCase(B=1,H=2,T=256,D=128, causal=True,  dtype=torch.float32),
]

if device == "cuda":
    test_cases += [
        TestCase(B=2,H=4,T=513,D=64,  causal=False, dtype=torch.float16),
        TestCase(B=2,H=4,T=513,D=64,  causal=True,  dtype=torch.float16),
    ]
    # bf16 needs hardware support; only add these cases when the GPU reports it.
    if torch.cuda.is_bf16_supported():
        test_cases += [
            TestCase(B=2,H=4,T=513,D=64,  causal=False, dtype=torch.bfloat16),
            TestCase(B=2,H=4,T=513,D=64,  causal=True,  dtype=torch.bfloat16),
        ]

for tc in test_cases:
    abs_err, rel_err = run_one_test(tc, block_q=128, block_k=128)
    print(f"B={tc.B} H={tc.H} T={tc.T} D={tc.D} causal={tc.causal} dtype={tc.dtype} | "
          f"max_abs_err={abs_err:.3e} max_rel_err={rel_err:.3e}")



## 9. Benchmarking memory and runtime

The streaming version avoids allocating `[B,H,T,T]`. On GPU we can measure peak allocated memory during the call.

In pure PyTorch, don’t expect a speedup: the streaming version runs a Python loop over tiles and launches several small kernels per tile, so it can easily be slower than the naïve version. The real win of FlashAttention comes when these steps are fused into a single CUDA kernel that keeps intermediate values on-chip. Still, the memory reduction shows up immediately.


In [None]:

def benchmark(fn, warmup=5, iters=20):
    for _ in range(warmup):
        _ = fn()
    if device == "cuda":
        torch.cuda.synchronize()
    t0 = time.perf_counter()
    for _ in range(iters):
        _ = fn()
    if device == "cuda":
        torch.cuda.synchronize()
    t1 = time.perf_counter()
    return (t1 - t0) / iters

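def peak_mem_bytes(fn):
    """Peak extra GPU memory allocated while running fn (tensors that already exist are excluded)."""
    if device != "cuda":
        return None
    torch.cuda.synchronize()
    baseline = torch.cuda.memory_allocated()   # inputs and anything left over from earlier cells
    torch.cuda.reset_peak_memory_stats()       # peak now starts from the current allocation
    _ = fn()
    torch.cuda.synchronize()
    return torch.cuda.max_memory_allocated() - baseline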

def run_bench_suite(B=2, H=8, D=64, dtype=None, causal=False, Ts=(512, 1024, 2048)):
    if dtype is None:
        dtype = torch.float16 if device == "cuda" else torch.float32

    rows = []
    for T in Ts:
        Q = torch.randn(B,H,T,D, device=device, dtype=dtype)
        K = torch.randn(B,H,T,D, device=device, dtype=dtype)
        V = torch.randn(B,H,T,D, device=device, dtype=dtype)

        naive_fn = lambda: attention_naive(Q.float(), K.float(), V.float(), causal=causal).to(dtype)
        stream_fn = lambda: attention_streaming_online_softmax(Q, K, V, causal=causal, block_q=128, block_k=128)

        naive_time = benchmark(naive_fn)
        stream_time = benchmark(stream_fn)

        naive_mem = peak_mem_bytes(naive_fn)
        stream_mem = peak_mem_bytes(stream_fn)

        rows.append((T, naive_time, stream_time, naive_mem, stream_mem))
    return rows

rows = run_bench_suite(B=2, H=8, D=64, causal=True, Ts=(512, 1024, 2048))

print("T | naive_ms | streaming_ms | naive_peak_mem_MB | streaming_peak_mem_MB")
for T, t_naive, t_stream, m_naive, m_stream in rows:
    naive_ms = t_naive * 1000
    stream_ms = t_stream * 1000
    if m_naive is None:
        print(f"{T:4d} | {naive_ms:8.3f} | {stream_ms:11.3f} | (cpu) | (cpu)")
    else:
        naive_mb = m_naive / (1024**2)
        stream_mb = m_stream / (1024**2)
        print(f"{T:4d} | {naive_ms:8.3f} | {stream_ms:11.3f} | {naive_mb:17.1f} | {stream_mb:20.1f}")



## 10. The bridge to CUDA (next post)

What we implemented is the algorithmic core: tiled computation plus online softmax.

The CUDA version will keep the same mathematical structure, but it will change *how* the work is scheduled:
it will load tiles into on-chip memory, fuse operations, reduce global memory traffic, and compute outputs efficiently.

A clean next step is to implement only the **forward** kernel first, validate it against the reference, and then move to backward once forward is solid.
