
Conversation

@LoserCheems (Collaborator) commented Aug 26, 2025

Fixes #72
Disables the static assertions on mask tensor dimensions and replaces gemm calls with sparse_gemm so that matrix multiplication operations accept a mask parameter. This improves efficiency by allowing computations to skip masked-out regions.

Comments out dimension validation checks for the mask tensor in the sparse GEMM operation.

Removes compile-time assertions that enforce size matching between input tensors and mask tensor dimensions, likely to accommodate different mask tensor layouts or sizes in sparse matrix operations.
Adds mask parameter support to matrix multiplication operations in the backward kernel.

Updates all gemm function calls to use sparse_gemm variant, which accepts an additional mask tensor parameter for handling sparse attention patterns.

Enables more efficient computation by skipping masked-out regions during matrix operations.
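
For illustration, a minimal sketch of the call-site change, reusing the tensor names that appear in the assertions quoted below. The exact sparse_gemm argument order is not shown on this page, so the signatures here (and the acc and tiled_mma operands) are assumptions, not the repository's actual API:

    // Before: dense GEMM computes every tile of acc, including regions
    // that the attention mask will zero out anyway (hypothetical signature).
    gemm(acc, tCrA, tCrB, tiled_mma);

    // After: the sparse variant also receives the mask tile (tSrMask in
    // flash_bwd_kernel.h) and can skip fully masked-out regions instead
    // of computing and then discarding them.
    sparse_gemm(acc, tCrA, tCrB, tiled_mma, tSrMask);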
Copilot AI (Contributor) left a comment
Pull Request Overview

This PR enables sparse GEMM operations by adding mask tensor support to matrix multiplication computations. The changes allow the system to skip computations in masked-out regions, improving efficiency for sparse operations.

  • Disables static assertions that enforce mask tensor dimension compatibility
  • Replaces all gemm calls with sparse_gemm and adds mask tensor parameters
  • Updates the backward kernel computations to use sparse operations throughout

Reviewed Changes

Copilot reviewed 2 out of 2 changed files in this pull request and generated 2 comments.

File                         Description
csrc/src/utils.h             Comments out static assertions for mask tensor dimension validation
csrc/src/flash_bwd_kernel.h  Replaces gemm calls with sparse_gemm and adds a tSrMask parameter to all matrix operations


Comment on lines +188 to +189
// CUTE_STATIC_ASSERT_V(size<1>(tCrA) == size<1>(tCrM)); // MMA_M
// CUTE_STATIC_ASSERT_V(size<1>(tCrB) == size<2>(tCrM)); // MMA_N
Copilot AI commented Aug 26, 2025

Commenting out static assertions removes important compile-time safety checks. Consider replacing these with runtime checks or conditional assertions that only apply when mask dimensions are expected to match.

Suggested change
// CUTE_STATIC_ASSERT_V(size<1>(tCrA) == size<1>(tCrM)); // MMA_M
// CUTE_STATIC_ASSERT_V(size<1>(tCrB) == size<2>(tCrM)); // MMA_N
if (size<1>(tCrA) != size<1>(tCrM)) {
    printf("Error: tCrA MMA_M (%d) does not match tCrM MMA_M (%d)\n", int(size<1>(tCrA)), int(size<1>(tCrM)));
    asm("trap;");
}
if (size<1>(tCrB) != size<2>(tCrM)) {
    printf("Error: tCrB MMA_N (%d) does not match tCrM MMA_N (%d)\n", int(size<1>(tCrB)), int(size<2>(tCrM)));
    asm("trap;");
}

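A conditional compile-time check, the other alternative Copilot names, could look like the sketch below. This assumes the enclosing function is a template (as is typical for these CUTE kernels), so the discarded branch is never instantiated; Has_mask is a hypothetical compile-time flag, not an identifier from this PR:

    // Only enforce the mask-dimension checks when a mask is actually supplied.
    if constexpr (Has_mask) {
        CUTE_STATIC_ASSERT_V(size<1>(tCrA) == size<1>(tCrM));  // MMA_M
        CUTE_STATIC_ASSERT_V(size<1>(tCrB) == size<2>(tCrM));  // MMA_N
    }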
Comment on lines +188 to +189
// CUTE_STATIC_ASSERT_V(size<1>(tCrA) == size<1>(tCrM)); // MMA_M
// CUTE_STATIC_ASSERT_V(size<1>(tCrB) == size<2>(tCrM)); // MMA_N
Copilot AI commented Aug 26, 2025

Commenting out static assertions removes important compile-time safety checks. Consider replacing these with runtime checks or conditional assertions that only apply when mask dimensions are expected to match.

Suggested change
// CUTE_STATIC_ASSERT_V(size<1>(tCrA) == size<1>(tCrM)); // MMA_M
// CUTE_STATIC_ASSERT_V(size<1>(tCrB) == size<2>(tCrM)); // MMA_N
assert(size<1>(tCrA) == size<1>(tCrM)); // MMA_M
assert(size<1>(tCrB) == size<2>(tCrM)); // MMA_N
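
Note that assert is available in CUDA device code (via <cassert>) but compiles to a no-op when NDEBUG is defined, so this variant only catches mismatches in debug builds; the printf/trap variant in the previous suggestion also fires in release builds.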

@LoserCheems (Collaborator, Author) commented:
LGTM, the backward pass of the CUDA kernel is finally complete.

@LoserCheems LoserCheems merged commit b87544f into main Aug 26, 2025
@LoserCheems LoserCheems deleted the Support-backward branch November 13, 2025 04:41