
Conversation

@LoserCheems
Collaborator

Implements a Triton-based Flash Attention kernel that supports dynamic attention masks and bias matrices. The implementation provides memory-efficient attention computation with support for causal masking, custom attention patterns, and bias injection.

Key features include optimized block-wise computation, configurable block sizes to avoid memory overflow, and proper handling of sequence length padding for efficient GPU utilization.
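For illustration, a minimal calling sketch; the argument order and the `is_causal` keyword of `flash_dmattn_func` are assumptions based on the discussion in this PR, not a confirmed signature:

```python
import torch
from flash_dmattn.flash_dmattn_triton import flash_dmattn_func

batch, seqlen, heads, head_dim = 2, 1024, 8, 64
device, dtype = "cuda", torch.float16

q = torch.randn(batch, seqlen, heads, head_dim, device=device, dtype=dtype)
k = torch.randn(batch, seqlen, heads, head_dim, device=device, dtype=dtype)
v = torch.randn(batch, seqlen, heads, head_dim, device=device, dtype=dtype)

# Dynamic attention mask (nonzero = attend) and additive bias, one value per
# (batch, head, query, key) position.
attn_mask = torch.ones(batch, heads, seqlen, seqlen, device=device, dtype=dtype)
attn_bias = torch.zeros(batch, heads, seqlen, seqlen, device=device, dtype=dtype)

out = flash_dmattn_func(q, k, v, attn_mask, attn_bias, is_causal=True)
```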

@LoserCheems LoserCheems requested review from Evanwu1125, SNHuan, Copilot and wubingheng111 and removed request for Copilot July 3, 2025 13:48
@LoserCheems LoserCheems added the feature New feature request label Jul 3, 2025
@LoserCheems LoserCheems requested a review from Copilot July 3, 2025 13:52
Introduces two new attention implementation variants to complement the existing Python and CUDA versions:

- Adds Triton-based dynamic mask attention with proper error handling and graceful fallback
- Adds Flex Attention implementation for PyTorch's native attention API
- Extends benchmark suite with separate test functions for each implementation
- Updates command-line interface to support selective testing of specific implementations
- Improves import handling with informative status messages and non-blocking failures, as sketched below

Enables comprehensive performance and accuracy comparison across all four attention implementations while maintaining backward compatibility with existing CUDA tests.
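
The non-blocking import handling can be sketched as follows; the status-message wording and the `flex_attention_forward` name are illustrative rather than the benchmark's exact symbols:

```python
# Each backend is optional: a failed import only disables that backend's
# tests instead of aborting the whole benchmark.
try:
    from flash_dmattn.flash_dmattn_triton import flash_dmattn_func
    print("✓ Triton implementation available")
except ImportError as exc:
    flash_dmattn_func = None
    print(f"✗ Triton implementation unavailable: {exc}")

try:
    from flash_dmattn.flash_dmattn_flex import flex_attention_forward  # name assumed
    print("✓ Flex Attention implementation available")
except ImportError as exc:
    flex_attention_forward = None
    print(f"✗ Flex Attention implementation unavailable: {exc}")
```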
Implements a new attention mechanism that leverages PyTorch's flex_attention API for improved performance and flexibility.

The implementation includes custom score modification with attention bias, causal masking support, and optimized kernel options for efficient computation.

Uses compile-friendly flex attention with configurable block sizes and returns both attention output and weights in the appropriate data types.
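
A minimal sketch of this approach using PyTorch's flex_attention API; the wrapper in flash_dmattn_flex.py may differ in names, shapes, and kernel options, and the bias tensor here is illustrative:

```python
import torch
from torch.nn.attention.flex_attention import flex_attention, create_block_mask

B, H, S, D = 2, 8, 1024, 64
q = torch.randn(B, H, S, D, device="cuda", dtype=torch.float16)
k = torch.randn(B, H, S, D, device="cuda", dtype=torch.float16)
v = torch.randn(B, H, S, D, device="cuda", dtype=torch.float16)
attn_bias = torch.zeros(B, H, S, S, device="cuda", dtype=torch.float16)

def bias_score_mod(score, b, h, q_idx, kv_idx):
    # Inject the additive attention bias into the raw score.
    return score + attn_bias[b][h][q_idx][kv_idx]

def causal_mask(b, h, q_idx, kv_idx):
    # Keep only keys at or before the query position.
    return q_idx >= kv_idx

block_mask = create_block_mask(causal_mask, B, H, S, S, device="cuda")

# Compiling flex_attention lets it lower to fused kernels.
flex_compiled = torch.compile(flex_attention)
out = flex_compiled(q, k, v, score_mod=bias_score_mod, block_mask=block_mask)
```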
@LoserCheems LoserCheems merged commit dd4c08d into main Jul 3, 2025
Contributor

Copilot AI left a comment

Pull Request Overview

This PR adds a Triton-based Flash Attention kernel with support for dynamic masks and bias matrices, introduces a Flex Attention wrapper, and extends the forward-equivalence benchmarks.

  • Implements _fwd_kernel and _flash_attn_forward in flash_dmattn_triton.py for high-performance, memory-efficient attention.
  • Introduces flash_dmattn_flex.py providing a Flex Attention forward helper using compile_friendly_flex_attention.
  • Updates benchmark_forward_equivalence.py to import and test the new CUDA, Triton, and Flex implementations with new CLI options.

Reviewed Changes

Copilot reviewed 3 out of 3 changed files in this pull request and generated 1 comment.

File                                          Description
flash_dmattn/flash_dmattn_triton.py           Added Triton JIT kernel and Python API for dynamic-mask attention
flash_dmattn/flash_dmattn_flex.py             Added Flex Attention forward wrapper
benchmarks/benchmark_forward_equivalence.py   Extended imports, added Triton/Flex test functions, updated CLI
Comments suppressed due to low confidence (2)

benchmarks/benchmark_forward_equivalence.py:33

  • After importing flash_dmattn_func, you should alias it to flash_attn_with_mask (or else update subsequent references) to avoid NameError when using flash_attn_with_mask.
    from flash_dmattn.flash_dmattn_triton import flash_dmattn_func

benchmarks/benchmark_forward_equivalence.py:304

  • flash_attn_with_mask is not defined here, so this check will raise a NameError. It should reference flash_dmattn_func or assign flash_attn_with_mask = flash_dmattn_func upon import.
    if flash_attn_with_mask is None:
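
A minimal sketch of the aliasing fix described in these two comments; the surrounding try/except is assumed to match the benchmark's import block:

```python
try:
    from flash_dmattn.flash_dmattn_triton import flash_dmattn_func
    # Alias so later checks like `if flash_attn_with_mask is None:` resolve.
    flash_attn_with_mask = flash_dmattn_func
except ImportError:
    flash_attn_with_mask = None
```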

Comment on lines +25 to +27
# It looks like you're attempting to use a Tensor in some data-dependent control flow.
# We don't support that yet, please shout over at https://github.com/pytorch/functorch/issues/257 .
# return q_idx >= kv_idx and attn_mask[batch_idx][head_idx][q_idx][kv_idx] > 0

Copilot AI Jul 3, 2025

[nitpick] Consider removing or clarifying this block of commented-out guidance and the broken link, as it may confuse future readers.

Suggested change
- # It looks like you're attempting to use a Tensor in some data-dependent control flow.
- # We don't support that yet, please shout over at https://github.com/pytorch/functorch/issues/257 .
- # return q_idx >= kv_idx and attn_mask[batch_idx][head_idx][q_idx][kv_idx] > 0
+ # Return True if q_idx is greater than or equal to kv_idx, indicating causal masking.

