### FlashAttention — Summary

FlashAttention is an IO-aware, exact attention algorithm that speeds up Transformer attention by reducing expensive GPU memory traffic.
Traditional attention constructs and stores the full $N \times N$ attention matrix in GPU DRAM (HBM), causing attention to become memory-bound.

⸻

Key Idea

FlashAttention avoids materializing the full attention matrix.
Instead, it:
	•	Splits $Q$, $K$, $V$ into small tiles that fit in fast on-chip GPU SRAM
	•	Streams those tiles from HBM through SRAM, so intermediate results never round-trip to DRAM
	•	Computes partial attention ($QK^\top$ → softmax → softmax·$V$) on-chip
	•	Uses online softmax to merge partial results
	•	Writes only the final output to DRAM

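This cuts HBM reads and writes from $\Theta(Nd + N^2)$ for standard attention to $\Theta(N^2 d^2 / M)$, where $d$ is the head dimension and $M$ is the SRAM size, and it cuts the extra memory beyond inputs and outputs from $O(N^2)$ to $O(N)$. The FLOP count stays $O(N^2 d)$; the speedup comes from less memory traffic, not less arithmetic.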
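To make the size of the avoided matrix concrete, here is a small back-of-the-envelope sketch. The helper name `attention_footprint_bytes` and the choice of fp16 (2 bytes per element) and $d = 64$ are illustrative assumptions, not values from the text above.

```python
def attention_footprint_bytes(n, d, bytes_per_el=2):
    """Return (score_matrix_bytes, qkvo_bytes) for a single attention head.

    Assumes fp16 storage by default; both numbers are per head, per batch item.
    """
    scores = n * n * bytes_per_el      # the N x N matrix that standard attention stores
    qkvo = 4 * n * d * bytes_per_el    # Q, K, V and the output O, each N x d
    return scores, qkvo


for n in (1024, 8192, 65536):
    scores, qkvo = attention_footprint_bytes(n, d=64)
    print(f"N={n:>6}: scores {scores / 2**20:8.1f} MiB, Q/K/V/O {qkvo / 2**20:5.1f} MiB")
```

Under these assumptions, at $N = 8192$ the score matrix alone takes 128 MiB per head, while $Q$, $K$, $V$, and $O$ together take 4 MiB.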

⸻

Why Online Softmax Works

Normal softmax for a row:

$$
\text{softmax}(s_i) = \frac{e^{s_i}}{\sum_j e^{s_j}}
$$

FlashAttention computes this incrementally by maintaining:
	•	Running max: $m$
	•	Running sum of exponentials: $l$
	•	Running output accumulator: $O$

For each tile:
	1.	Update running max
	2.	Rescale old statistics
	3.	Add new exponentials
	4.	Add contribution from $V$ tile
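The four steps above can be sketched for a single query row in plain Python. This is a minimal illustration, not the GPU kernel: the function names, the `tile` parameter, and the choice to keep the accumulator unnormalized until the end are assumptions made for the example.

```python
import math


def online_softmax_weighted_sum(scores, values, tile=2):
    """Compute sum_j softmax(scores)_j * values[j], one tile at a time."""
    m = float("-inf")             # running max
    l = 0.0                       # running sum of exp(score - m)
    o = [0.0] * len(values[0])    # running (unnormalized) output accumulator
    for start in range(0, len(scores), tile):
        s_tile = scores[start:start + tile]
        v_tile = values[start:start + tile]
        m_new = max(m, max(s_tile))        # 1. update running max
        scale = math.exp(m - m_new)        # 2. rescale old statistics
        l *= scale
        o = [x * scale for x in o]
        for s, v in zip(s_tile, v_tile):
            p = math.exp(s - m_new)
            l += p                                    # 3. add new exponentials
            o = [x + p * vc for x, vc in zip(o, v)]   # 4. add contribution from V tile
        m = m_new
    return [x / l for x in o]              # normalize once at the end


def reference_weighted_sum(scores, values):
    """Same quantity computed with all scores available at once."""
    top = max(scores)
    p = [math.exp(s - top) for s in scores]
    z = sum(p)
    return [sum(pj * v[c] for pj, v in zip(p, values)) / z
            for c in range(len(values[0]))]
```

Calling both functions on the same `scores` and `values` should give results that agree up to floating-point rounding, whatever `tile` size is chosen.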

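Because softmax is invariant to subtracting a constant $m$ from every score:

$$
\frac{e^{s_i}}{\sum_j e^{s_j}} = \frac{e^{s_i - m}}{\sum_j e^{s_j - m}}
$$

the online version produces the same softmax as processing all scores at once, up to floating-point rounding: whenever the running max changes, the old sum and accumulator are rescaled by $e^{m_{\text{old}} - m_{\text{new}}}$.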
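Written out, one way to express the per-tile update is the recurrence below. The tile index $t$, the value rows $v_j$, and the convention that $O$ stays unnormalized until the last tile are notational assumptions added here; $m$, $l$, and $O$ are the running quantities listed earlier, with $m^{(0)} = -\infty$, $l^{(0)} = 0$, $O^{(0)} = 0$.

$$
\begin{aligned}
m^{(t)} &= \max\Bigl(m^{(t-1)},\ \max_{j \in \text{tile } t} s_j\Bigr) \\
l^{(t)} &= e^{m^{(t-1)} - m^{(t)}}\, l^{(t-1)} + \sum_{j \in \text{tile } t} e^{s_j - m^{(t)}} \\
O^{(t)} &= e^{m^{(t-1)} - m^{(t)}}\, O^{(t-1)} + \sum_{j \in \text{tile } t} e^{s_j - m^{(t)}}\, v_j
\end{aligned}
$$

After the last tile $T$, $l^{(T)} = \sum_j e^{s_j - m^{(T)}}$ covers every score, so the row's output is $O^{(T)} / l^{(T)}$, which equals $\sum_j \text{softmax}(s_j)\, v_j$.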

⸻

FlashAttention-2 Improvements

FlashAttention-2 improves performance via:
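	•	More parallelism: work is split across the sequence length (query blocks), not only across batch and heads
	•	Better work partitioning between warps ($Q$ is split across warps while $K$/$V$ are shared), which cuts shared-memory communication
	•	Double-buffering K/V tiles
	•	Better Tensor Core utilization, helped by fewer non-matmul FLOPs
	•	Fewer synchronizations
	•	Better handling of variable sequence lengths

The authors report roughly a 2× speedup over FA-1; the gain over naïve attention is larger and depends on sequence length, head dimension, and GPU.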

⸻

Why It’s Faster

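On-chip SRAM is:
	•	Roughly an order of magnitude higher in bandwidth than HBM (on an A100, an estimated 19 TB/s versus 1.5–2.0 TB/s)
	•	Much lower in latency, but tiny (about 192 KB per streaming multiprocessor on an A100)

FlashAttention keeps almost all intermediate computation in SRAM: scores and softmax probabilities never leave the chip, HBM is read only for Q/K/V tiles, and only the final output is written back. Some tiles are re-read once per outer-loop iteration, so the inputs are not strictly read once.
This removes the $N \times N$ DRAM bottleneck and shifts attention from memory-bound toward compute-bound.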
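As a rough illustration of the traffic reduction, the sketch below plugs numbers into the asymptotic HBM access counts from the FlashAttention paper's IO analysis, $\Theta(Nd + N^2)$ for standard attention and $\Theta(N^2 d^2 / M)$ for FlashAttention. Constants are dropped, so only the ratio is meaningful. The function name and the assumption that $M$ is counted in elements (here 51,200 fp16 elements, about 100 KB) are hypothetical choices for the example.

```python
def hbm_accesses(n, d, sram_elems):
    """Asymptotic HBM element accesses with constant factors dropped.

    Returns (standard, flash) for sequence length n, head dimension d,
    and an SRAM capacity of sram_elems elements.
    """
    standard = n * d + n * n                # Theta(N d + N^2)
    flash = n * n * d * d / sram_elems      # Theta(N^2 d^2 / M)
    return standard, flash


standard, flash = hbm_accesses(n=4096, d=64, sram_elems=51200)
print(f"standard ~ {standard:.3g}, flash ~ {flash:.3g}, ratio ~ {standard / flash:.1f}x")
```

The ratio grows as $d^2 / M$ shrinks, which is why small head dimensions and larger SRAM favor the tiled approach.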

⸻

Implementation Overview

FlashAttention kernels:
	•	Load Q/K/V tiles from HBM → SRAM
	•	Compute $QK^\top$ using Tensor Cores
	•	Apply masking + online softmax in SRAM
	•	Multiply by V tiles
	•	Accumulate outputs
	•	Write final output $O$ to DRAM

The $N \times N$ matrix is never created.
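The kernel steps above can be mimicked on the CPU with a short pure-Python sketch. This is not the CUDA kernel: there are no Tensor Cores or SRAM here, and the function names and the `block_q`, `block_k`, and `causal` parameters are illustrative assumptions. It shows the loop structure only: stream K/V tiles for each Q tile, apply masking and online softmax, accumulate, and emit the final output.

```python
import math


def naive_attention(Q, K, V, causal=False):
    """Reference: builds every row of the N x N score matrix explicitly."""
    d = len(Q[0])
    out = []
    for i, q in enumerate(Q):
        s = [sum(a * b for a, b in zip(q, k)) / math.sqrt(d) for k in K]
        if causal:
            s = [x if j <= i else float("-inf") for j, x in enumerate(s)]
        top = max(s)
        p = [math.exp(x - top) for x in s]
        z = sum(p)
        out.append([sum(pj * v[c] for pj, v in zip(p, V)) / z
                    for c in range(len(V[0]))])
    return out


def tiled_attention(Q, K, V, block_q=2, block_k=2, causal=False):
    """Exact attention computed tile by tile; only a tile-sized slice of
    scores exists at any time."""
    n, d, dv = len(Q), len(Q[0]), len(V[0])
    scale = 1.0 / math.sqrt(d)
    out = []
    for qs in range(0, n, block_q):                # one Q tile at a time
        rows = range(qs, min(qs + block_q, n))
        m = [float("-inf")] * len(rows)            # running max per row
        l = [0.0] * len(rows)                      # running sum per row
        acc = [[0.0] * dv for _ in rows]           # unnormalized output per row
        for ks in range(0, len(K), block_k):       # stream K/V tiles
            cols = range(ks, min(ks + block_k, len(K)))
            for r, i in enumerate(rows):
                # Masked positions are skipped, which matches a score of -inf.
                keep = [j for j in cols if not (causal and j > i)]
                if not keep:
                    continue
                s = [sum(a * b for a, b in zip(Q[i], K[j])) * scale for j in keep]
                m_new = max(m[r], max(s))
                alpha = math.exp(m[r] - m_new)     # rescale old statistics
                l[r] *= alpha
                acc[r] = [x * alpha for x in acc[r]]
                for sj, j in zip(s, keep):
                    p = math.exp(sj - m_new)
                    l[r] += p
                    acc[r] = [x + p * c for x, c in zip(acc[r], V[j])]
                m[r] = m_new
        # Normalize and emit this Q tile's rows (the "write O" step).
        out.extend([x / l[r] for x in acc[r]] for r in range(len(rows)))
    return out
```

Comparing `tiled_attention` against `naive_attention` on small inputs, with and without `causal=True`, is a quick way to check that the tiling changes the memory access pattern but not the result.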

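PyTorch 2.0+ can dispatch to a FlashAttention kernel when calling:
`torch.nn.functional.scaled_dot_product_attention(...)`

on supported GPUs, provided the inputs qualify (for example, fp16 or bf16 CUDA tensors); otherwise it falls back to another backend.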

⸻

Benefits
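	•	Roughly 2–4× faster attention than a standard implementation, depending on sequence length and hardware
	•	Memory that grows linearly with sequence length instead of quadratically (often 10× or more lower at long sequence lengths)
	•	Makes 8k–64k token contexts practical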
	•	Exact (not approximate)
	•	Helps both training and inference