
Conversation

@LoserCheems
Collaborator

Makes mask and bias required parameters in Flash Attention so callers must pass them explicitly. Enables the backward pass with gradient computation for all relevant tensors, optimizing memory layout and performance with Triton kernels. Introduces utility functions for tensor initialization and safe gradient storage, addressing race conditions and dtype consistency.

Ensures mask and bias tensors are always properly initialized when not provided by the caller.

Converts boolean masks to float tensors and creates a default all-ones mask when none is specified.
Initializes a zero bias tensor when the bias parameter is None.

Updates contiguity check to include mask and bias tensors for consistent memory layout.
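
A minimal sketch of this initialization pattern (the shape convention and helper name are assumptions, not the exact source):

import torch

def prepare_mask_bias(q, k, mask=None, bias=None):
    # Illustrative shapes: q is (batch, seqlen_q, nheads, headdim),
    # k is (batch, seqlen_k, nheads, headdim).
    batch, seqlen_q, nheads, _ = q.shape
    seqlen_k = k.shape[1]
    if mask is None:
        # Default all-ones mask: every key position stays visible.
        mask = torch.ones(batch, nheads, seqlen_q, seqlen_k, device=q.device)
    elif mask.dtype == torch.bool:
        # Boolean masks are converted to float so the kernel can treat them numerically.
        mask = mask.to(q.dtype)
    if bias is None:
        # Default zero bias: no additive adjustment to the attention scores.
        bias = torch.zeros(batch, nheads, seqlen_q, seqlen_k, device=q.device)
    # The contiguity check now covers mask and bias as well.
    assert all(t.is_contiguous() for t in (q, k, mask, bias))
    return mask, bias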
Uncomments and updates the backward method to support gradient computation.

Adds bias gradient computation with dbias tensor allocation and updates the error message to reflect FlashDynamicMaskAttention functionality.

Reorders function parameters to include mask and bias in the backward call.
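
A hypothetical sketch of the buffer setup this adds; the helper name and argument list are illustrative, and the exact code lives inside the autograd Function's backward:

import torch

def allocate_backward_buffers(q, k, v, bias, needs_mask_grad=False):
    # Gradient buffers for the backward call; dbias is the new addition.
    if needs_mask_grad:
        raise NotImplementedError(
            "FlashDynamicMaskAttention does not support mask gradient yet"
        )
    dq = torch.empty_like(q)
    dk = torch.empty_like(k)
    dv = torch.empty_like(v)
    dbias = torch.empty_like(bias)  # newly allocated bias-gradient tensor
    # These buffers are then passed to the backward call together with the
    # reordered mask and bias arguments.
    return dq, dk, dv, dbias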
Implements the backward propagation function for the flash attention mechanism, enabling gradient computation through the attention layers.

The backward pass handles gradient computation for the query, key, value, and bias tensors, taking the mask tensor as an input, with proper memory layout validation and stride checking.

Includes preprocessing step for output gradients and uses Triton kernels for efficient backward computation with support for causal masking and custom softmax scaling.
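
The Copilot review below asks for gradient checks against a reference; a minimal pure-PyTorch reference that autograd can differentiate for q, k, v, and bias might look like this (shape and mask/bias conventions are assumptions):

import math
import torch

def attention_reference(q, k, v, mask, bias, causal=False, softmax_scale=None):
    # q, k, v: (batch, nheads, seqlen, headdim); mask, bias: (batch, nheads, seqlen_q, seqlen_k).
    scale = softmax_scale if softmax_scale is not None else 1.0 / math.sqrt(q.shape[-1])
    scores = torch.einsum("bhqd,bhkd->bhqk", q, k) * scale + bias
    scores = scores.masked_fill(mask == 0, float("-inf"))
    if causal:
        seqlen_q, seqlen_k = scores.shape[-2:]
        future = torch.ones(seqlen_q, seqlen_k, dtype=torch.bool, device=q.device).triu(1)
        scores = scores.masked_fill(future, float("-inf"))
    attn = torch.softmax(scores, dim=-1)
    return torch.einsum("bhqk,bhkd->bhqd", attn, v)

# out = attention_reference(q, k, v, mask, bias); out.sum().backward() then yields
# reference gradients for q, k, v, and bias to compare against the Triton path.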
Removes optional default values for mask and bias parameters to enforce explicit passing of these tensors.

Eliminates automatic creation of default all-ones mask and zero bias tensors, requiring callers to provide these inputs explicitly.

Simplifies parameter validation logic by removing conditional null checks.
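
A hypothetical before/after of the Python-level signature (names are illustrative, not the merged code):

# Before: defaults were created when mask or bias was omitted.
def flash_dmattn_func_old(q, k, v, mask=None, bias=None, causal=False, softmax_scale=None):
    ...

# After: mask and bias are required positional arguments, so the conditional
# default-creation branches disappear from the function body.
def flash_dmattn_func_new(q, k, v, mask, bias, causal=False, softmax_scale=None):
    ...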
Implements the backward pass computation with Triton autotuning support.

Includes configuration options for block sizes and sequence parallelism with optimized settings for different scenarios. The kernel supports both sequential and parallel execution modes based on the SEQUENCE_PARALLEL flag.

Adds proper memory stride handling and atomic operations for gradient accumulation in parallel mode.
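
An illustrative autotune setup in this spirit; block sizes and key names are assumptions rather than the merged configuration:

import triton

bwd_configs = [
    # Sequential mode: each program iterates over its key/value blocks and
    # writes dq without atomics.
    triton.Config({"BLOCK_M": 128, "BLOCK_N": 128, "SEQUENCE_PARALLEL": False},
                  num_warps=8, num_stages=1),
    # Sequence-parallel mode: the query sequence is split across programs and
    # dq is accumulated with atomic adds, buying parallelism on long sequences.
    triton.Config({"BLOCK_M": 128, "BLOCK_N": 128, "SEQUENCE_PARALLEL": True},
                  num_warps=8, num_stages=1),
]
# Applied as a decorator on the backward kernel, e.g.:
# @triton.autotune(configs=bwd_configs, key=["CACHE_KEY_SEQLEN_Q", "CACHE_KEY_SEQLEN_K"])
# @triton.jit
# def _bwd_kernel(...): ...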
Introduces a helper function that returns a lambda for zeroing tensors by name.

This utility simplifies tensor initialization patterns in the codebase by providing a reusable function that can zero out named tensor arguments.
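
The helper is tiny; its body matches the line quoted in the review below, while the autotune usage shown here is an assumption:

def init_to_zero(name):
    # Returns a pre-hook that zeroes the kernel argument with the given name;
    # Triton passes the kernel's argument dict (here called nargs) to the hook.
    return lambda nargs: nargs[name].zero_()

# Typical use: zero dq before a sequence-parallel run that accumulates into it.
# triton.Config({...}, num_warps=8, num_stages=1, pre_hook=init_to_zero("DQ"))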
Implements a new Triton kernel for the backward pass preprocessing step that computes the dot product between output and output gradients.

This kernel calculates the delta values needed for the backward pass by performing element-wise multiplication and reduction along the head dimension, which is a common operation in attention mechanism gradients.
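
A simplified sketch of such a preprocessing kernel; it assumes one contiguous (seqlen_q, headdim) slice per program and omits the batch/head stride handling of the real kernel:

import triton
import triton.language as tl

@triton.jit
def _bwd_preprocess_do_o_dot_sketch(
    Out, DO, Delta,
    seqlen_q, headdim,
    BLOCK_M: tl.constexpr, BLOCK_HEADDIM: tl.constexpr,
):
    start_m = tl.program_id(0)
    offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
    offs_d = tl.arange(0, BLOCK_HEADDIM)
    bounds = (offs_m[:, None] < seqlen_q) & (offs_d[None, :] < headdim)
    o = tl.load(Out + offs_m[:, None] * headdim + offs_d[None, :], mask=bounds, other=0.0)
    do = tl.load(DO + offs_m[:, None] * headdim + offs_d[None, :], mask=bounds, other=0.0)
    # delta_i = sum_d o[i, d] * do[i, d], the per-row term reused by the
    # softmax gradient in the main backward kernel.
    delta = tl.sum(o.to(tl.float32) * do.to(tl.float32), axis=1)
    tl.store(Delta + offs_m, delta, mask=offs_m < seqlen_q)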
Introduces a dedicated function to handle storing dk and dv gradients with proper masking logic to prevent race conditions.

The function handles different combinations of EVEN_M, EVEN_N, and EVEN_HEADDIM flags to apply appropriate masks during tensor stores, addressing a known race condition bug when certain dimension conditions are met.
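
A sketch of how such a store helper can branch purely on the compile-time flags; pointer and offset names are assumptions:

import triton
import triton.language as tl

@triton.jit
def _bwd_store_dk_dv_sketch(
    dk_ptrs, dv_ptrs, dk, dv, offs_n, offs_d, seqlen_k, headdim,
    EVEN_M: tl.constexpr, EVEN_N: tl.constexpr, EVEN_HEADDIM: tl.constexpr,
):
    # All branches are resolved at compile time, so each specialization gets a
    # single, race-free store path for its block-size / head-dim combination.
    if EVEN_N & EVEN_M:
        if EVEN_HEADDIM:
            tl.store(dv_ptrs, dv)
            tl.store(dk_ptrs, dk)
        else:
            tl.store(dv_ptrs, dv, mask=offs_d[None, :] < headdim)
            tl.store(dk_ptrs, dk, mask=offs_d[None, :] < headdim)
    else:
        if EVEN_HEADDIM:
            tl.store(dv_ptrs, dv, mask=offs_n[:, None] < seqlen_k)
            tl.store(dk_ptrs, dk, mask=offs_n[:, None] < seqlen_k)
        else:
            bounds = (offs_n[:, None] < seqlen_k) & (offs_d[None, :] < headdim)
            tl.store(dv_ptrs, dv, mask=bounds)
            tl.store(dk_ptrs, dk, mask=bounds)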
Ensures mask and bias tensors match the dtype of query tensor when created as defaults.

Prevents potential type mismatches that could cause runtime errors or unexpected behavior in attention computations.
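
A hypothetical version of the default creation with the dtype fix applied (shape convention is illustrative):

import torch

def default_mask_bias(q, seqlen_k):
    batch, seqlen_q, nheads, _ = q.shape
    # dtype=q.dtype keeps fp16/bf16 inputs paired with fp16/bf16 defaults,
    # avoiding the fp32 mismatch described above.
    mask = torch.ones(batch, nheads, seqlen_q, seqlen_k, dtype=q.dtype, device=q.device)
    bias = torch.zeros(batch, nheads, seqlen_q, seqlen_k, dtype=q.dtype, device=q.device)
    return mask, bias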
Standardizes parameter naming convention by shortening bias-related stride variable names from `stride_dbias*` to `stride_db*` format.

Improves code readability and maintains consistency with other stride parameter naming patterns throughout the kernel function.
Eliminates runtime conditional checks that prevented Triton from properly optimizing control flow at compile time.

Moves key and value loading outside of conditional blocks to ensure consistent execution paths, which allows the compiler to make better optimization decisions.

Removes the any_active check that was causing dynamic branching issues and simplifies the masking logic to use compile-time determinable conditions.
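
A sketch of the restructured load path: the block is loaded on one unconditional execution path, and the only branching left is on constexpr flags that Triton resolves at compile time (layout and names are assumptions):

import triton
import triton.language as tl

@triton.jit
def _load_block_sketch(
    K, Out, seqlen_k, headdim,
    EVEN_N: tl.constexpr, EVEN_HEADDIM: tl.constexpr,
    BLOCK_N: tl.constexpr, BLOCK_HEADDIM: tl.constexpr,
):
    start_n = tl.program_id(0)
    offs_n = start_n * BLOCK_N + tl.arange(0, BLOCK_N)
    offs_d = tl.arange(0, BLOCK_HEADDIM)
    ptrs = K + offs_n[:, None] * headdim + offs_d[None, :]
    # No data-dependent "any_active" branch: the load always happens, and the
    # mask choice is known when the kernel is specialized.
    if EVEN_N & EVEN_HEADDIM:
        k = tl.load(ptrs)
    else:
        k = tl.load(ptrs, mask=(offs_n[:, None] < seqlen_k) & (offs_d[None, :] < headdim),
                    other=0.0)
    tl.store(Out + offs_n[:, None] * headdim + offs_d[None, :], k,
             mask=(offs_n[:, None] < seqlen_k) & (offs_d[None, :] < headdim))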


Fixes the backward function to properly return the bias gradient instead of None, ensuring gradient computation flows correctly through the bias parameter during backpropagation.
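
A hypothetical sketch of the corrected return; the positions of the None entries depend on the forward signature and are assumptions:

def _backward_outputs(dq, dk, dv, dbias):
    # The gradient tuple now carries dbias in the bias slot instead of None,
    # so a bias created with requires_grad=True actually receives its gradient.
    return dq, dk, dv, None, dbias, None, None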
@LoserCheems LoserCheems requested a review from Copilot July 5, 2025 13:11
Contributor

Copilot AI left a comment


Pull Request Overview

This PR enhances the Flash Attention implementation by requiring explicit mask and bias inputs, adding a full backward pass via new Triton kernels, and improving performance with optimized autotune configurations.

  • Make mask and bias required parameters and remove default creation in core kernel.
  • Introduce _bwd_preprocess_do_o_dot, _bwd_store_dk_dv, and _flash_attn_backward to compute and store gradients.
  • Add utility init_to_zero hook and update Triton autotune & heuristic configs.
Comments suppressed due to low confidence (3)

flash_dmattn/flash_dmattn_triton.py:263

  • This new Triton JIT function lacks a docstring explaining its purpose, inputs, and output layout. Adding a brief description would improve readability and maintainability.
def _bwd_preprocess_do_o_dot(

flash_dmattn/flash_dmattn_triton.py:340

  • [nitpick] The parameter name nargs is unclear in the init_to_zero lambda. Renaming it to something like kwargs or args_dict would clarify its role.
    return lambda nargs: nargs[name].zero_()

flash_dmattn/flash_dmattn_triton.py:614

  • A complete backward implementation was added but there are no corresponding unit tests for gradient correctness. Consider adding tests to validate gradients against a reference implementation.
def _flash_attn_backward(

@staticmethod
def backward(ctx, do):
    q, k, v, o, lse, mask, bias = ctx.saved_tensors
    assert not ctx.needs_input_grad[3], "FlashDynamicMaskAttention does not support mask gradient yet"

Copilot AI Jul 5, 2025


The error message refers to "FlashDynamicMaskAttention", but the class is FlashAttnFunc. Update the message to accurately reference the function or class name.

Suggested change
assert not ctx.needs_input_grad[3], "FlashDynamicMaskAttention does not support mask gradient yet"
assert not ctx.needs_input_grad[3], "FlashAttnFunc does not support mask gradient yet"

Updates class name from FlashAttnFunc to FlashDMAttnFunc for consistency with the flash_dmattn_func function name and module purpose.

Also updates corresponding error message text to use the shortened "FlashDMAttn" naming convention.
@LoserCheems LoserCheems merged commit 7bd99a8 into main Jul 5, 2025