From 3f6d72cb23995213698e0731c3e9452664800d55 Mon Sep 17 00:00:00 2001 From: Loser Cheems Date: Sat, 21 Jun 2025 11:51:23 +0800 Subject: [PATCH 1/3] Optimizes sparse GEMM by removing per-MMA branching Replaces per-MMA block activation checks with global activation detection to reduce warp divergence and improve performance. Removes template parameter for number of warps and simplifies tensor parameter list by eliminating unused shared memory mask tensor. Uses thread synchronization to ensure consistent branching behavior across all threads in the cooperative thread array, avoiding divergent execution paths that can hurt GPU performance. --- csrc/src/utils.h | 132 ++++++++++++++++++++--------------------------- 1 file changed, 55 insertions(+), 77 deletions(-) diff --git a/csrc/src/utils.h b/csrc/src/utils.h index 37bb4b7..79d91fa 100644 --- a/csrc/src/utils.h +++ b/csrc/src/utils.h @@ -162,12 +162,12 @@ __forceinline__ __device__ void gemm( //////////////////////////////////////////////////////////////////////////////////////////////////// -template __forceinline__ __device__ void sparse_gemm( - Tensor0 &acc, Tensor1 &tCrA, Tensor2 &tCrB, Tensor3 const& tCsA, Tensor4 const& tCsB, Tensor5 const &tCrAM, Tensor6 const &tCsAM, + Tensor0 &acc, Tensor1 &tCrA, Tensor2 &tCrB, Tensor3 const& tCsA, Tensor4 const& tCsB, Tensor5 const &tCrAM, TiledMma tiled_mma, TiledCopyA smem_tiled_copy_A, TiledCopyB smem_tiled_copy_B, ThrCopyA smem_thr_copy_A, ThrCopyB smem_thr_copy_B ) { @@ -180,34 +180,30 @@ __forceinline__ __device__ void sparse_gemm( CUTE_STATIC_ASSERT_V(size<1>(tCsA) == size<1>(tCrA_copy_view)); // M auto tCrB_copy_view = smem_thr_copy_B.retile_D(tCrB); CUTE_STATIC_ASSERT_V(size<1>(tCsB) == size<1>(tCrB_copy_view)); // N - bool mma_active[kNWarps] = {}; // MMA - // Considering the characteristics of MMA and the chain of thoughts, - // when there is any activated element in the query row or key column, - // we will mark the MMA block as activated. 
+ // Check if any element in the entire active mask is non-zero + // Use thread-local computation then sync across all threads in the CTA + bool local_any_active = false; #pragma unroll - for (int mma = 0; mma < size<0>(active_mask); ++mma) { - mma_active[mma] = false; + for (int mma = 0; mma < size<0>(tCrAM) && !local_any_active; ++mma) { #pragma unroll - for (int m = 0; m < size<1>(active_mask); ++m) { + for (int m = 0; m < size<1>(tCrAM) && !local_any_active; ++m) { #pragma unroll - for (int n = 0; n < size<2>(active_mask); ++n) { - if (active_mask(mma, m, n)) { - mma_active[mma] = true; - goto mma_active_found; - } + for (int n = 0; n < size<2>(tCrAM) && !local_any_active; ++n) { + // Use direct comparison to avoid potential branching + local_any_active |= (tCrAM(mma, m, n) > 0); } } - mma_active_found:; } + // Ensure all threads in the CTA have the same any_active value to avoid warp divergence + bool any_active = __syncthreads_or(local_any_active); if (!A_in_regs) { cute::copy(smem_tiled_copy_A, tCsA(_, _, _0{}), tCrA_copy_view(_, _, _0{})); } if (!B_in_regs) { - #pragma unroll - for (int mma = 0; mma < size<0>(active_mask); ++mma) { - if (mma_active[mma]) { - cute::copy(smem_tiled_copy_B, tCsB(mma, _, _0{}), tCrB_copy_view(mma, _, _0{})); - } else { - cute::clear(tCrB_copy_view(mma, _, _0{})); - } + if (any_active) { + // If any MMA block is active, load normally like dense gemm + cute::copy(smem_tiled_copy_B, tCsB(_, _, _0{}), tCrB_copy_view(_, _, _0{})); + } else { + // If no MMA block is active, clear all registers + cute::clear(tCrB_copy_view); } } #pragma unroll @@ -215,28 +211,18 @@ __forceinline__ __device__ void sparse_gemm( if (i < size<2>(tCrA) - 1) { if (!A_in_regs) { cute::copy(smem_tiled_copy_A, tCsA(_, _, i + 1), tCrA_copy_view(_, _, i + 1)); } if (!B_in_regs) { - #pragma unroll - for (int mma = 0; mma < size<0>(active_mask); ++mma) { - if (mma_active[mma]) { - cute::copy(smem_tiled_copy_B, tCsB(mma, _, i + 1), tCrB_copy_view(mma, _, i + 1)); - } else { - cute::clear(tCrB_copy_view(mma, _, i + 1)); - } - + if (any_active) { + // If any MMA block is active, load normally like dense gemm + cute::copy(smem_tiled_copy_B, tCsB(_, _, i + 1), tCrB_copy_view(_, _, i + 1)); + } else { + // If no MMA block is active, clear all registers + cute::clear(tCrB_copy_view(_, _, i + 1)); } } } - // We must create a view to match `TiledMma` layout. 
- #pragma unroll - for (int mma = 0; mma < size<0>(active_mask); ++mma) { // MMA - if (mma_active[mma]) { - cute::gemm( - tiled_mma, - tCrA(mma, _, i), // (MMA_M, MMA_K) - tCrB(mma, _, i), // (MMA_N, MMA_K) - acc(mma, _, _) // (MMA_M, MMA_N) - ); - } + // Only perform GEMM if there are any active elements + if (any_active) { + cute::gemm(tiled_mma, tCrA(_, _, i), tCrB(_, _, i), acc); } } } @@ -268,8 +254,7 @@ __forceinline__ __device__ void gemm_rs( //////////////////////////////////////////////////////////////////////////////////////////////////// -template __forceinline__ __device__ void sparse_gemm_rs( @@ -277,61 +262,54 @@ __forceinline__ __device__ void sparse_gemm_rs( Tensor1 &tCrA, Tensor2 &tCrB, Tensor3 const& tCsB, - Tensor4 const &active_mask, + Tensor4 const &tCrAM, TiledMma tiled_mma, TiledCopy smem_tiled_copy_B, ThrCopy smem_thr_copy_B ) { CUTE_STATIC_ASSERT_V(size<1>(tCrA) == size<1>(acc)); // MMA_M CUTE_STATIC_ASSERT_V(size<1>(tCrB) == size<2>(acc)); // MMA_N - CUTE_STATIC_ASSERT_V(size<1>(tCrA) == size<1>(active_mask)); // MMA_M - CUTE_STATIC_ASSERT_V(size<1>(tCrB) == size<2>(active_mask)); // MMA_N CUTE_STATIC_ASSERT_V(size<2>(tCrA) == size<2>(tCrB)); // MMA_K // Retile B for thread-wise copy from shared memory to registers auto tCrB_copy_view = smem_thr_copy_B.retile_D(tCrB); CUTE_STATIC_ASSERT_V(size<1>(tCsB) == size<1>(tCrB_copy_view)); // N - // Check if any row or column in the MMA block is active. - bool mma_active[kNWarps] = {}; // MMA + // Check if any element in the entire active mask is non-zero + // Use thread-local computation then sync across all threads in the CTA + bool local_any_active = false; #pragma unroll - for (int mma = 0; mma < size<0>(active_mask); ++mma) { - mma_active[mma] = false; + for (int mma = 0; mma < size<0>(tCrAM) && !local_any_active; ++mma) { #pragma unroll - for (int m = 0; m < size<1>(active_mask); ++m) { + for (int m = 0; m < size<1>(tCrAM) && !local_any_active; ++m) { #pragma unroll - for (int n = 0; n < size<2>(active_mask); ++n) { - if (active_mask(mma, m, n)) { - mma_active[mma] = true; - goto mma_active_found; - } + for (int n = 0; n < size<2>(tCrAM) && !local_any_active; ++n) { + // Use direct comparison to avoid potential branching + local_any_active |= (tCrAM(mma, m, n) > 0); } } - mma_active_found:; } - #pragma unroll - for (int mma = 0; mma < size<0>(active_mask); ++mma) { - if (mma_active[mma]) { - cute::copy(smem_tiled_copy_B, tCsB(mma, _, _0{}), tCrB_copy_view(mma, _, _0{})); - } else { - cute::clear(tCrB_copy_view(mma, _, _0{})); - } + // Ensure all threads in the CTA have the same any_active value to avoid warp divergence + bool any_active = __syncthreads_or(local_any_active); + if (any_active) { + // If any MMA block is active, load normally like dense gemm + cute::copy(smem_tiled_copy_B, tCsB(_, _, _0{}), tCrB_copy_view(_, _, _0{})); + } else { + // If no MMA block is active, clear all registers + cute::clear(tCrB_copy_view); } #pragma unroll for (int i = 0; i < size<2>(tCrA); ++i) { if (i < size<2>(tCrA) - 1) { - #pragma unroll - for (int mma = 0; mma < size<0>(active_mask); ++mma) { - if (mma_active[mma]) { - cute::copy(smem_tiled_copy_B, tCsB(mma, _, i + 1), tCrB_copy_view(mma, _, i + 1)); - } else { - cute::clear(tCrB_copy_view(mma, _, i + 1)); - } + if (any_active) { + // If any MMA block is active, load normally like dense gemm + cute::copy(smem_tiled_copy_B, tCsB(_, _, i + 1), tCrB_copy_view(_, _, i + 1)); + } else { + // If no MMA block is active, clear all registers + cute::clear(tCrB_copy_view(_, _, i + 1)); 
            }
         }
-        #pragma unroll
-        for (int mma = 0; mma < size<0>(active_mask); ++mma) {
-            if (mma_active[mma]) {
-                cute::gemm(tiled_mma, tCrA(mma, _, i), tCrB(mma, _, i), acc(mma, _, _));
-            }
+        // Only perform GEMM if there are any active elements
+        if (any_active) {
+            cute::gemm(tiled_mma, tCrA(_, _, i), tCrB(_, _, i), acc);
         }
     }
 }

From 08ef0c28f1f524e28298ef3bc417296bba693be0 Mon Sep 17 00:00:00 2001
From: Loser Cheems
Date: Sat, 21 Jun 2025 11:51:40 +0800
Subject: [PATCH 2/3] Enables sparse matrix multiplication in attention

Replaces regular GEMM operations with sparse GEMM variants throughout the
attention computation kernel. Updates both the key-query multiplication
and the attention-value multiplication to use sparse operations,
incorporating active mask indices for improved performance on sparse
attention patterns. Removes TODO comments and activates previously
commented sparse multiplication parameters.
---
 csrc/src/flash_fwd_kernel.h | 23 +++++++++++------------
 1 file changed, 11 insertions(+), 12 deletions(-)

diff --git a/csrc/src/flash_fwd_kernel.h b/csrc/src/flash_fwd_kernel.h
index 921e3df..398a2d6 100644
--- a/csrc/src/flash_fwd_kernel.h
+++ b/csrc/src/flash_fwd_kernel.h
@@ -423,12 +423,11 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
         }
         cute::cp_async_fence();

-        // TODO: support sparse general matrix multiplication
-        FLASH_NAMESPACE::gemm(
+        // Use sparse general matrix multiplication
+        FLASH_NAMESPACE::sparse_gemm(
             acc_s,
             tSrQ,
-            tSrK, tSsQ, tSsK,
-            // tSrAM, tAMsAM, // Active key indices for sparse K matrix multiplication
+            tSrK, tSsQ, tSsK, tSrAM, // Active key indices for sparse K matrix multiplication
             tiled_mma, smem_tiled_copy_Q, smem_tiled_copy_K,
             smem_thr_copy_Q, smem_thr_copy_K
         );
@@ -497,11 +496,11 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
         // if using m16n8k16 or (4, MMA_M, MMA_N) if using m16n8k8.
         Tensor tOrP = make_tensor(rP.data(), FLASH_NAMESPACE::convert_layout_acc_Aregs(rP.layout()));
         // if (cute::thread0()) { print(tOrP); }
-        // TODO: support sparse general matrix multiplication with register accumulation
-        FLASH_NAMESPACE::gemm_rs(
+        // Use sparse general matrix multiplication with register accumulation for V as well
+        FLASH_NAMESPACE::sparse_gemm_rs(
             acc_o,
             tOrP, tOrVt, tOsVt,
-            // tSrAM, tAMsAM, // Apply the same mask for sparse V matrix multiplication
+            tSrAM, // Apply the same mask for sparse V matrix multiplication
             tiled_mma, smem_tiled_copy_V, smem_thr_copy_V
         );
         // if (cute::thread0()) { print(scores); }
@@ -529,11 +528,10 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
         FLASH_NAMESPACE::copy(gmem_tiled_copy_QKV, tVgV(_, _, _, n_block), tVsV, tKVcKV, tKVpKV);
         cute::cp_async_fence();

-        FLASH_NAMESPACE::gemm(
+        FLASH_NAMESPACE::sparse_gemm(
             acc_s,
             tSrQ,
-            tSrK, tSsQ, tSsK,
-            // tActiveMask, // Active key indices for sparse K matrix multiplication
+            tSrK, tSsQ, tSsK, tSrAM, // Active key indices for sparse K matrix multiplication
             tiled_mma, smem_tiled_copy_Q, smem_tiled_copy_K,
             smem_thr_copy_Q, smem_thr_copy_K
         );
@@ -596,10 +594,11 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
         // if using m16n8k16 or (4, MMA_M, MMA_N) if using m16n8k8.
         Tensor tOrP = make_tensor(rP.data(), FLASH_NAMESPACE::convert_layout_acc_Aregs(rP.layout()));
-        FLASH_NAMESPACE::gemm_rs(
+        // Use sparse general matrix multiplication with register accumulation for V as well
+        FLASH_NAMESPACE::sparse_gemm_rs(
             acc_o,
             tOrP, tOrVt, tOsVt,
-            // tActiveMask, // apply the same mask for sparse V matrix multiplication
+            tSrAM, // Apply the same mask for sparse V matrix multiplication
             tiled_mma, smem_tiled_copy_V, smem_thr_copy_V
         );
     }

From 82cc6a90c623778bf78d197735e979096a6e6341 Mon Sep 17 00:00:00 2001
From: Loser Cheems
Date: Sat, 21 Jun 2025 11:53:01 +0800
Subject: [PATCH 3/3] Refactors function parameter formatting for consistency

Combines multi-line parameter lists into single lines to improve code
readability and maintain consistent formatting style across the codebase.
---
 csrc/src/flash_fwd_kernel.h | 6 ++----
 1 file changed, 2 insertions(+), 4 deletions(-)

diff --git a/csrc/src/flash_fwd_kernel.h b/csrc/src/flash_fwd_kernel.h
index 398a2d6..42e352f 100644
--- a/csrc/src/flash_fwd_kernel.h
+++ b/csrc/src/flash_fwd_kernel.h
@@ -499,8 +499,7 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
         // Use sparse general matrix multiplication with register accumulation for V as well
         FLASH_NAMESPACE::sparse_gemm_rs(
             acc_o,
-            tOrP, tOrVt, tOsVt,
-            tSrAM, // Apply the same mask for sparse V matrix multiplication
+            tOrP, tOrVt, tOsVt, tSrAM, // Apply the same mask for sparse V matrix multiplication
             tiled_mma, smem_tiled_copy_V, smem_thr_copy_V
         );
         // if (cute::thread0()) { print(scores); }
@@ -597,8 +596,7 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
         // Use sparse general matrix multiplication with register accumulation for V as well
         FLASH_NAMESPACE::sparse_gemm_rs(
             acc_o,
-            tOrP, tOrVt, tOsVt,
-            tSrAM, // Apply the same mask for sparse V matrix multiplication
+            tOrP, tOrVt, tOsVt, tSrAM, // Apply the same mask for sparse V matrix multiplication
             tiled_mma, smem_tiled_copy_V, smem_thr_copy_V
         );
     }
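
Note on the synchronization pattern shared by these patches: every thread scans only its own fragment of the active mask, and __syncthreads_or then hands the same verdict to every thread in the block, so the branch that follows never diverges within a warp. The sketch below is a minimal standalone CUDA analogue of that pattern, not code from this repository: the kernel name, the raw-pointer mask layout, and n_per_thread are illustrative assumptions, whereas the real sparse_gemm / sparse_gemm_rs operate on CuTe register fragments (tCrAM, tCrB_copy_view) and call cute::copy / cute::clear / cute::gemm.

// Minimal standalone sketch (not repository code) of the CTA-uniform
// "any active" check. All names here are illustrative.
#include <cuda_runtime.h>

// Each thread scans only the mask elements it owns; __syncthreads_or then
// OR-reduces the per-thread predicates across the whole thread block and
// returns the same value to every thread, so branching on the result is
// uniform and causes no warp divergence.
__device__ bool any_active_in_block(const unsigned char* my_mask, int n) {
    bool local_any = false;
    for (int i = 0; i < n && !local_any; ++i) {
        local_any |= (my_mask[i] > 0);
    }
    return __syncthreads_or(local_any ? 1 : 0) != 0;
}

// Toy analogue of the sparse GEMM inner loop: if the whole tile is inactive,
// skip the loads and the accumulation (the real code clears the B register
// fragment and skips cute::gemm); otherwise run the dense path for every
// thread in the block.
__global__ void sparse_accumulate(const unsigned char* mask, const float* b,
                                  float* acc, int n_per_thread) {
    const int base = (blockIdx.x * blockDim.x + threadIdx.x) * n_per_thread;
    // The reduction is per block, matching the per-CTA decision above;
    // all threads of the block must reach this call.
    bool any_active = any_active_in_block(mask + base, n_per_thread);
    if (any_active) {
        for (int i = 0; i < n_per_thread; ++i) {
            acc[base + i] += b[base + i];   // dense load-and-accumulate path
        }
    }
    // else: the tile contributes nothing, so no work is issued at all.
}

A launch such as sparse_accumulate<<<grid, block>>>(d_mask, d_b, d_acc, 8) (hypothetical buffers) either runs the dense path for every thread or skips the tile entirely. That is the same trade-off the patches make: slightly more redundant work when only a single element is active, in exchange for uniform control flow across the CTA.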