Adds bias gradient computation to backward kernel #62
Conversation
Implements bias gradient calculation in the backward pass by adding a new column-block kernel that computes DBias alongside existing DQ, DK, and DV gradients. Updates initialization function to support multiple tensor names and modifies autotuning configurations to initialize both DQ and DBias tensors. Includes extensive masking logic and memory access patterns to handle various sequence length and head dimension configurations while maintaining numerical stability.
Pull Request Overview
Implements bias gradient computation in the Triton backward kernel, updates the tensor initialization helper to zero multiple tensors, and adjusts autotuning configurations to include the new DBias tensor.
- Introduces _bwd_kernel_one_col_block to compute DBias alongside existing gradients.
- Replaces the single-name init_to_zero helper with a version that accepts multiple tensor names (a sketch follows this list).
- Updates autotuning pre_hook entries to zero both DQ and DBias.
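A minimal sketch of what such a multi-name helper could look like (illustrative only; the PR's actual implementation may differ), generalizing the single-name init_to_zero shown in the diff further down:

```python
def init_to_zero(names):
    # Accept a single tensor name or a list of names; the returned
    # pre_hook zeroes each matching kernel argument before the
    # autotuned launch, so accumulated tensors like DQ and DBias
    # start from zero.
    if isinstance(names, str):
        names = [names]
    return lambda nargs: [nargs[name].zero_() for name in names]
```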
Comments suppressed due to low confidence (2)
flash_dmattn/flash_dmattn_triton.py:339
- Consider adding a docstring above the _bwd_kernel_one_col_block Triton kernel to describe its purpose, parameters, and outputs, which will improve readability and onboarding for new contributors.
@triton.jit
flash_dmattn/flash_dmattn_triton.py:541
- The new bias gradient computation introduces both atomic and non-atomic code paths; add unit tests covering these scenarios (varying ATOMIC_ADD, block sizes, and mask configurations) to ensure correctness across configurations.
dbias = (p * (dp - Di[:, None]))
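One way such a test could check the bias gradient is against a plain PyTorch autograd reference; the function name and shapes below are illustrative, not part of the repository:

```python
import torch

def reference_dbias(q, k, v, bias, do, softmax_scale):
    # The bias is added to the pre-softmax scores, so its gradient equals
    # the score gradient dS = P * (dP - rowsum(dO * O)), the same quantity
    # the kernel computes per block as p * (dp - Di[:, None]).
    bias = bias.detach().requires_grad_(True)
    scores = q @ k.transpose(-2, -1) * softmax_scale + bias
    p = torch.softmax(scores, dim=-1)
    (p @ v).backward(do)
    return bias.grad

# Illustrative shapes: (batch, heads, seqlen, headdim) plus a full bias matrix.
q, k, v = (torch.randn(2, 4, 128, 64) for _ in range(3))
bias = torch.randn(2, 4, 128, 128)
do = torch.randn(2, 4, 128, 64)
dbias_ref = reference_dbias(q, k, v, bias, do, 64 ** -0.5)
```

The Triton output from both the atomic and non-atomic paths could then be compared against dbias_ref with a tolerance appropriate for the input dtype.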
def init_to_zero(name):
    return lambda nargs: nargs[name].zero_()

@triton.jit
def _bwd_kernel_one_col_block(
Copilot AI Jul 7, 2025
[nitpick] This kernel function spans 300+ lines and handles multiple responsibilities; consider breaking it into smaller helper functions (e.g., pointer setup, mask application, gradient computation) to improve readability and ease future maintenance.
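As one possible shape for that refactor (purely illustrative, not code from this PR), Triton allows a @triton.jit function to be called from another jitted kernel, so a step such as sequence-length masking could live in its own helper:

```python
import triton
import triton.language as tl

@triton.jit
def _apply_seqlen_mask(x, offs_m, offs_n, seqlen_q, seqlen_k, fill_value):
    # Hypothetical helper: replace entries that fall outside the valid
    # query/key lengths for this block with fill_value.
    mask = (offs_m[:, None] < seqlen_q) & (offs_n[None, :] < seqlen_k)
    return tl.where(mask, x, fill_value)

# Inside _bwd_kernel_one_col_block this might then be used as, e.g.:
#     qk = _apply_seqlen_mask(qk, offs_m, offs_n, seqlen_q, seqlen_k, float("-inf"))
```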
# # Kernel is buggy (give wrong result) if we set BLOCK_m=128, BLOCK_n=64, num_warps=*4*
# triton.Config({"BLOCK_M": 128, "BLOCK_N": 64, "SEQUENCE_PARALLEL": False}, num_warps=8, num_stages=1, pre_hook=init_to_zero('DQ')),
# triton.Config({"BLOCK_M": 128, "BLOCK_N": 64, "SEQUENCE_PARALLEL": True}, num_warps=8, num_stages=1, pre_hook=init_to_zero('DQ')),
# triton.Config({"BLOCK_M": 64, "BLOCK_N": 64, "SEQUENCE_PARALLEL": False}, num_warps=4, num_stages=1, pre_hook=init_to_zero('DQ')),
# triton.Config({"BLOCK_M": 64, "BLOCK_N": 64, "SEQUENCE_PARALLEL": True}, num_warps=4, num_stages=1, pre_hook=init_to_zero('DQ')),
# triton.Config({"BLOCK_M": 128, "BLOCK_N": 64, "SEQUENCE_PARALLEL": False}, num_warps=8, num_stages=1, pre_hook=init_to_zero(['DQ', 'DBias'])),
# triton.Config({"BLOCK_M": 128, "BLOCK_N": 64, "SEQUENCE_PARALLEL": True}, num_warps=8, num_stages=1, pre_hook=init_to_zero(['DQ', 'DBias'])),
# triton.Config({"BLOCK_M": 64, "BLOCK_N": 64, "SEQUENCE_PARALLEL": False}, num_warps=4, num_stages=1, pre_hook=init_to_zero(['DQ', 'DBias'])),
# triton.Config({"BLOCK_M": 64, "BLOCK_N": 64, "SEQUENCE_PARALLEL": True}, num_warps=4, num_stages=1, pre_hook=init_to_zero(['DQ', 'DBias'])),
Copilot AI Jul 7, 2025
[nitpick] There are several commented-out Triton configuration lines that include DBias; once DBias support is verified, consider removing or updating these stale lines to reduce clutter.
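For example (a sketch only, reusing the hypothetical multi-name init_to_zero shown earlier), the stale commented entries could be replaced by active configurations that zero both tensors:

```python
import triton

configs = [
    triton.Config(
        {"BLOCK_M": 128, "BLOCK_N": 64, "SEQUENCE_PARALLEL": False},
        num_warps=8, num_stages=1, pre_hook=init_to_zero(["DQ", "DBias"]),
    ),
    triton.Config(
        {"BLOCK_M": 64, "BLOCK_N": 64, "SEQUENCE_PARALLEL": True},
        num_warps=4, num_stages=1, pre_hook=init_to_zero(["DQ", "DBias"]),
    ),
]
```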
Updates the function alias to better reflect its Triton-based implementation, improving code readability and making the backend technology more explicit for developers.
Creates a more convenient function name that follows the module's naming convention and improves code readability.