[FEATURE SUPPORT] Broadcastable 4D mask/bias, 128‑rounded key length, stride‑0 broadcasting, and dbias reductions #190
Conversation
Sets zero stride for size-1 dims to broadcast across batch/head/row, instead of using the underlying stride. Prevents incorrect indexing when mask/bias are shared across dimensions and aligns with standard broadcasting semantics. Improves forward-pass robustness with partially broadcast inputs.
Sets zero stride for singleton batch/head/row dims to align with broadcasting semantics, preventing misaddressing when bias is shared across dimensions. Improves correctness and flexibility of the backward path with broadcasted bias.
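A minimal Python sketch of this stride convention (the helper name and shapes are illustrative, not the extension's actual API):

```python
import torch

def broadcast_stride(t: torch.Tensor, dim: int) -> int:
    # Singleton (size-1) dims advance with stride 0 so the same slice is reused
    # across batch / head / row instead of stepping into unrelated memory.
    return 0 if t.size(dim) == 1 else t.stride(dim)

# Example: a bias shared across batch and heads.
bias = torch.randn(1, 1, 128, 256)                     # (batch, heads, seqlen_q, seqlen_k)
print([broadcast_stride(bias, d) for d in range(4)])   # [0, 0, 256, 1]
```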
Clarifies and validates mask/bias to a strict 4D contract with broadcastable batch, head, and query dims, and a key length rounded to 128. Removes implicit 3D unsqueeze/expand to prevent silent shape mismatches and stride issues. Reworks shape handling in the swapped-heads path to use reshape-based broadcasting and preserves original batch/head counts for correctness. Computes rounding earlier and applies consistent checks across inputs. Improves correctness and stability by aligning inputs with kernel expectations and reducing accidental expansions.
Standardizes attention aux inputs to strict 4D shapes with contiguous last dim and explicit broadcasting over batch, heads, and seqlen_q. Removes 3D mask/bias handling and validates dimensions against rounded key length. Allocates/validates dbias with broadcast-aware shapes and updates reductions to correctly sum over group, batch, and seqlen_q when broadcast, improving correctness for MQA/GQA and padded key lengths. Improves shape checks and internal consistency to prevent silent misalignment and shape-induced bugs.
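A hedged sketch of the broadcast-aware dbias reduction described above, assuming the gradient is first produced at the fully expanded shape (names and shapes are illustrative):

```python
import torch

def reduce_dbias(dbias_full: torch.Tensor, bias_shape: torch.Size) -> torch.Tensor:
    # dbias_full has the expanded shape (batch, heads, seqlen_q, seqlen_k_rounded);
    # sum over every dim that was broadcast (size 1 in the original bias).
    dbias = dbias_full
    for dim, size in enumerate(bias_shape):
        if size == 1 and dbias_full.size(dim) != 1:
            dbias = dbias.sum(dim=dim, keepdim=True)
    return dbias

bias = torch.randn(1, 1, 1, 256)           # broadcast over batch, heads, and queries
dbias_full = torch.randn(4, 8, 128, 256)   # gradient at the expanded shape
print(reduce_dbias(dbias_full, bias.shape).shape)   # torch.Size([1, 1, 1, 256])
```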
Updates attention mask/bias handling to round the key length to a multiple of 128 and expand length-1 tensors or pad as needed, preventing shape mismatches and reducing unnecessary padding of K/V. Simplifies backward by slicing only the bias gradient to the original key length and removing tracking of the original sequence length. Clarifies docs to allow broadcastable dimensions for mask/bias across batch, heads, and sequence.
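A rough Python sketch of the rounding, expand-or-pad, and backward slicing flow (the helper names here are assumptions, not the interface's actual functions):

```python
import torch
import torch.nn.functional as F

def round_multiple(x: int, m: int = 128) -> int:
    return (x + m - 1) // m * m

def align_bias_to_rounded_k(bias: torch.Tensor, seqlen_k: int) -> torch.Tensor:
    # Bring the last (key) dim up to the 128-rounded key length:
    # expand a length-1 dim, otherwise zero-pad the tail.
    seqlen_k_rounded = round_multiple(seqlen_k)
    if bias.size(-1) == 1:
        return bias.expand(*bias.shape[:-1], seqlen_k_rounded)
    if bias.size(-1) == seqlen_k_rounded:
        return bias
    return F.pad(bias, (0, seqlen_k_rounded - bias.size(-1)))

bias = torch.randn(2, 8, 128, 300)
bias_padded = align_bias_to_rounded_k(bias, seqlen_k=300)   # last dim -> 384
dbias_rounded = torch.randn_like(bias_padded)
dbias = dbias_rounded[..., :300]   # backward: slice back to the original key length
```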
Corrects tensor shape examples to use 1-sized broadcastable dims for batch, heads, query, and key, improving clarity and avoiding invalid 0-length notation. Removes “dbias” jargon from the gradient description for clearer wording. Syncs English and Chinese documentation on shape semantics.
Simplifies dynamic masking by accepting precomputed attention bias and an optional causal mask, removing dependence on internal ZOH/dt projection parameters and unifying the API across Python, CUDA, Triton, and Flex backends. Applies masking explicitly via a boolean mask with -inf before softmax and selects a top-k window per query (optionally respecting the causal mask), improving correctness and consistency across implementations. Aligns function signatures, renames keep_window_size to window_size, removes unused return flags, and fixes tensor layouts/contiguity where needed. Updates tests to generate attention bias and derive causal masks, improving forward-equivalence coverage and determinism while reducing coupling to value-state-derived features.
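A simplified sketch of the masking flow as described: apply the optional causal mask with -inf, select a top-k window per query, and mask everything else (function and argument names are illustrative, not the library's actual API):

```python
import torch

def prepare_dynamic_mask(attn_bias, window_size, causal_mask=None):
    # Push causally disallowed positions to -inf before the per-query top-k selection.
    scores = attn_bias.clone()
    if causal_mask is not None:
        scores = scores.masked_fill(~causal_mask, float("-inf"))
    # Keep only the top-k keys per query; detach so the selection itself
    # contributes no gradient. Everything outside the window gets -inf.
    topk_idx = scores.detach().topk(window_size, dim=-1).indices
    keep = torch.zeros_like(scores, dtype=torch.bool).scatter_(-1, topk_idx, True)
    return scores.masked_fill(~keep, float("-inf")), keep

bias = torch.randn(1, 8, 128, 128)
causal = torch.tril(torch.ones(128, 128, dtype=torch.bool)).view(1, 1, 128, 128)
masked_bias, keep_mask = prepare_dynamic_mask(bias, window_size=64, causal_mask=causal)
```

Detaching before the top-k matches the summaries' note that the window selection itself should not route gradients.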
Refactors attention paths to accept external attention bias and boolean causal mask, replacing zoh/dt-based masking and cache-position logic. Introduces a generic mask preparer that applies top-k windowing (optionally causal-aware), and standardizes interfaces across SDPA, Flash, Triton, and Flex implementations. Removes zoh/dt projection and related params, repeats KV artifacts for GQA, and consistently applies additive masks. Updates benchmarks to generate bias/mask inputs, rename keep_window_size to window_size, adjust head dims, and harmonize result handling and output labeling. Improves API consistency, simplifies experimentation with custom biases, and aligns masking semantics across kernels for more reliable benchmarking.
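For the GQA KV repetition mentioned above, here is a sketch of the usual repeat-across-groups pattern (it mirrors the common repeat_kv idiom and may differ from the repo's helper):

```python
import torch

def repeat_kv(x: torch.Tensor, n_rep: int) -> torch.Tensor:
    # (batch, num_kv_heads, seqlen, head_dim) -> (batch, num_kv_heads * n_rep, seqlen, head_dim)
    if n_rep == 1:
        return x
    b, h_kv, s, d = x.shape
    return x[:, :, None, :, :].expand(b, h_kv, n_rep, s, d).reshape(b, h_kv * n_rep, s, d)

k = torch.randn(2, 4, 256, 128)      # 4 KV heads
k_rep = repeat_kv(k, n_rep=8 // 4)   # repeated to match 8 query heads
```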
Replaces internal zoh/value-proj masking with an external attention bias plus optional causal mask and top‑k windowing, simplifying the interface and masking semantics across backends. Aligns Python, CUDA, Triton, and Flex to a shared signature, applies masking consistently, ensures contiguous layouts, and uses deterministic execution for stable gradients. Expands backward‑equivalence coverage to head dims 192/256 and updates tests to use bf16 bias and causal masks, improving reproducibility and backend parity.
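One plausible way such a test setup pins down nondeterminism and builds the bf16 bias / boolean causal mask inputs (standard PyTorch settings, not necessarily the exact test configuration):

```python
import torch

torch.manual_seed(0)
# Prefer deterministic kernels; warn instead of erroring if an op has no deterministic path.
torch.use_deterministic_algorithms(True, warn_only=True)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

# bf16 bias plus a boolean causal mask, matching the test inputs described above.
bias = torch.randn(1, 8, 128, 128, dtype=torch.bfloat16, requires_grad=True)
causal_mask = torch.tril(torch.ones(128, 128, dtype=torch.bool))
```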
Reworks dynamic masking to consume precomputed attention bias plus optional boolean causal mask and a window size, using top‑k selection within the window and honoring causality. Removes ZOH/dt_proj/A dependency to simplify masking and reduce coupling. Aligns CUDA, Triton, Flex, and SDPA wrapper to a unified interface, adds GQA support via KV repetition, and ensures consistent tensor layout. Detaches top‑k selection to avoid unintended gradients. Updates benchmarks to generate attention bias and boolean causal masks, renames keep_window_size to window_size, and adjusts configs/loops accordingly for consistent evaluation across backends. Improves clarity, consistency, and extensibility of the attention backward benchmarks.
Pull Request Overview
This PR introduces broadcastable 4D mask/bias with strict shape contracts, 128-rounded key length for kernel access alignment, stride-0 broadcasting to avoid materialization, and corrected dbias reduction semantics. It resolves issue #189 by implementing robust attention mask/bias support with proper broadcasting and memory efficiency optimizations.
- Implements strict 4D shape contract for mask/bias with broadcast support over batch, heads, and query dimensions (see the sketch after this list)
- Rounds key length to 128 multiples for kernel access patterns, preventing out-of-bounds access and NaNs
- Uses stride-0 broadcasting in CUDA to avoid memory expansion along broadcast dimensions
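A hedged sketch of that 4D shape contract as read from the summary (the accepted head counts and the helper name are assumptions, not verified against flash_api.cpp):

```python
import torch

def check_mask_bias_shape(t, batch, num_heads, num_heads_k, seqlen_q, seqlen_k_rounded):
    # Strict 4D contract: (batch | 1, num_heads | num_heads_k | 1, seqlen_q | 1, seqlen_k_rounded),
    # with a contiguous last dim so the kernel can read keys linearly.
    assert t.dim() == 4, "mask/bias must be 4D"
    assert t.size(0) in (batch, 1)
    assert t.size(1) in (num_heads, num_heads_k, 1)
    assert t.size(2) in (seqlen_q, 1)
    assert t.size(3) == seqlen_k_rounded, "key dim must equal the 128-rounded key length"
    assert t.stride(-1) == 1, "last dim must be contiguous"

bias = torch.randn(1, 1, 1, 384)   # broadcast over batch, heads, and queries
check_mask_bias_shape(bias, batch=4, num_heads=16, num_heads_k=4,
                      seqlen_q=512, seqlen_k_rounded=384)
```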
Reviewed Changes
Copilot reviewed 8 out of 8 changed files in this pull request and generated 3 comments.
| File | Description |
|---|---|
| flash_dmattn/flash_dmattn_interface.py | Implements 128-rounding for mask/bias K-dim and backward gradient slicing |
| csrc/flash_dmattn/flash_api.cpp | Adds stride-0 broadcasting, 4D shape validation, and corrected swapped path handling |
| benchmarks/forward_performance.py | Updates test functions to use simplified mask preparation with proper broadcasting |
| benchmarks/forward_equivalence.py | Refactors test code to remove complex zoh_states calculations and use direct bias inputs |
| benchmarks/backward_performance.py | Simplifies backward test functions with direct bias inputs and proper mask handling |
| benchmarks/backward_equivalence.py | Updates backward equivalence tests to use simplified bias approach |
| README_zh.md | Updates Chinese documentation to reflect new broadcastable 4D shape requirements |
| README.md | Updates English documentation to reflect new broadcastable 4D shape requirements |
```cpp
const int num_heads_k = k.size(2);
int num_heads_mask = has_mask ? mask.size(1) : 1;
int num_heads_bias = has_bias ? bias.size(1) : 1;
auto round_multiple = [](int x, int m) { return (x + m - 1) / m * m; };
```
Copilot AI commented on Oct 12, 2025
The round_multiple lambda is defined here, but a global function with the same name and logic already exists. Consider reusing the existing global function to avoid code duplication.
```cpp
auto round_multiple = [](int x, int m) { return (x + m - 1) / m * m; };
```
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
Pull Request Overview
Copilot reviewed 8 out of 8 changed files in this pull request and generated 1 comment.
Comments suppressed due to low confidence (1)
csrc/flash_dmattn/flash_api.cpp:96
- The multiplication by seqlen_q for mask and bias batch strides is missing when seqlenq_ngroups_swapped is true, unlike the q_batch_stride which is correctly updated on line 94. This could lead to incorrect memory access patterns in the swapped path.
```cpp
params.mask_batch_stride *= seqlen_q;
params.bias_batch_stride *= seqlen_q;
```
```python
(1, 2, 1, 4096, 4096, 128, True),
(1, 2, 1, 4096, 4096, 128, False),

# Not support head_dim > 128 in triton yet
```
Copilot AI commented on Oct 12, 2025
[nitpick] Commented-out test cases should either be removed if no longer needed, or carry a comment explaining why they are disabled and when they might be re-enabled.
Suggested change:

```diff
-# Not support head_dim > 128 in triton yet
+# The following test cases are disabled because Triton does not currently support head_dim > 128.
+# Once Triton adds support for head_dim > 128, these test cases should be re-enabled.
```