Conversation

@LoserCheems
Collaborator

Introduce optional backend selection for Triton and Flex, enabling users to skip CUDA builds when not needed. Implement dynamic mask attention with automatic backend detection and graceful fallbacks. Standardize function naming for consistency across implementations while maintaining backward compatibility.

Introduces an extras_require configuration that lets users install specific extras (triton, flex, all, dev, test) and pull in only the backends they need, without requiring a full CUDA compilation.

Implements auto-detection logic that skips the CUDA build when users explicitly request only the Triton or Flex backends, reducing installation time and complexity for users who don't need CUDA acceleration.

Maintains backward compatibility while providing more granular control over dependencies and build processes.
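
A minimal sketch of how such extras might be declared in setup.py follows; the package lists and version pins are illustrative assumptions, not taken from this PR:

    from setuptools import setup, find_packages

    # Illustrative extras; the actual package lists and pins in the PR may differ.
    extras_require = {
        "triton": ["triton>=2.0.0"],
        "flex": ["torch>=2.5.0"],   # flex_attention ships with recent PyTorch releases
        "test": ["pytest"],
        "dev": ["pytest", "ninja"],
    }
    # "all" aggregates every optional dependency above.
    extras_require["all"] = sorted({dep for deps in extras_require.values() for dep in deps})

    setup(
        name="flash_dmattn",
        packages=find_packages(),
        extras_require=extras_require,
    )

Users could then opt into a lighter install with something like pip install "flash-dmattn[triton]" (the exact distribution name may differ).
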
Initializes the flash_dmattn package with automatic backend selection between CUDA, Triton, and Flex implementations.

Provides graceful fallback mechanism that prioritizes CUDA for performance, then Triton and Flex as alternatives. Includes runtime availability checks and clear error messages for missing dependencies.

Enables users to explicitly specify backends or rely on automatic selection based on available installations.
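
A condensed sketch of that selection logic is shown below; the per-backend module paths and function names are assumptions (only get_available_backends and flash_dmattn_func are named in the review further down):

    # flash_dmattn/__init__.py -- illustrative backend probing and dispatch.
    CUDA_AVAILABLE = TRITON_AVAILABLE = FLEX_AVAILABLE = False
    _cuda_func = _triton_func = _flex_func = None

    try:
        from .flash_dmattn_interface import flash_dmattn_cuda_func as _cuda_func  # hypothetical path/name
        CUDA_AVAILABLE = True
    except ImportError:
        pass

    try:
        from .flash_dmattn_triton import triton_dmattn_func as _triton_func  # hypothetical path/name
        TRITON_AVAILABLE = True
    except ImportError:
        pass

    try:
        from .flash_dmattn_flex import flex_dmattn_func as _flex_func  # hypothetical path/name
        FLEX_AVAILABLE = True
    except ImportError:
        pass

    def get_available_backends():
        """Report which backends could be imported at package load time."""
        return {"cuda": CUDA_AVAILABLE, "triton": TRITON_AVAILABLE, "flex": FLEX_AVAILABLE}

    def flash_dmattn_func(*args, backend=None, **kwargs):
        """Dispatch to the requested backend, or fall back in CUDA -> Triton -> Flex order."""
        funcs = {"cuda": _cuda_func, "triton": _triton_func, "flex": _flex_func}
        if backend is not None:
            if funcs.get(backend) is None:
                raise ImportError(f"Backend '{backend}' was requested but is not installed.")
            return funcs[backend](*args, **kwargs)
        for fn in (_cuda_func, _triton_func, _flex_func):
            if fn is not None:
                return fn(*args, **kwargs)
        raise ImportError("No backend is available; install the CUDA extension or the triton/flex extras.")
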
Renames imported functions to follow a consistent naming convention with a 'dmattn_func' suffix across the Triton and Flex Attention implementations.

Updates function call parameters to use positional arguments instead of keyword arguments for cleaner code.

Removes hard exit on CUDA import failure to allow graceful degradation when some implementations are unavailable.
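
In the benchmark scripts this might look roughly like the following; the module paths, and any alias beyond the dmattn_func suffix, are assumptions:

    # Each backend import is optional; a missing one disables its tests instead of aborting.
    try:
        import flash_dmattn_cuda  # compiled CUDA extension
    except ImportError as exc:
        flash_dmattn_cuda = None
        print(f"CUDA implementation unavailable, skipping its benchmarks: {exc}")

    try:
        from flash_dmattn.flash_dmattn_triton import triton_dmattn_func  # hypothetical path
    except ImportError:
        triton_dmattn_func = None

    try:
        from flash_dmattn.flash_dmattn_flex import flex_dmattn_func  # hypothetical path
    except ImportError:
        flex_dmattn_func = None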

Expands benchmark to compare Flash Attention against CUDA, Triton, and Flex implementations of Dynamic Mask Attention.

Introduces modular testing framework allowing selective benchmarking of specific implementations or head-to-head comparisons. Updates function signatures to return timing measurements directly from kernels for more accurate performance metrics.

Renames variables for clarity (dt_states to zoh_states, active_mask to attn_mask) and adds comprehensive error handling for missing implementations. Includes new test configurations for window size variations and non-causal attention patterns.

Provides detailed performance analysis with implementation-specific speedup calculations and overhead comparisons between different approaches.
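
The warmup-then-measure pattern described above might be structured like this hypothetical helper (the real benchmark functions differ, but the time/gc usage matches the review comments further down):

    import gc
    import time

    import torch

    def _measure_forward(attn_func, inputs, num_runs=3, warmup_runs=2):
        """Hypothetical helper: warm up, then return mean wall-clock seconds per timed run."""
        for _ in range(warmup_runs):
            attn_func(*inputs)
        if torch.cuda.is_available():
            torch.cuda.synchronize()

        start_time = time.time()
        for _ in range(num_runs):
            attn_func(*inputs)
        if torch.cuda.is_available():
            torch.cuda.synchronize()
        elapsed = (time.time() - start_time) / num_runs

        # Release memory between configurations so large test cases don't accumulate.
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
        return elapsed
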
Streamlines the detection of Triton/Flex-only installations by using simpler string matching instead of complex pattern checking.

Removes unnecessary check for plain installations since the core logic focuses on whether specific extras are requested without all/dev variants.
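
A minimal sketch of that string-matching check, assuming the requested extras are visible to setup.py at build time (the real detection source may differ):

    import sys

    def should_skip_cuda_build():
        """Skip compiling the CUDA extension when only Triton/Flex extras are requested.

        Illustrative only: assumes the extras markers appear in the install command line.
        """
        requested = " ".join(sys.argv[1:]).lower()
        light_only = "[triton]" in requested or "[flex]" in requested or "[triton,flex]" in requested
        full_build = "[all]" in requested or "[dev]" in requested
        return light_only and not full_build
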
@LoserCheems LoserCheems requested a review from Copilot July 7, 2025 12:15
Contributor

Copilot AI left a comment

Pull Request Overview

Introduce optional backend selection (Triton, Flex, CUDA) and dynamic mask attention with automatic detection, along with benchmark updates to measure performance across all backends.

  • Added should_skip_cuda_build() to setup.py for auto-skipping CUDA compilation when only Triton/Flex extras are requested.
  • Exposed get_available_backends() and unified flash_dmattn_func dispatcher in __init__.py.
  • Expanded benchmarks to include Triton and Flex implementations, return execution time from functions, and parameterize test types.

Reviewed Changes

Copilot reviewed 4 out of 4 changed files in this pull request and generated no comments.

File | Description
setup.py | Auto-detect when the CUDA build can be skipped; add extras_require for Triton/Flex/all/dev/test and package data.
flash_dmattn/__init__.py | Export backend availability flags, the get_available_backends() helper, and the unified flash_dmattn_func dispatcher.
benchmarks/benchmark_forward_performance.py | Add timing measurement to the CUDA/Triton/Flex benchmarks; parameterize test_type; update imports.
benchmarks/benchmark_forward_equivalence.py | Correct imports and dispatch for the Triton/Flex forward-equivalence tests.
Comments suppressed due to low confidence (4)

benchmarks/benchmark_forward_performance.py:850

  • [nitpick] The new parameters test_type, num_runs, and warmup_runs are not documented in the function docstring. Consider updating or adding a docstring to describe their purpose.
def run_performance_benchmark(test_type='all', num_runs=3, warmup_runs=2):

benchmarks/benchmark_forward_performance.py:188

  • The time module is used here but not imported at the top of the file, which will raise a NameError. Please add import time.
        start_time = time.time()

benchmarks/benchmark_forward_performance.py:641

  • The gc module is used here but not imported at the top of the file, leading to a NameError. Please add import gc.
        gc.collect()

benchmarks/benchmark_forward_performance.py:271

  • Verify that the argument order passed to flash_dmattn_cuda.fwd matches the native CUDA extension signature (e.g., bias before mask or vice versa). A mismatch can silently produce incorrect outputs.
        result = flash_dmattn_cuda.fwd(    # type: ignore

@LoserCheems LoserCheems merged commit 47c6403 into main Jul 7, 2025