# Enable sparse GEMM with mask tensor support #130
## Conversation
Comments out the dimension validation checks for the mask tensor in the sparse GEMM operation. Removes compile-time assertions that enforce size matching between the input tensors and the mask tensor, likely to accommodate different mask tensor layouts or sizes in sparse matrix operations.
Adds mask parameter support to the matrix multiplication operations in the backward kernel. Updates all `gemm` function calls to the `sparse_gemm` variant, which accepts an additional mask tensor parameter for handling sparse attention patterns. This enables more efficient computation by skipping masked-out regions during matrix operations.
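For context, here is a minimal sketch of what a mask-aware GEMM can look like at the CUTE fragment level. This is an illustration under assumptions, not the PR's actual `sparse_gemm`: it assumes the mask fragment `tCrM` carries one entry per MMA tile position and that a zero first element marks a fully masked tile.

```cpp
#include <cute/tensor.hpp>

// Illustrative mask-aware GEMM over register fragments (hypothetical sketch,
// not this repository's implementation). Fragment shapes follow the
// conventions used by the assertions discussed below:
//   tCrA: (MMA, MMA_M, MMA_K), tCrB: (MMA, MMA_N, MMA_K),
//   tCrC: (MMA, MMA_M, MMA_N), tCrM: (MMA, MMA_M, MMA_N)
template <typename TiledMma, typename TensorA, typename TensorB,
          typename TensorC, typename TensorM>
CUTE_DEVICE void sparse_gemm(TiledMma& tiled_mma,
                             TensorA const& tCrA, TensorB const& tCrB,
                             TensorC& tCrC, TensorM const& tCrM) {
    using cute::_;
    CUTE_UNROLL
    for (int m = 0; m < cute::size<1>(tCrC); ++m) {
        CUTE_UNROLL
        for (int n = 0; n < cute::size<2>(tCrC); ++n) {
            // Assumption for illustration: element 0 of the mask fragment
            // indicates whether the whole (m, n) MMA tile is masked out.
            if (tCrM(0, m, n) == 0) { continue; }  // skip masked-out tile
            CUTE_UNROLL
            for (int k = 0; k < cute::size<2>(tCrA); ++k) {
                cute::gemm(tiled_mma, tCrA(_, m, k), tCrB(_, n, k), tCrC(_, m, n));
            }
        }
    }
}
```

The real kernel may apply the mask at a finer granularity; the point is that the mask travels alongside the fragments, so entire tiles of work can be skipped.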
Pull Request Overview
This PR enables sparse GEMM operations by adding mask tensor support to matrix multiplication computations. The changes allow the system to skip computations in masked-out regions, improving efficiency for sparse operations.
- Disables static assertions that enforce mask tensor dimension compatibility
- Replaces all `gemm` calls with `sparse_gemm` and adds mask tensor parameters (a standalone sketch of the block-skipping idea follows this list)
- Updates the backward kernel computations to use sparse operations throughout
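To make the efficiency claim concrete, here is a standalone host-side analogue of the block-skipping idea. It is plain C++ unrelated to the kernel code; `BLK` and `sparse_block_gemm` are names invented for this sketch.

```cpp
#include <cstdio>
#include <vector>

constexpr int BLK = 2;  // block size for this toy example

// C (MxN) += A (MxK) * B (KxN), skipping output blocks whose mask entry is 0.
// mask holds (M/BLK) x (N/BLK) entries, one per output block.
void sparse_block_gemm(int M, int N, int K,
                       const std::vector<float>& A,
                       const std::vector<float>& B,
                       std::vector<float>& C,
                       const std::vector<int>& mask) {
    const int nb_cols = N / BLK;
    for (int mb = 0; mb < M / BLK; ++mb) {
        for (int nb = 0; nb < nb_cols; ++nb) {
            if (mask[mb * nb_cols + nb] == 0) continue;  // masked block: no work at all
            for (int i = mb * BLK; i < (mb + 1) * BLK; ++i) {
                for (int j = nb * BLK; j < (nb + 1) * BLK; ++j) {
                    float acc = C[i * N + j];
                    for (int k = 0; k < K; ++k) acc += A[i * K + k] * B[k * N + j];
                    C[i * N + j] = acc;
                }
            }
        }
    }
}

int main() {
    const int M = 4, N = 4, K = 4;
    std::vector<float> A(M * K, 1.0f), B(K * N, 1.0f), C(M * N, 0.0f);
    std::vector<int> mask = {1, 0,   // only the left column of blocks
                             1, 0};  // contributes to the output
    sparse_block_gemm(M, N, K, A, B, C, mask);
    std::printf("C[0][0] = %.0f (computed), C[0][3] = %.0f (skipped)\n", C[0], C[3]);
    return 0;
}
```

With half the blocks masked out, the inner loops run for only half the output; the same proportional saving is what the sparse GEMM targets on the GPU.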
Reviewed Changes
Copilot reviewed 2 out of 2 changed files in this pull request and generated 2 comments.
| File | Description |
|---|---|
| `csrc/src/utils.h` | Comments out static assertions for mask tensor dimension validation |
| `csrc/src/flash_bwd_kernel.h` | Replaces `gemm` calls with `sparse_gemm` and adds `tSrMask` parameter to all matrix operations |
```cpp
// CUTE_STATIC_ASSERT_V(size<1>(tCrA) == size<1>(tCrM)); // MMA_M
// CUTE_STATIC_ASSERT_V(size<1>(tCrB) == size<2>(tCrM)); // MMA_N
```
**Copilot AI** reviewed on Aug 26, 2025:
Commenting out static assertions removes important compile-time safety checks. Consider replacing these with runtime checks or conditional assertions that only apply when mask dimensions are expected to match.
Suggested change:

```diff
- // CUTE_STATIC_ASSERT_V(size<1>(tCrA) == size<1>(tCrM)); // MMA_M
- // CUTE_STATIC_ASSERT_V(size<1>(tCrB) == size<2>(tCrM)); // MMA_N
+ if (size<1>(tCrA) != size<1>(tCrM)) {
+     printf("Error: tCrA MMA_M (%d) does not match tCrM MMA_M (%d)\n", int(size<1>(tCrA)), int(size<1>(tCrM)));
+     asm("trap;");
+ }
+ if (size<1>(tCrB) != size<2>(tCrM)) {
+     printf("Error: tCrB MMA_N (%d) does not match tCrM MMA_N (%d)\n", int(size<1>(tCrB)), int(size<2>(tCrM)));
+     asm("trap;");
+ }
```
```cpp
// CUTE_STATIC_ASSERT_V(size<1>(tCrA) == size<1>(tCrM)); // MMA_M
// CUTE_STATIC_ASSERT_V(size<1>(tCrB) == size<2>(tCrM)); // MMA_N
```
**Copilot AI** reviewed on Aug 26, 2025:
Commenting out static assertions removes important compile-time safety checks. Consider replacing these with runtime checks or conditional assertions that only apply when mask dimensions are expected to match.
Suggested change:

```diff
- // CUTE_STATIC_ASSERT_V(size<1>(tCrA) == size<1>(tCrM)); // MMA_M
- // CUTE_STATIC_ASSERT_V(size<1>(tCrB) == size<2>(tCrM)); // MMA_N
+ assert(size<1>(tCrA) == size<1>(tCrM)); // MMA_M
+ assert(size<1>(tCrB) == size<2>(tCrM)); // MMA_N
```
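One note on this variant: `assert` is supported in CUDA device code (include `<cassert>`); a failing assertion reports the expression and stops the kernel, and the checks compile away when `NDEBUG` is defined, so this trades the compile-time guarantee for a debug-build runtime check.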
LGTM, the backward pass of the CUDA kernel has finally been perfected.
Fixes #72

Disable static assertions for mask tensor dimensions and replace `gemm` calls with `sparse_gemm` to support mask parameters in matrix multiplication operations. This improves efficiency by allowing computations to skip masked-out regions.