Adds Flash Attention implementation with dynamic masking #56
Conversation
Implements a Triton-based Flash Attention kernel that supports dynamic attention masks and bias matrices. The implementation provides memory-efficient attention computation with support for causal masking, custom attention patterns, and bias injection. Key features include optimized block-wise computation, configurable block sizes to avoid memory overflow, and proper handling of sequence length padding for efficient GPU utilization.
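A minimal usage sketch of the Triton path, assuming the `flash_dmattn_func` entry point imported in the benchmarks takes query/key/value plus optional mask and bias tensors; the tensor layout, argument order, and the `is_causal` flag are assumptions, since the exact signature is not shown here:

```python
import torch
from flash_dmattn.flash_dmattn_triton import flash_dmattn_func

# Assumed layout: (batch, seqlen, num_heads, head_dim); the kernel may expect
# a different ordering.
B, L, H, D = 2, 1024, 8, 64
q = torch.randn(B, L, H, D, device="cuda", dtype=torch.float16)
k = torch.randn(B, L, H, D, device="cuda", dtype=torch.float16)
v = torch.randn(B, L, H, D, device="cuda", dtype=torch.float16)

# Dynamic attention mask (keep/drop per position) and additive bias.
mask = torch.ones(B, H, L, L, device="cuda", dtype=torch.bool)
bias = torch.zeros(B, H, L, L, device="cuda", dtype=torch.float16)

# Hypothetical call: causal masking is applied on top of the dynamic mask,
# and the bias is added to the attention scores before softmax.
out = flash_dmattn_func(q, k, v, mask, bias, is_causal=True)
```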
Introduces two new attention implementation variants to complement the existing Python and CUDA versions:

- Adds Triton-based dynamic mask attention with proper error handling and graceful fallback
- Adds Flex Attention implementation for PyTorch's native attention API
- Extends benchmark suite with separate test functions for each implementation
- Updates command-line interface to support selective testing of specific implementations (see the sketch after this list)
- Improves import handling with informative status messages and non-blocking failures

Enables comprehensive performance and accuracy comparison across all four attention implementations while maintaining backward compatibility with existing CUDA tests.
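A minimal sketch of the non-blocking import pattern and selective-testing CLI described above; the flag name, helper name, and status messages are illustrative rather than taken from `benchmark_forward_equivalence.py`:

```python
import argparse

# Each backend is imported independently, so a missing dependency only
# disables its own tests instead of aborting the whole benchmark run.
try:
    from flash_dmattn.flash_dmattn_flex import flex_attention_forward  # hypothetical name
    print("✓ Flex Attention implementation available")
except ImportError as exc:
    flex_attention_forward = None
    print(f"✗ Flex Attention implementation unavailable: {exc}")

parser = argparse.ArgumentParser(description="Forward-equivalence benchmarks")
parser.add_argument(
    "--impl",
    choices=["cuda", "triton", "flex", "all"],
    default="all",
    help="which implementation(s) to test",
)
args = parser.parse_args()

if args.impl in ("flex", "all") and flex_attention_forward is not None:
    pass  # run the Flex Attention forward-equivalence test here
```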
Implements a new attention mechanism that leverages PyTorch's flex_attention API for improved performance and flexibility. The implementation includes custom score modification with attention bias, causal masking support, and optimized kernel options for efficient computation. Uses compile-friendly flex attention with configurable block sizes and returns both attention output and weights in the appropriate data types.
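PyTorch's flex_attention API supports this pattern directly; below is a minimal sketch of a bias-injecting `score_mod` plus a causal block mask, with tensor shapes, names, and the compile step chosen for illustration rather than copied from `flash_dmattn_flex.py`:

```python
import torch
from torch.nn.attention.flex_attention import flex_attention, create_block_mask

B, H, L, D = 2, 8, 1024, 64
query = torch.randn(B, H, L, D, device="cuda", dtype=torch.float16)
key = torch.randn(B, H, L, D, device="cuda", dtype=torch.float16)
value = torch.randn(B, H, L, D, device="cuda", dtype=torch.float16)
attn_bias = torch.zeros(B, H, L, L, device="cuda", dtype=torch.float16)

def score_mod(score, b, h, q_idx, kv_idx):
    # Inject the custom attention bias into each score.
    return score + attn_bias[b, h, q_idx, kv_idx]

def causal_mask(b, h, q_idx, kv_idx):
    # Pure index comparison, so no data-dependent control flow is involved.
    return q_idx >= kv_idx

block_mask = create_block_mask(causal_mask, B, H, L, L)

# Compiling flex_attention lets PyTorch fuse score_mod and the block mask
# into a single kernel.
compiled_flex_attention = torch.compile(flex_attention)
output = compiled_flex_attention(
    query, key, value, score_mod=score_mod, block_mask=block_mask
)
```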
Pull Request Overview
This PR adds a Triton-based Flash Attention kernel with support for dynamic masks and bias matrices, a Flex Attention wrapper, and extends the forward‐equivalence benchmarks.
- Implements `_fwd_kernel` and `_flash_attn_forward` in `flash_dmattn_triton.py` for high-performance, memory-efficient attention.
- Introduces `flash_dmattn_flex.py`, providing a Flex Attention forward helper using `compile_friendly_flex_attention`.
- Updates `benchmark_forward_equivalence.py` to import and test the new CUDA, Triton, and Flex implementations with new CLI options.
Reviewed Changes
Copilot reviewed 3 out of 3 changed files in this pull request and generated 1 comment.
| File | Description |
|---|---|
| flash_dmattn/flash_dmattn_triton.py | Added Triton JIT kernel and Python API for dynamic-mask attention |
| flash_dmattn/flash_dmattn_flex.py | Added Flex Attention forward wrapper |
| benchmarks/benchmark_forward_equivalence.py | Extended imports, added Triton/Flex test functions, updated CLI |
Comments suppressed due to low confidence (2)
benchmarks/benchmark_forward_equivalence.py:33
- After importing `flash_dmattn_func`, you should alias it to `flash_attn_with_mask` (or else update subsequent references) to avoid a NameError when using `flash_attn_with_mask`.

  `from flash_dmattn.flash_dmattn_triton import flash_dmattn_func`
benchmarks/benchmark_forward_equivalence.py:304
- `flash_attn_with_mask` is not defined here, so this check will raise a NameError. It should reference `flash_dmattn_func` or assign `flash_attn_with_mask = flash_dmattn_func` upon import.

  `if flash_attn_with_mask is None:`
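A minimal sketch of the fix both comments point at, assuming the alias is assigned immediately after the import so the check above no longer raises a NameError:

```python
try:
    from flash_dmattn.flash_dmattn_triton import flash_dmattn_func
    flash_attn_with_mask = flash_dmattn_func  # alias used by the later checks
except ImportError:
    flash_attn_with_mask = None  # lets the benchmark skip the Triton test gracefully

if flash_attn_with_mask is None:
    print("Triton implementation not available, skipping test")
```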
Code under review:

    # It looks like you're attempting to use a Tensor in some data-dependent control flow.
    # We don't support that yet, please shout over at https://github.com/pytorch/functorch/issues/257 .
    # return q_idx >= kv_idx and attn_mask[batch_idx][head_idx][q_idx][kv_idx] > 0
Copilot AI, Jul 3, 2025
[nitpick] Consider removing or clarifying this block of commented-out guidance and the broken link, as it may confuse future readers.
Suggested change:

    - # It looks like you're attempting to use a Tensor in some data-dependent control flow.
    - # We don't support that yet, please shout over at https://github.com/pytorch/functorch/issues/257 .
    - # return q_idx >= kv_idx and attn_mask[batch_idx][head_idx][q_idx][kv_idx] > 0
    + # Return True if q_idx is greater than or equal to kv_idx, indicating causal masking.
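For reference, the combination the commented-out line was attempting can be written without Python's `and` (the source of the data-dependent control flow error) by combining tensor conditions with `&`; a sketch assuming `attn_mask` is a captured boolean-like tensor:

```python
import torch

B, H, L = 2, 8, 1024
attn_mask = torch.ones(B, H, L, L, device="cuda")  # assumed dynamic mask tensor

def causal_mask_mod(b, h, q_idx, kv_idx):
    # The suggested replacement: keep only the causal condition.
    return q_idx >= kv_idx

def causal_and_dynamic_mask_mod(b, h, q_idx, kv_idx):
    # If the dynamic mask is still wanted, use `&` on tensors instead of
    # Python `and`, which cannot be traced through data-dependent branches.
    return (q_idx >= kv_idx) & (attn_mask[b, h, q_idx, kv_idx] > 0)
```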