This tutorial is designed for beginners to understand and implement Flash Attention in PyTorch. We follow the structure of the document "From Online Softmax to FlashAttention" by Zihao Ye, explaining each concept step by step, deriving the formulas, and providing PyTorch code snippets. We'll build from basic self-attention to the full tiled Flash Attention, including both forward and backward passes for a complete, differentiable PyTorch implementation.
We assume basic knowledge of PyTorch and transformers. All code is tested for correctness and is device-agnostic (tested on CPU for simplicity, but it runs on CUDA as well). We'll use small dimensions for the examples but scale to larger ones in the final implementation.
Flash Attention is a memory-efficient and fast algorithm for computing self-attention in transformers. It avoids materializing large intermediate matrices (e.g., the attention matrix) in GPU global memory (High Bandwidth Memory, HBM). Instead, it employs tiling and online softmax techniques to fuse operations into a single kernel, leveraging fast GPU shared memory (SRAM). This reduces I/O overhead between slow HBM and fast SRAM, enabling longer sequences without out-of-memory errors.
A key limitation in standard attention implementations is the small size of SRAM (typically ~100 KB per streaming multiprocessor on modern GPUs), which cannot store the full $N \times N$ attention matrix for long sequences, so intermediates must be written to and re-read from HBM.
The core challenge is that softmax is not associative (unlike addition in matrix multiplication), complicating tiling. We'll derive how online softmax enables associative updates and build to a full implementation, highlighting reductions in HBM accesses at each stage.
Self-attention computes weighted sums of values based on query-key similarities. For simplicity, we initially ignore batches, heads, masks, and scaling (added later).
The formula is:

$$O = \mathrm{softmax}(Q K^\top)\, V$$

where $Q, K, V, O \in \mathbb{R}^{N \times D}$, $N$ is the sequence length, $D$ is the head dimension, and softmax is applied row-wise.

The standard implementation factorizes as:

$$X = Q K^\top, \qquad A = \mathrm{softmax}(X), \qquad O = A V$$

This requires storing the $N \times N$ intermediates $X$ and $A$ in HBM, which costs $O(N^2)$ memory and dominates the I/O for long sequences.
A basic (non-Flash) PyTorch implementation with batches, heads, mask, and scaling:
import torch
import torch.nn as nn
import math # Use math for sqrt to align with PyTorch style
def standard_attention(Q, K, V, mask=None):
# Inputs: Q, K, V shape (batch_size, num_heads, seq_len, head_dim)
head_dim = Q.shape[-1]
scale = 1 / math.sqrt(head_dim)
# Scale queries
Q_scaled = Q * scale
# Compute logits: (b, h, l, l)
logits = torch.matmul(Q_scaled, K.transpose(-2, -1))
if mask is not None:
# Mask: (b, l) -> (b, 1, 1, l)
key_mask = mask.unsqueeze(1).unsqueeze(1)
logits = torch.where(key_mask > 0, logits, float('-inf'))
# Softmax to get attention weights
attn_weights = nn.functional.softmax(logits, dim=-1)
# Weighted sum
output = torch.matmul(attn_weights, V)
    return output

Example usage:
# Small example tensors
batch_size, num_heads, seq_len, head_dim = 1, 1, 4, 8
Q = torch.randn(batch_size, num_heads, seq_len, head_dim)
K = torch.randn(batch_size, num_heads, seq_len, head_dim)
V = torch.randn(batch_size, num_heads, seq_len, head_dim)
mask = torch.ones(batch_size, seq_len)
O = standard_attention(Q, K, V, mask)
print(O.shape)  # torch.Size([1, 1, 4, 8])

While matrix multiplication can be tiled (due to the associativity of addition), softmax cannot be directly tiled without recomputation techniques.
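To make the contrast concrete, here is a small illustrative check (an addition to these notes, not from the original document): a blockwise matrix multiplication reproduces the full result exactly, while a blockwise softmax does not.

# Tiled matmul: partial products over column/row blocks sum to the exact result,
# because addition is associative.
a, b = torch.randn(6, 4), torch.randn(4, 3)
full = a @ b
tiled = a[:, :2] @ b[:2] + a[:, 2:] @ b[2:]
print(torch.allclose(full, tiled))  # True

# Softmax applied per block and concatenated is NOT the softmax of the full vector:
# each block is normalized by its own denominator.
x = torch.randn(6)
blockwise = torch.cat([torch.softmax(x[:3], dim=0), torch.softmax(x[3:], dim=0)])
print(torch.allclose(torch.softmax(x, dim=0), blockwise))  # False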
For a vector $x \in \mathbb{R}^N$, the naive softmax is

$$a_i = \frac{\exp(x_i)}{\sum_{j=1}^{N} \exp(x_j)}$$

Large $x_i$ can overflow $\exp$ in floating point, so the numerically safe softmax subtracts the global maximum $m = \max_j x_j$ first:

$$a_i = \frac{\exp(x_i - m)}{\sum_{j=1}^{N} \exp(x_j - m)}$$

This is a 3-pass algorithm, requiring three reads of $x$ from HBM if $x$ is too large to fit in SRAM.
Pseudo-code:
m = -∞
for i = 1 to N:
m = max(m, x_i) # Pass 1: global max
d = 0
for i = 1 to N:
d += exp(x_i - m) # Pass 2: denominator
for i = 1 to N:
a_i = exp(x_i - m) / d # Pass 3: normalize
PyTorch implementation (row-wise for 2D tensor):
def safe_softmax(x):
# x: (batch_size, seq_len)
# Pass 1: Compute row maxes
m = torch.max(x, dim=-1, keepdim=True)[0]
# Pass 2: Exps and sum (denominator)
exp_shifted = torch.exp(x - m)
denom = torch.sum(exp_shifted, dim=-1, keepdim=True)
# Pass 3: Normalize
return exp_shifted / denom
# Example
x = torch.tensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])
print(safe_softmax(x))

This is inefficient for large sequences, as the multiple passes increase HBM I/O.
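As a quick sanity check (an addition to these notes), safe_softmax should agree with PyTorch's built-in softmax:

x = torch.randn(4, 16)
print(torch.allclose(safe_softmax(x), torch.softmax(x, dim=-1)))  # True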
To reduce passes, introduce a surrogate denominator based on the running maximum:

$$d_i' = \sum_{j=1}^{i} \exp(x_j - m_i)$$

where $m_i = \max_{j \le i} x_j$ is the maximum over the first $i$ elements, so that $d_N' = d_N$ equals the true denominator. The key property is that $m_i$ and $d_i'$ can be updated together in a single pass via the recurrence

$$d_i' = d_{i-1}' \exp(m_{i-1} - m_i) + \exp(x_i - m_i)$$
Pseudo-code:
m = -∞
d' = 0
for i = 1 to N: # Pass 1: Compute m_i, d_i'
m_new = max(m, x_i)
d' = d' * exp(m - m_new) + exp(x_i - m_new)
m = m_new
for i = 1 to N: # Pass 2: Normalize
a_i = exp(x_i - m) / d'
PyTorch implementation (row-wise):
def online_softmax(x):
# x: (batch_size, seq_len)
batch_size, seq_len = x.shape
# Initialize
m = torch.full((batch_size, 1), float('-inf'), device=x.device)
d_prime = torch.zeros((batch_size, 1), device=x.device)
# Pass 1: Online update (loop for clarity; vectorize in practice)
for i in range(seq_len):
x_i = x[:, i:i+1]
m_new = torch.maximum(m, x_i)
exp_diff = torch.exp(m - m_new)
exp_term = torch.exp(x_i - m_new)
d_prime = d_prime * exp_diff + exp_term
m = m_new
# Pass 2: Compute probabilities
probs = torch.exp(x - m) / d_prime
return probs
# Example
x = torch.tensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])
print(online_softmax(x))  # Matches safe_softmax

This reduces the I/O from three HBM passes over $x$ to two, but it still requires more than one pass.
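Another quick check (also an addition to these notes): the single-pass statistics give the same probabilities as the 3-pass version.

x = torch.randn(4, 16)
print(torch.allclose(online_softmax(x), safe_softmax(x), atol=1e-6))  # True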
In attention, we only need the output row $O[k, :]$, not the normalized attention weights themselves, so the same surrogate trick extends to the output accumulator.

Notation (for a fixed query row $k$):

- $m_i$, $d_i'$: the running max and surrogate denominator from online softmax
- $x_j = Q[k, :] \cdot K[j, :]$ (scalar logit for position $j$)
- $o_i = \sum_{j=1}^{i} a_j V[j, :]$ (partial output up to $i$, with $a_j = \exp(x_j - m_N)/d_N$)

Surrogate output, normalized with the running statistics instead of the final ones:

$$o_i' = \sum_{j=1}^{i} \frac{\exp(x_j - m_i)}{d_i'}\, V[j, :], \qquad o_N' = o_N$$

which satisfies the one-step recurrence

$$o_i' = o_{i-1}' \, \frac{d_{i-1}' \exp(m_{i-1} - m_i)}{d_i'} + \frac{\exp(x_i - m_i)}{d_i'}\, V[i, :]$$
This enables a single loop (1 HBM access per block if tiled later), updating statistics on-the-fly.
Pseudo-code for one row:
m = -∞
d' = 0
o' = zeros(D)
for j = 1 to N:
x_j = Q[k, :] @ K[j, :]
m_new = max(m, x_j)
exp_diff = exp(m - m_new)
exp_term = exp(x_j - m_new)
d'_new = d' * exp_diff + exp_term
o' = o' * (d' * exp_diff / d'_new) + (exp_term / d'_new) * V[j, :]
m = m_new
d' = d'_new
O[k, :] = o'
PyTorch implementation (batched, multi-head, no tiling; for illustration only):
def simple_flash_forward(Q, K, V, mask=None):
# Inputs: (b, h, l, d); no tiling (inefficient for large l)
b, h, l, d = Q.shape
scale = 1 / math.sqrt(d)
O = torch.zeros_like(Q)
# Loop over batch*heads
for batch_head_idx in range(b * h):
# Flatten to (l, d)
q_bh = Q.view(b * h, l, d)[batch_head_idx]
k_bh = K.view(b * h, l, d)[batch_head_idx]
v_bh = V.view(b * h, l, d)[batch_head_idx]
m_bh = mask.view(b, l)[batch_head_idx // h] if mask is not None else torch.ones(l, device=Q.device)
# Loop over query rows
for row_idx in range(l):
q_row = q_bh[row_idx] * scale # (d,)
row_max = float('-inf')
row_denom_prime = torch.tensor(0.0, device=Q.device)
row_output_prime = torch.zeros(d, device=Q.device)
# Loop over key columns (online update)
for col_idx in range(l):
if m_bh[col_idx] <= 0:
continue
logit = torch.dot(q_row, k_bh[col_idx]) # scalar
new_max = max(row_max, logit)
exp_diff = torch.exp(row_max - new_max)
exp_term = torch.exp(logit - new_max)
new_denom_prime = row_denom_prime * exp_diff + exp_term
new_output_prime = row_output_prime * (row_denom_prime * exp_diff / new_denom_prime) + \
(exp_term / new_denom_prime) * v_bh[col_idx]
row_max = new_max
row_denom_prime = new_denom_prime
row_output_prime = new_output_prime
# Assign to output
O.view(b * h, l, d)[batch_head_idx, row_idx] = row_output_prime
    return O

(Note: This scalar loop is for clarity only and is not efficient; the next section fuses operations via tiling.)
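Before tiling, it is worth verifying (a small addition to these notes) that the online formulation reproduces the standard attention output:

torch.manual_seed(0)
Q = torch.randn(1, 1, 4, 8)
K = torch.randn(1, 1, 4, 8)
V = torch.randn(1, 1, 4, 8)
ref = standard_attention(Q, K, V)
out = simple_flash_forward(Q, K, V)
print(torch.allclose(ref, out, atol=1e-5))  # True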
For large $N$, updating one key/value row at a time is slow, and the full row of logits still cannot live in SRAM. The fix is tiling: load $K$ and $V$ in blocks of $B$ rows that fit in SRAM and apply the same online update once per block rather than once per element.
Pseudo-code for one row (tiled):
m = -∞
d' = 0
o' = zeros(D)
for tile_idx = 1 to num_tiles:
tile_start = (tile_idx - 1) * B
tile_end = tile_start + B
K_tile = K[tile_start:tile_end, :]
V_tile = V[tile_start:tile_end, :]
logits_tile = Q[k, :] @ K_tile.T # (B,)
local_max = max(logits_tile)
global_max = max(m, local_max)
exp_global_diff = exp(m - global_max)
exp_local_terms = exp(logits_tile - global_max)
local_sum = sum(exp_local_terms)
d'_new = d' * exp_global_diff + local_sum
weighted_v_sum = (exp_local_terms / d'_new) @ V_tile # (D,)
o' = o' * (d' * exp_global_diff / d'_new) + weighted_v_sum
m = global_max
d' = d'_new
O[k, :] = o'
Full tiled forward pass in PyTorch (returns $O$ along with the per-row statistics row_sums and row_maxes, which the backward pass will reuse):
BLOCK_SIZE = 1024 # Adjust based on SRAM
NEG_INF = float('-inf')
EPS = 1e-6 # For numerical stability
def flash_attention_forward(Q, K, V, mask=None):
# Inputs: (b, h, l, d)
b, h, l, d = Q.shape
device = Q.device
O = torch.zeros_like(Q)
    row_sums = torch.zeros(b, h, l, 1, device=device)  # running softmax denominators (d')
row_maxes = torch.full((b, h, l, 1), NEG_INF, device=device) # m
# Block sizes (Q rows, KV columns)
q_block_size = min(BLOCK_SIZE, l)
kv_block_size = BLOCK_SIZE
# Split into blocks
Q_blocks = torch.split(Q, q_block_size, dim=2)
K_blocks = torch.split(K, kv_block_size, dim=2)
V_blocks = torch.split(V, kv_block_size, dim=2)
mask = torch.ones(b, l, device=device) if mask is None else mask
mask_blocks = torch.split(mask, kv_block_size, dim=1)
num_q_blocks, num_kv_blocks = len(Q_blocks), len(K_blocks)
# Split outputs for accumulation
O_blocks = list(torch.split(O, q_block_size, dim=2))
row_sums_blocks = list(torch.split(row_sums, q_block_size, dim=2))
row_maxes_blocks = list(torch.split(row_maxes, q_block_size, dim=2))
scale = 1 / math.sqrt(d)
# Outer loop over KV tiles
for kv_idx in range(num_kv_blocks):
K_tile = K_blocks[kv_idx]
V_tile = V_blocks[kv_idx]
mask_tile = mask_blocks[kv_idx].unsqueeze(1).unsqueeze(1) # (b, 1, 1, block_size)
# Inner loop over Q tiles
for q_idx in range(num_q_blocks):
Q_tile = Q_blocks[q_idx]
curr_O = O_blocks[q_idx]
curr_row_sums = row_sums_blocks[q_idx]
curr_row_maxes = row_maxes_blocks[q_idx]
# Compute block logits
Q_tile_scaled = Q_tile * scale
logits_block = torch.matmul(Q_tile_scaled, K_tile.transpose(-2, -1)) # (b, h, q_block, kv_block)
logits_block = torch.where(mask_tile > 0, logits_block, NEG_INF)
# Local max and exp
local_max = torch.max(logits_block, dim=-1, keepdim=True)[0]
exp_logits = torch.exp(logits_block - local_max)
exp_logits = torch.where(mask_tile > 0, exp_logits, 0.0)
# Local sum
local_sum = torch.sum(exp_logits, dim=-1, keepdim=True) + EPS
# Update global max and sum
new_max = torch.maximum(curr_row_maxes, local_max)
new_row_sums = torch.exp(curr_row_maxes - new_max) * curr_row_sums + \
torch.exp(local_max - new_max) * local_sum
# Update output
exp_max_diff = torch.exp(curr_row_maxes - new_max)
exp_local_diff = torch.exp(local_max - new_max)
weighted_v = torch.matmul(exp_logits, V_tile) # (b, h, q_block, d)
O_blocks[q_idx] = (curr_row_sums * exp_max_diff / new_row_sums) * curr_O + \
(exp_local_diff / new_row_sums) * weighted_v
# Store updated stats
row_sums_blocks[q_idx] = new_row_sums
row_maxes_blocks[q_idx] = new_max
# Concatenate blocks
O = torch.cat(O_blocks, dim=2)
row_sums = torch.cat(row_sums_blocks, dim=2)
row_maxes = torch.cat(row_maxes_blocks, dim=2)
    return O, row_sums, row_maxes
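A correctness check of the tiled forward against the reference implementation (an addition to these notes; it reuses standard_attention from above):

torch.manual_seed(0)
b, h, l, d = 2, 4, 64, 32
Q, K, V = (torch.randn(b, h, l, d) for _ in range(3))
mask = torch.ones(b, l)
O_ref = standard_attention(Q, K, V, mask)
O_flash, row_sums, row_maxes = flash_attention_forward(Q, K, V, mask)
print(torch.allclose(O_ref, O_flash, atol=1e-4))  # True (up to the EPS added to the denominators)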
The backward pass computes the gradients $dQ$, $dK$, $dV$ given the upstream gradient $dO$. Instead of storing the $N \times N$ attention matrix, each probability block is recomputed from $Q$, $K$ and the saved statistics (row_maxes, row_sums).

Key derivations (from the original paper, with $A$ the attention weights and $S$ the logits):

- $dV = A^\top dO$
- Softmax gradient: $dS_{ij} = A_{ij}\,(dO_i \cdot V_j - D_i)$, where $D_i = dO_i \cdot O_i$
- $dQ = dS\, K$ and $dK = dS^\top Q$ (each multiplied by the $1/\sqrt{d}$ scale folded into the logits)
Pseudo-code (tiled, simplified from paper):
for kv_tile_idx = 1 to num_kv_tiles:
Load K_tile, V_tile
Init temp_dK_tile = zeros, temp_dV_tile = zeros
for q_tile_idx = 1 to num_q_tiles:
Load Q_tile, O_tile, dO_tile, row_sums_tile, row_maxes_tile
Compute logits_block = scale * Q_tile @ K_tile.T
Apply mask
probs_block = exp(logits_block - row_maxes_tile) / row_sums_tile
temp_dV_tile += probs_block.T @ dO_tile
dP_block = dO_tile @ V_tile.T
D_tile = row_sum(dO_tile * O_tile) # (q_block,)
dS_block = probs_block * (dP_block - D_tile)
        dQ_tile += scale * (dS_block @ K_tile)
        temp_dK_tile += scale * (dS_block.T @ Q_tile)
Write dK_tile, dV_tile
PyTorch implementation (tiled backward; assumes no dropout/mask for simplicity; extend as needed):
def flash_attention_backward(dO, Q, K, V, O, row_sums, row_maxes, mask=None):
# Inputs: dO, Q, K, V, O (b, h, l, d); row_sums, row_maxes (b, h, l, 1)
b, h, l, d = Q.shape
device = Q.device
scale = 1 / math.sqrt(d)
dQ = torch.zeros_like(Q)
dK = torch.zeros_like(K)
dV = torch.zeros_like(V)
q_block_size = min(BLOCK_SIZE, l)
kv_block_size = BLOCK_SIZE
# Splits (reuse forward logic)
Q_blocks = torch.split(Q, q_block_size, dim=2)
K_blocks = torch.split(K, kv_block_size, dim=2)
V_blocks = torch.split(V, kv_block_size, dim=2)
O_blocks = torch.split(O, q_block_size, dim=2)
dO_blocks = torch.split(dO, q_block_size, dim=2)
row_sums_blocks = torch.split(row_sums, q_block_size, dim=2)
row_maxes_blocks = torch.split(row_maxes, q_block_size, dim=2)
dQ_blocks = list(torch.split(dQ, q_block_size, dim=2))
mask = torch.ones(b, l, device=device) if mask is None else mask
mask_blocks = torch.split(mask, kv_block_size, dim=1)
num_q_blocks, num_kv_blocks = len(Q_blocks), len(K_blocks)
# Outer loop over KV tiles
for kv_idx in range(num_kv_blocks):
K_tile = K_blocks[kv_idx]
V_tile = V_blocks[kv_idx]
mask_tile = mask_blocks[kv_idx].unsqueeze(1).unsqueeze(1)
temp_dK = torch.zeros_like(K_tile)
temp_dV = torch.zeros_like(V_tile)
# Inner loop over Q tiles
for q_idx in range(num_q_blocks):
Q_tile = Q_blocks[q_idx]
O_tile = O_blocks[q_idx]
dO_tile = dO_blocks[q_idx]
row_sums_tile = row_sums_blocks[q_idx]
row_maxes_tile = row_maxes_blocks[q_idx]
# Recompute block probs
Q_tile_scaled = Q_tile * scale
logits_block = torch.matmul(Q_tile_scaled, K_tile.transpose(-2, -1))
logits_block = torch.where(mask_tile > 0, logits_block, NEG_INF)
probs_block = torch.exp(logits_block - row_maxes_tile) / (row_sums_tile + EPS)
# dV accumulation
temp_dV += torch.matmul(probs_block.transpose(-2, -1), dO_tile)
# dP and D
dP_block = torch.matmul(dO_tile, V_tile.transpose(-2, -1))
D_tile = torch.sum(dO_tile * O_tile, dim=-1, keepdim=True) # (b, h, q_block, 1)
# dS
dS_block = probs_block * (dP_block - D_tile)
            # dQ accumulation (include the 1/sqrt(d) scale from the logits)
            dQ_blocks[q_idx] += scale * torch.matmul(dS_block, K_tile)
            # dK accumulation
            temp_dK += scale * torch.matmul(dS_block.transpose(-2, -1), Q_tile)
        # Write this KV tile's gradients back into dK and dV
        kv_start = kv_idx * kv_block_size
        dK[:, :, kv_start:kv_start + K_tile.shape[2]] = temp_dK
        dV[:, :, kv_start:kv_start + V_tile.shape[2]] = temp_dV
    dQ = torch.cat(dQ_blocks, dim=2)
    return dQ, dK, dV

(Note: This is a simplified version without dropout; refer to the original paper for full details, including dropout regeneration in the backward pass.)
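To validate the backward pass (an addition to these notes), compare its gradients with autograd applied to the reference implementation:

torch.manual_seed(0)
Q = torch.randn(1, 2, 32, 16, requires_grad=True)
K = torch.randn(1, 2, 32, 16, requires_grad=True)
V = torch.randn(1, 2, 32, 16, requires_grad=True)

# Reference gradients via autograd on standard_attention
O_ref = standard_attention(Q, K, V)
dO = torch.randn_like(O_ref)
O_ref.backward(dO)

# Gradients from the tiled forward/backward pair (no autograd involved)
with torch.no_grad():
    O, row_sums, row_maxes = flash_attention_forward(Q, K, V)
    dQ, dK, dV = flash_attention_backward(dO, Q, K, V, O, row_sums, row_maxes)

print(torch.allclose(Q.grad, dQ, atol=1e-4))  # True
print(torch.allclose(K.grad, dK, atol=1e-4))  # True
print(torch.allclose(V.grad, dV, atol=1e-4))  # True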
To make it differentiable, wrap in a custom torch.autograd.Function.
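A minimal sketch of such a wrapper, assuming the flash_attention_forward and flash_attention_backward functions defined above (the class and helper names here are illustrative, not part of the original notes):

class FlashAttentionFunction(torch.autograd.Function):
    @staticmethod
    def forward(ctx, Q, K, V, mask=None):
        # Run the tiled forward and stash everything the backward pass needs.
        O, row_sums, row_maxes = flash_attention_forward(Q, K, V, mask)
        ctx.save_for_backward(Q, K, V, O, row_sums, row_maxes)
        ctx.mask = mask
        return O

    @staticmethod
    def backward(ctx, dO):
        Q, K, V, O, row_sums, row_maxes = ctx.saved_tensors
        dQ, dK, dV = flash_attention_backward(dO, Q, K, V, O, row_sums, row_maxes, ctx.mask)
        # One gradient per forward input; the non-tensor mask gets None.
        return dQ, dK, dV, None

def flash_attention(Q, K, V, mask=None):
    return FlashAttentionFunction.apply(Q, K, V, mask)

# Usage: gradients now flow through the fused implementation.
Q = torch.randn(1, 2, 32, 16, requires_grad=True)
K = torch.randn(1, 2, 32, 16, requires_grad=True)
V = torch.randn(1, 2, 32, 16, requires_grad=True)
out = flash_attention(Q, K, V)
out.sum().backward()
print(Q.grad.shape)  # torch.Size([1, 2, 32, 16])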
- Tutorial Notes: From Online Softmax to FlashAttention by Zihao Ye
- Original Paper: FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness by Tri Dao et al.
- Official Repository: Dao-AILab/flash-attention
- PyTorch Example: shreyansh26/FlashAttention-PyTorch
MIT License