Description
Problem statement
- In the CUDA path, score scaling and bias addition typically happen outside the GEMM inner loop:
- Compute logits S = Q K^T,
- Scale S = s · S,
- Add bias S = S + B,
- Apply masks, then streaming softmax.
- This causes:
- Extra per-logit multiplies over an M×N tile (the scale on S),
- An additional pass over the logits tile for the bias add (see the baseline sketch after this list).
- In backward, dQ/dK scaling is often done as separate epilogues (extra M·D or N·D multiplies), rather than absorbing s into the matmul operands where it’s cheaper and friendlier to the memory hierarchy.
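For concreteness, a minimal sketch of the baseline ordering from the list above. This is a hypothetical naive kernel (the name, flat row-major layouts, and boolean mask convention are assumptions for illustration), not the tiled code in flash_fwd_kernel.h; it only shows where the extra per-logit scale and bias operations land.

```cuda
#include <math.h>

// Hypothetical naive kernel illustrating the baseline ordering:
// S = Q K^T, then the per-logit scale, then the per-logit bias add, then masking.
__global__ void baseline_logits(const float* Q, const float* K, const float* B,
                                const bool* mask, float* S,
                                int M, int N, int D, float s) {
    int i = blockIdx.y * blockDim.y + threadIdx.y;  // query row
    int j = blockIdx.x * blockDim.x + threadIdx.x;  // key column
    if (i >= M || j >= N) return;

    float acc = 0.f;
    for (int d = 0; d < D; ++d)
        acc += Q[i * D + d] * K[j * D + d];            // S = Q K^T

    acc *= s;                                          // extra multiply on every logit (M*N total)
    acc += B[i * N + j];                               // separate bias add over the logits tile
    S[i * N + j] = mask[i * N + j] ? acc : -INFINITY;  // mask; streaming softmax follows downstream
}
```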
Proposed solution
Forward: pre-scale Q and start the logits accumulator from bias; masks after matmul.
Let s = softmax_scale, M = seqlen_q, N = seqlen_k, D = headdim.
Baseline logits: S = s · (Q K^T) + B.
Proposed (scale Q): Q' = s · Q, so S = Q' K^T + B.
Use GEMM in “accumulate” form D = A·B + C with C initialized from the bias tile: acc_s ← bias, then acc_s ← Q' K^T + acc_s across K-tiles.
Then apply masks to acc_s and proceed with the streaming softmax update. Bias is never scaled, and we avoid a per-logit multiply.
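A minimal CUDA sketch of this ordering, assuming flat row-major buffers and a boolean mask (the kernel and buffer names are placeholders, not the actual flash_fwd_kernel.h code). The scale is a single pass over Q, the accumulator starts from the bias tile, and the mask is applied after the matmul; the streaming-softmax update is omitted.

```cuda
#include <math.h>

// Step 1: one vector scale over Q -- M*D multiplies instead of M*N.
__global__ void prescale_q(float* Q, int M, int D, float s) {
    int idx = blockIdx.x * blockDim.x + threadIdx.x;
    if (idx < M * D) Q[idx] *= s;
}

// Step 2: logits with the accumulator initialized from the bias tile,
// i.e. the "C" in D = A*B + C; the bias is never scaled.
__global__ void fwd_logits_bias_init(const float* Qs,  // pre-scaled Q (Q' = s*Q)
                                     const float* K, const float* B,
                                     const bool* mask, float* S,
                                     int M, int N, int D) {
    int i = blockIdx.y * blockDim.y + threadIdx.y;  // query row
    int j = blockIdx.x * blockDim.x + threadIdx.x;  // key column
    if (i >= M || j >= N) return;

    float acc = B[i * N + j];                       // acc_s <- bias
    for (int d = 0; d < D; ++d)
        acc += Qs[i * D + d] * K[j * D + d];        // acc_s <- Q' K^T + acc_s

    // Mask post-matmul; the streaming softmax (m_i / lse_i) update would follow here.
    S[i * N + j] = mask[i * N + j] ? acc : -INFINITY;
}
```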
Backward: pre-scale K for both score recomputation and dQ, avoiding a dQ epilogue multiply.
With forward defined as above, the gradient wrt the logits S is dS (from the softmax backprop). The parameter gradients are dQ = s · (dS · K) and dK = s · (dS^T · Q).
Proposed (scale K): K' = s · K, reused for the score recomputation (S = Q K'^T + B) and for dQ = dS · K', so dQ needs no epilogue scale.
For dK, either apply a single epilogue multiply dK ← s·dK, or pre-scale the Q tile by s inside the dK path: dK ← dS^T · (s·Q).
Choose the cheaper option by comparing M·D vs N·D.
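A matching backward sketch under the same assumptions (placeholder names, flat row-major layouts; the softmax backprop that produces dS is taken as given). K is scaled once and reused, dQ needs no epilogue multiply, and the dK kernel shows the single-epilogue-multiply option.

```cuda
// Step 1: one vector scale over K -- N*D multiplies, reused for score
// recomputation and for dQ.
__global__ void prescale_k(float* K, int N, int D, float s) {
    int idx = blockIdx.x * blockDim.x + threadIdx.x;
    if (idx < N * D) K[idx] *= s;
}

// Step 2: dQ = dS * K' with K' = s*K, so the M x D output needs no epilogue scale.
__global__ void bwd_dq(const float* dS, const float* Ks,  // Ks = pre-scaled K
                       float* dQ, int M, int N, int D) {
    int i = blockIdx.y * blockDim.y + threadIdx.y;  // query row
    int d = blockIdx.x * blockDim.x + threadIdx.x;  // head-dim column
    if (i >= M || d >= D) return;
    float acc = 0.f;
    for (int j = 0; j < N; ++j)
        acc += dS[i * N + j] * Ks[j * D + d];
    dQ[i * D + d] = acc;
}

// Step 3 (one of the two options): dK = dS^T * Q with a single epilogue multiply.
// The alternative is to feed a pre-scaled Q tile into this matmul and skip the epilogue.
__global__ void bwd_dk_epilogue(const float* dS, const float* Q,
                                float* dK, int M, int N, int D, float s) {
    int j = blockIdx.y * blockDim.y + threadIdx.y;  // key row
    int d = blockIdx.x * blockDim.x + threadIdx.x;  // head-dim column
    if (j >= N || d >= D) return;
    float acc = 0.f;
    for (int i = 0; i < M; ++i)
        acc += dS[i * N + j] * Q[i * D + d];
    dK[j * D + d] = s * acc;                        // one multiply per output element (N*D total)
}
```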
Why this is effective
- Pre-scaling an operand replaces an M·N per-logit multiply (on S) with a single M·D (scale Q) or N·D (scale K) vector scale, which touches far fewer elements for typical attention shapes (worked numbers after this list).
- Initializing the GEMM accumulator with bias leverages the D = A·B + C contract, avoids a separate “add bias” pass, and guarantees bias isn’t scaled.
- In backward, pre-scaling K naturally incorporates the s factor into dQ via the dot, eliminating an extra multiply on the M×D output; the dK path can be done with one multiply either inside the loop (scale Q tile) or as a single epilogue.
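To make the element counts concrete, a worked example for an assumed long-context shape (M = N = 8192, D = 128, chosen only for illustration):

```math
M \cdot N = 8192 \cdot 8192 \approx 6.7 \times 10^{7}
\quad \text{vs.} \quad
M \cdot D = 8192 \cdot 128 \approx 1.0 \times 10^{6}
```

That is roughly 64× fewer elements touched by the scale; for this square shape, pre-scaling K (N·D) gives the same reduction, and for decode (M = 1) scaling Q costs only D multiplies.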
Alternatives considered
- Scale logits: S ← s·(QK^T), then add B. Simple, but costs M·N multiplies and risks scaling B if ordering slips.
- Scale accumulator post-GEMM: acc_s ← s·acc_s. Also M·N multiplies and forces another pass over the logits tile.
- Split scale across operands: Q ← √s·Q and K ← √s·K. Equivalent math, but doubles the number of vector scales unless carefully shared or cached.
- Epilogue scaling in bwd: multiply dQ and/or dK after matmul. Works, but adds M·D and/or N·D multiplies; pushing s into the reused operand tile is cheaper and better for cache.
Implementation details
- Forward (CUDA), file: flash_fwd_kernel.h
- Scale the chosen operand (Q preferred for decode; heuristic otherwise) once in registers/shared before the K-loop.
- Prefill the logits accumulator from the bias tile once (acc_s ← bias).
- Accumulate across K-tiles with cute::gemm(..., acc_s) so the bias participates as C in D = A·B + C.
- Apply masks to acc_s post-matmul; proceed with streaming softmax (m_i/lse_i) update as today.
- Backward (CUDA), files: csrc/flash_dmattn/src/bwd.h/cu
- Load K and pre-scale K ← s·K for both score recomputation and dQ:
- dQ path: dQ ← dS · K' (no epilogue scale).
- dK path: either pre-scale Q inside the loop (dK ← dS^T · (s·Q)) or scale the final dK tile once before store (dK ← s·dK).
- Keep bias-first accumulation and mask application identical to forward for stable numerics.
- API/ABI: No public API changes. Padding, dtype, and head-dim constraints unchanged. Continue to rely on auto-padding to multiples of 8 for CUDA.
- Heuristic:
- Forward: scale the operand with fewer elements in the working set (Q for M ≪ N decode; otherwise pick min(M·D, N·D)); see the sketch after this list.
- Backward: pick the side that avoids the larger epilogue (mirror the forward choice).
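A host-side sketch of the heuristic (hypothetical helper, not an existing function in the repo); since D is common to both sides, min(M·D, N·D) reduces to comparing the sequence lengths, and decode (M = 1) always selects Q:

```cuda
// Hypothetical host-side helper: pick the operand with the smaller working set.
// Comparing M*D vs N*D with a shared D is the same as comparing M and N.
static inline bool should_prescale_q(int seqlen_q, int seqlen_k) {
    return seqlen_q <= seqlen_k;  // true for decode (seqlen_q == 1) and whenever M*D <= N*D
}
```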
Use case
- Long-context inference and training (e.g., M,N ≥ 8k), where replacing M·N multiplies on logits with M·D or N·D vector scales yields tangible wins.
- Single-token decoding (M=1): scaling Q once is especially beneficial; bias-first accumulation reduces instruction count and improves stability.
- GQA/MQA workloads where bias and mask broadcasting are common.
Related work
- Our Triton backend already applies these ideas:
- Forward: Q ← s·Q; acc_s initialized from bias; masks post-matmul.
- Backward: K ← s·K; dQ via dS·K'; dK with a single multiply (inside loop or epilogue).
Additional context
- On A800, the Triton implementation delivered large forward gains vs SDPA (up to ~15× on long contexts) and up to ~2× in backward for long sequences, with exact numerical parity. We expect analogous benefits in the CUDA backend once implemented.