Remove cub submodule and add cutlass; implement FlashDynamicMaskAttention #31
Conversation
Eliminates the NVIDIA cub submodule from the project configuration, likely because cub is no longer needed as a standalone dependency (it ships with the CUDA toolkit); this reduces external dependencies and simplifies the build process.
Includes the cutlass library as a git submodule for enhanced matrix-operation support. Comments out unused index-calculation variables in the mask optimization path to eliminate compiler warnings while preserving the logic structure for potential future use.
Implements the complete forward pass for flash attention with dynamic masking support, including parameter setup, kernel dispatch logic, and memory optimization through a split-KV heuristic. Supports dropout, softcapping, causal masking, and multi-head attention, with a grouped-query attention (GQA) optimization for single-token sequences. Provides Python bindings for integration with PyTorch tensors and CUDA operations.
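The split-KV heuristic itself is not shown in this summary; the sketch below is a minimal Python illustration of the general idea used by flash-attention-style forward passes (the function name `choose_num_splits`, its parameters, and the 0.8 occupancy threshold are hypothetical, not taken from this PR):

```python
def choose_num_splits(batch_nheads_mblocks: int, num_sms: int, max_splits: int = 128) -> int:
    """Heuristically pick how many chunks to split the KV sequence into.

    If the grid (batch * heads * query blocks) already has enough thread
    blocks to occupy most of the GPU's streaming multiprocessors, splitting
    only adds reduction overhead, so keep a single split. Otherwise grow
    the split count until one wave of blocks roughly fills the device.
    """
    if batch_nheads_mblocks >= 0.8 * num_sms:
        return 1  # enough parallelism already; skip the extra combine pass
    for num_splits in range(2, max_splits + 1):
        if batch_nheads_mblocks * num_splits >= 0.8 * num_sms:
            return num_splits
    return max_splits

# Example: a single-token decode step (batch=4, heads=8, one query block)
# on a 108-SM GPU leaves most SMs idle unless the KV sequence is split.
print(choose_num_splits(4 * 8 * 1, num_sms=108))  # -> 3
```

Splitting the KV dimension trades an extra combine step for better SM occupancy, which pays off mainly for short query lengths such as single-token decoding.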
Updates the benchmark to import and use the new flash_dma_cuda module instead of the previous flash_dma_cpp implementation. Adds proper error handling for the CUDA extension import with informative messages and graceful exit on failure. Refactors the CUDA attention function to use the new mha_fwd API signature with proper parameter mapping and tensor format requirements. Improves test configuration with additional test cases for multi-batch and GQA scenarios, and enhances the test runner with better reporting and exit codes. Fixes a bug in the dynamic mask preparation logic by moving active_mask initialization to the correct conditional branch.
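As a rough illustration of the import guard described above (the exact message text and build hint in the benchmark may differ; this is a sketch, not the PR's code):

```python
import sys

try:
    import flash_dma_cuda  # compiled CUDA extension introduced in this PR
except ImportError as err:
    # Informative message and graceful exit on failure, as described above.
    print(f"Failed to import flash_dma_cuda: {err}", file=sys.stderr)
    print("Build the CUDA extension first, e.g. `pip install -e .` from the repo root.", file=sys.stderr)
    sys.exit(1)
```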
Cleans up the codebase by removing the unused dynamic mask attention implementation. Eliminates API wrappers, CUDA kernels, and associated infrastructure that were no longer needed, reducing maintenance overhead and code complexity.
Pull Request Overview
This PR removes the old CUB dependency, adds CUTLASS, implements a new CUDA extension for FlashDynamicMaskAttention (with Python bindings), and updates benchmarks to use the new API.
- Swap out the NVIDIA CUB submodule and include CUTLASS.
- Add `flash_api.cpp` providing `mha_fwd` C++/PyTorch bindings.
- Delete legacy `apply_dynamic_mask*` kernels and refactor benchmarks to call the new extension.
Reviewed Changes
Copilot reviewed 9 out of 9 changed files in this pull request and generated 1 comment.
| File | Description |
|---|---|
| csrc/src/mask.h | Commented out unused dynamic-mask indexing logic |
| csrc/flash_api.cpp | New FlashDynamicMaskAttention implementation & binding |
| csrc/cutlass | Added CUTLASS submodule pointer |
| csrc/apply_dynamic_mask_kernel.cu | Removed legacy dynamic mask kernel |
| csrc/apply_dynamic_mask_attention_kernel.cu | Removed legacy attention kernel |
| csrc/apply_dynamic_mask_attention_api.cpp | Removed old Python API glue for dynamic mask attention |
| csrc/apply_dynamic_mask_api.cpp | Removed old dynamic mask API |
| benchmarks/benchmark_forward_equivalence.py | Updated to import flash_dma_cuda.fwd and test types |
| .gitmodules | Deleted CUB submodule entry |
Comments suppressed due to low confidence (2)
csrc/flash_api.cpp:48
- The parameter name `seqused_k` appears to be a typo. Consider renaming it to `seqlens_k_d` or `cu_seqlens_k` for clarity.

  `void *seqused_k,`
benchmarks/benchmark_forward_equivalence.py:360
- The medium-scale GQA test for sequence length 128 was removed. Re-add a `(1, 2, 1, 128, 128, 32, True)` test case to maintain coverage for grouped-query attention at that size.

  `(1, 1, 1, 128, 128, 32, False),  # Medium scale test, non-causal mask`
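For illustration, the suggestion might translate into a config-list entry like the one below; the tuple field order `(batch, num_heads_q, num_heads_kv, seqlen_q, seqlen_k, head_dim, is_causal)` is an assumption inferred from the GQA and causal-mask descriptions, not confirmed by the benchmark source:

```python
# Hypothetical excerpt of the benchmark's test-configuration list.
# Assumed field order: (batch, num_heads_q, num_heads_kv, seqlen_q, seqlen_k, head_dim, is_causal)
test_configs = [
    (1, 1, 1, 128, 128, 32, False),  # medium scale, non-causal mask
    (1, 2, 1, 128, 128, 32, True),   # re-added: medium-scale GQA, causal mask
]
```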
The inline comment below refers to this hunk from `csrc/src/mask.h`, where the unused index calculations were commented out:

```diff
  // const int row_idx_base = row_idx_offset + mi * warp_row_stride;
  #pragma unroll
  for (int i = 0; i < size<0, 0>(tensor); ++i) {
-     const int row_idx = row_idx_base + i * 8;
+     // const int row_idx = row_idx_base + i * 8;
      #pragma unroll
      for (int nj = 0; nj < size<1, 1>(tensor); ++nj) {
-         const int col_idx_base = col_idx_offset + nj * 8;
+         // const int col_idx_base = col_idx_offset + nj * 8;
          #pragma unroll
          for (int j = 0; j < size<1, 0>(tensor); ++j) {
-             const int col_idx = col_idx_base + j;
+             // const int col_idx = col_idx_base + j;
```
Copilot AI · Jun 26, 2025
[nitpick] Rather than leaving large blocks of dead code commented out, remove these unused index calculations entirely to keep the codebase clean.
Eliminate the cub submodule to reduce dependencies and include the cutlass library for improved matrix operations. Implement the FlashDynamicMaskAttention C++ API with dynamic masking support and Python bindings for PyTorch. Refactor benchmarks to utilize the new CUDA extension API and enhance error handling. Remove unused dynamic mask functionality from the CUDA extension to streamline the codebase.