Add backend selection and dynamic mask attention support #63
Conversation
Introduces extras_require configuration to allow users to install specific backends (triton, flex, all, dev, test) without requiring full CUDA compilation. Implements auto-detection logic that skips CUDA build when users explicitly request only Triton or Flex backends, reducing installation time and complexity for users who don't need CUDA acceleration. Maintains backward compatibility while providing more granular control over dependencies and build processes.
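A minimal sketch of what the `extras_require` layout might look like in `setup.py` (package names and version pins here are illustrative, not the exact ones in this PR):

```python
# setup.py (illustrative excerpt; core metadata and ext_modules omitted)
from setuptools import setup

setup(
    name="flash_dmattn",
    extras_require={
        "triton": ["triton"],                 # Triton-only install, no CUDA build
        "flex": ["torch>=2.5.0"],             # Flex Attention ships with recent PyTorch
        "all": ["triton", "torch>=2.5.0"],    # every backend
        "dev": ["triton", "torch>=2.5.0", "ninja", "packaging"],
        "test": ["pytest"],
    },
)
```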
Initializes the flash_dmattn package with automatic backend selection between CUDA, Triton, and Flex implementations. Provides graceful fallback mechanism that prioritizes CUDA for performance, then Triton and Flex as alternatives. Includes runtime availability checks and clear error messages for missing dependencies. Enables users to explicitly specify backends or rely on automatic selection based on available installations.
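A rough sketch of the fallback logic described above (module paths, flag names, and the `backend` keyword are assumptions based on this description, not a verbatim copy of the PR):

```python
# flash_dmattn/__init__.py -- illustrative sketch of the auto-selection logic
CUDA_AVAILABLE = TRITON_AVAILABLE = FLEX_AVAILABLE = False

try:
    from .flash_dmattn_interface import flash_dmattn_func as cuda_dmattn_func
    CUDA_AVAILABLE = True
except ImportError:
    pass

try:
    from .flash_dmattn_triton import triton_dmattn_func
    TRITON_AVAILABLE = True
except ImportError:
    pass

try:
    from .flash_dmattn_flex import flex_dmattn_func
    FLEX_AVAILABLE = True
except ImportError:
    pass


def get_available_backends():
    """Return the names of the backends that imported successfully."""
    return [name for name, ok in
            [("cuda", CUDA_AVAILABLE), ("triton", TRITON_AVAILABLE), ("flex", FLEX_AVAILABLE)]
            if ok]


def flash_dmattn_func(*args, backend=None, **kwargs):
    """Dispatch to an explicit backend, or fall back CUDA -> Triton -> Flex."""
    if backend is None:
        backend = ("cuda" if CUDA_AVAILABLE else
                   "triton" if TRITON_AVAILABLE else
                   "flex" if FLEX_AVAILABLE else None)
    if backend == "cuda" and CUDA_AVAILABLE:
        return cuda_dmattn_func(*args, **kwargs)
    if backend == "triton" and TRITON_AVAILABLE:
        return triton_dmattn_func(*args, **kwargs)
    if backend == "flex" and FLEX_AVAILABLE:
        return flex_dmattn_func(*args, **kwargs)
    raise RuntimeError(
        f"Backend {backend!r} is not available; install flash-dmattn with the matching extra "
        f"(available backends: {get_available_backends()})."
    )
```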
Renames imported functions to use consistent naming convention with 'dmattn_func' suffix across Triton and Flex Attention implementations. Updates function call parameters to use positional arguments instead of keyword arguments for cleaner code. Removes hard exit on CUDA import failure to allow graceful degradation when some implementations are unavailable.
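The graceful-degradation part of that change might look roughly like this in the benchmark imports (a sketch, not the exact diff):

```python
# Previously a failed CUDA import called sys.exit(); now it degrades gracefully
# so the Triton/Flex code paths can still run without the compiled extension.
try:
    import flash_dmattn_cuda  # compiled CUDA extension
    CUDA_AVAILABLE = True
except ImportError as err:
    print(f"CUDA backend unavailable, continuing with Triton/Flex: {err}")
    flash_dmattn_cuda = None
    CUDA_AVAILABLE = False
```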
Expands benchmark to compare Flash Attention against CUDA, Triton, and Flex implementations of Dynamic Mask Attention. Introduces modular testing framework allowing selective benchmarking of specific implementations or head-to-head comparisons. Updates function signatures to return timing measurements directly from kernels for more accurate performance metrics. Renames variables for clarity (dt_states to zoh_states, active_mask to attn_mask) and adds comprehensive error handling for missing implementations. Includes new test configurations for window size variations and non-causal attention patterns. Provides detailed performance analysis with implementation-specific speedup calculations and overhead comparisons between different approaches.
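A sketch of the "return timing directly from the kernel call" pattern (the helper name and exact bookkeeping are illustrative):

```python
import gc
import time

import torch


def time_kernel(fn, *args, num_runs=3, warmup_runs=2):
    """Run warmup iterations, then return (output, mean milliseconds per run)."""
    for _ in range(warmup_runs):
        fn(*args)
    torch.cuda.synchronize()

    start_time = time.time()
    for _ in range(num_runs):
        out = fn(*args)
    torch.cuda.synchronize()
    elapsed_ms = (time.time() - start_time) * 1000.0 / num_runs

    gc.collect()
    torch.cuda.empty_cache()
    return out, elapsed_ms
```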
Streamlines the detection of Triton/Flex-only installations by using simpler string matching instead of complex pattern checking. Removes unnecessary check for plain installations since the core logic focuses on whether specific extras are requested without all/dev variants.
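The simplified detection might look roughly like this (a sketch; the real trigger in `setup.py` may use a different environment variable or command-line inspection):

```python
# setup.py -- illustrative sketch of the simplified extras detection
import os
import sys


def should_skip_cuda_build():
    """Return True when the install only asks for the Triton/Flex extras."""
    # Hypothetical escape hatch; the actual variable name is an assumption.
    if os.environ.get("FLASH_DMATTN_SKIP_CUDA_BUILD", "FALSE") == "TRUE":
        return True
    # Simple string matching on the requested extras, no all/dev variants.
    cmdline = " ".join(sys.argv).lower()
    wants_triton_or_flex = "[triton]" in cmdline or "[flex]" in cmdline
    wants_full_build = "[all]" in cmdline or "[dev]" in cmdline
    return wants_triton_or_flex and not wants_full_build
```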
Pull Request Overview
Introduce optional backend selection (Triton, Flex, CUDA) and dynamic mask attention with automatic detection, along with benchmark updates to measure performance across all backends.
- Added `should_skip_cuda_build()` to `setup.py` for auto-skipping CUDA compilation when only Triton/Flex extras are requested.
- Exposed `get_available_backends()` and a unified `flash_dmattn_func` dispatcher in `__init__.py`.
- Expanded benchmarks to include Triton and Flex implementations, return execution time from functions, and parameterize test types.
Reviewed Changes
Copilot reviewed 4 out of 4 changed files in this pull request and generated no comments.
| File | Description |
|---|---|
| `setup.py` | Auto-detect when to skip the CUDA build; add extras_require for Triton/Flex/all/dev/test and package data. |
| `flash_dmattn/__init__.py` | Export backend flags, the `get_available_backends()` helper, and the unified `flash_dmattn_func` dispatcher. |
| `benchmarks/benchmark_forward_performance.py` | Add timing support to CUDA/Triton/Flex benchmarks; parameterize `test_type`; update imports. |
| `benchmarks/benchmark_forward_equivalence.py` | Correct imports and dispatcher for Triton/Flex forward-equivalence tests. |
Comments suppressed due to low confidence (4)
benchmarks/benchmark_forward_performance.py:850
- [nitpick] The new parameters `test_type`, `num_runs`, and `warmup_runs` are not documented in the function docstring. Consider updating or adding a docstring to describe their purpose.
def run_performance_benchmark(test_type='all', num_runs=3, warmup_runs=2):
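A possible docstring addressing this nitpick (the accepted `test_type` values listed here are assumptions based on the PR description):

```python
def run_performance_benchmark(test_type='all', num_runs=3, warmup_runs=2):
    """Run the forward-performance benchmark suite.

    Args:
        test_type: Which implementations to benchmark ('all', 'cuda', 'triton',
            'flex', or a head-to-head comparison mode).
        num_runs: Number of timed iterations averaged per configuration.
        warmup_runs: Untimed iterations executed first to warm caches and
            trigger kernel compilation.
    """
```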
benchmarks/benchmark_forward_performance.py:188
- The `time` module is used here but not imported at the top of the file, which will raise a NameError. Please add `import time`.
start_time = time.time()
benchmarks/benchmark_forward_performance.py:641
- The `gc` module is used here but not imported at the top of the file, leading to a NameError. Please add `import gc`.
gc.collect()
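Both comments point at the same omission; the fix is simply to add the standard-library imports at the top of the benchmark module:

```python
# Top of benchmarks/benchmark_forward_performance.py
import gc
import time

# ...later uses such as time.time() and gc.collect() then resolve correctly.
```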
benchmarks/benchmark_forward_performance.py:271
- Verify that the argument order passed to `flash_dmattn_cuda.fwd` matches the native CUDA extension signature (e.g., bias before mask or vice versa). A mismatch can silently produce incorrect outputs.
result = flash_dmattn_cuda.fwd( # type: ignore
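One low-effort way to check the order is to print the signature that pybind11 embeds in the bound function's docstring (assuming the extension is built with pybind11, as Torch C++ extensions normally are):

```python
import flash_dmattn_cuda

# pybind11 writes the generated signature into __doc__, so this shows the
# exact positional order the native fwd entry point expects.
print(flash_dmattn_cuda.fwd.__doc__)
```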
Introduce optional backend selection for Triton and Flex, enabling users to skip CUDA builds when not needed. Implement dynamic mask attention with automatic backend detection and graceful fallbacks. Standardize function naming for consistency across implementations while maintaining backward compatibility.