diff --git a/csrc/src/flash_bwd_kernel.h b/csrc/src/flash_bwd_kernel.h index c16b71f..66a5398 100644 --- a/csrc/src/flash_bwd_kernel.h +++ b/csrc/src/flash_bwd_kernel.h @@ -623,9 +623,9 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params ¶ms, const in // cute::copy(smem_tiled_copy_KV, tSsK(_, _, k), tSrK_copy_view(_, _, k)); // } // if (cute::thread0()) { print(tSrK); } - FLASH_NAMESPACE::gemm( + FLASH_NAMESPACE::sparse_gemm( acc_s, - tSrQ, tSrK, tSsQ, tSsK, + tSrQ, tSrK, tSsQ, tSsK, tSrMask, tiled_mma_sdp, smem_tiled_copy_QdO, smem_tiled_copy_KV, smem_thr_copy_QdO, smem_thr_copy_KV @@ -699,9 +699,9 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params ¶ms, const in // if (cute::thread0()) { print(dP_sum); } - FLASH_NAMESPACE::gemm( + FLASH_NAMESPACE::sparse_gemm( acc_dp, - tdPrdO, tdPrV, tdPsdO, tdPsV, + tdPrdO, tdPrV, tdPsdO, tdPsV, tSrMask, tiled_mma_sdp, smem_tiled_copy_QdO, smem_tiled_copy_KV, smem_thr_copy_QdO, smem_thr_copy_KV @@ -774,9 +774,9 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params ¶ms, const in // FLASH_NAMESPACE::gemm_rs(acc_dv, tdVrPt, tdVrdO, tdVsdOt, tiled_mma_dkv, smem_thr_copy_QdOt); // Tensor tdKrdSt = make_tensor(tdSrdS.data(), tdVrPt.layout()); // FLASH_NAMESPACE::gemm_rs(acc_dk, tdKrdSt, tdKrQt, tdKsQt, tiled_mma_dkv, smem_thr_copy_QdOt); - FLASH_NAMESPACE::gemm( + FLASH_NAMESPACE::sparse_gemm( acc_dv, - tdVrPt, tdVrdO, tdVsPt, tdVsdOt, + tdVrPt, tdVrdO, tdVsPt, tdVsdOt, tSrMask, tiled_mma_dkv, smem_tiled_copy_PdSt, smem_tiled_copy_QdOt, smem_thr_copy_PdSt, smem_thr_copy_QdOt @@ -811,9 +811,9 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params ¶ms, const in } } - FLASH_NAMESPACE::gemm( + FLASH_NAMESPACE::sparse_gemm( acc_dq, - tdQrdS, tdQrKt, tdQsdS, tdQsKt, + tdQrdS, tdQrKt, tdQsdS, tdQsKt, tSrMask, tiled_mma_dq, smem_tiled_copy_dS, smem_tiled_copy_Kt, smem_thr_copy_dS, smem_thr_copy_Kt @@ -866,9 +866,9 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params ¶ms, const in cute::copy(smem_tiled_copy_dQ, taccdQrdQ, taccdQsdQ); } - FLASH_NAMESPACE::gemm( + FLASH_NAMESPACE::sparse_gemm( acc_dk, - tdKrdSt, tdKrQt, tdKsdSt, tdKsQt, + tdKrdSt, tdKrQt, tdKsdSt, tdKsQt, tSrMask, tiled_mma_dkv, smem_tiled_copy_PdSt, smem_tiled_copy_QdOt, smem_thr_copy_PdSt, smem_thr_copy_QdOt diff --git a/csrc/src/utils.h b/csrc/src/utils.h index 7b5ff93..81c716a 100644 --- a/csrc/src/utils.h +++ b/csrc/src/utils.h @@ -185,8 +185,8 @@ __forceinline__ __device__ void sparse_gemm( ) { CUTE_STATIC_ASSERT_V(size<1>(tCrA) == size<1>(acc)); // MMA_M CUTE_STATIC_ASSERT_V(size<1>(tCrB) == size<2>(acc)); // MMA_N - CUTE_STATIC_ASSERT_V(size<1>(tCrA) == size<1>(tCrM)); // MMA_M - CUTE_STATIC_ASSERT_V(size<1>(tCrB) == size<2>(tCrM)); // MMA_N + // CUTE_STATIC_ASSERT_V(size<1>(tCrA) == size<1>(tCrM)); // MMA_M + // CUTE_STATIC_ASSERT_V(size<1>(tCrB) == size<2>(tCrM)); // MMA_N CUTE_STATIC_ASSERT_V(size<2>(tCrA) == size<2>(tCrB)); // MMA_K auto tCrA_copy_view = smem_thr_copy_A.retile_D(tCrA); CUTE_STATIC_ASSERT_V(size<1>(tCsA) == size<1>(tCrA_copy_view)); // M