Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 10 additions & 10 deletions csrc/src/flash_bwd_kernel.h
Original file line number Diff line number Diff line change
Expand Up @@ -623,9 +623,9 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params &params, const in
// cute::copy(smem_tiled_copy_KV, tSsK(_, _, k), tSrK_copy_view(_, _, k));
// }
// if (cute::thread0()) { print(tSrK); }
FLASH_NAMESPACE::gemm(
FLASH_NAMESPACE::sparse_gemm(
acc_s,
tSrQ, tSrK, tSsQ, tSsK,
tSrQ, tSrK, tSsQ, tSsK, tSrMask,
tiled_mma_sdp,
smem_tiled_copy_QdO, smem_tiled_copy_KV,
smem_thr_copy_QdO, smem_thr_copy_KV
Expand Down Expand Up @@ -699,9 +699,9 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params &params, const in

// if (cute::thread0()) { print(dP_sum); }

FLASH_NAMESPACE::gemm</*A_in_regs=*/false, /*B_in_regs=*/Kernel_traits::Is_V_in_regs>(
FLASH_NAMESPACE::sparse_gemm</*A_in_regs=*/false, /*B_in_regs=*/Kernel_traits::Is_V_in_regs>(
acc_dp,
tdPrdO, tdPrV, tdPsdO, tdPsV,
tdPrdO, tdPrV, tdPsdO, tdPsV, tSrMask,
tiled_mma_sdp,
smem_tiled_copy_QdO, smem_tiled_copy_KV,
smem_thr_copy_QdO, smem_thr_copy_KV
Expand Down Expand Up @@ -774,9 +774,9 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params &params, const in
// FLASH_NAMESPACE::gemm_rs(acc_dv, tdVrPt, tdVrdO, tdVsdOt, tiled_mma_dkv, smem_thr_copy_QdOt);
// Tensor tdKrdSt = make_tensor(tdSrdS.data(), tdVrPt.layout());
// FLASH_NAMESPACE::gemm_rs(acc_dk, tdKrdSt, tdKrQt, tdKsQt, tiled_mma_dkv, smem_thr_copy_QdOt);
FLASH_NAMESPACE::gemm(
FLASH_NAMESPACE::sparse_gemm(
acc_dv,
tdVrPt, tdVrdO, tdVsPt, tdVsdOt,
tdVrPt, tdVrdO, tdVsPt, tdVsdOt, tSrMask,
tiled_mma_dkv,
smem_tiled_copy_PdSt, smem_tiled_copy_QdOt,
smem_thr_copy_PdSt, smem_thr_copy_QdOt
Expand Down Expand Up @@ -811,9 +811,9 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params &params, const in
}
}

FLASH_NAMESPACE::gemm(
FLASH_NAMESPACE::sparse_gemm(
acc_dq,
tdQrdS, tdQrKt, tdQsdS, tdQsKt,
tdQrdS, tdQrKt, tdQsdS, tdQsKt, tSrMask,
tiled_mma_dq,
smem_tiled_copy_dS, smem_tiled_copy_Kt,
smem_thr_copy_dS, smem_thr_copy_Kt
Expand Down Expand Up @@ -866,9 +866,9 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params &params, const in
cute::copy(smem_tiled_copy_dQ, taccdQrdQ, taccdQsdQ);
}

FLASH_NAMESPACE::gemm(
FLASH_NAMESPACE::sparse_gemm(
acc_dk,
tdKrdSt, tdKrQt, tdKsdSt, tdKsQt,
tdKrdSt, tdKrQt, tdKsdSt, tdKsQt, tSrMask,
tiled_mma_dkv,
smem_tiled_copy_PdSt, smem_tiled_copy_QdOt,
smem_thr_copy_PdSt, smem_thr_copy_QdOt
Expand Down
4 changes: 2 additions & 2 deletions csrc/src/utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -185,8 +185,8 @@ __forceinline__ __device__ void sparse_gemm(
) {
CUTE_STATIC_ASSERT_V(size<1>(tCrA) == size<1>(acc)); // MMA_M
CUTE_STATIC_ASSERT_V(size<1>(tCrB) == size<2>(acc)); // MMA_N
CUTE_STATIC_ASSERT_V(size<1>(tCrA) == size<1>(tCrM)); // MMA_M
CUTE_STATIC_ASSERT_V(size<1>(tCrB) == size<2>(tCrM)); // MMA_N
// CUTE_STATIC_ASSERT_V(size<1>(tCrA) == size<1>(tCrM)); // MMA_M
// CUTE_STATIC_ASSERT_V(size<1>(tCrB) == size<2>(tCrM)); // MMA_N
Comment on lines +188 to +189
Copy link

Copilot AI Aug 26, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Commenting out static assertions removes important compile-time safety checks. Consider replacing these with runtime checks or conditional assertions that only apply when mask dimensions are expected to match.

Suggested change
// CUTE_STATIC_ASSERT_V(size<1>(tCrA) == size<1>(tCrM)); // MMA_M
// CUTE_STATIC_ASSERT_V(size<1>(tCrB) == size<2>(tCrM)); // MMA_N
if (size<1>(tCrA) != size<1>(tCrM)) {
printf("Error: tCrA MMA_M (%d) does not match tCrM MMA_M (%d)\n", int(size<1>(tCrA)), int(size<1>(tCrM)));
asm("trap;");
}
if (size<1>(tCrB) != size<2>(tCrM)) {
printf("Error: tCrB MMA_N (%d) does not match tCrM MMA_N (%d)\n", int(size<1>(tCrB)), int(size<2>(tCrM)));
asm("trap;");
}

Copilot uses AI. Check for mistakes.
Comment on lines +188 to +189
Copy link

Copilot AI Aug 26, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Commenting out static assertions removes important compile-time safety checks. Consider replacing these with runtime checks or conditional assertions that only apply when mask dimensions are expected to match.

Suggested change
// CUTE_STATIC_ASSERT_V(size<1>(tCrA) == size<1>(tCrM)); // MMA_M
// CUTE_STATIC_ASSERT_V(size<1>(tCrB) == size<2>(tCrM)); // MMA_N
assert(size<1>(tCrA) == size<1>(tCrM)); // MMA_M
assert(size<1>(tCrB) == size<2>(tCrM)); // MMA_N

Copilot uses AI. Check for mistakes.
CUTE_STATIC_ASSERT_V(size<2>(tCrA) == size<2>(tCrB)); // MMA_K
auto tCrA_copy_view = smem_thr_copy_A.retile_D(tCrA);
CUTE_STATIC_ASSERT_V(size<1>(tCsA) == size<1>(tCrA_copy_view)); // M
Expand Down