[FEATURE REQUEST] fuse softmax scale into operands and bias-first accumulation #202

@LoserCheems

Problem statement

  • In the CUDA path, score scaling and bias addition typically happen outside the GEMM inner loop:
    • Compute logits S = Q K^T,
    • Scale S = s · S,
    • Add bias S = S + B,
    • Apply masks, then streaming softmax.
  • This causes:
    • Extra per-logit multiplies over an M×N tile (scale on S),
    • An additional pass over the logits tile (the bias add).
  • In backward, dQ/dK scaling is often done as separate epilogues (extra M·D or N·D multiplies), rather than absorbing s into the matmul operands where it’s cheaper and friendlier to the memory hierarchy.

Proposed solution

Forward: pre-scale Q and start the logits accumulator from bias; masks after matmul.
Let s = softmax_scale, M = seqlen_q, N = seqlen_k, D = headdim.
Baseline logits:

$$ S = s \cdot (Q K^\top) + B $$

Proposed (scale Q):

$$ Q' = s \cdot Q,\quad S = Q' K^\top + B $$

Use GEMM in “accumulate” form D = A·B + C with C initialized from the bias tile:

$$ \mathrm{acc}_s \leftarrow B \quad\text{(once, before the K-loop)}, \qquad \mathrm{acc}_s \mathrel{+}= Q' K^\top \quad\text{(across K-tiles)} $$

Then apply masks to acc_s and proceed with the streaming softmax update. Bias is never scaled, and we avoid a per-logit multiply.
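
Because this is a pure reassociation of the scalar s, it can be sanity-checked outside the kernel. Below is a minimal NumPy sketch of the forward rewrite, not the CUDA path; the shapes, the tile width `BLOCK_N`, and the host-side loop are illustrative only. It seeds the accumulator from the bias tile, accumulates Q′K^T across K-tiles, and compares against the baseline s·(QK^T) + B.

```python
import numpy as np

rng = np.random.default_rng(0)
M, N, D, s = 128, 256, 64, 0.125            # seqlen_q, seqlen_k, headdim, softmax_scale
BLOCK_N = 64                                 # illustrative K-tile width

Q = rng.standard_normal((M, D)).astype(np.float32)
K = rng.standard_normal((N, D)).astype(np.float32)
B = rng.standard_normal((M, N)).astype(np.float32)

# Baseline: scale the full M x N logits tile, then add the bias
S_ref = s * (Q @ K.T) + B

# Proposed: Q' = s * Q once (M * D multiplies); accumulator seeded from the bias tile
Qp = s * Q
acc_s = B.copy()                             # acc_s <- B, once, before the K-loop
for n0 in range(0, N, BLOCK_N):              # acc_s += Q' K_tile^T across K-tiles
    acc_s[:, n0:n0 + BLOCK_N] += Qp @ K[n0:n0 + BLOCK_N].T

print(np.allclose(S_ref, acc_s, atol=1e-4))  # True, and the bias was never scaled
```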

Backward: pre-scale K for both score recomputation and dQ, avoiding a dQ epilogue multiply.
With the forward defined as above, let dS be the gradient with respect to the logits S (from the softmax backward). The input gradients are:

$$ dQ = s \cdot (dS\,K),\qquad dK = s \cdot (dS^\top Q),\qquad dB = \sum_{m,n} dS_{m,n} $$

Proposed (scale K):

$$ K' = s \cdot K,\quad dQ = dS K' \quad\text{(no extra epilogue scale)} $$

For dK, either apply a single epilogue multiply dK ← s·dK, or pre-scale the Q tile by s inside the dK path:

$$ dK = dS^\top (s \cdot Q) $$

Choose the cheaper option by comparing M·D vs N·D.
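
As in the forward pass, these are reassociations of s, so they can be checked numerically. A minimal NumPy sketch follows (shapes illustrative; dS is an arbitrary stand-in for the gradient produced by the softmax backward), comparing the pre-scaled forms against the baseline gradients:

```python
import numpy as np

rng = np.random.default_rng(1)
M, N, D, s = 128, 256, 64, 0.125
Q  = rng.standard_normal((M, D)).astype(np.float32)
K  = rng.standard_normal((N, D)).astype(np.float32)
dS = rng.standard_normal((M, N)).astype(np.float32)   # stand-in gradient w.r.t. the logits S

# Baseline gradients with an explicit epilogue scale
dQ_ref = s * (dS @ K)            # M x D
dK_ref = s * (dS.T @ Q)          # N x D

# Proposed: fold s into K, so dQ needs no epilogue multiply
Kp = s * K
dQ = dS @ Kp

# dK: either pre-scale the Q tile inside the loop ...
dK_in_loop  = dS.T @ (s * Q)
# ... or keep a single epilogue multiply on the finished N x D tile
dK_epilogue = s * (dS.T @ Q)

print(np.allclose(dQ, dQ_ref, atol=1e-4),
      np.allclose(dK_in_loop, dK_ref, atol=1e-4),
      np.allclose(dK_epilogue, dK_ref, atol=1e-4))
```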

Why this is effective

  • Pre-scaling an operand replaces an M·N per-logit multiply (on S) with a single M·D (scale Q) or N·D (scale K) vector scale, touching far fewer elements for typical attention shapes (a rough count follows this list).
  • Initializing the GEMM accumulator with bias leverages the D = A·B + C contract, avoids a separate “add bias” pass, and guarantees bias isn’t scaled.
  • In backward, pre-scaling K naturally incorporates the s factor into dQ via the dot, eliminating an extra multiply on the M×D output; the dK path can be done with one multiply either inside the loop (scale Q tile) or as a single epilogue.
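
For a concrete sense of scale, here is a rough multiply count for an illustrative long-context shape (ignoring tiling, vectorization, and the avoided bias pass):

```python
M, N, D = 8192, 8192, 128                 # illustrative long-context shape

scale_S = M * N                           # per-logit multiply on the logits tile
scale_Q = M * D                           # pre-scale Q once
scale_K = N * D                           # pre-scale K once

print(f"scale S: {scale_S:,} multiplies")   # -> 67,108,864
print(f"scale Q: {scale_Q:,} multiplies")   # ->  1,048,576
print(f"scale K: {scale_K:,} multiplies")   # ->  1,048,576
```

For this shape the operand scale is N/D = 64× fewer multiplies than scaling the logits, before counting the avoided extra pass over the M×N tile.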

Alternatives considered

  • Scale logits: S ← s·(QK^T), then add B. Simple, but costs M·N multiplies and risks scaling B if ordering slips.
  • Scale accumulator post-GEMM: acc_s ← s·acc_s. Also M·N multiplies and forces another pass over the logits tile.
  • Split scale across operands: Q ← √s·Q and K ← √s·K. Equivalent math (checked in the sketch after this list), but doubles the number of vector scales unless carefully shared or cached.
  • Epilogue scaling in bwd: multiply dQ and/or dK after matmul. Works, but adds M·D and/or N·D multiplies; pushing s into the reused operand tile is cheaper and better for cache.
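
For completeness, a small NumPy check of the split-scale alternative (shapes illustrative): the logits match, but both an M×D and an N×D tile have to be scaled.

```python
import numpy as np

rng = np.random.default_rng(2)
M, N, D, s = 128, 256, 64, 0.125
Q = rng.standard_normal((M, D)).astype(np.float32)
K = rng.standard_normal((N, D)).astype(np.float32)
B = rng.standard_normal((M, N)).astype(np.float32)

sqrt_s  = s ** 0.5
S_ref   = s * (Q @ K.T) + B                         # baseline
S_split = (sqrt_s * Q) @ (sqrt_s * K).T + B         # two vector scales: M*D + N*D
print(np.allclose(S_ref, S_split, atol=1e-4))       # same logits, both operands touched
```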

Implementation details

  • Forward (CUDA), file: flash_fwd_kernel.h
    • Scale the chosen operand (Q preferred for decode; heuristic otherwise) once in registers/shared before the K-loop.
    • Prefill the logits accumulator from the bias tile once (acc_s ← bias).
    • Accumulate across K-tiles with cute::gemm(..., acc_s) so the bias participates as C in D = A·B + C.
    • Apply masks to acc_s post-matmul; proceed with the streaming softmax (m_i/lse_i) update as today (a reference sketch of this loop structure follows this list).
  • Backward (CUDA), files: csrc/flash_dmattn/src/bwd.h/cu
    • Load K and pre-scale K ← s·K for both score recomputation and dQ:
      • dQ path: dQ ← dS · K' (no epilogue scale).
    • dK path: either pre-scale Q inside the loop (dK ← dS^T · (s·Q)) or scale the final dK tile once before store (dK ← s·dK).
    • Keep bias-first accumulation and mask application identical to forward for stable numerics.
  • API/ABI: No public API changes. Padding, dtype, and head-dim constraints unchanged. Continue to rely on auto-padding to multiples of 8 for CUDA.
  • Heuristic:
    • Forward: scale the operand with fewer elements in the working set (Q for M ≪ N decode; otherwise pick min(M·D, N·D)).
    • Backward: pick the side that avoids the larger epilogue (mirror the forward choice).
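
As a language-neutral reference for the forward changes (the actual work is in the cute::gemm tile loop of flash_fwd_kernel.h), the NumPy sketch below mirrors the proposed structure: Q scaled once before the K-loop, the accumulator seeded from the bias tile, masks applied after the matmul, then the running streaming-softmax update. The tile width, the boolean keep-mask convention, and the variable names are illustrative rather than taken from the kernel; the sketch tracks the running denominator l_i directly instead of its log (lse_i).

```python
import numpy as np

rng = np.random.default_rng(3)
M, N, D, s = 64, 512, 64, 0.125
BLOCK_N = 128                                   # illustrative K/V tile width

Q = rng.standard_normal((M, D)).astype(np.float32)
K = rng.standard_normal((N, D)).astype(np.float32)
V = rng.standard_normal((N, D)).astype(np.float32)
B = rng.standard_normal((M, N)).astype(np.float32)
keep = rng.random((M, N)) > 0.1                 # illustrative boolean mask (True = keep)

Qp = s * Q                                      # scale Q once, before the K-loop
m_i   = np.full(M, -np.inf, dtype=np.float32)   # running row max
l_i   = np.zeros(M, dtype=np.float32)           # running softmax denominator
acc_o = np.zeros((M, D), dtype=np.float32)      # unnormalized output accumulator

for n0 in range(0, N, BLOCK_N):
    sl = slice(n0, n0 + BLOCK_N)
    acc_s = B[:, sl].copy()                     # bias-first: acc_s <- bias tile
    acc_s += Qp @ K[sl].T                       # acc_s += Q' K_tile^T
    acc_s[~keep[:, sl]] = -np.inf               # masks applied post-matmul

    m_new = np.maximum(m_i, acc_s.max(axis=1))  # streaming softmax update
    alpha = np.exp(m_i - m_new)                 # rescale factor for previous state
    p = np.exp(acc_s - m_new[:, None])
    l_i = alpha * l_i + p.sum(axis=1)
    acc_o = alpha[:, None] * acc_o + p @ V[sl]
    m_i = m_new

O = acc_o / l_i[:, None]

# Reference: dense softmax attention with the same scale, bias, and mask
S_ref = s * (Q @ K.T) + B
S_ref[~keep] = -np.inf
P_ref = np.exp(S_ref - S_ref.max(axis=1, keepdims=True))
P_ref /= P_ref.sum(axis=1, keepdims=True)
print(np.allclose(O, P_ref @ V, atol=1e-4))
```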

Use case

  • Long-context inference and training (e.g., M,N ≥ 8k), where replacing M·N multiplies on logits with M·D or N·D vector scales yields tangible wins.
  • Single-token decoding (M=1): scaling Q once is especially beneficial; bias-first accumulation reduces instruction count and improves stability.
  • GQA/MQA workloads where bias and mask broadcasting are common.

Related work

  • Our Triton backend already applies these ideas:
    • Forward: Q ← s·Q; acc_s initialized from bias; masks post-matmul.
    • Backward: K ← s·K; dQ via dS·K'; dK with a single multiply (inside loop or epilogue).

Additional context

  • On A800, this delivered large forward gains vs SDPA (up to ~15× on long contexts) and up to ~2× in backward for long sequences, with exact numerical parity. We expect analogous benefits in the CUDA backend once implemented.
