Enhance Flash Attention with required parameters and improved backward pass #59
Conversation
Ensures mask and bias tensors are always properly initialized when not provided by the caller. Converts boolean masks to float tensors and creates default all-ones mask when none is specified. Initializes zero bias tensor when bias parameter is None. Updates contiguity check to include mask and bias tensors for consistent memory layout.
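A minimal sketch of that default handling, assuming a (batch, seqlen, nheads, headdim) query/key layout and a (batch, nheads, seqlen_q, seqlen_k) mask/bias layout; the helper name is illustrative, not the actual code:

```python
import torch

def prepare_mask_and_bias(q, k, mask=None, bias=None):
    # Hypothetical helper illustrating the behaviour described above.
    batch, seqlen_q, nheads, _ = q.shape
    seqlen_k = k.shape[1]
    if mask is None:
        # No mask given: default to an all-ones mask (attend everywhere).
        mask = torch.ones(batch, nheads, seqlen_q, seqlen_k, dtype=q.dtype, device=q.device)
    elif mask.dtype == torch.bool:
        # Boolean masks are converted to float so the kernel sees a single dtype.
        mask = mask.to(q.dtype)
    if bias is None:
        # No bias given: default to a zero bias tensor.
        bias = torch.zeros(batch, nheads, seqlen_q, seqlen_k, dtype=q.dtype, device=q.device)
    # Contiguity is enforced for mask and bias as well as q/k/v.
    return mask.contiguous(), bias.contiguous()
```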
Uncomments and updates the backward method to support gradient computation. Adds bias gradient computation with dbias tensor allocation and updates error message to reflect FlashDynamicMaskAttention functionality. Reorders function parameters to include mask and bias in the backward call.
Implements the backward propagation function for the flash attention mechanism, enabling gradient computation through the attention layers. The backward pass handles gradient computation for queries, keys, values, bias, and mask tensors with proper memory layout validation and stride checking. Includes preprocessing step for output gradients and uses Triton kernels for efficient backward computation with support for causal masking and custom softmax scaling.
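For reference, the math these kernels compute blockwise can be written in plain PyTorch as below (masking, dropout, and the saved logsumexp trick omitted); the (batch, seqlen, nheads, headdim) layout and the function name are assumptions:

```python
import torch

def attention_backward_reference(q, k, v, o, do, bias, softmax_scale):
    # q/k/v/o/do: (batch, seqlen, nheads, headdim); bias: (batch, nheads, seqlen_q, seqlen_k)
    scores = torch.einsum("bqhd,bkhd->bhqk", q, k) * softmax_scale + bias
    p = torch.softmax(scores, dim=-1)
    dv = torch.einsum("bhqk,bqhd->bkhd", p, do)
    dp = torch.einsum("bqhd,bkhd->bhqk", do, v)
    # delta = rowsum(o * do): the quantity the preprocessing kernel produces.
    delta = (o * do).sum(dim=-1).transpose(1, 2)    # (batch, nheads, seqlen_q)
    ds = p * (dp - delta.unsqueeze(-1))             # gradient w.r.t. the pre-softmax logits
    dbias = ds
    dq = torch.einsum("bhqk,bkhd->bqhd", ds, k) * softmax_scale
    dk = torch.einsum("bhqk,bqhd->bkhd", ds, q) * softmax_scale
    return dq, dk, dv, dbias
```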
Removes optional default values for mask and bias parameters to enforce explicit passing of these tensors. Eliminates automatic creation of default all-ones mask and zero bias tensors, requiring callers to provide these inputs explicitly. Simplifies parameter validation logic by removing conditional null checks.
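An illustrative call site after this change; the exact `flash_dmattn_func` signature, import path, and tensor shapes are assumptions:

```python
import torch
from flash_dmattn.flash_dmattn_triton import flash_dmattn_func  # import path assumed

batch, seqlen, nheads, headdim = 2, 128, 8, 64
q = torch.randn(batch, seqlen, nheads, headdim, device="cuda", dtype=torch.float16)
k, v = torch.randn_like(q), torch.randn_like(q)
# Callers must now build these explicitly; no defaults are created inside the wrapper.
mask = torch.ones(batch, nheads, seqlen, seqlen, device="cuda", dtype=q.dtype)   # attend everywhere
bias = torch.zeros(batch, nheads, seqlen, seqlen, device="cuda", dtype=q.dtype)  # no additive bias
out = flash_dmattn_func(q, k, v, mask, bias)
```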
Implements the backward pass computation with Triton autotuning support. Includes configuration options for block sizes and sequence parallelism with optimized settings for different scenarios. The kernel supports both sequential and parallel execution modes based on the SEQUENCE_PARALLEL flag. Adds proper memory stride handling and atomic operations for gradient accumulation in parallel mode.
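A sketch of what such an autotune setup can look like; the block sizes, key names, and the "DQ" argument name are assumptions, and `init_to_zero` is the helper described in the next paragraph:

```python
import triton

def init_to_zero(name):
    # Pre-hook that zeroes the named kernel argument before each autotune trial.
    return lambda nargs: nargs[name].zero_()

bwd_configs = [
    # Sequential mode: each program iterates over column blocks itself, so dq
    # can be accumulated without atomics.
    triton.Config({"BLOCK_M": 128, "BLOCK_N": 128, "SEQUENCE_PARALLEL": False},
                  num_warps=8, num_stages=1, pre_hook=init_to_zero("DQ")),
    # Sequence-parallel mode: column blocks run as separate programs and add
    # into dq with atomics, so DQ must start from zero.
    triton.Config({"BLOCK_M": 128, "BLOCK_N": 128, "SEQUENCE_PARALLEL": True},
                  num_warps=8, num_stages=1, pre_hook=init_to_zero("DQ")),
]
# bwd_configs would then be passed to @triton.autotune(configs=bwd_configs, key=[...])
# on the backward kernel.
```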
Introduces a helper function that returns a lambda for zeroing tensors by name. This utility simplifies tensor initialization patterns in the codebase by providing a reusable function that can zero out named tensor arguments.
Implements a new Triton kernel for the backward pass preprocessing step that computes the dot product between output and output gradients. This kernel calculates the delta values needed for the backward pass by performing element-wise multiplication and reduction along the head dimension, which is a common operation in attention mechanism gradients.
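A simplified version of such a preprocessing kernel, with pointers assumed to be pre-offset to the current (batch, head) slice and headdim padded to BLOCK_HEADDIM; the real kernel's signature and stride handling may differ:

```python
import triton
import triton.language as tl

@triton.jit
def _bwd_preprocess_do_o_dot_sketch(
    Out, DO, Delta,            # pointers, already offset to one (batch, head) slice
    stride_om, stride_dom,     # row strides of Out and DO
    seqlen_q,
    BLOCK_M: tl.constexpr, BLOCK_HEADDIM: tl.constexpr,
):
    # Each program handles BLOCK_M query rows.
    start_m = tl.program_id(0)
    offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
    offs_d = tl.arange(0, BLOCK_HEADDIM)
    mask_m = offs_m < seqlen_q
    # Load O and dO tiles, multiply elementwise, reduce over the head dimension.
    o = tl.load(Out + offs_m[:, None] * stride_om + offs_d[None, :],
                mask=mask_m[:, None], other=0.0).to(tl.float32)
    do = tl.load(DO + offs_m[:, None] * stride_dom + offs_d[None, :],
                 mask=mask_m[:, None], other=0.0).to(tl.float32)
    delta = tl.sum(o * do, axis=1)
    # Delta is assumed contiguous over the query dimension for this slice.
    tl.store(Delta + offs_m, delta, mask=mask_m)
```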
Introduces a dedicated function to handle storing dk and dv gradients with proper masking logic to prevent race conditions. The function handles different combinations of EVEN_M, EVEN_N, and EVEN_HEADDIM flags to apply appropriate masks during tensor stores, addressing a known race condition bug when certain dimension conditions are met.
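The store helper in the upstream flash-attention Triton kernel follows the pattern below; the version in this PR presumably mirrors it, though the exact signature is not shown here:

```python
import triton
import triton.language as tl

@triton.jit
def _bwd_store_dk_dv_sketch(dk_ptrs, dv_ptrs, dk, dv, offs_n, offs_d,
                            seqlen_k, headdim,
                            EVEN_M: tl.constexpr, EVEN_N: tl.constexpr,
                            EVEN_HEADDIM: tl.constexpr):
    # When EVEN_N is true but EVEN_M is false, an unmasked store can race with
    # other programs, so the mask choice depends on both flags.
    if EVEN_N & EVEN_M:
        if EVEN_HEADDIM:
            tl.store(dv_ptrs, dv)
            tl.store(dk_ptrs, dk)
        else:
            tl.store(dv_ptrs, dv, mask=offs_d[None, :] < headdim)
            tl.store(dk_ptrs, dk, mask=offs_d[None, :] < headdim)
    else:
        if EVEN_HEADDIM:
            tl.store(dv_ptrs, dv, mask=offs_n[:, None] < seqlen_k)
            tl.store(dk_ptrs, dk, mask=offs_n[:, None] < seqlen_k)
        else:
            tl.store(dv_ptrs, dv, mask=(offs_n[:, None] < seqlen_k) & (offs_d[None, :] < headdim))
            tl.store(dk_ptrs, dk, mask=(offs_n[:, None] < seqlen_k) & (offs_d[None, :] < headdim))
```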
Ensures mask and bias tensors match the dtype of query tensor when created as defaults. Prevents potential type mismatches that could cause runtime errors or unexpected behavior in attention computations.
Standardizes parameter naming convention by shortening bias-related stride variable names from `stride_dbias*` to `stride_db*` format. Improves code readability and maintains consistency with other stride parameter naming patterns throughout the kernel function.
Eliminates runtime conditional checks that prevented Triton from properly optimizing control flow at compile time. Moves key and value loading outside of conditional blocks to ensure consistent execution paths, which allows the compiler to make better optimization decisions. Removes the any_active check that was causing dynamic branching issues and simplifies the masking logic to use compile-time determinable conditions.
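Roughly, the change replaces a data-dependent branch with unconditional loads plus arithmetic masking; the snippet below is a hypothetical illustration of that pattern, not the actual diff:

```python
# Before: a runtime, data-dependent branch that Triton cannot resolve at compile time.
#
#     any_active = tl.sum(mask_block) > 0
#     if any_active:
#         k = tl.load(k_ptrs)
#         v = tl.load(v_ptrs)
#
# After: K/V are always loaded and masking is folded into the score computation,
# so the only remaining branches depend on constexpr flags (EVEN_M, EVEN_N,
# IS_CAUSAL) that are fixed at compile time.
#
#     k = tl.load(k_ptrs)
#     v = tl.load(v_ptrs)
#     qk = tl.dot(q, tl.trans(k))
#     qk = tl.where(mask_block > 0.0, qk, float("-inf"))
```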
Fixes the backward function to properly return the bias gradient instead of None, ensuring gradient computation flows correctly through the bias parameter during backpropagation.
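A sketch of the corrected autograd wiring; the `_flash_attn_backward` argument order and the saved-tensor layout are assumptions based on the description above:

```python
import torch

class FlashDMAttnFuncSketch(torch.autograd.Function):
    # Forward omitted; illustrative only.
    @staticmethod
    def backward(ctx, do):
        q, k, v, o, lse, mask, bias = ctx.saved_tensors
        dq, dk, dv = torch.empty_like(q), torch.empty_like(k), torch.empty_like(v)
        dbias = torch.empty_like(bias)
        _flash_attn_backward(do, q, k, v, o, lse, mask, bias, dq, dk, dv, dbias,
                             softmax_scale=ctx.softmax_scale, causal=ctx.causal)
        # One slot per forward input (q, k, v, mask, bias, causal, softmax_scale):
        # return dbias in the bias slot instead of None so its gradient propagates.
        return dq, dk, dv, None, dbias, None, None
```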
Pull Request Overview
This PR enhances the Flash Attention implementation by requiring explicit mask and bias inputs, adding a full backward pass via new Triton kernels, and improving performance with optimized autotune configurations.
- Make `mask` and `bias` required parameters and remove default creation in core kernel.
- Introduce `_bwd_preprocess_do_o_dot`, `_bwd_store_dk_dv`, and `_flash_attn_backward` to compute and store gradients.
- Add utility `init_to_zero` hook and updated Triton autotune & heuristic configs.
Comments suppressed due to low confidence (3)
flash_dmattn/flash_dmattn_triton.py:263
- This new Triton JIT function lacks a docstring explaining its purpose, inputs, and output layout. Adding a brief description would improve readability and maintainability.
def _bwd_preprocess_do_o_dot(
flash_dmattn/flash_dmattn_triton.py:340
- [nitpick] The parameter name `nargs` is unclear in the `init_to_zero` lambda. Renaming it to something like `kwargs` or `args_dict` would clarify its role.
return lambda nargs: nargs[name].zero_()
flash_dmattn/flash_dmattn_triton.py:614
- A complete backward implementation was added but there are no corresponding unit tests for gradient correctness. Consider adding tests to validate gradients against a reference implementation.
def _flash_attn_backward(
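One way to address that last point would be a test that compares gradients from the Triton path against a pure PyTorch reference; everything below (shapes, tolerances, the `flash_fn` and `reference_fn` callables) is illustrative:

```python
import torch

def test_backward_matches_reference(flash_fn, reference_fn, atol=1e-2, rtol=1e-2):
    torch.manual_seed(0)
    q, k, v = (torch.randn(1, 64, 4, 32, device="cuda", dtype=torch.float16,
                           requires_grad=True) for _ in range(3))
    mask = torch.ones(1, 4, 64, 64, device="cuda", dtype=q.dtype)
    bias = torch.zeros(1, 4, 64, 64, device="cuda", dtype=q.dtype, requires_grad=True)
    g = torch.randn(1, 64, 4, 32, device="cuda", dtype=torch.float16)

    out = flash_fn(q, k, v, mask, bias)
    grads = torch.autograd.grad(out, (q, k, v, bias), g)

    out_ref = reference_fn(q, k, v, mask, bias)
    grads_ref = torch.autograd.grad(out_ref, (q, k, v, bias), g)

    for got, ref in zip(grads, grads_ref):
        torch.testing.assert_close(got, ref, atol=atol, rtol=rtol)
```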
flash_dmattn/flash_dmattn_triton.py (Outdated)
    @staticmethod
    def backward(ctx, do):
        q, k, v, o, lse, mask, bias = ctx.saved_tensors
        assert not ctx.needs_input_grad[3], "FlashDynamicMaskAttention does not support mask gradient yet"
Copilot AI, Jul 5, 2025
The error message refers to "FlashDynamicMaskAttention", but the class is FlashAttnFunc. Update the message to accurately reference the function or class name.
Suggested change:
-        assert not ctx.needs_input_grad[3], "FlashDynamicMaskAttention does not support mask gradient yet"
+        assert not ctx.needs_input_grad[3], "FlashAttnFunc does not support mask gradient yet"
Updates class name from FlashAttnFunc to FlashDMAttnFunc for consistency with the flash_dmattn_func function name and module purpose. Also updates corresponding error message text to use the shortened "FlashDMAttn" naming convention.
Implement required mask and bias parameters in Flash Attention, ensuring explicit input from callers. Enable backward pass with gradient computation for all relevant tensors, optimizing memory layout and performance with Triton kernels. Introduce utility functions for tensor initialization and safe gradient storage, addressing race conditions and dtype consistency.