diff --git a/csrc/src/flash_fwd_kernel.h b/csrc/src/flash_fwd_kernel.h
index 921e3df..42e352f 100644
--- a/csrc/src/flash_fwd_kernel.h
+++ b/csrc/src/flash_fwd_kernel.h
@@ -423,12 +423,11 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
         }
         cute::cp_async_fence();
 
-        // TODO: support sparse general matrix multiplication
-        FLASH_NAMESPACE::gemm(
+        // Use sparse general matrix multiplication
+        FLASH_NAMESPACE::sparse_gemm(
             acc_s,
             tSrQ,
-            tSrK, tSsQ, tSsK,
-            // tSrAM, tAMsAM,    // Active key indices for sparse K matrix multiplication
+            tSrK, tSsQ, tSsK, tSrAM,    // Active key indices for sparse K matrix multiplication
             tiled_mma, smem_tiled_copy_Q, smem_tiled_copy_K,
             smem_thr_copy_Q, smem_thr_copy_K
         );
@@ -497,11 +496,10 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
         // if using m16n8k16 or (4, MMA_M, MMA_N) if using m16n8k8.
         Tensor tOrP = make_tensor(rP.data(), FLASH_NAMESPACE::convert_layout_acc_Aregs(rP.layout()));
         // if (cute::thread0()) { print(tOrP); }
-        // TODO: support sparse general matrix multiplication with register accumulation
-        FLASH_NAMESPACE::gemm_rs(
+        // Use sparse general matrix multiplication with register accumulation for V as well
+        FLASH_NAMESPACE::sparse_gemm_rs(
             acc_o,
-            tOrP, tOrVt, tOsVt,
-            // tSrAM, tAMsAM,    // Apply the same mask for sparse V matrix multiplication
+            tOrP, tOrVt, tOsVt, tSrAM,    // Apply the same mask for sparse V matrix multiplication
             tiled_mma, smem_tiled_copy_V, smem_thr_copy_V
         );
         // if (cute::thread0()) { print(scores); }
@@ -529,11 +527,10 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
         FLASH_NAMESPACE::copy(gmem_tiled_copy_QKV, tVgV(_, _, _, n_block), tVsV, tKVcKV, tKVpKV);
         cute::cp_async_fence();
 
-        FLASH_NAMESPACE::gemm(
+        FLASH_NAMESPACE::sparse_gemm(
             acc_s,
             tSrQ,
-            tSrK, tSsQ, tSsK,
-            // tActiveMask,    // Active key indices for sparse K matrix multiplication
+            tSrK, tSsQ, tSsK, tSrAM,    // Active key indices for sparse K matrix multiplication
             tiled_mma, smem_tiled_copy_Q, smem_tiled_copy_K,
             smem_thr_copy_Q, smem_thr_copy_K
         );
@@ -596,10 +593,10 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
         // if using m16n8k16 or (4, MMA_M, MMA_N) if using m16n8k8.
         Tensor tOrP = make_tensor(rP.data(), FLASH_NAMESPACE::convert_layout_acc_Aregs(rP.layout()));
-        FLASH_NAMESPACE::gemm_rs(
+        // Use sparse general matrix multiplication with register accumulation for V as well
+        FLASH_NAMESPACE::sparse_gemm_rs(
             acc_o,
-            tOrP, tOrVt, tOsVt,
-            // tActiveMask,    // apply the same mask for sparse V matrix multiplication
+            tOrP, tOrVt, tOsVt, tSrAM,    // Apply the same mask for sparse V matrix multiplication
             tiled_mma, smem_tiled_copy_V, smem_thr_copy_V
         );
     }
diff --git a/csrc/src/utils.h b/csrc/src/utils.h
index 37bb4b7..79d91fa 100644
--- a/csrc/src/utils.h
+++ b/csrc/src/utils.h
@@ -162,12 +162,12 @@ __forceinline__ __device__ void gemm(
 
 ////////////////////////////////////////////////////////////////////////////////////////////////////
 
-template<bool A_in_regs=false, bool B_in_regs=false, int kNWarps,
-         typename Tensor0, typename Tensor1, typename Tensor2, typename Tensor3, typename Tensor4, typename Tensor5, typename Tensor6,
-         typename TiledMma, typename TiledCopyA, typename TiledCopyB,
-         typename ThrCopyA, typename ThrCopyB>
+template<bool A_in_regs=false, bool B_in_regs=false,
+         typename Tensor0, typename Tensor1, typename Tensor2, typename Tensor3, typename Tensor4, typename Tensor5,
+         typename TiledMma, typename TiledCopyA, typename TiledCopyB,
+         typename ThrCopyA, typename ThrCopyB>
 __forceinline__ __device__ void sparse_gemm(
-    Tensor0 &acc, Tensor1 &tCrA, Tensor2 &tCrB, Tensor3 const& tCsA, Tensor4 const& tCsB, Tensor5 const &tCrAM, Tensor6 const &tCsAM,
+    Tensor0 &acc, Tensor1 &tCrA, Tensor2 &tCrB, Tensor3 const& tCsA, Tensor4 const& tCsB, Tensor5 const &tCrAM,
     TiledMma tiled_mma, TiledCopyA smem_tiled_copy_A, TiledCopyB smem_tiled_copy_B,
     ThrCopyA smem_thr_copy_A, ThrCopyB smem_thr_copy_B
 ) {
@@ -180,34 +180,30 @@ __forceinline__ __device__ void sparse_gemm(
     CUTE_STATIC_ASSERT_V(size<1>(tCsA) == size<1>(tCrA_copy_view));            // M
     auto tCrB_copy_view = smem_thr_copy_B.retile_D(tCrB);
     CUTE_STATIC_ASSERT_V(size<1>(tCsB) == size<1>(tCrB_copy_view));            // N
-    bool mma_active[kNWarps] = {};    // MMA
-    // Considering the characteristics of MMA and the chain of thoughts,
-    // when there is any activated element in the query row or key column,
-    // we will mark the MMA block as activated.
+    // Check if any element in the entire active mask is non-zero
+    // Use thread-local computation then sync across all threads in the CTA
+    bool local_any_active = false;
     #pragma unroll
-    for (int mma = 0; mma < size<0>(active_mask); ++mma) {
-        mma_active[mma] = false;
+    for (int mma = 0; mma < size<0>(tCrAM) && !local_any_active; ++mma) {
         #pragma unroll
-        for (int m = 0; m < size<1>(active_mask); ++m) {
+        for (int m = 0; m < size<1>(tCrAM) && !local_any_active; ++m) {
             #pragma unroll
-            for (int n = 0; n < size<2>(active_mask); ++n) {
-                if (active_mask(mma, m, n)) {
-                    mma_active[mma] = true;
-                    goto mma_active_found;
-                }
+            for (int n = 0; n < size<2>(tCrAM) && !local_any_active; ++n) {
+                // Use direct comparison to avoid potential branching
+                local_any_active |= (tCrAM(mma, m, n) > 0);
             }
         }
-        mma_active_found:;
     }
+    // Ensure all threads in the CTA have the same any_active value to avoid warp divergence
+    bool any_active = __syncthreads_or(local_any_active);
     if (!A_in_regs) { cute::copy(smem_tiled_copy_A, tCsA(_, _, _0{}), tCrA_copy_view(_, _, _0{})); }
     if (!B_in_regs) {
-        #pragma unroll
-        for (int mma = 0; mma < size<0>(active_mask); ++mma) {
-            if (mma_active[mma]) {
-                cute::copy(smem_tiled_copy_B, tCsB(mma, _, _0{}), tCrB_copy_view(mma, _, _0{}));
-            } else {
-                cute::clear(tCrB_copy_view(mma, _, _0{}));
-            }
+        if (any_active) {
+            // If any MMA block is active, load normally like dense gemm
+            cute::copy(smem_tiled_copy_B, tCsB(_, _, _0{}), tCrB_copy_view(_, _, _0{}));
+        } else {
+            // If no MMA block is active, clear all registers
+            cute::clear(tCrB_copy_view);
         }
     }
     #pragma unroll
@@ -215,28 +211,18 @@ __forceinline__ __device__ void sparse_gemm(
         if (i < size<2>(tCrA) - 1) {
             if (!A_in_regs) { cute::copy(smem_tiled_copy_A, tCsA(_, _, i + 1), tCrA_copy_view(_, _, i + 1)); }
             if (!B_in_regs) {
-                #pragma unroll
-                for (int mma = 0; mma < size<0>(active_mask); ++mma) {
-                    if (mma_active[mma]) {
-                        cute::copy(smem_tiled_copy_B, tCsB(mma, _, i + 1), tCrB_copy_view(mma, _, i + 1));
-                    } else {
-                        cute::clear(tCrB_copy_view(mma, _, i + 1));
-                    }
-
+                if (any_active) {
+                    // If any MMA block is active, load normally like dense gemm
+                    cute::copy(smem_tiled_copy_B, tCsB(_, _, i + 1), tCrB_copy_view(_, _, i + 1));
+                } else {
+                    // If no MMA block is active, clear all registers
+                    cute::clear(tCrB_copy_view(_, _, i + 1));
                 }
             }
         }
-        // We must create a view to match `TiledMma` layout.
-        #pragma unroll
-        for (int mma = 0; mma < size<0>(active_mask); ++mma) {    // MMA
-            if (mma_active[mma]) {
-                cute::gemm(
-                    tiled_mma,
-                    tCrA(mma, _, i),    // (MMA_M, MMA_K)
-                    tCrB(mma, _, i),    // (MMA_N, MMA_K)
-                    acc(mma, _, _)      // (MMA_M, MMA_N)
-                );
-            }
+        // Only perform GEMM if there are any active elements
+        if (any_active) {
+            cute::gemm(tiled_mma, tCrA(_, _, i), tCrB(_, _, i), acc);
         }
     }
 }
@@ -268,8 +254,7 @@ __forceinline__ __device__ void gemm_rs(
 
 ////////////////////////////////////////////////////////////////////////////////////////////////////
 
-template<int kNWarps,
-         typename Tensor0, typename Tensor1, typename Tensor2,
-         typename Tensor3, typename Tensor4,
-         typename TiledMma, typename TiledCopy, typename ThrCopy>
+template<typename Tensor0, typename Tensor1, typename Tensor2,
+         typename Tensor3, typename Tensor4,
+         typename TiledMma, typename TiledCopy, typename ThrCopy>
 __forceinline__ __device__ void sparse_gemm_rs(
@@ -277,61 +262,54 @@ __forceinline__ __device__ void sparse_gemm_rs(
     Tensor1 &tCrA,
     Tensor2 &tCrB,
    Tensor3 const& tCsB,
-    Tensor4 const &active_mask,
+    Tensor4 const &tCrAM,
     TiledMma tiled_mma,
     TiledCopy smem_tiled_copy_B,
     ThrCopy smem_thr_copy_B
 ) {
     CUTE_STATIC_ASSERT_V(size<1>(tCrA) == size<1>(acc));                        // MMA_M
     CUTE_STATIC_ASSERT_V(size<1>(tCrB) == size<2>(acc));                        // MMA_N
-    CUTE_STATIC_ASSERT_V(size<1>(tCrA) == size<1>(active_mask));                // MMA_M
-    CUTE_STATIC_ASSERT_V(size<1>(tCrB) == size<2>(active_mask));                // MMA_N
     CUTE_STATIC_ASSERT_V(size<2>(tCrA) == size<2>(tCrB));                       // MMA_K
     // Retile B for thread-wise copy from shared memory to registers
     auto tCrB_copy_view = smem_thr_copy_B.retile_D(tCrB);
     CUTE_STATIC_ASSERT_V(size<1>(tCsB) == size<1>(tCrB_copy_view));             // N
-    // Check if any row or column in the MMA block is active.
-    bool mma_active[kNWarps] = {};    // MMA
+    // Check if any element in the entire active mask is non-zero
+    // Use thread-local computation then sync across all threads in the CTA
+    bool local_any_active = false;
     #pragma unroll
-    for (int mma = 0; mma < size<0>(active_mask); ++mma) {
-        mma_active[mma] = false;
+    for (int mma = 0; mma < size<0>(tCrAM) && !local_any_active; ++mma) {
         #pragma unroll
-        for (int m = 0; m < size<1>(active_mask); ++m) {
+        for (int m = 0; m < size<1>(tCrAM) && !local_any_active; ++m) {
             #pragma unroll
-            for (int n = 0; n < size<2>(active_mask); ++n) {
-                if (active_mask(mma, m, n)) {
-                    mma_active[mma] = true;
-                    goto mma_active_found;
-                }
+            for (int n = 0; n < size<2>(tCrAM) && !local_any_active; ++n) {
+                // Use direct comparison to avoid potential branching
+                local_any_active |= (tCrAM(mma, m, n) > 0);
             }
         }
-        mma_active_found:;
     }
-    #pragma unroll
-    for (int mma = 0; mma < size<0>(active_mask); ++mma) {
-        if (mma_active[mma]) {
-            cute::copy(smem_tiled_copy_B, tCsB(mma, _, _0{}), tCrB_copy_view(mma, _, _0{}));
-        } else {
-            cute::clear(tCrB_copy_view(mma, _, _0{}));
-        }
+    // Ensure all threads in the CTA have the same any_active value to avoid warp divergence
+    bool any_active = __syncthreads_or(local_any_active);
+    if (any_active) {
+        // If any MMA block is active, load normally like dense gemm
+        cute::copy(smem_tiled_copy_B, tCsB(_, _, _0{}), tCrB_copy_view(_, _, _0{}));
+    } else {
+        // If no MMA block is active, clear all registers
+        cute::clear(tCrB_copy_view);
     }
     #pragma unroll
     for (int i = 0; i < size<2>(tCrA); ++i) {
         if (i < size<2>(tCrA) - 1) {
-            #pragma unroll
-            for (int mma = 0; mma < size<0>(active_mask); ++mma) {
-                if (mma_active[mma]) {
-                    cute::copy(smem_tiled_copy_B, tCsB(mma, _, i + 1), tCrB_copy_view(mma, _, i + 1));
-                } else {
-                    cute::clear(tCrB_copy_view(mma, _, i + 1));
-                }
+            if (any_active) {
+                // If any MMA block is active, load normally like dense gemm
+                cute::copy(smem_tiled_copy_B, tCsB(_, _, i + 1), tCrB_copy_view(_, _, i + 1));
+            } else {
+                // If no MMA block is active, clear all registers
+                cute::clear(tCrB_copy_view(_, _, i + 1));
             }
         }
-        #pragma unroll
-        for (int mma = 0; mma < size<0>(active_mask); ++mma) {
-            if (mma_active[mma]) {
-                cute::gemm(tiled_mma, tCrA(mma, _, i), tCrB(mma, _, i), acc(mma, _, _));
-            }
+        // Only perform GEMM if there are any active elements
+        if (any_active) {
+            cute::gemm(tiled_mma, tCrA(_, _, i), tCrB(_, _, i), acc);
         }
     }
 }
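The common thread in both helpers is that the per-warp mma_active[] bookkeeping is replaced by a single block-wide predicate: each thread scans its own fragment of the active mask, and __syncthreads_or collapses the thread-local results into one any_active flag that is identical across the CTA, so the copy/clear and GEMM branches are taken uniformly. The sketch below is a minimal standalone illustration of that pattern, not code from this PR: the flat integer mask, the any_active_demo kernel, and the host driver are hypothetical stand-ins for the cute tensor fragments (tCrAM) and the surrounding attention kernel.

// any_active_demo.cu: block-wide "any active" reduction with __syncthreads_or.
#include <cstdio>
#include <cuda_runtime.h>

__global__ void any_active_demo(const int *mask, int mask_len, int *block_result) {
    // Each thread scans only its own slice of the mask, mirroring the
    // thread-local loop over the tCrAM fragment in sparse_gemm.
    bool local_any_active = false;
    for (int i = threadIdx.x; i < mask_len; i += blockDim.x) {
        local_any_active |= (mask[i] > 0);
    }
    // __syncthreads_or returns non-zero on every thread of the CTA if any
    // thread passed a non-zero predicate, so the branch that follows it is
    // taken uniformly by the whole block: either everyone does the dense-style
    // work, or everyone skips it, and no thread diverges around a copy or MMA
    // that its neighbours did not issue.
    int any_active = __syncthreads_or(local_any_active);
    if (threadIdx.x == 0) {
        block_result[blockIdx.x] = any_active ? 1 : 0;
    }
}

int main() {
    const int n = 1024, threads = 128;
    int h_mask[n];
    for (int i = 0; i < n; ++i) h_mask[i] = 0;
    h_mask[777] = 1;  // a single active element is enough to keep the block "dense"

    int *d_mask, *d_result, h_result = -1;
    cudaMalloc(&d_mask, n * sizeof(int));
    cudaMalloc(&d_result, sizeof(int));
    cudaMemcpy(d_mask, h_mask, n * sizeof(int), cudaMemcpyHostToDevice);

    any_active_demo<<<1, threads>>>(d_mask, n, d_result);
    cudaMemcpy(&h_result, d_result, sizeof(int), cudaMemcpyDeviceToHost);
    printf("any_active = %d\n", h_result);  // expected: 1

    cudaFree(d_mask);
    cudaFree(d_result);
    return 0;
}

Note that __syncthreads_or is itself a barrier, so the pattern relies on every thread of the CTA reaching the sparse_gemm / sparse_gemm_rs call; the payoff in the diff above is that a fully masked-out key block skips both the shared-memory-to-register copies and the MMAs without introducing divergent shared-memory traffic.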