From 8012055ded4065e6de4c2ffa41655f54748a1ba4 Mon Sep 17 00:00:00 2001 From: Loser Cheems Date: Wed, 27 Aug 2025 00:09:07 +0800 Subject: [PATCH 1/2] Disables static assertions for mask tensor dimensions Comments out dimension validation checks for the mask tensor in sparse GEMM operation. Removes compile-time assertions that enforce size matching between input tensors and mask tensor dimensions, likely to accommodate different mask tensor layouts or sizes in sparse matrix operations. --- csrc/src/utils.h | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/csrc/src/utils.h b/csrc/src/utils.h index 7b5ff93..81c716a 100644 --- a/csrc/src/utils.h +++ b/csrc/src/utils.h @@ -185,8 +185,8 @@ __forceinline__ __device__ void sparse_gemm( ) { CUTE_STATIC_ASSERT_V(size<1>(tCrA) == size<1>(acc)); // MMA_M CUTE_STATIC_ASSERT_V(size<1>(tCrB) == size<2>(acc)); // MMA_N - CUTE_STATIC_ASSERT_V(size<1>(tCrA) == size<1>(tCrM)); // MMA_M - CUTE_STATIC_ASSERT_V(size<1>(tCrB) == size<2>(tCrM)); // MMA_N + // CUTE_STATIC_ASSERT_V(size<1>(tCrA) == size<1>(tCrM)); // MMA_M + // CUTE_STATIC_ASSERT_V(size<1>(tCrB) == size<2>(tCrM)); // MMA_N CUTE_STATIC_ASSERT_V(size<2>(tCrA) == size<2>(tCrB)); // MMA_K auto tCrA_copy_view = smem_thr_copy_A.retile_D(tCrA); CUTE_STATIC_ASSERT_V(size<1>(tCsA) == size<1>(tCrA_copy_view)); // M From 146854385de2887031f1cfb57d823e7dee5fbbf5 Mon Sep 17 00:00:00 2001 From: Loser Cheems Date: Wed, 27 Aug 2025 00:09:21 +0800 Subject: [PATCH 2/2] Replaces gemm calls with sparse_gemm implementation Adds mask parameter support to matrix multiplication operations in the backward kernel. Updates all gemm function calls to use sparse_gemm variant, which accepts an additional mask tensor parameter for handling sparse attention patterns. Enables more efficient computation by skipping masked-out regions during matrix operations. 
--- csrc/src/flash_bwd_kernel.h | 20 ++++++++++---------- 1 file changed, 10 insertions(+), 10 deletions(-) diff --git a/csrc/src/flash_bwd_kernel.h b/csrc/src/flash_bwd_kernel.h index c16b71f..66a5398 100644 --- a/csrc/src/flash_bwd_kernel.h +++ b/csrc/src/flash_bwd_kernel.h @@ -623,9 +623,9 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params &params, const in // cute::copy(smem_tiled_copy_KV, tSsK(_, _, k), tSrK_copy_view(_, _, k)); // } // if (cute::thread0()) { print(tSrK); } - FLASH_NAMESPACE::gemm( + FLASH_NAMESPACE::sparse_gemm( acc_s, - tSrQ, tSrK, tSsQ, tSsK, + tSrQ, tSrK, tSsQ, tSsK, tSrMask, tiled_mma_sdp, smem_tiled_copy_QdO, smem_tiled_copy_KV, smem_thr_copy_QdO, smem_thr_copy_KV @@ -699,9 +699,9 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params &params, const in // if (cute::thread0()) { print(dP_sum); } - FLASH_NAMESPACE::gemm( + FLASH_NAMESPACE::sparse_gemm( acc_dp, - tdPrdO, tdPrV, tdPsdO, tdPsV, + tdPrdO, tdPrV, tdPsdO, tdPsV, tSrMask, tiled_mma_sdp, smem_tiled_copy_QdO, smem_tiled_copy_KV, smem_thr_copy_QdO, smem_thr_copy_KV @@ -774,9 +774,9 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params &params, const in // FLASH_NAMESPACE::gemm_rs(acc_dv, tdVrPt, tdVrdO, tdVsdOt, tiled_mma_dkv, smem_thr_copy_QdOt); // Tensor tdKrdSt = make_tensor(tdSrdS.data(), tdVrPt.layout()); // FLASH_NAMESPACE::gemm_rs(acc_dk, tdKrdSt, tdKrQt, tdKsQt, tiled_mma_dkv, smem_thr_copy_QdOt); - FLASH_NAMESPACE::gemm( + FLASH_NAMESPACE::sparse_gemm( acc_dv, - tdVrPt, tdVrdO, tdVsPt, tdVsdOt, + tdVrPt, tdVrdO, tdVsPt, tdVsdOt, tSrMask, tiled_mma_dkv, smem_tiled_copy_PdSt, smem_tiled_copy_QdOt, smem_thr_copy_PdSt, smem_thr_copy_QdOt @@ -811,9 +811,9 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params &params, const in } } - FLASH_NAMESPACE::gemm( + FLASH_NAMESPACE::sparse_gemm( acc_dq, - tdQrdS, tdQrKt, tdQsdS, tdQsKt, + tdQrdS, tdQrKt, tdQsdS, tdQsKt, tSrMask, tiled_mma_dq, smem_tiled_copy_dS, smem_tiled_copy_Kt, smem_thr_copy_dS, 
smem_thr_copy_Kt @@ -866,9 +866,9 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params &params, const in cute::copy(smem_tiled_copy_dQ, taccdQrdQ, taccdQsdQ); } - FLASH_NAMESPACE::gemm( + FLASH_NAMESPACE::sparse_gemm( acc_dk, - tdKrdSt, tdKrQt, tdKsdSt, tdKsQt, + tdKrdSt, tdKrQt, tdKsdSt, tdKsQt, tSrMask, tiled_mma_dkv, smem_tiled_copy_PdSt, smem_tiled_copy_QdOt, smem_thr_copy_PdSt, smem_thr_copy_QdOt