Conversation

@LoserCheems (Collaborator)

Summary

  • Implements a robust contract for attention mask/bias with strict 4D shapes and broadcast over batch, heads, and query dims.
  • Rounds key length to a multiple of 128 for mask/bias to match kernel access, preventing OOB and NaNs.
  • Uses stride=0 broadcasting in CUDA for size‑1 dims to avoid materialization along B/H/S_q.
  • Fixes swapped seqlen_q==1 path and aligns dbias reductions with broadcast semantics.
  • Resolves issue [FEATURE REQUEST] Support {num_heads, num_kv_heads, 1} shaped bias in attention functions #189.

Design

  • Contract for aux tensors:
    • Shapes: ({batch_size|1}, {num_heads|num_kv_heads|1}, {seqlen_q|1}, round_multiple(seqlen_k, 128)); a minimal validation sketch follows this list.
    • Boolean mask (True = keep), floating-point bias (additive).
    • Last dimension must be contiguous; only K-dim is materialized/padded to 128 alignment. Q/K/V remain unpadded; no tail copies for K/V to reduce memory overhead.
  • Broadcasting:
    • CUDA sets zero strides for singleton B/H/S_q to broadcast without expanding.
    • Only the last K-dim is expanded/padded (S_k==1 → expand; S_k==seqlen_k → pad to 128).
  • Swapped path (seqlen_q==1 and H>H_k):
    • Reshape heads to (H_k, ngroups) via views; preserve original mask/bias batch and head semantics.
    • Pass seqlenq_ngroups_swapped into params so batch/head strides and LSE writeback are correct.
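
The contract can be stated compactly as a Python-side check. This is a minimal sketch under the assumptions above; the helper name check_aux and its signature are illustrative, while the real validation lives in flash_api.cpp (mha_fwd/mha_bwd) and the Python interface:

```python
import torch

def round_multiple(x: int, m: int = 128) -> int:
    # Round x up to the nearest multiple of m, matching the kernel's K-dim rounding.
    return (x + m - 1) // m * m

def check_aux(t: torch.Tensor, batch: int, num_heads: int, num_heads_k: int,
              seqlen_q: int, seqlen_k: int, name: str) -> None:
    # Strict 4D contract: broadcastable batch/head/query dims, K-dim rounded to 128.
    assert t.dim() == 4, f"{name} must be 4D"
    b, h, sq, sk = t.shape
    assert b in (1, batch), f"{name}: batch must be 1 or {batch}"
    assert h in (1, num_heads_k, num_heads), f"{name}: heads must be 1, {num_heads_k}, or {num_heads}"
    assert sq in (1, seqlen_q), f"{name}: query dim must be 1 or {seqlen_q}"
    assert sk == round_multiple(seqlen_k), f"{name}: key dim must be seqlen_k rounded to 128"
    assert t.stride(-1) == 1, f"{name}: last dim must be contiguous"
```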

Changes

  • Python (flash_dmattn_interface.py):
    • Pads/expands mask/bias last dim to seqlen_k_rounded = round_multiple(seqlen_k, 128) (see the sketch after this list).
    • Enforces 4D mask/bias; avoids materializing broadcast dims; keeps last dim contiguous.
    • Backward slices returned dbias to the true seqlen_k so callers see trimmed gradients.
  • CUDA C++ (csrc/flash_dmattn/flash_api.cpp):
    • set_params_fprop/dgrad:
      • mask/bias batch/head/row strides set to 0 for size‑1 dims (broadcast-safe).
      • Batch stride handling honors seqlenq_ngroups_swapped.
    • mha_fwd:
      • Enforces 4D mask/bias with K-dim = seqlen_k_rounded and contiguous last dim.
      • Validates B∈{B,1}, H∈{1,H_k,H}, S_q∈{1,L_q}; no implicit 3D unsqueeze/expand.
      • Swapped path uses reshape-only views for mask/bias and restores original views after compute.
    • mha_bwd:
      • Enforces the same 4D contract; allocates/accepts dbias with K-dim rounded to 128.
      • Correctly reduces dbias over groups, batch, and S_q when broadcast, then copies back into user-provided dbias.
  • Documentation:
    • README/README_zh updated to reflect broadcastable 4D shapes and removal of “dbias” jargon in prose.
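
A hedged sketch of the Python-side preparation and the backward trimming described above. Variable and helper names are illustrative; the actual code lives in flash_dmattn_interface.py and may differ in detail:

```python
import torch
import torch.nn.functional as F

def pad_key_dim(aux: torch.Tensor, seqlen_k: int) -> torch.Tensor:
    # Bring the mask/bias K-dim to round_multiple(seqlen_k, 128).
    seqlen_k_rounded = (seqlen_k + 127) // 128 * 128
    if aux.size(-1) == 1:
        # S_k == 1: expand, then materialize only along K (other singleton dims stay broadcast).
        aux = aux.expand(*aux.shape[:-1], seqlen_k_rounded).contiguous()
    elif aux.size(-1) == seqlen_k and seqlen_k_rounded != seqlen_k:
        # S_k == seqlen_k: pad only the tail; a bool mask pads with False (= masked out),
        # a floating-point bias pads with 0.
        aux = F.pad(aux, (0, seqlen_k_rounded - seqlen_k))
    return aux

# Backward (sketch): the kernel produces dbias with the rounded K-dim; slice it back
# before returning so callers see gradients for the true key length, and sum over any
# dim the user's bias broadcast, e.g.
#   dbias = dbias_rounded[..., :seqlen_k]
#   if bias.size(0) == 1: dbias = dbias.sum(0, keepdim=True)
#   if bias.size(2) == 1: dbias = dbias.sum(2, keepdim=True)
```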

Implementation Notes

  • No physical expand on B/H/S_q; broadcasting via stride=0 avoids large temporaries and keeps the ABI stable (illustrated after this list).
  • Only the K-dim is materialized/rounded, which the kernels iterate over; last dim is enforced contiguous.
  • Swapped path:
    • Save original sizes for mask/bias, reshape to (H_k, ngroups) forms during compute, and restore after.
    • Pass seqlenq_ngroups_swapped into params; missing this led to NaNs in multi-batch/multi-head settings.
  • Stability fix:
    • With mask/bias present, kernels iterate N blocks against rounded length. Aligning mask/bias K-dim to 128 prevents OOB; K/V are not padded to reduce memory.
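
The stride-0 scheme mirrors what torch.Tensor.expand does; a small illustration (not the library's code) of why no memory is materialized along B/H/S_q:

```python
import torch

bias = torch.randn(1, 1, 1, 4096)            # shared across batch, heads, and queries
view = bias.expand(8, 16, 1024, 4096)        # logical (B, H, S_q, S_k) view, no copy

print(view.stride())                         # (0, 0, 0, 1): singleton dims become stride 0
print(view.data_ptr() == bias.data_ptr())    # True: same underlying storage
# set_params_fprop/dgrad do the equivalent by setting mask/bias batch/head/row strides
# to 0 for size-1 dims, so any (b, h, q) index lands on the single stored row.
```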

Tests

  • Benchmarks:
    • forward_equivalence.py: passed for various B, with H in {1, H_k, H}, S_q in {1, L_q}, causal and non-causal, bf16/fp16 (a reference-check sketch follows this list).
    • backward_equivalence.py: dQ/dK/dV matched; dbias is now non-zero and matches the reference; gradients are sliced back to the original seqlen_k on return.
  • Regressions covered:
    • seqlen_q==1 swapped path with H>H_k across B>1.
    • mask/bias None vs provided; broadcast over batch/head/S_q dims.
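
The equivalence benchmarks boil down to comparing against unfused attention math with the same broadcast mask/bias. A minimal reference of that kind (sketch only; shapes chosen arbitrarily):

```python
import math
import torch

B, H, L_q, L_k, D = 2, 8, 128, 256, 64
q = torch.randn(B, H, L_q, D)
k = torch.randn(B, H, L_k, D)
v = torch.randn(B, H, L_k, D)
mask = torch.rand(1, H, 1, L_k) > 0.1        # boolean, True = keep, broadcast over B and S_q
bias = torch.randn(1, H, 1, L_k)             # additive, broadcast over B and S_q

scores = q @ k.transpose(-2, -1) / math.sqrt(D) + bias
scores = scores.masked_fill(~mask, float("-inf"))
ref = torch.softmax(scores, dim=-1) @ v      # reference output the fused kernels are checked against
```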

Docs

  • Updated shape conventions in README and README_zh:
    • Now documented as ({1|batch_size}, {1|num_kv_heads|num_heads}, {1|query_len}, {1|key_len}) with K rounded to 128; a usage sketch follows this list.
    • Clarified broadcast semantics and gradient wording.
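
Under the documented contract, a call could look like the following. flash_dmattn_func and the keyword names attn_mask/attn_bias/is_causal are assumptions for illustration; the exact entry point and signature are defined in flash_dmattn/flash_dmattn_interface.py:

```python
import torch
from flash_dmattn import flash_dmattn_func   # assumed import; check the package's interface

B, H, H_k, L_q, L_k, D = 2, 16, 4, 1024, 1024, 128
q = torch.randn(B, L_q, H, D, device="cuda", dtype=torch.bfloat16)
k = torch.randn(B, L_k, H_k, D, device="cuda", dtype=torch.bfloat16)
v = torch.randn(B, L_k, H_k, D, device="cuda", dtype=torch.bfloat16)

# Broadcastable aux tensors: batch and query dims are 1, heads follow num_kv_heads.
attn_mask = torch.ones(1, H_k, 1, L_k, device="cuda", dtype=torch.bool)
attn_bias = torch.zeros(1, H_k, 1, L_k, device="cuda", dtype=torch.bfloat16)

out = flash_dmattn_func(q, k, v, attn_mask=attn_mask, attn_bias=attn_bias, is_causal=True)
```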

Checklist

  • Sets zero stride for size-1 dims so mask/bias broadcast across batch/head/row instead of reusing the underlying stride, preventing incorrect indexing when mask/bias are shared across dimensions and aligning with standard broadcasting semantics; improves forward-pass robustness with partially broadcast inputs.
  • Sets zero stride for singleton batch/head/row dims in the backward path as well, preventing misaddressing when bias is shared across dimensions and improving correctness and flexibility of the backward path with broadcasted bias.
  • Clarifies and validates mask/bias to a strict 4D contract with broadcastable batch, head, and query dims and a key length rounded to 128; removes implicit 3D unsqueeze/expand to prevent silent shape mismatches and stride issues.
  • Reworks shape handling in the swapped-heads path to use reshape-based broadcasting and preserves original batch/head counts for correctness; computes rounding earlier and applies consistent checks across inputs, aligning inputs with kernel expectations and reducing accidental expansions.
  • Standardizes attention aux inputs to strict 4D shapes with a contiguous last dim and explicit broadcasting over batch, heads, and seqlen_q; removes 3D mask/bias handling and validates dimensions against the rounded key length.
  • Allocates/validates dbias with broadcast-aware shapes and updates reductions to correctly sum over group, batch, and seqlen_q when broadcast, improving correctness for MQA/GQA and padded key lengths; improves shape checks and internal consistency to prevent silent misalignment and shape-induced bugs.
  • Updates attention mask/bias handling to round the key length to a multiple of 128 and to expand length-1 tensors or pad as needed, preventing shape mismatches and reducing unnecessary padding of K/V.
  • Simplifies backward by slicing only the bias gradient to the original key length and removing tracking of the original sequence length.
  • Clarifies docs to allow broadcastable dimensions for mask/bias across batch, heads, and sequence; corrects tensor shape examples to use 1-sized broadcastable dims for batch, heads, query, and key (avoiding invalid 0-length notation); removes "dbias" jargon from the gradient description; syncs the English and Chinese documentation on shape semantics.
  • Simplifies dynamic masking by accepting a precomputed attention bias and an optional causal mask, removing dependence on internal ZOH/dt projection parameters and unifying the API across the Python, CUDA, Triton, and Flex backends; applies masking explicitly via a boolean mask with -inf before softmax and selects a top-k window per query (optionally respecting the causal mask), improving correctness and consistency across implementations.
  • Aligns function signatures, renames keep_window_size to window_size, removes unused return flags, and fixes tensor layouts/contiguity where needed; updates tests to generate attention bias and derive causal masks, improving forward-equivalence coverage and determinism while reducing coupling to value-state-derived features.
  • Refactors attention paths to accept an external attention bias and boolean causal mask, replacing zoh/dt-based masking and cache-position logic; introduces a generic mask preparer that applies top-k windowing, optionally causal-aware (see the sketch after this list), and standardizes interfaces across the SDPA, Flash, Triton, and Flex implementations.
  • Removes the zoh/dt projection and related params, repeats KV artifacts for GQA, and consistently applies additive masks; updates benchmarks to generate bias/mask inputs, rename keep_window_size to window_size, adjust head dims, and harmonize result handling and output labeling, improving API consistency, simplifying experimentation with custom biases, and aligning masking semantics across kernels for more reliable benchmarking.
  • Replaces internal zoh/value-proj masking with an external attention bias plus optional causal mask and top-k windowing, simplifying the interface and masking semantics across backends; aligns Python, CUDA, Triton, and Flex to a shared signature, applies masking consistently, ensures contiguous layouts, and uses deterministic execution for stable gradients; expands backward-equivalence coverage to head dims 192/256 and updates tests to use bf16 bias and causal masks, improving reproducibility and backend parity.
  • Reworks dynamic masking to consume a precomputed attention bias plus an optional boolean causal mask and a window size, using top-k selection within the window and honoring causality; removes the ZOH/dt_proj/A dependency to simplify masking and reduce coupling; aligns the CUDA, Triton, Flex, and SDPA wrappers to a unified interface, adds GQA support via KV repetition, ensures consistent tensor layout, and detaches top-k selection to avoid unintended gradients; updates benchmarks to generate attention bias and boolean causal masks, renames keep_window_size to window_size, and adjusts configs/loops for consistent evaluation across backends, improving clarity, consistency, and extensibility of the attention backward benchmarks.
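
A hedged sketch of the generic mask preparer described above: top-k windowing over a precomputed attention bias, optionally respecting a boolean causal mask. Function and argument names are illustrative rather than the library's exact API:

```python
from typing import Optional
import torch

def prepare_dynamic_mask(bias: torch.Tensor, window_size: int,
                         causal_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
    """Return a boolean mask (True = keep) keeping the top-`window_size` keys per query."""
    scores = bias
    if causal_mask is not None:
        # Exclude non-causal positions from the top-k competition.
        scores = scores.masked_fill(~causal_mask, float("-inf"))
    if window_size >= scores.size(-1):
        keep = torch.ones_like(scores, dtype=torch.bool)
    else:
        # Detach so the top-k selection does not route gradients.
        idx = scores.detach().topk(window_size, dim=-1).indices
        keep = torch.zeros_like(scores, dtype=torch.bool)
        keep.scatter_(-1, idx, torch.ones_like(idx, dtype=torch.bool))
    if causal_mask is not None:
        keep = keep & causal_mask
    return keep  # the bias itself is still applied additively before softmax
```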

Copilot AI (Contributor) left a comment

Pull Request Overview

This PR introduces broadcastable 4D mask/bias with strict shape contracts, 128-rounded key length for kernel access alignment, stride-0 broadcasting to avoid materialization, and corrected dbias reduction semantics. It resolves issue #189 by implementing robust attention mask/bias support with proper broadcasting and memory efficiency optimizations.

  • Implements strict 4D shape contract for mask/bias with broadcast support over batch, heads, and query dimensions
  • Rounds key length to 128 multiples for kernel access patterns, preventing out-of-bounds access and NaNs
  • Uses stride-0 broadcasting in CUDA to avoid memory expansion along broadcast dimensions

Reviewed Changes

Copilot reviewed 8 out of 8 changed files in this pull request and generated 3 comments.

Summary per file:
  • flash_dmattn/flash_dmattn_interface.py: Implements 128-rounding for the mask/bias K-dim and backward gradient slicing
  • csrc/flash_dmattn/flash_api.cpp: Adds stride-0 broadcasting, 4D shape validation, and corrected swapped-path handling
  • benchmarks/forward_performance.py: Updates test functions to use simplified mask preparation with proper broadcasting
  • benchmarks/forward_equivalence.py: Refactors test code to remove complex zoh_states calculations and use direct bias inputs
  • benchmarks/backward_performance.py: Simplifies backward test functions with direct bias inputs and proper mask handling
  • benchmarks/backward_equivalence.py: Updates backward equivalence tests to use the simplified bias approach
  • README_zh.md: Updates the Chinese documentation to reflect the new broadcastable 4D shape requirements
  • README.md: Updates the English documentation to reflect the new broadcastable 4D shape requirements


const int num_heads_k = k.size(2);
int num_heads_mask = has_mask ? mask.size(1) : 1;
int num_heads_bias = has_bias ? bias.size(1) : 1;
auto round_multiple = [](int x, int m) { return (x + m - 1) / m * m; };

Copilot AI Oct 12, 2025

The round_multiple lambda function is defined but there's already a global function with the same name and logic. Consider using the existing global function to avoid code duplication.

Suggested change
- auto round_multiple = [](int x, int m) { return (x + m - 1) / m * m; };

LoserCheems and others added 2 commits October 12, 2025 19:18
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
@LoserCheems requested a review from Copilot October 12, 2025 11:19

Copilot AI (Contributor) left a comment

Pull Request Overview

Copilot reviewed 8 out of 8 changed files in this pull request and generated 1 comment.

Comments suppressed due to low confidence (1)

csrc/flash_dmattn/flash_api.cpp:96

  • The multiplication by seqlen_q for mask and bias batch strides is missing when seqlenq_ngroups_swapped is true, unlike the q_batch_stride which is correctly updated on line 94. This could lead to incorrect memory access patterns in the swapped path.
            params.mask_batch_stride *= seqlen_q;
            params.bias_batch_stride *= seqlen_q;


(1, 2, 1, 4096, 4096, 128, True),
(1, 2, 1, 4096, 4096, 128, False),

# Not support head_dim > 128 in triton yet

Copilot AI Oct 12, 2025

[nitpick] Commented-out test cases should either be removed if no longer needed or carry a comment explaining why they are disabled and when they might be re-enabled.

Suggested change
- # Not support head_dim > 128 in triton yet
+ # The following test cases are disabled because Triton does not currently support head_dim > 128.
+ # Once Triton adds support for head_dim > 128, these test cases should be re-enabled.

@LoserCheems merged commit 464baf7 into main Oct 12, 2025