25 changes: 11 additions & 14 deletions csrc/src/flash_fwd_kernel.h
@@ -423,12 +423,11 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
}
cute::cp_async_fence();

// TODO: support sparse general matrix multiplication
FLASH_NAMESPACE::gemm</*kNWarps*/ /*A_in_regs=*/Kernel_traits::Is_Q_in_regs>(
// Use sparse general matrix multiplication
FLASH_NAMESPACE::sparse_gemm</*A_in_regs=*/Kernel_traits::Is_Q_in_regs>(
acc_s,
tSrQ,
tSrK, tSsQ, tSsK,
// tSrAM, tAMsAM, // Active key indices for sparse K matrix multiplication
tSrK, tSsQ, tSsK, tSrAM, // Active key indices for sparse K matrix multiplication
tiled_mma, smem_tiled_copy_Q, smem_tiled_copy_K,
smem_thr_copy_Q, smem_thr_copy_K
);
@@ -497,11 +496,10 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
// if using m16n8k16 or (4, MMA_M, MMA_N) if using m16n8k8.
Tensor tOrP = make_tensor(rP.data(), FLASH_NAMESPACE::convert_layout_acc_Aregs<typename Kernel_traits::TiledMma>(rP.layout()));
// if (cute::thread0()) { print(tOrP); }
// TODO: support sparse general matrix multiplication with register accumulation
FLASH_NAMESPACE::gemm_rs</*kNWarps*/>(
// Use sparse general matrix multiplication with register accumulation for V as well
FLASH_NAMESPACE::sparse_gemm_rs(
acc_o,
tOrP, tOrVt, tOsVt,
// tSrAM, tAMsAM, // Apply the same mask for sparse V matrix multiplication
tOrP, tOrVt, tOsVt, tSrAM, // Apply the same mask for sparse V matrix multiplication
tiled_mma, smem_tiled_copy_V, smem_thr_copy_V
);
// if (cute::thread0()) { print(scores); }
@@ -529,11 +527,10 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
FLASH_NAMESPACE::copy</*Is_even_MN=*/true, Is_even_K>(gmem_tiled_copy_QKV, tVgV(_, _, _, n_block), tVsV, tKVcKV, tKVpKV);
cute::cp_async_fence();

FLASH_NAMESPACE::gemm</*kNWarps*/ /*A_in_regs=*/Kernel_traits::Is_Q_in_regs>(
FLASH_NAMESPACE::sparse_gemm</*A_in_regs=*/Kernel_traits::Is_Q_in_regs>(
acc_s,
tSrQ,
tSrK, tSsQ, tSsK,
// tActiveMask, // Active key indices for sparse K matrix multiplication
tSrK, tSsQ, tSsK, tSrAM, // Active key indices for sparse K matrix multiplication
tiled_mma, smem_tiled_copy_Q, smem_tiled_copy_K,
smem_thr_copy_Q, smem_thr_copy_K
);
@@ -596,10 +593,10 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
// if using m16n8k16 or (4, MMA_M, MMA_N) if using m16n8k8.
Tensor tOrP = make_tensor(rP.data(), FLASH_NAMESPACE::convert_layout_acc_Aregs<typename Kernel_traits::TiledMma>(rP.layout()));

FLASH_NAMESPACE::gemm_rs</*kNWarps*/>(
// Use sparse general matrix multiplication with register accumulation for V as well
FLASH_NAMESPACE::sparse_gemm_rs(
acc_o,
tOrP, tOrVt, tOsVt,
// tActiveMask, // apply the same mask for sparse V matrix multiplication
tOrP, tOrVt, tOsVt, tSrAM, // Apply the same mask for sparse V matrix multiplication
tiled_mma, smem_tiled_copy_V, smem_thr_copy_V
);
}
132 changes: 55 additions & 77 deletions csrc/src/utils.h
@@ -162,12 +162,12 @@ __forceinline__ __device__ void gemm(

////////////////////////////////////////////////////////////////////////////////////////////////////

template <int kNWarps, bool A_in_regs=false, bool B_in_regs=false,
typename Tensor0, typename Tensor1, typename Tensor2, typename Tensor3, typename Tensor4, typename Tensor5, typename Tensor6,
template <bool A_in_regs=false, bool B_in_regs=false,
typename Tensor0, typename Tensor1, typename Tensor2, typename Tensor3, typename Tensor4, typename Tensor5,
typename TiledMma, typename TiledCopyA, typename TiledCopyB,
typename ThrCopyA, typename ThrCopyB>
__forceinline__ __device__ void sparse_gemm(
Tensor0 &acc, Tensor1 &tCrA, Tensor2 &tCrB, Tensor3 const& tCsA, Tensor4 const& tCsB, Tensor5 const &tCrAM, Tensor6 const &tCsAM,
Tensor0 &acc, Tensor1 &tCrA, Tensor2 &tCrB, Tensor3 const& tCsA, Tensor4 const& tCsB, Tensor5 const &tCrAM,
TiledMma tiled_mma, TiledCopyA smem_tiled_copy_A, TiledCopyB smem_tiled_copy_B,
ThrCopyA smem_thr_copy_A, ThrCopyB smem_thr_copy_B
) {
@@ -180,63 +180,49 @@ __forceinline__ __device__ void sparse_gemm(
CUTE_STATIC_ASSERT_V(size<1>(tCsA) == size<1>(tCrA_copy_view)); // M
auto tCrB_copy_view = smem_thr_copy_B.retile_D(tCrB);
CUTE_STATIC_ASSERT_V(size<1>(tCsB) == size<1>(tCrB_copy_view)); // N
bool mma_active[kNWarps] = {}; // MMA
// Considering the characteristics of MMA and the chain of thoughts,
// when there is any activated element in the query row or key column,
// we will mark the MMA block as activated.
// Check if any element in the entire active mask is non-zero
// Use thread-local computation then sync across all threads in the CTA
bool local_any_active = false;
#pragma unroll
for (int mma = 0; mma < size<0>(active_mask); ++mma) {
mma_active[mma] = false;
for (int mma = 0; mma < size<0>(tCrAM) && !local_any_active; ++mma) {
#pragma unroll
for (int m = 0; m < size<1>(active_mask); ++m) {
for (int m = 0; m < size<1>(tCrAM) && !local_any_active; ++m) {
#pragma unroll
for (int n = 0; n < size<2>(active_mask); ++n) {
if (active_mask(mma, m, n)) {
mma_active[mma] = true;
goto mma_active_found;
}
for (int n = 0; n < size<2>(tCrAM) && !local_any_active; ++n) {
// Use direct comparison to avoid potential branching
local_any_active |= (tCrAM(mma, m, n) > 0);
}
}
mma_active_found:;
}
// Ensure all threads in the CTA have the same any_active value to avoid warp divergence
bool any_active = __syncthreads_or(local_any_active);
Copilot AI commented on Jun 21, 2025:

Using __syncthreads_or for a CTA-wide reduction may not be supported in standard CUDA. Consider using warp-level intrinsics like __any_sync for each warp and a final combine via shared memory, or use Cooperative Groups (cg::grid_group.any()) to correctly reduce across all threads in the CTA.

Collaborator (Author) replied:

the __syncthreads_or is in sm_20_intrinsics.h!
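
For illustration, a minimal sketch of the warp-level alternative the reviewer describes, assuming a one-dimensional CTA whose size is a multiple of 32 with at most 32 warps; the `cta_any` helper below is hypothetical and not part of this PR:

```cuda
#include <cuda_runtime.h>

// Hypothetical helper: CTA-wide OR of a per-thread predicate without
// __syncthreads_or. Every thread of the block must call it.
__device__ bool cta_any(bool local_pred) {
    __shared__ int warp_any[32];                 // one slot per warp
    const int lane      = threadIdx.x % 32;
    const int warp      = threadIdx.x / 32;
    const int num_warps = (blockDim.x + 31) / 32;

    // Warp-level reduction: non-zero if any lane in this warp is active.
    const int warp_result = __any_sync(0xffffffffu, local_pred);
    if (lane == 0) { warp_any[warp] = warp_result; }
    __syncthreads();

    // Each thread ORs the per-warp results (num_warps is small).
    int any = 0;
    for (int w = 0; w < num_warps; ++w) { any |= warp_any[w]; }
    __syncthreads();                             // allow safe reuse of warp_any
    return any != 0;
}
```

With such a helper, `bool any_active = cta_any(local_any_active);` would take the place of the `__syncthreads_or` call above.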

if (!A_in_regs) { cute::copy(smem_tiled_copy_A, tCsA(_, _, _0{}), tCrA_copy_view(_, _, _0{})); }
if (!B_in_regs) {
#pragma unroll
for (int mma = 0; mma < size<0>(active_mask); ++mma) {
if (mma_active[mma]) {
cute::copy(smem_tiled_copy_B, tCsB(mma, _, _0{}), tCrB_copy_view(mma, _, _0{}));
} else {
cute::clear(tCrB_copy_view(mma, _, _0{}));
}
if (any_active) {
// If any MMA block is active, load normally like dense gemm
cute::copy(smem_tiled_copy_B, tCsB(_, _, _0{}), tCrB_copy_view(_, _, _0{}));
} else {
// If no MMA block is active, clear all registers
cute::clear(tCrB_copy_view);
}
}
#pragma unroll
for (int i = 0; i < size<2>(tCrA); ++i) {
if (i < size<2>(tCrA) - 1) {
if (!A_in_regs) { cute::copy(smem_tiled_copy_A, tCsA(_, _, i + 1), tCrA_copy_view(_, _, i + 1)); }
if (!B_in_regs) {
#pragma unroll
for (int mma = 0; mma < size<0>(active_mask); ++mma) {
if (mma_active[mma]) {
cute::copy(smem_tiled_copy_B, tCsB(mma, _, i + 1), tCrB_copy_view(mma, _, i + 1));
} else {
cute::clear(tCrB_copy_view(mma, _, i + 1));
}

if (any_active) {
// If any MMA block is active, load normally like dense gemm
cute::copy(smem_tiled_copy_B, tCsB(_, _, i + 1), tCrB_copy_view(_, _, i + 1));
} else {
// If no MMA block is active, clear all registers
cute::clear(tCrB_copy_view(_, _, i + 1));
}
}
}
// We must create a view to match `TiledMma` layout.
#pragma unroll
for (int mma = 0; mma < size<0>(active_mask); ++mma) { // MMA
if (mma_active[mma]) {
cute::gemm(
tiled_mma,
tCrA(mma, _, i), // (MMA_M, MMA_K)
tCrB(mma, _, i), // (MMA_N, MMA_K)
acc(mma, _, _) // (MMA_M, MMA_N)
);
}
// Only perform GEMM if there are any active elements
if (any_active) {
cute::gemm(tiled_mma, tCrA(_, _, i), tCrB(_, _, i), acc);
}
}
}
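
The decision logic introduced above can also be shown outside of CuTe. The following toy kernel is only a sketch under assumed names and shapes (a byte mask of `TILE*TILE` flags per tile, one accumulator per tile, a plain multiply-accumulate standing in for the real MMA); it demonstrates the same pattern as the new `sparse_gemm`: a thread-local scan of the mask, a CTA-uniform decision via `__syncthreads_or`, and skipping a tile's loads and math entirely when nothing is active.

```cuda
#include <cuda_runtime.h>

// Toy illustration of the any_active pattern (names and shapes are
// illustrative, not the PR's kernel): skip a whole tile of loads and
// math when its mask is empty, with a CTA-uniform branch.
template <int TILE>
__global__ void masked_tile_accumulate(const float* A, const float* B,
                                       const unsigned char* mask,   // TILE*TILE flags per tile
                                       float* acc, int num_tiles) {
    for (int t = blockIdx.x; t < num_tiles; t += gridDim.x) {
        const unsigned char* m = mask + t * TILE * TILE;

        // Thread-local check over a strided slice of this tile's mask.
        bool local_any = false;
        for (int i = threadIdx.x; i < TILE * TILE; i += blockDim.x) {
            local_any |= (m[i] > 0);
        }
        // CTA-wide OR: every thread sees the same value, so the branch
        // below is uniform and does not diverge within the block.
        const bool any_active = __syncthreads_or(local_any) != 0;
        if (!any_active) { continue; }   // skip the tile entirely

        // Dense work only for active tiles (stand-in for the real MMA).
        for (int i = threadIdx.x; i < TILE * TILE; i += blockDim.x) {
            atomicAdd(&acc[t], A[t * TILE * TILE + i] * B[t * TILE * TILE + i]);
        }
    }
}
```

A launch such as `masked_tile_accumulate<16><<<80, 128>>>(dA, dB, dMask, dAcc, n)` exercises it; inactive tiles cost only the mask scan and the barrier.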
@@ -268,70 +254,62 @@ __forceinline__ __device__ void gemm_rs(

////////////////////////////////////////////////////////////////////////////////////////////////////

template <int kNWarps,
typename Tensor0, typename Tensor1, typename Tensor2, typename Tensor3, typename Tensor4,
template <typename Tensor0, typename Tensor1, typename Tensor2, typename Tensor3, typename Tensor4,
typename TiledMma, typename TiledCopy,
typename ThrCopy>
__forceinline__ __device__ void sparse_gemm_rs(
Tensor0 &acc,
Tensor1 &tCrA,
Tensor2 &tCrB,
Tensor3 const& tCsB,
Tensor4 const &active_mask,
Tensor4 const &tCrAM,
TiledMma tiled_mma,
TiledCopy smem_tiled_copy_B,
ThrCopy smem_thr_copy_B
) {
CUTE_STATIC_ASSERT_V(size<1>(tCrA) == size<1>(acc)); // MMA_M
CUTE_STATIC_ASSERT_V(size<1>(tCrB) == size<2>(acc)); // MMA_N
CUTE_STATIC_ASSERT_V(size<1>(tCrA) == size<1>(active_mask)); // MMA_M
CUTE_STATIC_ASSERT_V(size<1>(tCrB) == size<2>(active_mask)); // MMA_N
CUTE_STATIC_ASSERT_V(size<2>(tCrA) == size<2>(tCrB)); // MMA_K
// Retile B for thread-wise copy from shared memory to registers
auto tCrB_copy_view = smem_thr_copy_B.retile_D(tCrB);
CUTE_STATIC_ASSERT_V(size<1>(tCsB) == size<1>(tCrB_copy_view)); // N
// Check if any row or column in the MMA block is active.
bool mma_active[kNWarps] = {}; // MMA
// Check if any element in the entire active mask is non-zero
// Use thread-local computation then sync across all threads in the CTA
bool local_any_active = false;
#pragma unroll
for (int mma = 0; mma < size<0>(active_mask); ++mma) {
mma_active[mma] = false;
for (int mma = 0; mma < size<0>(tCrAM) && !local_any_active; ++mma) {
#pragma unroll
for (int m = 0; m < size<1>(active_mask); ++m) {
for (int m = 0; m < size<1>(tCrAM) && !local_any_active; ++m) {
#pragma unroll
for (int n = 0; n < size<2>(active_mask); ++n) {
if (active_mask(mma, m, n)) {
mma_active[mma] = true;
goto mma_active_found;
}
for (int n = 0; n < size<2>(tCrAM) && !local_any_active; ++n) {
// Use direct comparison to avoid potential branching
local_any_active |= (tCrAM(mma, m, n) > 0);
}
}
mma_active_found:;
}
#pragma unroll
for (int mma = 0; mma < size<0>(active_mask); ++mma) {
if (mma_active[mma]) {
cute::copy(smem_tiled_copy_B, tCsB(mma, _, _0{}), tCrB_copy_view(mma, _, _0{}));
} else {
cute::clear(tCrB_copy_view(mma, _, _0{}));
}
// Ensure all threads in the CTA have the same any_active value to avoid warp divergence
bool any_active = __syncthreads_or(local_any_active);
if (any_active) {
// If any MMA block is active, load normally like dense gemm
cute::copy(smem_tiled_copy_B, tCsB(_, _, _0{}), tCrB_copy_view(_, _, _0{}));
} else {
// If no MMA block is active, clear all registers
cute::clear(tCrB_copy_view);
}
#pragma unroll
for (int i = 0; i < size<2>(tCrA); ++i) {
if (i < size<2>(tCrA) - 1) {
#pragma unroll
for (int mma = 0; mma < size<0>(active_mask); ++mma) {
if (mma_active[mma]) {
cute::copy(smem_tiled_copy_B, tCsB(mma, _, i + 1), tCrB_copy_view(mma, _, i + 1));
} else {
cute::clear(tCrB_copy_view(mma, _, i + 1));
}
if (any_active) {
// If any MMA block is active, load normally like dense gemm
cute::copy(smem_tiled_copy_B, tCsB(_, _, i + 1), tCrB_copy_view(_, _, i + 1));
} else {
// If no MMA block is active, clear all registers
cute::clear(tCrB_copy_view(_, _, i + 1));
}
}
#pragma unroll
for (int mma = 0; mma < size<0>(active_mask); ++mma) {
if (mma_active[mma]) {
cute::gemm(tiled_mma, tCrA(mma, _, i), tCrB(mma, _, i), acc(mma, _, _));
}
// Only perform GEMM if there are any active elements
if (any_active) {
cute::gemm(tiled_mma, tCrA(_, _, i), tCrB(_, _, i), acc);
}
}
}