Flash Attention Tutorial

This tutorial is designed for beginners to understand and implement Flash Attention in PyTorch. We follow the structure of the document "From Online Softmax to FlashAttention" by Zihao Ye, explaining each concept step by step, deriving the formulas, and providing PyTorch code snippets. We'll build from basic self-attention to the full tiled Flash Attention, including both forward and backward passes for a complete, differentiable PyTorch implementation.

We assume basic knowledge of PyTorch and transformers. All code is device-agnostic (tested on CPU for simplicity; it also runs on CUDA). We use small dimensions in the examples and scale to larger ones in the final implementation.

Introduction

Flash Attention is a memory-efficient and fast algorithm for computing self-attention in transformers. It avoids materializing large intermediate matrices (e.g., the attention matrix) in GPU global memory (High Bandwidth Memory, HBM). Instead, it employs tiling and online softmax techniques to fuse operations into a single kernel, leveraging fast GPU shared memory (SRAM). This reduces I/O overhead between slow HBM and fast SRAM, enabling longer sequences without out-of-memory errors.

A key limitation in standard attention implementations is the small size of SRAM (typically ~100KB per streaming multiprocessor on modern GPUs), which cannot store the full $O(L^2)$ attention matrix for long sequences $L$. As a result, computations require multiple HBM accesses: writing intermediates to HBM, then reading them back for subsequent operations. This I/O bottleneck slows down performance, especially for large $L$. Flash Attention addresses this by recomputing intermediates on-the-fly in SRAM using tiled blocks, reducing HBM accesses from $O(L^2 + LD)$ to $O(L^2 D^2 / M)$, where $M$ is the SRAM size.

The core challenge is that softmax is not associative (unlike addition in matrix multiplication), complicating tiling. We'll derive how online softmax enables associative updates and build to a full implementation, highlighting reductions in HBM accesses at each stage.

1. Standard Self-Attention

Self-attention computes weighted sums of values based on query-key similarities. For simplicity, we initially ignore batches, heads, masks, and scaling (added later).

The formula is:

$$ O = \mathrm{softmax}(Q K^T) V $$

where $Q, K, V, O \in \mathbb{R}^{L \times D}$, $L$ is the sequence length, and $D$ is the head dimension. Softmax is applied row-wise.

The standard implementation factorizes as:

$$ X = Q K^T, \quad A = \mathrm{softmax}(X), \quad O = A V $$

This requires storing $X$ and $A$ ($O(L^2)$ memory), which is inefficient for large $L$. In terms of memory accesses, it involves at least three HBM round-trips: (1) compute and write $X$ to HBM, (2) read $X$ from HBM for softmax and write $A$ to HBM, (3) read $A$ from HBM to compute $O$.
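
To make these intermediates explicit, here is a minimal sketch of the factorization for a single head (no scaling or mask; shapes chosen for illustration):

import torch

L, D = 4, 8
Q = torch.randn(L, D)
K = torch.randn(L, D)
V = torch.randn(L, D)

X = Q @ K.T                    # (L, L) logits, materialized in HBM
A = torch.softmax(X, dim=-1)   # (L, L) attention weights, materialized again
O = A @ V                      # (L, D) output
print(X.shape, A.shape, O.shape)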

A basic (non-Flash) PyTorch implementation with batches, heads, mask, and scaling:

import torch
import torch.nn as nn
import math  # for the 1/sqrt(head_dim) scaling factor

def standard_attention(Q, K, V, mask=None):
    # Inputs: Q, K, V shape (batch_size, num_heads, seq_len, head_dim)
    head_dim = Q.shape[-1]
    scale = 1 / math.sqrt(head_dim)
    # Scale queries
    Q_scaled = Q * scale
    # Compute logits: (b, h, l, l)
    logits = torch.matmul(Q_scaled, K.transpose(-2, -1))
    
    if mask is not None:
        # Mask: (b, l) -> (b, 1, 1, l)
        key_mask = mask.unsqueeze(1).unsqueeze(1)
        logits = torch.where(key_mask > 0, logits, float('-inf'))
    
    # Softmax to get attention weights
    attn_weights = nn.functional.softmax(logits, dim=-1)
    # Weighted sum
    output = torch.matmul(attn_weights, V)
    return output

Example usage:

# Small example tensors
batch_size, num_heads, seq_len, head_dim = 1, 1, 4, 8
Q = torch.randn(batch_size, num_heads, seq_len, head_dim)
K = torch.randn(batch_size, num_heads, seq_len, head_dim)
V = torch.randn(batch_size, num_heads, seq_len, head_dim)
mask = torch.ones(batch_size, seq_len)
O = standard_attention(Q, K, V, mask)
print(O.shape)  # torch.Size([1, 1, 4, 8])

While matrix multiplication can be tiled (due to associativity), softmax cannot be directly tiled without recomputation techniques.
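
To see why matmul tiles so easily, split the inner (reduction) dimension into blocks: the partial products simply add up, so each block can be processed independently and accumulated. A minimal sketch (block size and names are illustrative):

import torch

M, K_dim, N, B = 6, 8, 5, 2  # matrix sizes and tile width along the reduction dim
A = torch.randn(M, K_dim)
B_mat = torch.randn(K_dim, N)

full = A @ B_mat

# Tile along the reduction dimension: partial products accumulate by plain addition
tiled = torch.zeros(M, N)
for k0 in range(0, K_dim, B):
    tiled += A[:, k0:k0 + B] @ B_mat[k0:k0 + B, :]

print(torch.allclose(full, tiled, atol=1e-6))  # True

Softmax has no such direct decomposition, because the normalizing denominator depends on every element of the row; the next sections build the workaround.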

2. Safe Softmax

For a vector $x = [x_1, \dots, x_N]$:

$$ \mathrm{softmax}(x)_i = \frac{e^{x_i}}{\sum_{j=1}^N e^{x_j}} $$

Large $x_i$ can cause numerical overflow (e.g., in float16). Safe softmax subtracts the row max $m = \max_j x_j$:

$$ \mathrm{softmax}(x)_i = \frac{e^{x_i - m}}{\sum_{j=1}^N e^{x_j - m}} $$

This is a 3-pass algorithm, requiring three HBM reads if $x$ does not fit in SRAM: (1) find max, (2) compute exp and sum, (3) normalize. Each pass reloads $x$ from HBM.

Pseudo-code:

m = -∞
for i = 1 to N:
    m = max(m, x_i)  # Pass 1: global max

d = 0
for i = 1 to N:
    d += exp(x_i - m)  # Pass 2: denominator

for i = 1 to N:
    a_i = exp(x_i - m) / d  # Pass 3: normalize

PyTorch implementation (row-wise for 2D tensor):

def safe_softmax(x):
    # x: (batch_size, seq_len)
    # Pass 1: Compute row maxes
    m = torch.max(x, dim=-1, keepdim=True)[0]
    # Pass 2: Exps and sum (denominator)
    exp_shifted = torch.exp(x - m)
    denom = torch.sum(exp_shifted, dim=-1, keepdim=True)
    # Pass 3: Normalize
    return exp_shifted / denom

# Example
x = torch.tensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])
print(safe_softmax(x))

This is inefficient for large sequences, as multiple passes increase HBM I/O.

3. Online Softmax

To reduce passes, introduce a surrogate $d_i' = \sum_{j=1}^i e^{x_j - m_i}$ with the recurrence:

$$ d_i' = d_{i-1}' \cdot e^{m_{i-1} - m_i} + e^{x_i - m_i} $$

where $m_i = \max(m_{i-1}, x_i)$. At $i=N$, $d_N' = \sum e^{x_j - m_N}$. This enables a 2-pass algorithm (2 HBM reads if not in SRAM): one for statistics, one for normalization.

Pseudo-code:

m = -∞
d' = 0
for i = 1 to N:  # Pass 1: Compute m_i, d_i'
    m_new = max(m, x_i)
    d' = d' * exp(m - m_new) + exp(x_i - m_new)
    m = m_new

for i = 1 to N:  # Pass 2: Normalize
    a_i = exp(x_i - m) / d'

PyTorch implementation (row-wise):

def online_softmax(x):
    # x: (batch_size, seq_len)
    batch_size, seq_len = x.shape
    # Initialize
    m = torch.full((batch_size, 1), float('-inf'), device=x.device)
    d_prime = torch.zeros((batch_size, 1), device=x.device)
    # Pass 1: Online update (loop for clarity; vectorize in practice)
    for i in range(seq_len):
        x_i = x[:, i:i+1]
        m_new = torch.maximum(m, x_i)
        exp_diff = torch.exp(m - m_new)
        exp_term = torch.exp(x_i - m_new)
        d_prime = d_prime * exp_diff + exp_term
        m = m_new
    # Pass 2: Compute probabilities
    probs = torch.exp(x - m) / d_prime
    return probs

# Example
x = torch.tensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])
print(online_softmax(x))  # Matches safe_softmax

This reduces I/O to two passes over $x$ (two HBM reads), but we would still like a single pass.
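
As a quick sanity check (assuming the functions above are in scope), the online result matches both safe softmax and PyTorch's built-in softmax:

x = torch.randn(3, 16)
print(torch.allclose(online_softmax(x), torch.softmax(x, dim=-1), atol=1e-6))  # True
print(torch.allclose(safe_softmax(x), torch.softmax(x, dim=-1), atol=1e-6))    # True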

4. Flash Attention (Without Tiling)

In attention, we only need $O = A V$, not $A$ itself. For each row $k$, derive a one-pass recurrence using surrogates.

Notation:

  • $x_j = Q[k, :] \cdot K[j, :]$ (scalar logit for position $j$)
  • $o_i = \sum_{j=1}^i a_j V[j, :]$ (partial output up to $i$)

Surrogate $o_i' = \sum_{j=1}^i e^{x_j - m_i} V[j, :] / d_i'$, with recurrence:

$$ o_i' = o_{i-1}' \cdot \frac{d_{i-1}' e^{m_{i-1} - m_i}}{d_i'} + \frac{e^{x_i - m_i}}{d_i'} V[i, :] $$

This enables a single pass over $K$ and $V$ (one HBM read per tile once tiling is added), updating the statistics on-the-fly.

Pseudo-code for one row:

m = -∞
d' = 0
o' = zeros(D)
for j = 1 to N:
    x_j = Q[k, :] @ K[j, :]
    m_new = max(m, x_j)
    exp_diff = exp(m - m_new)
    exp_term = exp(x_j - m_new)
    d'_new = d' * exp_diff + exp_term
    o' = o' * (d' * exp_diff / d'_new) + (exp_term / d'_new) * V[j, :]
    m = m_new
    d' = d'_new
O[k, :] = o'

PyTorch implementation (batched, multi-head, no tiling; for illustration only):

def simple_flash_forward(Q, K, V, mask=None):
    # Inputs: (b, h, l, d); no tiling (inefficient for large l)
    b, h, l, d = Q.shape
    scale = 1 / math.sqrt(d)
    O = torch.zeros_like(Q)
    # Loop over batch*heads
    for batch_head_idx in range(b * h):
        # Flatten to (l, d)
        q_bh = Q.view(b * h, l, d)[batch_head_idx]
        k_bh = K.view(b * h, l, d)[batch_head_idx]
        v_bh = V.view(b * h, l, d)[batch_head_idx]
        m_bh = mask.view(b, l)[batch_head_idx // h] if mask is not None else torch.ones(l, device=Q.device)
        # Loop over query rows
        for row_idx in range(l):
            q_row = q_bh[row_idx] * scale  # (d,)
            row_max = torch.tensor(float('-inf'), device=Q.device)
            row_denom_prime = torch.tensor(0.0, device=Q.device)
            row_output_prime = torch.zeros(d, device=Q.device)
            # Loop over key columns (online update)
            for col_idx in range(l):
                if m_bh[col_idx] <= 0:
                    continue
                logit = torch.dot(q_row, k_bh[col_idx])  # scalar
                new_max = torch.maximum(row_max, logit)
                exp_diff = torch.exp(row_max - new_max)
                exp_term = torch.exp(logit - new_max)
                new_denom_prime = row_denom_prime * exp_diff + exp_term
                new_output_prime = row_output_prime * (row_denom_prime * exp_diff / new_denom_prime) + \
                                   (exp_term / new_denom_prime) * v_bh[col_idx]
                row_max = new_max
                row_denom_prime = new_denom_prime
                row_output_prime = new_output_prime
            # Assign to output
            O.view(b * h, l, d)[batch_head_idx, row_idx] = row_output_prime
    return O

(Note: This scalar loop is for clarity and not efficient; tiling fuses operations in the next section.)
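
A quick correctness check against standard_attention from Section 1 (a sketch, assuming both functions above are in scope):

b, h, l, d = 1, 2, 6, 8
Q = torch.randn(b, h, l, d)
K = torch.randn(b, h, l, d)
V = torch.randn(b, h, l, d)
mask = torch.ones(b, l)

O_ref = standard_attention(Q, K, V, mask)
O_flash = simple_flash_forward(Q, K, V, mask)
print(torch.allclose(O_ref, O_flash, atol=1e-5))  # True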

5. Flash Attention (With Tiling)

For large $L$, tile $K$ and $V$ into blocks of size $B$ (fitting in SRAM). For each query row/block, load tiles, compute local max/exp/sum in SRAM, and update global statistics associatively. This ensures a single effective pass, with HBM accesses scaling as $O(L^2 D^2 / M)$ (where $M$ is SRAM size), far fewer than standard attention's $O(L^2 + L D)$.

Pseudo-code for one row (tiled):

m = -∞
d' = 0
o' = zeros(D)
for tile_idx = 1 to num_tiles:
    tile_start = (tile_idx - 1) * B
    tile_end = tile_start + B
    K_tile = K[tile_start:tile_end, :]
    V_tile = V[tile_start:tile_end, :]
    logits_tile = Q[k, :] @ K_tile.T  # (B,)
    local_max = max(logits_tile)
    global_max = max(m, local_max)
    exp_global_diff = exp(m - global_max)
    exp_local_terms = exp(logits_tile - global_max)
    local_sum = sum(exp_local_terms)
    d'_new = d' * exp_global_diff + local_sum
    weighted_v_sum = (exp_local_terms / d'_new) @ V_tile  # (D,)
    o' = o' * (d' * exp_global_diff / d'_new) + weighted_v_sum
    m = global_max
    d' = d'_new
O[k, :] = o'

Full tiled forward in PyTorch (returns $O$, row sums $l$, row maxes $m$ for backward):

BLOCK_SIZE = 1024  # Adjust based on SRAM
NEG_INF = float('-inf')
EPS = 1e-6  # For numerical stability

def flash_attention_forward(Q, K, V, mask=None):
    # Inputs: (b, h, l, d)
    b, h, l, d = Q.shape
    device = Q.device
    O = torch.zeros_like(Q)
    row_sums = torch.zeros(b, h, l, 1, device=device)  # l (denominators)
    row_maxes = torch.full((b, h, l, 1), NEG_INF, device=device)  # m
    
    # Block sizes (Q rows, KV columns)
    q_block_size = min(BLOCK_SIZE, l)
    kv_block_size = BLOCK_SIZE
    
    # Split into blocks
    Q_blocks = torch.split(Q, q_block_size, dim=2)
    K_blocks = torch.split(K, kv_block_size, dim=2)
    V_blocks = torch.split(V, kv_block_size, dim=2)
    mask = torch.ones(b, l, device=device) if mask is None else mask
    mask_blocks = torch.split(mask, kv_block_size, dim=1)
    
    num_q_blocks, num_kv_blocks = len(Q_blocks), len(K_blocks)
    
    # Split outputs for accumulation
    O_blocks = list(torch.split(O, q_block_size, dim=2))
    row_sums_blocks = list(torch.split(row_sums, q_block_size, dim=2))
    row_maxes_blocks = list(torch.split(row_maxes, q_block_size, dim=2))
    
    scale = 1 / math.sqrt(d)
    
    # Outer loop over KV tiles
    for kv_idx in range(num_kv_blocks):
        K_tile = K_blocks[kv_idx]
        V_tile = V_blocks[kv_idx]
        mask_tile = mask_blocks[kv_idx].unsqueeze(1).unsqueeze(1)  # (b, 1, 1, block_size)
        
        # Inner loop over Q tiles
        for q_idx in range(num_q_blocks):
            Q_tile = Q_blocks[q_idx]
            curr_O = O_blocks[q_idx]
            curr_row_sums = row_sums_blocks[q_idx]
            curr_row_maxes = row_maxes_blocks[q_idx]
            
            # Compute block logits
            Q_tile_scaled = Q_tile * scale
            logits_block = torch.matmul(Q_tile_scaled, K_tile.transpose(-2, -1))  # (b, h, q_block, kv_block)
            logits_block = torch.where(mask_tile > 0, logits_block, NEG_INF)
            
            # Local max and exp
            local_max = torch.max(logits_block, dim=-1, keepdim=True)[0]
            exp_logits = torch.exp(logits_block - local_max)
            exp_logits = torch.where(mask_tile > 0, exp_logits, 0.0)
            
            # Local sum
            local_sum = torch.sum(exp_logits, dim=-1, keepdim=True) + EPS
            
            # Update global max and sum
            new_max = torch.maximum(curr_row_maxes, local_max)
            new_row_sums = torch.exp(curr_row_maxes - new_max) * curr_row_sums + \
                           torch.exp(local_max - new_max) * local_sum
            
            # Update output
            exp_max_diff = torch.exp(curr_row_maxes - new_max)
            exp_local_diff = torch.exp(local_max - new_max)
            weighted_v = torch.matmul(exp_logits, V_tile)  # (b, h, q_block, d)
            O_blocks[q_idx] = (curr_row_sums * exp_max_diff / new_row_sums) * curr_O + \
                              (exp_local_diff / new_row_sums) * weighted_v
            
            # Store updated stats
            row_sums_blocks[q_idx] = new_row_sums
            row_maxes_blocks[q_idx] = new_max
    
    # Concatenate blocks
    O = torch.cat(O_blocks, dim=2)
    row_sums = torch.cat(row_sums_blocks, dim=2)
    row_maxes = torch.cat(row_maxes_blocks, dim=2)
    return O, row_sums, row_maxes
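
A quick sanity check against standard_attention (a sketch, assuming the functions above are in scope; we temporarily lower the module-level BLOCK_SIZE so a short sequence actually spans several tiles, and loosen the tolerance slightly because of the EPS added per tile):

BLOCK_SIZE = 4  # small tile size so a 16-token sequence spans several KV tiles
b, h, l, d = 2, 2, 16, 8
Q = torch.randn(b, h, l, d)
K = torch.randn(b, h, l, d)
V = torch.randn(b, h, l, d)
mask = torch.ones(b, l)

O_flash, row_sums, row_maxes = flash_attention_forward(Q, K, V, mask)
O_ref = standard_attention(Q, K, V, mask)
print(torch.allclose(O_flash, O_ref, atol=1e-4))  # True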

6. Flash Attention Backward Pass

The backward pass computes gradients $dQ$, $dK$, $dV$ without storing the $O(L^2)$ attention matrix. It recomputes block-wise in SRAM using saved row maxes $m$ and sums $l$ from forward, with similar tiling. This maintains memory efficiency, with HBM accesses $O(L^2 D^2 / M)$.

Key derivations (from original paper):

  • $dV = A^T dO$
  • For softmax grad: $dS_{ij} = A_{ij} (dO_i^T V_j - D_i)$ where $D_i = dO_i^T O_i$
  • $dQ_i = \frac{1}{\sqrt{D}}\, dS_{i:} K$, $dK_j = \frac{1}{\sqrt{D}}\, dS_{:j}^T Q$ (the $1/\sqrt{D}$ factor comes from the scaled logits)

Pseudo-code (tiled, simplified from paper):

for kv_tile_idx = 1 to num_kv_tiles:
    Load K_tile, V_tile
    Init temp_dK_tile = zeros, temp_dV_tile = zeros
    for q_tile_idx = 1 to num_q_tiles:
        Load Q_tile, O_tile, dO_tile, row_sums_tile, row_maxes_tile
        Compute logits_block = scale * Q_tile @ K_tile.T
        Apply mask
        probs_block = exp(logits_block - row_maxes_tile) / row_sums_tile
        temp_dV_tile += probs_block.T @ dO_tile
        dP_block = dO_tile @ V_tile.T
        D_tile = row_sum(dO_tile * O_tile)  # (q_block,)
        dS_block = probs_block * (dP_block - D_tile)
        dQ_tile += scale * dS_block @ K_tile
        temp_dK_tile += scale * dS_block.T @ Q_tile
    Write temp_dK_tile, temp_dV_tile into dK, dV

PyTorch implementation (tiled backward; handles a padding mask but omits dropout for simplicity; extend as needed):

def flash_attention_backward(dO, Q, K, V, O, row_sums, row_maxes, mask=None):
    # Inputs: dO, Q, K, V, O (b, h, l, d); row_sums, row_maxes (b, h, l, 1)
    b, h, l, d = Q.shape
    device = Q.device
    scale = 1 / math.sqrt(d)
    
    dQ = torch.zeros_like(Q)
    dK = torch.zeros_like(K)
    dV = torch.zeros_like(V)
    
    q_block_size = min(BLOCK_SIZE, l)
    kv_block_size = BLOCK_SIZE
    
    # Splits (reuse forward logic)
    Q_blocks = torch.split(Q, q_block_size, dim=2)
    K_blocks = torch.split(K, kv_block_size, dim=2)
    V_blocks = torch.split(V, kv_block_size, dim=2)
    O_blocks = torch.split(O, q_block_size, dim=2)
    dO_blocks = torch.split(dO, q_block_size, dim=2)
    row_sums_blocks = torch.split(row_sums, q_block_size, dim=2)
    row_maxes_blocks = torch.split(row_maxes, q_block_size, dim=2)
    dQ_blocks = list(torch.split(dQ, q_block_size, dim=2))
    mask = torch.ones(b, l, device=device) if mask is None else mask
    mask_blocks = torch.split(mask, kv_block_size, dim=1)
    
    num_q_blocks, num_kv_blocks = len(Q_blocks), len(K_blocks)
    
    # Outer loop over KV tiles
    for kv_idx in range(num_kv_blocks):
        K_tile = K_blocks[kv_idx]
        V_tile = V_blocks[kv_idx]
        mask_tile = mask_blocks[kv_idx].unsqueeze(1).unsqueeze(1)
        
        temp_dK = torch.zeros_like(K_tile)
        temp_dV = torch.zeros_like(V_tile)
        
        # Inner loop over Q tiles
        for q_idx in range(num_q_blocks):
            Q_tile = Q_blocks[q_idx]
            O_tile = O_blocks[q_idx]
            dO_tile = dO_blocks[q_idx]
            row_sums_tile = row_sums_blocks[q_idx]
            row_maxes_tile = row_maxes_blocks[q_idx]
            
            # Recompute block probs
            Q_tile_scaled = Q_tile * scale
            logits_block = torch.matmul(Q_tile_scaled, K_tile.transpose(-2, -1))
            logits_block = torch.where(mask_tile > 0, logits_block, NEG_INF)
            probs_block = torch.exp(logits_block - row_maxes_tile) / (row_sums_tile + EPS)
            
            # dV accumulation
            temp_dV += torch.matmul(probs_block.transpose(-2, -1), dO_tile)
            
            # dP and D
            dP_block = torch.matmul(dO_tile, V_tile.transpose(-2, -1))
            D_tile = torch.sum(dO_tile * O_tile, dim=-1, keepdim=True)  # (b, h, q_block, 1)
            
            # dS
            dS_block = probs_block * (dP_block - D_tile)
            
            # dQ accumulation (include the 1/sqrt(d) scale from the scaled logits)
            dQ_blocks[q_idx] += scale * torch.matmul(dS_block, K_tile)
            
            # dK accumulation (include the 1/sqrt(d) scale from the scaled logits)
            temp_dK += scale * torch.matmul(dS_block.transpose(-2, -1), Q_tile)
        
        # Write the accumulated gradients for this KV tile (torch.split returns views
        # that share storage with dK and dV, so copy_ writes in place)
        torch.split(dK, kv_block_size, dim=2)[kv_idx].copy_(temp_dK)
        torch.split(dV, kv_block_size, dim=2)[kv_idx].copy_(temp_dV)
    
    dQ = torch.cat(dQ_blocks, dim=2)
    return dQ, dK, dV

(Note: This is a simplified version without dropout; refer to the original paper for full details including dropout regeneration.)

To make the whole pipeline differentiable, wrap the forward and backward passes in a custom torch.autograd.Function.
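
Here is a minimal sketch of such a wrapper, assuming the flash_attention_forward and flash_attention_backward functions above are in scope; the class name FlashAttentionFunction is illustrative, and the mask receives no gradient:

class FlashAttentionFunction(torch.autograd.Function):
    @staticmethod
    def forward(ctx, Q, K, V, mask=None):
        # Run the tiled forward and stash what the backward pass needs
        O, row_sums, row_maxes = flash_attention_forward(Q, K, V, mask)
        ctx.save_for_backward(Q, K, V, O, row_sums, row_maxes)
        ctx.mask = mask
        return O

    @staticmethod
    def backward(ctx, dO):
        Q, K, V, O, row_sums, row_maxes = ctx.saved_tensors
        dQ, dK, dV = flash_attention_backward(dO, Q, K, V, O, row_sums, row_maxes, ctx.mask)
        return dQ, dK, dV, None  # no gradient for the mask

A quick gradient comparison against autograd through standard_attention (small shapes; tolerance accounts for the EPS terms):

b, h, l, d = 1, 2, 8, 4
Q = torch.randn(b, h, l, d, requires_grad=True)
K = torch.randn(b, h, l, d, requires_grad=True)
V = torch.randn(b, h, l, d, requires_grad=True)

FlashAttentionFunction.apply(Q, K, V, None).sum().backward()
dQ_flash, dK_flash, dV_flash = Q.grad.clone(), K.grad.clone(), V.grad.clone()

Q.grad = K.grad = V.grad = None
standard_attention(Q, K, V).sum().backward()
print(torch.allclose(dQ_flash, Q.grad, atol=1e-4),
      torch.allclose(dK_flash, K.grad, atol=1e-4),
      torch.allclose(dV_flash, V.grad, atol=1e-4))  # True True True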

Reference

  • Zihao Ye, "From Online Softmax to FlashAttention"
  • Tri Dao, Daniel Y. Fu, Stefano Ermon, Atri Rudra, Christopher Ré, "FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness" (2022)

License

MIT License
