Refactor backward kernel for attention mask and bias support #55
Conversation
Simplifies the backward pass implementation by removing support for local (windowed) attention patterns and ALiBi (Attention with Linear Biases). Updates the copyright header to include a new contributor and removes the now-unneeded template parameters and conditional logic blocks that handled these specialized attention mechanisms, streamlining the kernel to focus on core functionality without windowed attention or positional bias computation.
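For context, a hypothetical before/after of the backward kernel's template parameter list. The actual declaration lives in csrc/src/flash_bwd_kernel.h; the function name and the remaining flags shown here are assumptions modeled on typical Flash-Attention kernels, not the PR's exact code.

```cpp
// Sketch only: illustrates dropping the Is_local / Has_alibi template parameters.

// Before the refactor (Is_local and Has_alibi carried the removed features):
template <typename Kernel_traits, bool Is_dropout, bool Is_causal,
          bool Is_local, bool Has_alibi, bool Is_even_MN, bool Is_even_K,
          typename Params>
__device__ void compute_dq_dk_dv_1colblock_old(const Params &params,
                                               int bidb, int bidh, int n_block);

// After the refactor: the specialized flags and their conditional code paths are gone.
template <typename Kernel_traits, bool Is_dropout, bool Is_causal,
          bool Is_even_MN, bool Is_even_K, typename Params>
__device__ void compute_dq_dk_dv_1colblock(const Params &params,
                                           int bidb, int bidh, int n_block);
```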
Introduces shared memory layouts and copy operations for mask and bias tensors in the backward kernel configuration. Updates memory size calculations to account for the additional mask and bias storage requirements. Adds specialized copy atoms with 64-byte alignment for efficient mask and bias data transfers.
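A minimal sketch of what such declarations might look like in kernel_traits.h using CuTe. The tile shapes, the absence of swizzling, and the copy-atom type are assumptions (the PR mentions 64-byte-aligned atoms; a 128-bit UniversalCopy is used here purely as a stand-in), not the PR's actual code.

```cpp
#include <cute/tensor.hpp>
using namespace cute;

template <int kBlockM, int kBlockN, typename Element>
struct Bwd_mask_bias_traits_sketch {
    // Row-major (kBlockM x kBlockN) tiles for the attention mask and bias.
    using SmemLayoutMask = Layout<Shape<Int<kBlockM>, Int<kBlockN>>,
                                  Stride<Int<kBlockN>, _1>>;
    using SmemLayoutBias = SmemLayoutMask;

    // Vectorized copy atom for moving bias tiles between memory spaces.
    using SmemCopyAtomBias = Copy_Atom<UniversalCopy<uint128_t>, Element>;

    // Extra shared memory needed on top of the existing Q/K/V/dO buffers,
    // which a kSmemSize calculation would add in.
    static constexpr int kSmemMaskSize = size(SmemLayoutMask{}) * sizeof(Element);
    static constexpr int kSmemBiasSize = size(SmemLayoutBias{}) * sizeof(Element);
};
```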
Updates variable and parameter names from "dzoh" (dZeroHold) to "dbias" throughout the Flash backward parameters structure to better reflect their actual purpose as bias gradients. Improves code readability and maintains consistency with naming conventions.
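An illustration of the renamed fields; this is not the full Flash_bwd_params from csrc/src/flash.h, and the specific stride members shown are assumptions about which dzoh_* fields were renamed.

```cpp
#include <cstdint>

struct Flash_bwd_params_excerpt {   // hypothetical excerpt, not the real struct
    using index_t = int64_t;

    // Gradient of the attention bias, formerly the dzoh_* ("dZeroHold") fields.
    void*   dbias_ptr;
    index_t dbias_batch_stride;
    index_t dbias_head_stride;
    index_t dbias_row_stride;

    // ... remaining backward-pass members omitted.
};
```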
Introduces tensor definitions for attention mask, bias, and bias gradient to enable masked attention and bias computations in the backward kernel. Calculates proper memory offsets for mask and bias tensors based on batch, head, and block dimensions to ensure correct data access patterns during gradient computation.
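A standalone sketch of the kind of offset arithmetic the commit describes; the stride names and block-index convention are assumptions modeled on how the existing Q/K/V offsets are computed, not the PR's exact code.

```cpp
#include <cstdint>

using index_t = std::int64_t;

// Offset of the (m_block, n_block) tile of a mask/bias tensor for a given
// batch index (bidb) and head index (bidh), assuming a contiguous last dim.
index_t mask_bias_tile_offset(index_t bidb, index_t bidh,
                              index_t m_block, index_t n_block,
                              index_t batch_stride, index_t head_stride,
                              index_t row_stride,
                              int kBlockM, int kBlockN) {
    return bidb * batch_stride               // select the batch
         + bidh * head_stride                // select the head
         + m_block * kBlockM * row_stride    // first query row of this tile
         + n_block * kBlockN;                // first key column of this tile
}
```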
Corrects the tensor reference used for creating the attention mask when not using dynamic masking, ensuring consistent device and shape alignment with the attention bias tensor.
Pull Request Overview
This PR refactors the backward attention kernel by removing local-attention and ALiBi support, and introducing explicit mask and bias tensor support throughout the backward pass.
- Adds shared-memory layouts, copy-atom definitions, and smem-size calculations for attention masks and biases
- Removes `Is_local` and `Has_alibi` template parameters and associated ALiBi/local masking code
- Renames backward parameters in `Flash_bwd_params` (e.g., `dzoh_*` → `dbias_*`) for clarity
Reviewed Changes
Copilot reviewed 4 out of 4 changed files in this pull request and generated 1 comment.
| File | Description |
|---|---|
| csrc/src/kernel_traits.h | Define SmemLayoutMask/SmemLayoutBias, SmemCopyAtomBias, update kSmemSize to include mask & bias |
| csrc/src/flash_bwd_kernel.h | Drop ALiBi and local‐attention code, load gMask/gBias/gdBias, adjust template signatures |
| csrc/src/flash.h | Rename dzoh_* fields and pointers in Flash_bwd_params to dbias_* |
| benchmarks/benchmark_forward_equivalence.py | Change default mask init to use attn_bias |
Comments suppressed due to low confidence (1)
csrc/src/flash_bwd_kernel.h:138
- The loaded `gMask`, `gBias`, and `gdBias` tensors are never used later in the kernel, and no bias is applied to `scores` nor is `gdBias` written back to global memory. You should integrate mask/bias into the score computation (e.g., `scores += gBias; scores = scores * gMask;`) and add a final copy for `gdBias`, or remove these unused allocations.
```cpp
Tensor gMask = make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.attn_mask_ptr) + row_offset_mask),
```
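A standalone illustration of what the suppressed comment asks for. It operates on plain arrays rather than the kernel's CuTe fragments, so every name here is illustrative; an actual fix would act on the register fragments holding the scores and dS inside flash_bwd_kernel.h.

```cpp
#include <cmath>
#include <cstddef>

// Apply the additive bias and zero/one mask to the raw attention scores,
// mirroring the reviewer's `scores += gBias; scores = scores * gMask;` idea
// (masked positions are driven to -inf so softmax ignores them).
void apply_mask_and_bias(float* scores, const float* mask, const float* bias,
                         std::size_t n) {
    for (std::size_t i = 0; i < n; ++i) {
        scores[i] += bias[i];
        if (mask[i] == 0.0f) {
            scores[i] = -INFINITY;
        }
    }
}

// Because the bias enters the scores additively, its gradient matches the
// gradient w.r.t. the pre-softmax scores; writing dBias back is then a copy
// of dS for the kept positions.
void write_dbias(float* dbias, const float* dscores, const float* mask,
                 std::size_t n) {
    for (std::size_t i = 0; i < n; ++i) {
        dbias[i] = (mask[i] != 0.0f) ? dscores[i] : 0.0f;
    }
}
```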
benchmarks/benchmark_forward_equivalence.py:

```diff
     attn_bias = attn_bias.masked_fill(attn_mask == 0.0, min_dtype)
 else:
-    attn_mask = torch.ones_like(attn_mask, dtype=dtype, device=attn_mask.device)
+    attn_mask = torch.ones_like(attn_bias, dtype=dtype, device=attn_bias.device)
```
Copilot AI · Jul 3, 2025
Initializing `attn_mask` from `attn_bias` may yield an incorrect shape if the two tensors differ. Consider using `torch.ones_like(attn_mask)` or explicitly matching dimensions to ensure the mask has the intended shape.
Suggested change:

```diff
-attn_mask = torch.ones_like(attn_bias, dtype=dtype, device=attn_bias.device)
+attn_mask = torch.ones(
+    attn_bias.shape, dtype=dtype, device=attn_bias.device
+)
```
Remove local attention and ALiBi support to simplify the backward pass. Introduce mask and bias tensor support with updated memory layouts and efficient data transfer operations. Rename variables for clarity and improve code readability. Fix tensor references for consistent device and shape alignment.