From 502aed126280c9c1cea265123bec279785d470d7 Mon Sep 17 00:00:00 2001
From: Loser Cheems
Date: Fri, 11 Jul 2025 17:03:22 +0800
Subject: [PATCH 1/3] Refactors variable declarations for better readability

Reorders tensor variable declarations to group related functionality
together and improves code formatting consistency.

Moves tensor fragment declarations closer to their logical usage groups
and reformats conditional statements with proper indentation for
enhanced code maintainability.
---
 csrc/src/flash_fwd_kernel.h | 17 +++++++++++------
 1 file changed, 11 insertions(+), 6 deletions(-)

diff --git a/csrc/src/flash_fwd_kernel.h b/csrc/src/flash_fwd_kernel.h
index d49653a..aea9604 100644
--- a/csrc/src/flash_fwd_kernel.h
+++ b/csrc/src/flash_fwd_kernel.h
@@ -707,8 +707,13 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, cons
     const int n_block_min = n_split_idx * n_blocks_per_split;
     int n_block_max = std::min(cute::ceil_div(binfo.actual_seqlen_k, kBlockN), (n_split_idx + 1) * n_blocks_per_split);
     if (Is_causal) {
-        n_block_max = std::min(n_block_max,
-                               cute::ceil_div((m_block + 1) * kBlockM + binfo.actual_seqlen_k - binfo.actual_seqlen_q, kBlockN));
+        n_block_max = std::min(
+            n_block_max,
+            cute::ceil_div(
+                (m_block + 1) * kBlockM + binfo.actual_seqlen_k - binfo.actual_seqlen_q,
+                kBlockN
+            )
+        );
     }
     if (n_block_min >= n_block_max) {  // This also covers the case where n_block_max <= 0
         // We exit early and write 0 to gOaccum and -inf to gLSEaccum.
@@ -863,9 +868,9 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, cons
     auto thr_mma = tiled_mma.get_thread_slice(tidx);
     Tensor tSrQ = thr_mma.partition_fragment_A(sQ);                           // (MMA, MMA_M, MMA_K)
     Tensor tSrK = thr_mma.partition_fragment_B(sK);                           // (MMA, MMA_N, MMA_K)
+    Tensor tOrVt = thr_mma.partition_fragment_B(sVtNoSwizzle);                // (MMA, MMA_K, MMA_N)
     Tensor tSrMask = partition_fragment_C(tiled_mma, Shape<Int<kBlockM>, Int<kBlockN>>{});  // (MMA, MMA_M, MMA_N)
     Tensor tSrBias = partition_fragment_C(tiled_mma, Shape<Int<kBlockM>, Int<kBlockN>>{});  // (MMA, MMA_M, MMA_N)
-    Tensor tOrVt = thr_mma.partition_fragment_B(sVtNoSwizzle);                // (MMA, MMA_K, MMA_N)
     Tensor acc_o = partition_fragment_C(tiled_mma, Shape<Int<kBlockM>, Int<kHeadDim>>{});   // (MMA, MMA_M, MMA_K)

     // Copy Atom retiling
@@ -875,15 +880,15 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, cons
     auto smem_tiled_copy_K = make_tiled_copy_B(typename Kernel_traits::SmemCopyAtom{}, tiled_mma);
     auto smem_thr_copy_K = smem_tiled_copy_K.get_thread_slice(tidx);
     Tensor tSsK = smem_thr_copy_K.partition_S(sK);
+    auto smem_tiled_copy_V = make_tiled_copy_B(typename Kernel_traits::SmemCopyAtomTransposed{}, tiled_mma);
+    auto smem_thr_copy_V = smem_tiled_copy_V.get_thread_slice(tidx);
+    Tensor tOsVt = smem_thr_copy_V.partition_S(sVt);
    auto smem_tiled_copy_Mask = make_tiled_copy_C(typename Kernel_traits::SmemCopyAtomO{}, tiled_mma);
     auto smem_thr_copy_Mask = smem_tiled_copy_Mask.get_thread_slice(tidx);
     Tensor tSsMask = smem_thr_copy_Mask.partition_S(sMask);
     auto smem_tiled_copy_Bias = make_tiled_copy_C(typename Kernel_traits::SmemCopyAtomO{}, tiled_mma);
     auto smem_thr_copy_Bias = smem_tiled_copy_Bias.get_thread_slice(tidx);
     Tensor tSsBias = smem_thr_copy_Bias.partition_S(sBias);
-    auto smem_tiled_copy_V = make_tiled_copy_B(typename Kernel_traits::SmemCopyAtomTransposed{}, tiled_mma);
-    auto smem_thr_copy_V = smem_tiled_copy_V.get_thread_slice(tidx);
-    Tensor tOsVt = smem_thr_copy_V.partition_S(sVt);

     // PREDICATES
     // Construct identity layout for sQ and sK
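The causal clamp that patch 1 reformats is easier to see as standalone arithmetic. Below is a minimal host-side sketch of the same computation; ceil_div stands in for cute::ceil_div, and the block sizes and sequence lengths are made-up example values, not taken from any real launch configuration.

    #include <algorithm>
    #include <cstdio>

    // Stand-in for cute::ceil_div.
    static int ceil_div(int a, int b) { return (a + b - 1) / b; }

    int main() {
        // Example values only; the real kernel takes these from Kernel_traits
        // and the BlockInfo of the current batch element.
        const int kBlockM = 64, kBlockN = 64;
        const int actual_seqlen_q = 128, actual_seqlen_k = 1024;
        const int n_blocks_per_split = 4, n_split_idx = 1, m_block = 1;
        const bool is_causal = true;

        // Each split owns a contiguous range of key blocks.
        const int n_block_min = n_split_idx * n_blocks_per_split;
        int n_block_max = std::min(ceil_div(actual_seqlen_k, kBlockN),
                                   (n_split_idx + 1) * n_blocks_per_split);
        // Causal masking: rows in query block m_block attend to at most
        // (m_block + 1) * kBlockM + seqlen_k - seqlen_q key positions, so the
        // key-block range is clamped (the expression reformatted above).
        if (is_causal) {
            n_block_max = std::min(
                n_block_max,
                ceil_div((m_block + 1) * kBlockM + actual_seqlen_k - actual_seqlen_q, kBlockN)
            );
        }
        printf("key block range for this split: [%d, %d)\n", n_block_min, n_block_max);
        return 0;
    }

When n_block_min >= n_block_max, the clamp has emptied the range, which is exactly the early-exit case the kernel handles by writing 0 to gOaccum and -inf to gLSEaccum.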
From b27b21fa389d4e7e57a53202859dff94dbde1460 Mon Sep 17 00:00:00 2001
From: Loser Cheems
Date: Fri, 11 Jul 2025 18:44:01 +0800
Subject: [PATCH 2/3] Removes Append_KV parameter from splitkv kernel

Simplifies the flash forward splitkv kernel by removing the Append_KV
template parameter and its associated logic. This reduces template
instantiation complexity and removes one level of nested switch
statements.

Also prefixes debug printf statements with the enclosing function name
for clearer debugging output, and reduces the split-KV kBlockN block
size constants.
---
 csrc/src/flash_fwd_launch_template.h | 49 +++++++++++++---------------
 1 file changed, 22 insertions(+), 27 deletions(-)

diff --git a/csrc/src/flash_fwd_launch_template.h b/csrc/src/flash_fwd_launch_template.h
index b5241bb..af992e0 100644
--- a/csrc/src/flash_fwd_launch_template.h
+++ b/csrc/src/flash_fwd_launch_template.h
@@ -38,9 +38,9 @@ DEFINE_FLASH_FORWARD_KERNEL(flash_fwd_kernel, bool Is_dropout, bool Is_causal, b
 #endif
 }

-DEFINE_FLASH_FORWARD_KERNEL(flash_fwd_splitkv_kernel, bool Is_causal, bool Is_even_MN, bool Is_even_K, bool Is_softcap, bool Split, bool Append_KV) {
+DEFINE_FLASH_FORWARD_KERNEL(flash_fwd_splitkv_kernel, bool Is_causal, bool Is_even_MN, bool Is_even_K, bool Is_softcap, bool Split) {
 #if defined(ARCH_SUPPORTS_FLASH)
-    FLASH_NAMESPACE::compute_attn_splitkv<Kernel_traits, Is_causal, Is_even_MN, Is_even_K, Is_softcap, Split, Append_KV>(params);
+    FLASH_NAMESPACE::compute_attn_splitkv<Kernel_traits, Is_causal, Is_even_MN, Is_even_K, Is_softcap, Split>(params);
 #else
     FLASH_UNSUPPORTED_ARCH
 #endif
@@ -74,7 +74,7 @@ void run_flash_fwd(Flash_fwd_params &params, cudaStream_t stream) {
             // If head dim > 128, set IsEvenMNConst to false to reduce number of templates
             auto kernel = &flash_fwd_kernel;
             // auto kernel = &flash_fwd_kernel;
-            // printf("IsEvenMNConst = %d, IsEvenKConst = %d, Is_causal = %d, ReturnSoftmaxConst = %d, Is_dropout = %d\n", int(IsEvenMNConst), int(IsEvenKConst), int(Is_causal), int(ReturnSoftmaxConst), int(Is_dropout));
+            // printf("run_flash_fwd: IsEvenMNConst = %d, IsEvenKConst = %d, Is_causal = %d, ReturnSoftmaxConst = %d, Is_dropout = %d\n", int(IsEvenMNConst), int(IsEvenKConst), int(Is_causal), int(ReturnSoftmaxConst), int(Is_dropout));
             // auto kernel = &flash_fwd_kernel;
             if (smem_size >= 48 * 1024) {
                 C10_CUDA_CHECK(cudaFuncSetAttribute(
@@ -83,7 +83,7 @@ void run_flash_fwd(Flash_fwd_params &params, cudaStream_t stream) {
             // int ctas_per_sm;
             // cudaError status_ = cudaOccupancyMaxActiveBlocksPerMultiprocessor(
             //     &ctas_per_sm, kernel, Kernel_traits::kNThreads, smem_size);
-            // printf("smem_size = %d, CTAs per SM = %d\n", int(smem_size), ctas_per_sm);
+            // printf("run_flash_fwd: smem_size = %d, CTAs per SM = %d\n", int(smem_size), ctas_per_sm);
             kernel<<<grid, block, smem_size, stream>>>(params);
             C10_CUDA_KERNEL_LAUNCH_CHECK();
         });
@@ -104,26 +104,23 @@
     BOOL_SWITCH(is_even_MN, IsEvenMNConst, [&] {
         EVENK_SWITCH(is_even_K, IsEvenKConst, [&] {
             BOOL_SWITCH(params.num_splits > 1, Split, [&] {
-                BOOL_SWITCH(params.knew_ptr != nullptr, Append_KV, [&] {
-                    SOFTCAP_SWITCH(params.softcap > 0.0, Is_softcap, [&] {
-                        // If Append_KV, then we must have seqlen_offsets, which means cu_seqlens_k != nullptr.
-                        // If not IsEvenKConst, we also set IsEvenMNConst to false to reduce number of templates.
-                        // If Is_local, set Is_causal to false
-                        auto kernel = &flash_fwd_splitkv_kernel;
-                        // auto kernel = &flash_fwd_splitkv_kernel;
-                        // auto kernel = &flash_fwd_splitkv_kernel;
-                        // printf("Split = %d, Append_KV = %d, Is_causal = %d, IsEvenMNConst = %d, IsEvenKConst = %d, Is_softcap = %d\n", int(Split), int(Append_KV), int(Is_causal), int(IsEvenMNConst), int(IsEvenKConst), int(Is_softcap));
-                        if (smem_size >= 48 * 1024) {
-                            C10_CUDA_CHECK(cudaFuncSetAttribute(
-                                kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size));
-                        }
-                        // int ctas_per_sm;
-                        // cudaError status_ = cudaOccupancyMaxActiveBlocksPerMultiprocessor(
-                        //     &ctas_per_sm, kernel, Kernel_traits::kNThreads, smem_size);
-                        // printf("smem_size = %d, CTAs per SM = %d\n", int(smem_size), ctas_per_sm);
-                        kernel<<<grid, block, smem_size, stream>>>(params);
-                        C10_CUDA_KERNEL_LAUNCH_CHECK();
-                    });
+                SOFTCAP_SWITCH(params.softcap > 0.0, Is_softcap, [&] {
+                    // If not IsEvenKConst, we also set IsEvenMNConst to false to reduce number of templates.
+                    // If Is_local, set Is_causal to false
+                    auto kernel = &flash_fwd_splitkv_kernel;
+                    // auto kernel = &flash_fwd_splitkv_kernel;
+                    // auto kernel = &flash_fwd_splitkv_kernel;
+                    // printf("run_flash_splitkv_fwd: Split = %d, Is_causal = %d, IsEvenMNConst = %d, IsEvenKConst = %d, Is_softcap = %d\n", int(Split), int(Is_causal), int(IsEvenMNConst), int(IsEvenKConst), int(Is_softcap));
+                    if (smem_size >= 48 * 1024) {
+                        C10_CUDA_CHECK(cudaFuncSetAttribute(
+                            kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size));
+                    }
+                    // int ctas_per_sm;
+                    // cudaError status_ = cudaOccupancyMaxActiveBlocksPerMultiprocessor(
+                    //     &ctas_per_sm, kernel, Kernel_traits::kNThreads, smem_size);
+                    // printf("run_flash_splitkv_fwd: smem_size = %d, CTAs per SM = %d\n", int(smem_size), ctas_per_sm);
+                    kernel<<<grid, block, smem_size, stream>>>(params);
+                    C10_CUDA_KERNEL_LAUNCH_CHECK();
                 });
             });
         });
@@ -158,9 +155,7 @@
 template<typename T, int Headdim, bool Is_causal>
 void run_mha_fwd_splitkv_dispatch(Flash_fwd_params &params, cudaStream_t stream) {
     constexpr static int kBlockM = 64;  // Fixed for all head dimensions
-    // TD [2023-08-28]: nvcc segfaults for headdim 96 with block size 64 x 256,
-    // and for headdim 192 with block size 64 x 128.
-    constexpr static int kBlockN = Headdim <= 64 ? 128 : (Headdim <= 128 ? 64 : 32);
+    constexpr static int kBlockN = Headdim <= 64 ? 64 : (Headdim < 128 ? 64 : 32);
     run_flash_splitkv_fwd<Flash_fwd_kernel_traits<Headdim, kBlockM, kBlockN, 4, false, false, T>, Is_causal>(params, stream);
 }

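The payoff of patch 2 is clearer with the switch-macro dispatch pattern spelled out. The sketch below uses a simplified BOOL_SWITCH stand-in modeled on the common lambda-based pattern, not copied from this repository's headers; it shows how each switch lifts a runtime flag into a compile-time template parameter, so every nested switch doubles the number of kernel instantiations and dropping the Append_KV switch halves them.

    #include <cstdio>

    // Simplified stand-in for the project's BOOL_SWITCH/SOFTCAP_SWITCH macros.
    // Each branch binds the runtime condition to a constexpr bool visible to
    // the lambda body, so the body can use it as a template argument.
    #define BOOL_SWITCH(COND, CONST_NAME, ...)             \
        [&] {                                              \
            if (COND) {                                    \
                constexpr static bool CONST_NAME = true;   \
                return __VA_ARGS__();                      \
            } else {                                       \
                constexpr static bool CONST_NAME = false;  \
                return __VA_ARGS__();                      \
            }                                              \
        }()

    // Stand-in for a kernel launch wrapper: one instantiation per flag combo.
    template <bool Split, bool Is_softcap>
    void launch() { printf("Split=%d Is_softcap=%d\n", Split, Is_softcap); }

    int main() {
        int num_splits = 2;
        float softcap = 0.0f;
        // Two nested switches -> up to 4 instantiations; the removed
        // Append_KV level would have doubled that to 8.
        BOOL_SWITCH(num_splits > 1, Split, [&] {
            BOOL_SWITCH(softcap > 0.0f, Is_softcap, [&] {
                launch<Split, Is_softcap>();
            });
        });
        return 0;
    }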
From 6755b4d2a324ec2fab40be762b815811439cb6ed Mon Sep 17 00:00:00 2001
From: Loser Cheems
Date: Fri, 11 Jul 2025 18:44:14 +0800
Subject: [PATCH 3/3] Adjusts block size parameters and re-enables split-KV

Updates the block size calculation to use smaller values for better
memory efficiency.

Comments out the override that always forced num_splits back to 1,
which had been added to avoid the extra memory overhead of Split-KV
(see the existing comment and issue #47), thereby re-enabling Split-KV
when a split count is requested.
---
 csrc/flash_api.cpp | 8 ++++----
 1 file changed, 4 insertions(+), 4 deletions(-)

diff --git a/csrc/flash_api.cpp b/csrc/flash_api.cpp
index 2db3e99..b44a185 100644
--- a/csrc/flash_api.cpp
+++ b/csrc/flash_api.cpp
@@ -232,7 +232,7 @@ std::tuple<at::Tensor, at::Tensor> set_params_splitkv(
 ) {

     // This needs to match with run_mha_fwd_splitkv_dispatch
-    const int block_n = head_size <= 64 ? 256 : (head_size <= 128 ? 128 : 64);
+    const int block_n = head_size <= 64 ? 64 : (head_size < 128 ? 64 : 32);
     const int num_n_blocks = (max_seqlen_k + block_n - 1) / block_n;
     // Technically kBlockM = 64 only for the splitKV kernels, not the standard kernel.
     // In any case we don't expect seqlen_q to be larger than 64 for inference.
@@ -259,9 +259,9 @@ std::tuple<at::Tensor, at::Tensor> set_params_splitkv(
     // See: https://github.com/SmallDoges/flash-dmattn/issues/47
     // Regardless of how it is set externally, always set num_splits back to 1.
     // This is to avoid the extra memory overhead of Split-KV.
-    params.num_splits = 1;
-    softmax_lse_accum.reset();
-    out_accum.reset();
+    // params.num_splits = 1;
+    // softmax_lse_accum.reset();
+    // out_accum.reset();

     return std::make_tuple(softmax_lse_accum, out_accum);
 }
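A closing note on the head-size heuristic: the comment in set_params_splitkv says block_n "needs to match with run_mha_fwd_splitkv_dispatch", and patches 2 and 3 change both sides. The sketch below cross-checks the two post-patch expressions; block_n_api and kBlockN_dispatch are hypothetical helper names for illustration, and the head sizes are examples, not an exhaustive list of supported dimensions.

    #include <cstdio>

    // Heuristic as written in csrc/flash_api.cpp after patch 3.
    static int block_n_api(int head_size) {
        return head_size <= 64 ? 64 : (head_size < 128 ? 64 : 32);
    }

    // Heuristic as written in run_mha_fwd_splitkv_dispatch after patch 2.
    template <int Headdim>
    static constexpr int kBlockN_dispatch() {
        return Headdim <= 64 ? 64 : (Headdim < 128 ? 64 : 32);
    }

    int main() {
        printf("head  32: api=%d dispatch=%d\n", block_n_api(32),  kBlockN_dispatch<32>());
        printf("head  96: api=%d dispatch=%d\n", block_n_api(96),  kBlockN_dispatch<96>());
        printf("head 128: api=%d dispatch=%d\n", block_n_api(128), kBlockN_dispatch<128>());
        return 0;
    }

Note that in the patched expression both branches below 128 evaluate to 64, so `head_size <= 64 ? 64 : (head_size < 128 ? 64 : 32)` could be simplified to `head_size < 128 ? 64 : 32`; the sketch keeps the patch's exact form so the cross-check stays literal.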