[FEATURE SUPPORT] Robust dBias accumulation for seqlen_q_bias == 1 #194
Conversation
Simplifies backward parameters by removing an unused boolean that toggled bias accumulation. Reduces maintenance surface and potential ABI issues; no behavior change expected.
Replaces atomic updates in the bias-gradient path with on-chip accumulation and a post-loop row-sum reduction, reducing contention and improving performance and determinism. Derives the accumulation condition from layout (row stride == 0) and sequence length, drops auxiliary pointers/increments, and adds necessary synchronization to avoid shared-memory races when reusing buffers. Cleans up zeroing and copy ordering and consolidates the final write to global memory.
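The idea can be illustrated with a stripped-down CUDA sketch. This is not the CUTE-based kernel from the PR; the kernel signature, buffer names, and the tile sizes BLOCK_M/BLOCK_N are illustrative assumptions, and it only shows the in-block row-sum that replaces the atomic updates.

```cpp
// Illustrative only: each block's partial dBias rows are reduced on chip and
// written once, instead of every row issuing atomicAdd into the broadcast row.
template <int BLOCK_M, int BLOCK_N>
__global__ void reduce_dbias_rows(const float* __restrict__ partial,  // [numBlocks, BLOCK_M, BLOCK_N]
                                  float* __restrict__ dbias_out) {    // [numBlocks, BLOCK_N]
    __shared__ float sdBias[BLOCK_M][BLOCK_N];
    const float* tile = partial + (size_t)blockIdx.x * BLOCK_M * BLOCK_N;

    // Stage this block's partial dBias rows into shared memory.
    for (int idx = threadIdx.x; idx < BLOCK_M * BLOCK_N; idx += blockDim.x) {
        sdBias[idx / BLOCK_N][idx % BLOCK_N] = tile[idx];
    }
    __syncthreads();  // all rows must be visible before the reduction

    // Column-wise sum of all M rows into row 0 of shared memory.
    for (int n = threadIdx.x; n < BLOCK_N; n += blockDim.x) {
        float sum = 0.f;
        for (int m = 0; m < BLOCK_M; ++m) sum += sdBias[m][n];
        sdBias[0][n] = sum;
    }

    // Single consolidated write of the reduced row (the seqlen_q_bias == 1 case).
    for (int n = threadIdx.x; n < BLOCK_N; n += blockDim.x) {
        dbias_out[(size_t)blockIdx.x * BLOCK_N + n] = sdBias[0][n];
    }
}
```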
Simplifies bias-gradient handling by deriving accumulation from the bias sequence-length condition, removing the redundant parameter and related plumbing. Aligns zero-init of bias buffers with provided tensor options (no forced float), preventing mixed-precision dtype mismatches and improving correctness for MQA/GQA bias shapes. Streamlines the backward API with no intended behavior changes beyond dtype fix.
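On the host side, the shape selection boils down to a check on the bias sequence length. A minimal sketch, assuming the surrounding variables (batch_size, num_heads, seqlen_q, seqlen_k_rounded, seqlen_q_bias, opts) as they appear in flash_api.cpp:

```cpp
// Sketch only, not the exact flash_api.cpp code: when the bias broadcasts over
// the query dimension (seqlen_q_bias == 1), the expanded dBias buffer keeps a
// single query row and the kernel accumulates into it.
at::Tensor dbias_expanded = (seqlen_q_bias == 1)
    ? torch::zeros({batch_size, num_heads, 1, seqlen_k_rounded}, opts)
    : torch::zeros({batch_size, num_heads, seqlen_q, seqlen_k_rounded}, opts);
```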
Pull Request Overview
This PR implements robust dBias accumulation for the case when the query sequence dimension is broadcasted (seqlen_q_bias == 1), addressing numerical correctness issues in the backward pass. The kernel now automatically infers when to perform dBias accumulation based on the dbias_row_stride parameter, eliminating the need for an explicit flag in the API.
Key changes:
- In-kernel inference of the dBias accumulation mode from `dbias_row_stride == 0` instead of an explicit API parameter (see the sketch below)
- Block-wide reduction across the M dimension to correctly accumulate gradients into a single row when broadcasting
- Removal of the `accum_dbias` parameter from the API layer and struct definition
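A hedged sketch of what the in-kernel inference in the first bullet could look like; aside from `dbias_row_stride`, the `params` field names are assumptions, not the exact kernel code:

```cpp
// Illustrative: the kernel no longer receives an accum_dbias flag. A zero row
// stride means every query row maps onto the same dBias row, so the block must
// reduce across M before writing.
const bool accum_dbias = (params.dbias_row_stride == 0) && (params.seqlen_q > 1);
```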
Reviewed Changes
Copilot reviewed 3 out of 3 changed files in this pull request and generated 3 comments.
| File | Description |
|---|---|
| csrc/flash_dmattn/src/flash_bwd_kernel.h | Implements robust dBias accumulation with block-wide reduction for broadcasted query dimension; removes atomic accumulation path |
| csrc/flash_dmattn/src/flash.h | Removes deprecated accum_dbias field from Flash_bwd_params struct |
| csrc/flash_dmattn/flash_api.cpp | Removes accum_dbias parameter from function signatures and replaces its usage with direct seqlen_q_bias == 1 checks |
Review comment on csrc/flash_dmattn/flash_api.cpp:

```cpp
? torch::zeros({batch_size, num_heads, 1, seqlen_k_rounded}, opts)
: torch::zeros({batch_size, num_heads, seqlen_q, seqlen_k_rounded}, opts)
```

Copilot (Oct 22, 2025): The dtype for the dbias_expanded allocation has changed from at::kFloat to the default opts dtype. This may cause precision loss if opts is not Float32, since the accumulation in the kernel uses ElementAccum (fp32). Consider explicitly passing opts.dtype(at::kFloat) to match the kernel's accumulation precision.

Suggested change:

```cpp
? torch::zeros({batch_size, num_heads, 1, seqlen_k_rounded}, opts.dtype(at::kFloat))
: torch::zeros({batch_size, num_heads, seqlen_q, seqlen_k_rounded}, opts.dtype(at::kFloat))
```
Review comment on csrc/flash_dmattn/src/flash_bwd_kernel.h:

```cpp
cute::copy(smem_tiled_copy_PdS, tdBiasadBias, tdSsdS);
__syncthreads();
```

Copilot (Oct 22, 2025): The reduction loop should document that it is performing a column-wise sum across all M rows into row 0 of shared memory, since this is the critical step for correct dBias accumulation when dbias_row_stride == 0.

Suggested change:

```cpp
__syncthreads();
// Perform a column-wise sum across all M rows into row 0 of shared memory.
// This is the critical step for correct dBias accumulation when dbias_row_stride == 0.
```
Review comment on csrc/flash_dmattn/src/flash_bwd_kernel.h:

```cpp
gmem_tiled_copy_dBias,
tBiassBias, tdBiasgdBias,
tBiascBias, tBiaspBias,
/*max_M=*/1
```

Copilot (Oct 22, 2025): The max_M=1 parameter is critical to the correctness of this implementation, as it ensures only the reduced row is written. This should have a comment explaining that we're writing only row 0, which contains the sum across all M rows.
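One way the requested comment could read, shown around a paraphrased version of the call above; `copy_dbias` is a placeholder name, since the snippet does not show the enclosing copy helper:

```cpp
// Write only row 0 of sBias back to global memory: after the in-block reduction
// it holds the column-wise sum over all M rows, and max_M = 1 restricts the
// predicated copy to that single row.
copy_dbias(  // placeholder name for the enclosing copy helper
    gmem_tiled_copy_dBias,
    tBiassBias, tdBiasgdBias,
    tBiascBias, tBiaspBias,
    /*max_M=*/1
);
```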
Summary
Design
Alternatives considered:
Changes
Implementation Notes
Tests
Docs
Checklist