A from-scratch, handwritten Flash Attention 2 CUDA kernel targeting the NVIDIA Blackwell architecture (RTX 5060 Laptop, SM120). It uses TMA, Tensor Core MMA, shared memory swizzle, double buffering, and warp specialization, and was optimized step by step. In the benchmarks below it reaches 0.62x to 1.05x of cuDNN throughput.
| Config (B, H, N, d) | BlackFlash (TFLOPS) | cuDNN (TFLOPS) | Ratio |
|---|---|---|---|
| 2, 2, 512, 64 | 4.54 | 4.42 | 1.05x |
| 4, 8, 1024, 64 | 29.8 | 41.2 | 0.73x |
| 8, 16, 1024, 64 | 30.6 | 48.1 | 0.64x |
| 64, 64, 1024, 64 | 31.45 | 51.14 | 0.62x |
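Throughput figures like these are usually derived from a FLOP count and a measured kernel time. The sketch below assumes the common convention of 4·B·H·N²·d FLOPs for a non-causal forward pass (two matrix multiplies, Q × Kᵀ and P × V, at 2·N²·d each per batch and head). The benchmark script may count differently, and both helper names are hypothetical.

```cpp
#include <cassert>
#include <cstdint>

// Assumed convention: forward attention FLOPs = 4 * B * H * N^2 * d
// (2*N*N*d for Q x K^T plus 2*N*N*d for P x V, per batch and head).
inline std::uint64_t attention_fwd_flops(std::uint64_t B, std::uint64_t H,
                                         std::uint64_t N, std::uint64_t d) {
    return 4ull * B * H * N * N * d;
}

// Convert a FLOP count and a measured kernel time in milliseconds to TFLOPS.
inline double to_tflops(std::uint64_t flops, double elapsed_ms) {
    assert(elapsed_ms > 0.0);
    return static_cast<double>(flops) / (elapsed_ms * 1e-3) / 1e12;
}
```

Under this convention the smallest configuration (2, 2, 512, 64) performs 2^28 FLOPs per forward pass, so its runtime is dominated by launch and pipeline fill overhead rather than sustained Tensor Core throughput.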
- Flash Attention 2 forward pass with online softmax and tiled computation
- TMA (Tensor Memory Accelerator) for asynchronous global-to-shared memory data movement
- MMA instructions for matrix multiply-accumulate on Tensor Cores
- 128B shared memory swizzle to eliminate bank conflicts
- Double buffering pipeline overlapping computation with data movement
- Warp specialization with producer/consumer division of labor
Tile and thread configuration:
- Br = Bc = d = 64
- Block size = 288 threads (9 warps total): 1 producer warp (32 threads) + 8 consumer warps (256 threads)
- 2-stage double buffering for K/V tiles (sK0/sV0 and sK1/sV1)
Shared memory layout (per block):
    |--- sQ ---|--- sK0 ---|--- sV0 ---|--- sK1 ---|--- sV1 ---|--- barriers & scratch ---|
       Br × d     Bc × d      Bc × d      Bc × d      Bc × d     5 mbarriers + 4·Br floats
      (64×64)     (64×64)     (64×64)     (64×64)     (64×64)       (max/sum reduction)
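As a rough size check, the layout can be turned into byte offsets. This sketch assumes 2-byte bf16 tile elements, 8-byte mbarrier objects, barriers placed before the reduction scratch, and tightly packed regions. None of these are stated in the diagram, and the struct and function names are hypothetical.

```cpp
#include <cassert>
#include <cstddef>

// Assumptions (not stated in the layout diagram): bf16 tiles (2 bytes per
// element), 8-byte mbarriers, regions packed back to back without padding.
const std::size_t kBr = 64, kBc = 64, kD = 64;
const std::size_t kElemBytes = 2;
const std::size_t kMbarBytes = 8;

struct SmemLayout {
    std::size_t sQ, sK0, sV0, sK1, sV1, barriers, scratch, total;
};

inline SmemLayout make_smem_layout() {
    SmemLayout L{};
    std::size_t off = 0;
    L.sQ  = off; off += kBr * kD * kElemBytes;   // Q tile
    L.sK0 = off; off += kBc * kD * kElemBytes;   // stage 0 K
    L.sV0 = off; off += kBc * kD * kElemBytes;   // stage 0 V
    L.sK1 = off; off += kBc * kD * kElemBytes;   // stage 1 K
    L.sV1 = off; off += kBc * kD * kElemBytes;   // stage 1 V
    L.barriers = off; off += 5 * kMbarBytes;     // mbar_q, full0/1, empty0/1
    L.scratch  = off; off += 4 * kBr * sizeof(float);  // max/sum exchange
    L.total = off;
    assert(L.sK0 % 1024 == 0);  // every tile starts on a 1024-byte boundary
    return L;
}
```

With these assumptions each tile is 8 KiB and the whole block needs about 41 KiB, so the five tiles stay 1024-byte aligned relative to the start of the buffer as long as the buffer itself is aligned.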
All Q/K/V tiles are loaded via TMA with 128B swizzle addressing. The P matrix (softmax output) reuses the current K buffer (sP = sK_cur) to avoid extra shared memory allocation.
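The 128B swizzle pattern can be illustrated on the host. The sketch below uses the widely documented form of the pattern: within each 1024-byte group, the 16-byte chunk index (address bits 4 to 6) is XORed with the row index modulo 8 (bits 7 to 9). It assumes 2-byte elements, 128-byte rows, and a tile base aligned to 1024 bytes. Whether the kernel uses exactly this expression is an assumption, and the function names are hypothetical.

```cpp
#include <cassert>
#include <cstdint>

// Unswizzled byte offset of element (row, col) in a 64x64 tile of 2-byte
// elements. Each row occupies 64 * 2 = 128 bytes.
inline std::uint32_t linear_offset(std::uint32_t row, std::uint32_t col) {
    assert(row < 64 && col < 64);
    return (row * 64u + col) * 2u;
}

// 128B swizzle: XOR the 16-byte chunk index (bits 4-6) with the row index
// modulo 8 (bits 7-9). Assumes the tile base is 1024-byte aligned.
inline std::uint32_t swizzle_128b(std::uint32_t byte_offset) {
    return byte_offset ^ (((byte_offset >> 7) & 0x7u) << 4);
}

// Index of the 16-byte chunk (0-7) within its 128-byte line after swizzling.
inline std::uint32_t chunk_of(std::uint32_t row, std::uint32_t col) {
    return (swizzle_128b(linear_offset(row, col)) >> 4) & 0x7u;
}
```

Without swizzling, the same 16-byte column chunk in eight consecutive rows sits 128 bytes apart and therefore maps to the same four banks. After swizzling, those eight chunks land in eight different chunk positions and cover all 32 banks, which is what removes the conflict for row-wise `ldmatrix` reads. A manual write of P into `sP` would apply the same mapping to its destination offsets.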
Consumer warp organization:
The 8 consumer warps are organized as a 4×2 grid. Each warp_pair (pair index 0–3) owns 16 rows of the Q tile. Within each pair, warp_half 0 and 1 split the column dimension, so every element of the 64×64 output tile is covered.
              half=0    half=1
    pair=0:   warp 0    warp 1   → Q rows 0–15
    pair=1:   warp 2    warp 3   → Q rows 16–31
    pair=2:   warp 4    warp 5   → Q rows 32–47
    pair=3:   warp 6    warp 7   → Q rows 48–63
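The grid above maps to a simple index computation. In this sketch the warp index is relative to the consumer group (0 to 7), and the column dimension is assumed to be split evenly into two 32-column halves. The text only says the halves split the columns, so the even split and all names here are assumptions.

```cpp
#include <cassert>

struct ConsumerTile {
    int pair;       // warp_pair index, 0-3
    int half;       // warp_half index, 0-1
    int row_begin;  // first Q row owned by this pair
    int col_begin;  // first column owned by this half (assumed even split)
};

// consumer_warp is the warp index within the consumer group, in [0, 8).
inline ConsumerTile consumer_tile(int consumer_warp) {
    assert(consumer_warp >= 0 && consumer_warp < 8);
    ConsumerTile t;
    t.pair = consumer_warp / 2;
    t.half = consumer_warp % 2;
    t.row_begin = t.pair * 16;  // 16 rows per pair
    t.col_begin = t.half * 32;  // assumption: 32 columns per half
    return t;
}
```

For example, consumer warp 5 resolves to pair 2, half 1, which matches the row range 32–47 in the grid.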
Pipeline execution flow:
- Q load: The producer warp (warp 0) issues a TMA load for the Q tile into `sQ`, then immediately begins the K/V loading pipeline. Consumers wait on `mbar_q` for Q data to arrive.
- K/V streaming (producer): The producer iterates over all K/V tile pairs. For each iteration `j`:
  - If `j >= 2`, wait on `mbar_empty` to ensure consumers have finished reading the buffer from 2 iterations ago (double buffering).
  - Issue TMA loads for K and V into the current stage (`sK0/sV0` or `sK1/sV1`), signaling `mbar_full` upon completion.
  - After all K/V tiles are dispatched, the producer warp exits.
- S = Q × Kᵀ (consumers): Consumers wait on `mbar_full` for the current K/V stage. Each consumer warp loads Q and K fragments from shared memory via `ldmatrix` (with swizzle-aware addressing), then performs `m16n8k16` HMMA to accumulate S tiles. The result is scaled by `1/√d`.
- Online softmax: Each warp computes a local row-max across its S fragment using warp shuffle reductions (`__shfl_xor_sync`). Partial max values are exchanged between `warp_half` pairs through shared memory scratch space to obtain the true row-max `m_ij`. Then:
  - Compute `m_new = max(m_i, m_ij)` (running max across all K/V tiles so far)
  - Compute `α = exp(m_i − m_new)` (rescale factor for prior accumulator)
  - Compute `P = exp(S − m_new)` (softmax numerator)
  - Row-sum of P is similarly reduced across warp halves
- P writeback to SMEM: P values are converted from fp32 to bf16 and written to `sP` (which aliases `sK_cur`) with manually applied swizzle addressing. This is necessary because P is not loaded via TMA: the swizzle layout must match the subsequent `ldmatrix` reads.
- O = P × V (consumers): P and V fragments are loaded via `ldmatrix` (x4 for P, x2.trans for V) with swizzle addressing, then accumulated through HMMA. The running output accumulator is rescaled: `acc = α · acc + P × V`.
- Output writeback: After all K/V tiles are processed, each consumer thread writes its final output: `O = acc / l_i` (fp32 → bf16 conversion). Results are first staged in `sQ` (reused as output buffer), then cooperatively copied to global memory via 8-byte vectorized stores.
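The online softmax recurrence in the steps above can be checked on the CPU. The sketch below processes one query row tile by tile, using a scalar value per key (d = 1) to keep it short, and compares against a plain two-pass softmax. The running-sum update `l_i = α · l_i + rowsum(P)` is the standard Flash Attention 2 form and is assumed here, since the text only mentions the row-sum reduction. All function names are hypothetical, and this is a reference model rather than the kernel's code.

```cpp
#include <algorithm>
#include <cassert>
#include <cmath>
#include <cstddef>
#include <limits>
#include <vector>

// Tiled (online) softmax-weighted sum for one row. s holds the scaled
// scores, v one scalar value per key, tile the number of keys per K/V tile.
inline float online_attention_row(const std::vector<float>& s,
                                  const std::vector<float>& v,
                                  std::size_t tile) {
    assert(!s.empty() && s.size() == v.size() && tile > 0);
    float m_i = -std::numeric_limits<float>::infinity();  // running max
    float l_i = 0.0f;                                     // running sum
    float acc = 0.0f;                                     // running output
    for (std::size_t start = 0; start < s.size(); start += tile) {
        const std::size_t end = std::min(start + tile, s.size());
        float m_ij = s[start];                            // tile row-max
        for (std::size_t j = start + 1; j < end; ++j) m_ij = std::max(m_ij, s[j]);
        const float m_new = std::max(m_i, m_ij);
        const float alpha = std::exp(m_i - m_new);        // 0 on the first tile
        float row_sum = 0.0f, pv = 0.0f;
        for (std::size_t j = start; j < end; ++j) {
            const float p = std::exp(s[j] - m_new);       // softmax numerator
            row_sum += p;
            pv += p * v[j];
        }
        l_i = alpha * l_i + row_sum;                      // assumed FA2 update
        acc = alpha * acc + pv;
        m_i = m_new;
    }
    return acc / l_i;                                     // final normalization
}

// Two-pass reference: global max first, then a normalized weighted sum.
inline float reference_attention_row(const std::vector<float>& s,
                                     const std::vector<float>& v) {
    assert(!s.empty() && s.size() == v.size());
    const float m = *std::max_element(s.begin(), s.end());
    float l = 0.0f, acc = 0.0f;
    for (std::size_t j = 0; j < s.size(); ++j) {
        const float p = std::exp(s[j] - m);
        l += p;
        acc += p * v[j];
    }
    return acc / l;
}
```

The two functions agree up to floating-point rounding for any tile size, which is the property that lets the kernel stream K/V tiles without ever materializing the full score matrix.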
Synchronization mechanisms:
- `mbar_q` (arrive count = 1): signals Q tile arrival from TMA
- `mbar_full0/1` (arrive count = 1): signals K/V tile arrival in each double buffer stage
- `mbar_empty0/1` (arrive count = 256): signals consumers have finished reading a K/V stage, allowing the producer to reuse the buffer
- `bar.sync 1, 256`: intra-consumer-group barriers for shared memory data exchange (max/sum reduction, P writeback)
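With two stages, each full/empty barrier is reused every other iteration, so waits have to track which phase of the barrier they expect. The sketch below shows one common bookkeeping scheme: stage `j % 2`, with the phase parity derived from how many times that stage has already been used. The source only gives arrive counts, so this parity scheme and the names used are assumptions about how such a pipeline is typically written.

```cpp
#include <cassert>

struct StageInfo {
    int stage;         // which double-buffer stage iteration j uses (0 or 1)
    int full_parity;   // parity consumers wait for on mbar_full[stage]
    bool wait_empty;   // whether the producer must wait before reloading
    int empty_parity;  // parity the producer waits for on mbar_empty[stage]
};

// Assumed scheme: both barriers start in phase 0 and complete one phase per
// use of the stage, so the parity follows the per-stage use count.
inline StageInfo stage_for_iteration(int j) {
    assert(j >= 0);
    StageInfo s;
    s.stage = j % 2;
    const int use = j / 2;            // prior uses of this stage
    s.full_parity = use % 2;          // phase completed by this TMA load
    s.wait_empty = (j >= 2);          // first two loads find empty buffers
    s.empty_parity = s.wait_empty ? (use - 1) % 2 : 0;  // release from j - 2
    return s;
}
```

For iteration 2 the producer reuses stage 0 and waits for the consumers' first release of that stage (parity 0), while the consumers wait for the second fill of stage 0 (parity 1).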
Each version was entirely handwritten and profiled with ncu to guide the next optimization:
- Naive FA2 — Baseline tiled Flash Attention implementation
- SMEM vectorized loads — Vectorized shared memory read/write
- MMA compute — Replaced scalar arithmetic with Tensor Core MMA instructions
- TMA data movement — Switched global→shared loads to asynchronous TMA
- Swizzle — 128B swizzle address mapping to eliminate shared memory bank conflicts
- Double buffering — Overlapped compute and data movement via pipeline
- Warp specialization — Producer/consumer warp partitioning to further hide latency
- NVIDIA GPU: Blackwell architecture (SM120), e.g. RTX 5060 Laptop
- CUDA Toolkit >= 12.8
- Compiler: nvcc
    # run
    make clean && make run
    # look
    ./build/flash_attn profile/ref_data
    # compare BlackFlash and cuDNN
    python3 tests/bench_blackflash.py