
Conversation

@LoserCheems (Collaborator)

Eliminate the cub submodule to reduce dependencies and include the cutlass library for improved matrix operations. Implement the FlashDynamicMaskAttention C++ API with dynamic masking support and Python bindings for PyTorch. Refactor benchmarks to utilize the new CUDA extension API and enhance error handling. Remove unused dynamic mask functionality from the CUDA extension to streamline the codebase.

Eliminates the NVIDIA CUB submodule from the project configuration; CUB now ships with the CUDA Toolkit, so the standalone submodule is an unnecessary external dependency and dropping it simplifies the build.
Adds the CUTLASS library as a git submodule to supply the GEMM and tensor-core primitives the attention kernels build on.

Comments out unused index calculation variables in mask optimization path to eliminate compiler warnings while preserving the logic structure for potential future use.
Implements the complete forward pass for flash attention with dynamic masking support, including parameter setup, kernel dispatch logic, and a split-KV heuristic that decides how finely to partition work across the GPU's SMs (sketched below).
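
For context, here is a minimal Python sketch of the kind of split-KV heuristic this refers to, modeled on the heuristic FlashAttention uses; the function name, thresholds, and signature are illustrative, not the PR's actual implementation.

import math

def num_splits_heuristic(batch_nheads_mblocks: int, num_sms: int,
                         num_n_blocks: int, max_splits: int = 128) -> int:
    """Pick how many chunks to split the KV sequence into so the kernel grid
    fills the GPU's SMs without making each chunk needlessly small.
    Assumes all arguments are positive integers. Illustrative sketch only."""
    # If the unsplit grid already nearly fills the SMs, don't split at all.
    if batch_nheads_mblocks >= 0.8 * num_sms:
        return 1
    max_splits = min(max_splits, num_sms, num_n_blocks)

    def ceildiv(a: int, b: int) -> int:
        return -(-a // b)

    # A split count is only worth considering if it changes how many KV
    # blocks each split has to process.
    def eligible(s: int) -> bool:
        return s == 1 or ceildiv(num_n_blocks, s) != ceildiv(num_n_blocks, s - 1)

    efficiency = []
    for s in range(1, max_splits + 1):
        if not eligible(s):
            efficiency.append(0.0)
            continue
        n_waves = batch_nheads_mblocks * s / num_sms
        efficiency.append(n_waves / math.ceil(n_waves))  # fullness of the last wave

    best = max(efficiency)
    # Return the smallest split count whose occupancy is close to the best one.
    for s in range(1, max_splits + 1):
        if eligible(s) and efficiency[s - 1] >= 0.85 * best:
            return s
    return 1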

Supports dropout, softcapping, causal masking, and multi-head attention, with a grouped-query attention (GQA) optimization for single-token (decode) queries.
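
As a point of reference, softcapping is commonly applied to the attention scores before the softmax by squashing them through a scaled tanh. A minimal PyTorch sketch of that transform (illustrative; the kernel applies it on-chip):

import torch

def softcap_scores(scores: torch.Tensor, softcap: float) -> torch.Tensor:
    """Smoothly bound attention scores to the open interval (-softcap, softcap)."""
    if softcap <= 0.0:  # by convention a non-positive value disables capping
        return scores
    return softcap * torch.tanh(scores / softcap)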

Provides Python bindings for integration with PyTorch tensors and CUDA operations.
Updates the benchmark to import and use the new flash_dma_cuda module instead of the previous flash_dma_cpp implementation.

Adds proper error handling for the CUDA extension import with informative messages and graceful exit on failure.
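
A sketch of that guarded import (the module name matches the benchmark; the message text is illustrative):

import sys

try:
    import flash_dma_cuda  # compiled CUDA extension built from csrc/
except ImportError as err:
    print("Error: could not import flash_dma_cuda.")
    print("Build the extension first (e.g. `pip install -e .`) and make sure a CUDA toolchain is available.")
    print(f"Original import error: {err}")
    sys.exit(1)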

Refactors the CUDA attention function to use the new mha_fwd API signature with proper parameter mapping and tensor format requirements.
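
A hedged sketch of what such a call site might look like; the real argument list of flash_dma_cuda.fwd is defined by mha_fwd in flash_api.cpp and may well differ from the names and order guessed here.

import torch
import flash_dma_cuda  # the compiled extension imported above

def cuda_attention(q, k, v, softmax_scale=None, is_causal=False):
    """Illustrative wrapper only: assumes the binding takes contiguous
    fp16/bf16 CUDA tensors in (batch, seqlen, num_heads, head_dim) layout
    and returns the attention output first. The real mha_fwd presumably
    also takes the dynamic-mask tensors (e.g. the active_mask mentioned
    above); consult flash_api.cpp for the authoritative signature."""
    if softmax_scale is None:
        softmax_scale = q.shape[-1] ** -0.5
    results = flash_dma_cuda.fwd(q.contiguous(), k.contiguous(), v.contiguous(),
                                 softmax_scale, is_causal)
    return results[0]  # attention output, same shape as q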

Improves test configuration with additional test cases for multi-batch and GQA scenarios, and enhances the test runner with better reporting and exit codes.
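
For reference, the benchmark cases appear to be 7-tuples like the ones quoted in the review below; a hedged reading of the format (the exact meaning of each position is defined in benchmark_forward_equivalence.py):

# Based on the cases quoted in the review, each tuple seems to encode batch
# size, head counts (differing counts exercise GQA), query/key sequence
# lengths, head dimension, and a causal-mask flag.
test_configs = [
    (1, 1, 1, 128, 128, 32, False),  # medium scale test, non-causal mask
    (1, 2, 1, 128, 128, 32, True),   # GQA case the reviewer suggests re-adding
]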

Fixes a bug in the dynamic mask preparation logic by moving active_mask initialization to the correct conditional branch.
Cleans up codebase by removing unused dynamic mask attention implementation.

Eliminates API wrappers, CUDA kernels, and associated infrastructure that were no longer needed, reducing maintenance overhead and code complexity.
LoserCheems added the bug and feature labels on Jun 26, 2025.
Copilot AI (Contributor) left a comment

Pull Request Overview

This PR removes the old CUB dependency, adds CUTLASS, implements a new CUDA extension for FlashDynamicMaskAttention (with Python bindings), and updates benchmarks to use the new API.

  • Swap out the NVIDIA CUB submodule and include CUTLASS.
  • Add flash_api.cpp providing mha_fwd C++/PyTorch bindings.
  • Delete legacy apply_dynamic_mask* kernels and refactor benchmarks to call the new extension.

Reviewed Changes

Copilot reviewed 9 out of 9 changed files in this pull request and generated 1 comment.

Summary of changes per file:

  • csrc/src/mask.h: Commented out unused dynamic-mask indexing logic
  • csrc/flash_api.cpp: New FlashDynamicMaskAttention implementation & binding
  • csrc/cutlass: Added CUTLASS submodule pointer
  • csrc/apply_dynamic_mask_kernel.cu: Removed legacy dynamic mask kernel
  • csrc/apply_dynamic_mask_attention_kernel.cu: Removed legacy attention kernel
  • csrc/apply_dynamic_mask_attention_api.cpp: Removed old Python API glue for dynamic mask attention
  • csrc/apply_dynamic_mask_api.cpp: Removed old dynamic mask API
  • benchmarks/benchmark_forward_equivalence.py: Updated to import flash_dma_cuda.fwd and test types
  • .gitmodules: Deleted CUB submodule entry

Comments suppressed due to low confidence (2)

csrc/flash_api.cpp:48

  • The parameter name seqused_k appears to be a typo. Consider renaming it to seqlens_k_d or cu_seqlens_k for clarity.
    void *seqused_k,

benchmarks/benchmark_forward_equivalence.py:360

  • The medium-scale GQA test for sequence length 128 was removed. Re-add a (1, 2, 1, 128, 128, 32, True) test case to maintain coverage for grouped-query attention at that size.
        (1, 1, 1, 128, 128, 32, False),  # Medium scale test, non-causal mask

Comment on lines +90 to +99
// const int row_idx_base = row_idx_offset + mi * warp_row_stride;
#pragma unroll
for (int i = 0; i < size<0, 0>(tensor); ++i) {
    const int row_idx = row_idx_base + i * 8;
    // const int row_idx = row_idx_base + i * 8;
    #pragma unroll
    for (int nj = 0; nj < size<1, 1>(tensor); ++nj) {
        const int col_idx_base = col_idx_offset + nj * 8;
        // const int col_idx_base = col_idx_offset + nj * 8;
        #pragma unroll
        for (int j = 0; j < size<1, 0>(tensor); ++j) {
            const int col_idx = col_idx_base + j;
            // const int col_idx = col_idx_base + j;

Copilot AI, Jun 26, 2025:

[nitpick] Rather than leaving large blocks of dead code commented out, remove these unused index calculations entirely to keep the codebase clean.

LoserCheems merged commit e3c8cfc into main on Jun 26, 2025.