
Conversation

@LoserCheems (Collaborator)

This pull request updates the dynamic mask attention benchmarks to improve functionality and performance, with a focus on CUDA integration and mask preparation. Key changes: the Python-based dynamic mask attention implementation is replaced with a CUDA-based one, import error handling is added, and the mask preparation logic is refined.

Dynamic Mask Attention Enhancements:

  • Updated calculate_zero_hold_states function to remove the causal_mask parameter, simplifying its interface across multiple functions. ([[1]](https://github.com/flash-algo/flash-sparse-attention/pull/36/files#diff-f12b9e1531698ec12c175cab7478023220ba63a1dcd52ae265381725400608beL79-R79), [[2]](https://github.com/flash-algo/flash-sparse-attention/pull/36/files#diff-f12b9e1531698ec12c175cab7478023220ba63a1dcd52ae265381725400608beL155-R155), [[3]](https://github.com/flash-algo/flash-sparse-attention/pull/36/files#diff-f12b9e1531698ec12c175cab7478023220ba63a1dcd52ae265381725400608beL211-R211), [[4]](https://github.com/flash-algo/flash-sparse-attention/pull/36/files#diff-f1412b5998e5d0df84551a063f09794deec64831ce47e4d643b88a94e207a89dL69-R82), [[5]](https://github.com/flash-algo/flash-sparse-attention/pull/36/files#diff-f1412b5998e5d0df84551a063f09794deec64831ce47e4d643b88a94e207a89dL178-R187), [[6]](https://github.com/flash-algo/flash-sparse-attention/pull/36/files#diff-f1412b5998e5d0df84551a063f09794deec64831ce47e4d643b88a94e207a89dL248-R265))
  • Replaced Python-based dynamic mask attention with CUDA-based implementation (flash_dma_cuda.fwd) for improved performance and memory handling. ([[1]](https://github.com/flash-algo/flash-sparse-attention/pull/36/files#diff-f1412b5998e5d0df84551a063f09794deec64831ce47e4d643b88a94e207a89dL197-R224), [[2]](https://github.com/flash-algo/flash-sparse-attention/pull/36/files#diff-f1412b5998e5d0df84551a063f09794deec64831ce47e4d643b88a94e207a89dR284-L285))
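For illustration, a minimal pure-PyTorch sketch of the simplified signature is below. The function name and parameters match the diff; the body is an assumption modeled on Doge-style dynamic mask attention (exp of A-scaled softplus of a dt projection) and may differ from the repository's actual implementation.

```python
import torch
import torch.nn.functional as F

def calculate_zero_hold_states(value_states, dt_proj, A):
    """Sketch of the simplified signature (no causal_mask parameter).

    Assumed computation: project value states through dt_proj and apply
    exp(A * softplus(.)), as in public Doge-style dynamic mask attention.
    The real body in the benchmarks may differ.
    """
    b, num_kv_heads, seq_len, head_dim = value_states.shape
    # One rate per key/value head and position.
    dt = dt_proj(value_states.transpose(1, 2).reshape(b, seq_len, -1))  # [b, seq, kv_heads]
    zero_hold = torch.exp(F.softplus(dt) * A)                           # [b, seq, kv_heads]
    return zero_hold.transpose(-1, -2)                                  # [b, kv_heads, seq]

# Toy usage with assumed shapes.
b, kvh, s, d = 1, 2, 16, 32
value_states = torch.randn(b, kvh, s, d)
dt_proj = torch.nn.Linear(kvh * d, kvh)
A = torch.rand(kvh)
print(calculate_zero_hold_states(value_states, dt_proj, A).shape)  # torch.Size([1, 2, 16])
```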

Mask Preparation Logic Updates:

  • Enhanced prepare_dynamic_mask to ensure proper handling of active_mask initialization, including fallback logic when attention_mask is not provided. ([[1]](https://github.com/flash-algo/flash-sparse-attention/pull/36/files#diff-f1412b5998e5d0df84551a063f09794deec64831ce47e4d643b88a94e207a89dR42-R43), [[2]](https://github.com/flash-algo/flash-sparse-attention/pull/36/files#diff-f1412b5998e5d0df84551a063f09794deec64831ce47e4d643b88a94e207a89dL49), [[3]](https://github.com/flash-algo/flash-sparse-attention/pull/36/files#diff-f1412b5998e5d0df84551a063f09794deec64831ce47e4d643b88a94e207a89dL69-R82))
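A hedged sketch of the fallback behavior described above follows; the names, shapes, and top-k selection are illustrative assumptions rather than the benchmark's exact code.

```python
import torch

def prepare_dynamic_mask(zero_hold_states, keep_window_size=2048, attention_mask=None):
    # zero_hold_states: [batch, num_heads, query_len, key_len] attention bias (assumed layout).
    attn_bias = zero_hold_states
    # Apply the user-provided attention mask if there is one ...
    if attention_mask is not None:
        attn_bias = attn_bias.masked_fill(attention_mask == 0, torch.finfo(attn_bias.dtype).min)
    # ... then keep only the top keep_window_size keys per query.
    if attn_bias.shape[-1] > keep_window_size:
        topk_idx = torch.topk(attn_bias, keep_window_size, dim=-1, sorted=False).indices
        active_mask = torch.zeros_like(attn_bias).scatter_(-1, topk_idx, 1.0)
        attn_bias = attn_bias.masked_fill(active_mask == 0.0, torch.finfo(attn_bias.dtype).min)
    else:
        # Fallback: active_mask is still initialized even when no attention_mask
        # is provided and no truncation is needed, so downstream code can rely on it.
        active_mask = torch.ones_like(attn_bias)
    return attn_bias, active_mask

# Toy usage.
bias = torch.randn(1, 2, 8, 4096)
attn_bias, active_mask = prepare_dynamic_mask(bias, keep_window_size=2048)
print(attn_bias.shape, active_mask.shape)
```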

Performance Benchmark Improvements:

  • Reintroduced head dimension variations in the run_performance_benchmark function to enable broader testing scenarios. ([benchmarks/benchmark_forward_performance.pyL532-R558](https://github.com/flash-algo/flash-sparse-attention/pull/36/files#diff-f1412b5998e5d0df84551a063f09794deec64831ce47e4d643b88a94e207a89dL532-R558))
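Re-enabling the head-dimension cases amounts to restoring entries in the benchmark's configuration list, roughly as sketched below; the tuple values here are placeholders, not the exact configurations from the diff.

```python
# Illustrative sweep over head_dim; values are placeholders, not the
# exact tuples re-enabled in run_performance_benchmark.
configs = [
    # (batch_size, num_heads, query_len, key_len, head_dim)
    (1, 2, 256, 256, 32),
    (1, 2, 256, 256, 64),
    (1, 2, 256, 256, 96),
    (1, 2, 256, 256, 128),
]

for batch_size, num_heads, query_len, key_len, head_dim in configs:
    print(f"benchmark config: head_dim={head_dim}")
```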

Error Handling and Import Updates:

  • Added error handling for importing the flash_dma_cuda module, including user guidance for installation issues. ([benchmarks/benchmark_forward_performance.pyL22-R30](https://github.com/flash-algo/flash-sparse-attention/pull/36/files#diff-f1412b5998e5d0df84551a063f09794deec64831ce47e4d643b88a94e207a89dL22-R30))
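The guarded import might look like the sketch below; the message wording and the suggested install command are assumptions, not the benchmark's exact text.

```python
# Hedged sketch of the guarded import; the real error message in the
# benchmark may differ.
try:
    import flash_dma_cuda  # compiled CUDA extension from this repository
except ImportError as e:
    raise ImportError(
        "flash_dma_cuda is not available. Build and install the CUDA extension "
        "first (e.g. `pip install -e .` from the repository root) and make sure "
        "a compatible CUDA toolkit and PyTorch build are present."
    ) from e
```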

Simplifies the function signature by eliminating the causal_mask parameter, which was not being used in the implementation, reducing unnecessary complexity in both the function definition and all call sites.

Replaces the old flash_dma_cpp import with the flash_dma_cuda module and adds proper error handling for import failures.

Updates function calls to use the new flash_dma_cuda.fwd API with expanded parameter list including dropout, softcap, and generator parameters.
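A hypothetical sketch of the expanded call follows. Only the existence of dropout, softcap, and generator parameters comes from this PR description; the wrapper name, argument order, and the remaining slots are assumptions modeled on flash-attention-style forward bindings, not the extension's confirmed signature.

```python
import math

def cuda_dynamic_mask_attention(q, k, v, zoh_states, active_mask, is_causal=True):
    """Hypothetical wrapper around flash_dma_cuda.fwd.

    The argument order below is an assumption; consult the linked diff for
    the real call used in the benchmarks.
    """
    import flash_dma_cuda  # compiled extension from this repository

    softmax_scale = 1.0 / math.sqrt(q.shape[-1])
    outputs = flash_dma_cuda.fwd(
        q, k, v,          # query / key / value
        zoh_states,       # zero-hold states used as the attention bias
        active_mask,      # active-key mask from prepare_dynamic_mask
        None,             # out buffer (assumed: kernel allocates when None)
        0.0,              # dropout_p
        softmax_scale,    # softmax scaling
        is_causal,        # causal flag
        0.0,              # softcap
        False,            # return_softmax
        None,             # generator (torch.Generator for dropout)
    )
    return outputs[0]     # attention output
```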

Removes unused causal_mask parameter from calculate_zero_hold_states function and fixes active_mask initialization logic in prepare_dynamic_mask.

Re-enables commented head dimension test cases in performance benchmark configuration.
Copilot AI (Contributor) left a comment


Pull Request Overview

This PR enhances dynamic mask attention by replacing the Python implementation with a CUDA-based version for improved performance, and refines mask preparation and import error handling for the CUDA integration.

  • Removed the redundant causal_mask parameter from calculate_zero_hold_states.
  • Replaced Python-based dynamic mask attention with the CUDA-based implementation in two benchmark files.
  • Enhanced error messages for flash_dma_cuda import issues and updated head dimension variations in performance benchmarks.

Reviewed Changes

Copilot reviewed 2 out of 2 changed files in this pull request and generated no comments.

| File | Description |
| --- | --- |
| benchmarks/benchmark_forward_performance.py | Updated dynamic mask attention logic and enhanced CUDA import error handling. |
| benchmarks/benchmark_forward_equivalence.py | Synchronized changes in calculate_zero_hold_states to remove causal_mask. |
Comments suppressed due to low confidence (2)

benchmarks/benchmark_forward_performance.py:82

  • Update the docstring of calculate_zero_hold_states to remove references to the removed 'causal_mask' parameter.
`def calculate_zero_hold_states(value_states, dt_proj, A):`

benchmarks/benchmark_forward_equivalence.py:79

  • Update the docstring of calculate_zero_hold_states to remove references to the removed 'causal_mask' parameter.
`def calculate_zero_hold_states(value_states, dt_proj, A):`

@LoserCheems merged commit 7d0dcb0 into main on Jun 27, 2025.

Labels: bug (Something isn't working)