This pull request introduces several updates that improve the functionality and performance of the dynamic mask attention implementation, focusing on CUDA integration and mask preparation logic. Key changes include replacing the Python-based dynamic mask attention implementation with a CUDA-based one, strengthening error handling for imports, and refining the mask preparation logic.
Dynamic Mask Attention Enhancements:
- Updated the `calculate_zero_hold_states` function to remove the `causal_mask` parameter, simplifying its interface across multiple functions (see the sketch below). ([[1]](https://github.com/flash-algo/flash-sparse-attention/pull/36/files#diff-f12b9e1531698ec12c175cab7478023220ba63a1dcd52ae265381725400608beL79-R79), [[2]](https://github.com/flash-algo/flash-sparse-attention/pull/36/files#diff-f12b9e1531698ec12c175cab7478023220ba63a1dcd52ae265381725400608beL155-R155), [[3]](https://github.com/flash-algo/flash-sparse-attention/pull/36/files#diff-f12b9e1531698ec12c175cab7478023220ba63a1dcd52ae265381725400608beL211-R211), [[4]](https://github.com/flash-algo/flash-sparse-attention/pull/36/files#diff-f1412b5998e5d0df84551a063f09794deec64831ce47e4d643b88a94e207a89dL69-R82), [[5]](https://github.com/flash-algo/flash-sparse-attention/pull/36/files#diff-f1412b5998e5d0df84551a063f09794deec64831ce47e4d643b88a94e207a89dL178-R187), [[6]](https://github.com/flash-algo/flash-sparse-attention/pull/36/files#diff-f1412b5998e5d0df84551a063f09794deec64831ce47e4d643b88a94e207a89dL248-R265))
- Replaced the Python-based dynamic mask attention implementation with a CUDA-based one (`flash_dma_cuda.fwd`) for improved performance and memory handling. ([[1]](https://github.com/flash-algo/flash-sparse-attention/pull/36/files#diff-f1412b5998e5d0df84551a063f09794deec64831ce47e4d643b88a94e207a89dL197-R224), [[2]](https://github.com/flash-algo/flash-sparse-attention/pull/36/files#diff-f1412b5998e5d0df84551a063f09794deec64831ce47e4d643b88a94e207a89dR284-L285))
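For context, a minimal sketch of the simplified `calculate_zero_hold_states` interface. The tensor shapes and the function body below are illustrative assumptions, not the PR's actual implementation:

```python
import torch
import torch.nn.functional as F


def calculate_zero_hold_states(value_states, dt_proj, A):
    """Simplified interface with the former `causal_mask` parameter removed.

    Assumed (hypothetical) shapes:
        value_states: [batch, num_kv_heads, key_len, head_dim]
        dt_proj:      [num_heads, num_kv_heads * head_dim]
        A:            [num_heads]
    """
    b, h_kv, s, d = value_states.shape
    # Flatten the key/value heads so they can be projected per query head.
    flat = value_states.transpose(1, 2).reshape(b, s, h_kv * d)
    # Zero-hold states: a non-negative gate per (query head, key position).
    dt = torch.matmul(flat, dt_proj.T)                    # [b, s, num_heads]
    return torch.exp(A * F.softplus(dt)).transpose(1, 2)  # [b, num_heads, s]
```

The CUDA path then consumes these states through `flash_dma_cuda.fwd`; its exact argument list is extension-specific and is not reproduced here.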
Mask Preparation Logic Updates:

- Refined `prepare_dynamic_mask` to ensure proper handling of `active_mask` initialization, including fallback logic when `attention_mask` is not provided (see the sketch below). ([[1]](https://github.com/flash-algo/flash-sparse-attention/pull/36/files#diff-f1412b5998e5d0df84551a063f09794deec64831ce47e4d643b88a94e207a89dR42-R43), [[2]](https://github.com/flash-algo/flash-sparse-attention/pull/36/files#diff-f1412b5998e5d0df84551a063f09794deec64831ce47e4d643b88a94e207a89dL49), [[3]](https://github.com/flash-algo/flash-sparse-attention/pull/36/files#diff-f1412b5998e5d0df84551a063f09794deec64831ce47e4d643b88a94e207a89dL69-R82))
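A minimal sketch of the `active_mask` fallback described above, assuming hypothetical shapes and a top-k selection; it is not the PR's exact code:

```python
import torch


def prepare_dynamic_mask(zoh_states, keep_window_size, attention_mask=None):
    """Illustrative only: the argument list and shapes are assumptions.

    zoh_states:     [batch, num_heads, key_len] importance scores
    attention_mask: optional [batch, 1, key_len], 1 = keep, 0 = mask out
    """
    attn_mask = zoh_states.clone()
    if attention_mask is not None:
        # Fold the user-supplied padding mask into the scores.
        attn_mask = attn_mask.masked_fill(
            attention_mask == 0, torch.finfo(attn_mask.dtype).min
        )
    # Fallback: without an attention_mask, every key position stays a candidate.
    k = min(keep_window_size, attn_mask.shape[-1])
    topk = attn_mask.topk(k, dim=-1).indices
    # active_mask marks the positions each head actually attends to.
    active_mask = torch.zeros_like(attn_mask).scatter_(-1, topk, 1.0)
    return attn_mask, active_mask
```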
Performance Benchmark Improvements:

- Updated the test configurations in the `run_performance_benchmark` function to enable broader testing scenarios. ([benchmarks/benchmark_forward_performance.py L532-R558](https://github.com/flash-algo/flash-sparse-attention/pull/36/files#diff-f1412b5998e5d0df84551a063f09794deec64831ce47e4d643b88a94e207a89dL532-R558))
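The broader scenarios could take the shape of a grid like the one below; the actual values in `run_performance_benchmark` may differ:

```python
# Hypothetical benchmark grid:
# (batch_size, num_heads, num_kv_heads, query_len, key_len, head_dim)
test_configs = [
    (1, 2, 1, 256, 256, 64),
    (1, 2, 1, 1024, 1024, 64),
    (1, 2, 1, 4096, 4096, 128),
    (4, 8, 2, 2048, 2048, 128),
]
```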
Error Handling and Import Updates:

- Improved error handling when importing the `flash_dma_cuda` module, including user guidance for installation issues (see the sketch below). ([benchmarks/benchmark_forward_performance.py L22-R30](https://github.com/flash-algo/flash-sparse-attention/pull/36/files#diff-f1412b5998e5d0df84551a063f09794deec64831ce47e4d643b88a94e207a89dL22-R30))
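A minimal sketch of the guarded import; the benchmark's actual guidance text may differ:

```python
try:
    import flash_dma_cuda  # compiled CUDA extension shipped with the repo
except ImportError as exc:
    raise ImportError(
        "flash_dma_cuda is not available. Build the CUDA extension first "
        "(for example, `pip install -e .` from the repository root) and make "
        "sure a compatible PyTorch/CUDA toolchain is installed."
    ) from exc
```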