From 489f698b36d816cbd0097a29fbaf934e56c0c9d4 Mon Sep 17 00:00:00 2001 From: Loser Cheems Date: Mon, 23 Jun 2025 14:29:34 +0800 Subject: [PATCH] Adds dropout support and softcap feature to flash attention Removes causal attention constraint to enable broader kernel usage Integrates dropout parameter throughout kernel dispatch functions and adds conditional logic based on dropout state Introduces softcap switching mechanism for attention score capping Adds split-KV kernel implementation with combine functionality for improved memory efficiency Optimizes kernel selection based on dropout state - uses different block configurations when dropout is enabled versus disabled Updates shared memory size calculation and removes mask-specific memory references --- csrc/src/flash_fwd_launch_template.h | 220 +++++++++++++++++++++------ 1 file changed, 174 insertions(+), 46 deletions(-) diff --git a/csrc/src/flash_fwd_launch_template.h b/csrc/src/flash_fwd_launch_template.h index e86e5ce..3b330ed 100644 --- a/csrc/src/flash_fwd_launch_template.h +++ b/csrc/src/flash_fwd_launch_template.h @@ -32,7 +32,6 @@ __global__ void kernelName(KERNEL_PARAM_MODIFIER const Flash_fwd_params params) DEFINE_FLASH_FORWARD_KERNEL(flash_fwd_kernel, bool Is_dropout, bool Is_causal, bool Is_even_MN, bool Is_even_K, bool Is_softcap, bool Return_softmax) { #if defined(ARCH_SUPPORTS_FLASH) - static_assert(!(Is_causal)); // Enforce constraints FLASH_NAMESPACE::compute_attn(params); #else FLASH_UNSUPPORTED_ARCH @@ -52,10 +51,10 @@ DEFINE_FLASH_FORWARD_KERNEL(flash_fwd_splitkv_combine_kernel, int kBlockM, int L FLASH_NAMESPACE::combine_attn_seqk_parallel(params); } -template +template void run_flash_fwd(Flash_fwd_params ¶ms, cudaStream_t stream) { - const size_t smem_size = Kernel_traits::kSmemSizeWithMask; - // printf("smem_size = %d\n", smem_size); + const size_t smem_size = Kernel_traits::kSmemSize; + // printf("smem_size = %d (includes mask memory)\n", int(smem_size)); // Work-around for gcc 7. It doesn't like nested BOOL_SWITCH. // https://github.com/kokkos/kokkos-kernels/issues/349 @@ -66,38 +65,126 @@ void run_flash_fwd(Flash_fwd_params ¶ms, cudaStream_t stream) { const bool is_even_MN = params.cu_seqlens_q == nullptr && params.cu_seqlens_k == nullptr && params.seqlen_k % Kernel_traits::kBlockN == 0 && params.seqlen_q % Kernel_traits::kBlockM == 0; const bool is_even_K = params.d == Kernel_traits::kHeadDim; const bool return_softmax = params.p_ptr != nullptr; - BOOL_SWITCH(is_even_MN, IsEvenMNConst, [&] { EVENK_SWITCH(is_even_K, IsEvenKConst, [&] { BOOL_SWITCH(return_softmax, ReturnSoftmaxConst, [&] { - // If not IsEvenKConst, we also set IsEvenMNConst to false to reduce number of templates. - // If return_softmax, set IsEvenMNConst to false to reduce number of templates - // If head dim > 128, set IsEvenMNConst to false to reduce number of templates - auto kernel = &flash_fwd_kernel; - if (smem_size >= 48 * 1024) { - C10_CUDA_CHECK(cudaFuncSetAttribute( - kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size)); - } - kernel<<>>(params); - C10_CUDA_KERNEL_LAUNCH_CHECK(); + SOFTCAP_SWITCH(params.softcap > 0.0, Is_softcap, [&] { + // If not IsEvenKConst, we also set IsEvenMNConst to false to reduce number of templates. + // If return_softmax, set IsEvenMNConst to false to reduce number of templates + // If head dim > 128, set IsEvenMNConst to false to reduce number of templates + auto kernel = &flash_fwd_kernel; + // auto kernel = &flash_fwd_kernel; + // printf("IsEvenMNConst = %d, IsEvenKConst = %d, Is_causal = %d, ReturnSoftmaxConst = %d, Is_dropout = %d\n", int(IsEvenMNConst), int(IsEvenKConst), int(Is_causal), int(ReturnSoftmaxConst), int(Is_dropout)); + // auto kernel = &flash_fwd_kernel; + if (smem_size >= 48 * 1024) { + C10_CUDA_CHECK(cudaFuncSetAttribute( + kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size)); + } + // int ctas_per_sm; + // cudaError status_ = cudaOccupancyMaxActiveBlocksPerMultiprocessor( + // &ctas_per_sm, kernel, Kernel_traits::kNThreads, smem_size); + // printf("smem_size = %d, CTAs per SM = %d\n", int(smem_size), ctas_per_sm); + kernel<<>>(params); + C10_CUDA_KERNEL_LAUNCH_CHECK(); + }); + }); + }); + }); +} + +template +void run_flash_splitkv_fwd(Flash_fwd_params ¶ms, cudaStream_t stream) { + static_assert(!Kernel_traits::Is_Q_in_regs, "SplitKV implementation does not support Is_Q_in_regs"); + static_assert(!Kernel_traits::Share_Q_K_smem, "SplitKV implementation does not support Share_Q_K_smem"); + constexpr size_t smem_size = Kernel_traits::kSmemSize; + const int num_m_block = (params.seqlen_q + Kernel_traits::kBlockM - 1) / Kernel_traits::kBlockM; + dim3 grid(num_m_block, params.num_splits > 1 ? params.num_splits : params.b, params.num_splits > 1 ? params.b * params.h : params.h); + const bool is_even_MN = params.cu_seqlens_q == nullptr && params.cu_seqlens_k == nullptr && params.seqlen_k % Kernel_traits::kBlockN == 0 && params.seqlen_q % Kernel_traits::kBlockM == 0; + const bool is_even_K = params.d == Kernel_traits::kHeadDim; + BOOL_SWITCH(is_even_MN, IsEvenMNConst, [&] { + EVENK_SWITCH(is_even_K, IsEvenKConst, [&] { + BOOL_SWITCH(params.num_splits > 1, Split, [&] { + BOOL_SWITCH(params.knew_ptr != nullptr, Append_KV, [&] { + SOFTCAP_SWITCH(params.softcap > 0.0, Is_softcap, [&] { + // If Append_KV, then we must have seqlen_offsets, which means cu_seqlens_k != nullptr. + // If not IsEvenKConst, we also set IsEvenMNConst to false to reduce number of templates. + // If Is_local, set Is_causal to false + auto kernel = &flash_fwd_splitkv_kernel; + // auto kernel = &flash_fwd_splitkv_kernel; + // auto kernel = &flash_fwd_splitkv_kernel; + if (smem_size >= 48 * 1024) { + C10_CUDA_CHECK(cudaFuncSetAttribute( + kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size)); + } + kernel<<>>(params); + C10_CUDA_KERNEL_LAUNCH_CHECK(); + }); + }); }); }); }); + if (params.num_splits > 1) { + // We want kBlockM to be as small as possible for more parallelism. + // With 128 threads we can load 512 elements at a time, so if headdim is divisible by 128, kBlockM = 4. + // If headdim is divisible by 64, then we set kBlockM = 8, etc. + constexpr static int kBlockM = Kernel_traits::kHeadDim % 128 == 0 ? 4 : (Kernel_traits::kHeadDim % 64 == 0 ? 8 : 16); + dim3 grid_combine((params.b * params.h * params.seqlen_q + kBlockM - 1) / kBlockM); + EVENK_SWITCH(is_even_K, IsEvenKConst, [&] { + if (params.num_splits <= 2) { + flash_fwd_splitkv_combine_kernel<<>>(params); + } else if (params.num_splits <= 4) { + flash_fwd_splitkv_combine_kernel<<>>(params); + } else if (params.num_splits <= 8) { + flash_fwd_splitkv_combine_kernel<<>>(params); + } else if (params.num_splits <= 16) { + flash_fwd_splitkv_combine_kernel<<>>(params); + } else if (params.num_splits <= 32) { + flash_fwd_splitkv_combine_kernel<<>>(params); + } else if (params.num_splits <= 64) { + flash_fwd_splitkv_combine_kernel<<>>(params); + } else if (params.num_splits <= 128) { + flash_fwd_splitkv_combine_kernel<<>>(params); + } + C10_CUDA_KERNEL_LAUNCH_CHECK(); + }); + } +} + +template +void run_mha_fwd_splitkv_dispatch(Flash_fwd_params ¶ms, cudaStream_t stream) { + constexpr static int kBlockM = 64; // Fixed for all head dimensions + // TD [2023-08-28]: nvcc segfaults for headdim 96 with block size 64 x 256, + // and for headdim 192 with block size 64 x 128. + constexpr static int kBlockN = Headdim <= 64 ? 256 : (Headdim <= 128 ? 128 : 64); + run_flash_splitkv_fwd, Is_causal>(params, stream); } template void run_mha_fwd_hdim32(Flash_fwd_params ¶ms, cudaStream_t stream) { constexpr static int Headdim = 32; - run_flash_fwd, Is_causal>(params, stream); + DROPOUT_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] { + run_flash_fwd, Is_dropout, Is_causal>(params, stream); + }); } template void run_mha_fwd_hdim64(Flash_fwd_params ¶ms, cudaStream_t stream) { constexpr static int Headdim = 64; - // Using 8 warps is 18% slower for seqlen=2k, 2 warps is 5% slower - // Using block size (64 x 256) is 27% slower for seqlen=2k - // Using block size (256 x 64) is 85% slower for seqlen=2k, because of register spilling - run_flash_fwd, Is_causal>(params, stream); + DROPOUT_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] { + if constexpr(!Is_dropout) { + // Using 8 warps is 18% slower for seqlen=2k, 2 warps is 5% slower + // Using block size (64 x 256) is 27% slower for seqlen=2k + // Using block size (256 x 64) is 85% slower for seqlen=2k, because of register spilling + run_flash_fwd, Is_dropout, Is_causal>(params, stream); + // run_flash_fwd, Is_dropout, Is_causal>(params, stream); + // run_flash_fwd, Is_dropout, Is_causal>(params, stream); + } else { + run_flash_fwd, Is_dropout, Is_causal>(params, stream); + // run_flash_fwd, Is_dropout, Is_causal>(params, stream); + // run_flash_fwd, Is_dropout, Is_causal>(params, stream); + // run_flash_fwd, Is_dropout, Is_causal>(params, stream); + } + }); } template @@ -105,16 +192,23 @@ void run_mha_fwd_hdim96(Flash_fwd_params ¶ms, cudaStream_t stream) { constexpr static int Headdim = 96; auto [cc_major, cc_minor] = get_compute_capability(get_current_device()); bool is_sm8x = cc_major == 8 && cc_minor > 0; - // For sm86 or sm89, 64 x 64 is the fastest for causal (because it's square), - if (is_sm8x) { - if constexpr(!Is_causal) { - run_flash_fwd, Is_causal>(params, stream); + DROPOUT_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] { + // For sm86 or sm89, 64 x 64 is the fastest for causal (because it's square), + if (is_sm8x) { + if constexpr(!Is_causal) { + run_flash_fwd, Is_dropout, Is_causal>(params, stream); + } else { + run_flash_fwd, Is_dropout, Is_causal>(params, stream); + } } else { - run_flash_fwd, Is_causal>(params, stream); + run_flash_fwd, Is_dropout, Is_causal>(params, stream); } - } else { - run_flash_fwd, Is_causal>(params, stream); - } + // run_flash_fwd, Is_dropout, Is_causal>(params, stream); + // run_flash_fwd, Is_dropout, Is_causal>(params, stream); + // These two are always slower + // run_flash_fwd>(params, stream); + // run_flash_fwd>(params, stream); + }); } template @@ -122,23 +216,51 @@ void run_mha_fwd_hdim128(Flash_fwd_params ¶ms, cudaStream_t stream) { constexpr static int Headdim = 128; auto [cc_major, cc_minor] = get_compute_capability(get_current_device()); bool is_sm8x = cc_major == 8 && cc_minor > 0; - // For sm86 or sm89, 64 x 64 is the fastest for causal (because it's square), - // and 128 x 32 (48 KB smem) is the fastest for non-causal since we get 2 CTAs per SM. - if (is_sm8x) { - if constexpr(!Is_causal) { - run_flash_fwd, Is_causal>(params, stream); + DROPOUT_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] { + if constexpr(!Is_dropout) { + // For sm86 or sm89, 64 x 64 is the fastest for causal (because it's square), + // and 128 x 32 (48 KB smem) is the fastest for non-causal since we get 2 CTAs per SM. + if (is_sm8x) { + if constexpr(!Is_causal) { + run_flash_fwd, Is_dropout, Is_causal>(params, stream); + } else { + run_flash_fwd, Is_dropout, Is_causal>(params, stream); + } + } else { + run_flash_fwd, Is_dropout, Is_causal>(params, stream); + } + // run_flash_fwd, Is_dropout, Is_causal>(params, stream); + // run_flash_fwd, Is_dropout, Is_causal>(params, stream); + // run_flash_fwd, Is_dropout, Is_causal>(params, stream); + // Using 8 warps (128 x 128 and 256 x 64) is 28% slower for seqlen=2k + // run_flash_fwd, Is_dropout, Is_causal>(params, stream); + // run_flash_fwd, Is_dropout, Is_causal>(params, stream); + // 1st ones are good for H100, A100 + // 2nd one is good for A6000 bc we get slightly better occupancy } else { - run_flash_fwd, Is_causal>(params, stream); + run_flash_fwd, Is_dropout, Is_causal>(params, stream); + // run_flash_fwd, Is_dropout, Is_causal>(params, stream); + // run_flash_fwd, Is_dropout, Is_causal>(params, stream); + // run_flash_fwd, Is_dropout, Is_causal>(params, stream); } - } else { - run_flash_fwd, Is_causal>(params, stream); - } + }); } template void run_mha_fwd_hdim192(Flash_fwd_params ¶ms, cudaStream_t stream) { constexpr static int Headdim = 192; - run_flash_fwd, Is_causal>(params, stream); + DROPOUT_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] { + if constexpr(!Is_dropout) { + run_flash_fwd, Is_dropout, Is_causal>(params, stream); + } else { + run_flash_fwd, Is_dropout, Is_causal>(params, stream); + } + // run_flash_fwd, Is_dropout, Is_causal>(params, stream); + // run_flash_fwd, Is_dropout, Is_causal>(params, stream); + // run_flash_fwd>(params, stream); + // run_flash_fwd>(params, stream); + // run_flash_fwd>(params, stream); + }); } template @@ -155,13 +277,19 @@ void run_mha_fwd_hdim256(Flash_fwd_params ¶ms, cudaStream_t stream) { C10_CUDA_CHECK(status_); } // printf("max_smem_per_sm = %d, max_smem_per_block = %d\n", max_smem_per_sm, max_smem_per_block); - // For A100, we want to run with 128 x 64 (128KB smem). - // For H100 we want to run with 64 x 64 (96KB smem) since then we can get 2 CTAs per SM. - if (max_smem_per_block >= 2 * Headdim * (128 + 2 * 64) && max_smem_per_sm < 4 * Headdim * (64 + 2 * 64)) { - run_flash_fwd, Is_causal>(params, stream); - } else { - run_flash_fwd, Is_causal>(params, stream); - } + DROPOUT_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] { + // For A100, we want to run with 128 x 64 (128KB smem). + // For H100 we want to run with 64 x 64 (96KB smem) since then we can get 2 CTAs per SM. + if (max_smem_per_block >= 2 * Headdim * (128 + 2 * 64) && max_smem_per_sm < 4 * Headdim * (64 + 2 * 64)) { + run_flash_fwd, Is_dropout, Is_causal>(params, stream); + } else { + run_flash_fwd, Is_dropout, Is_causal>(params, stream); + } + // 64 KB + // run_flash_fwd, Is_dropout, Is_causal>(params, stream); + // 96 KB + // run_flash_fwd, Is_dropout, Is_causal>(params, stream); + }); } -} // namespace FLASH_NAMESPACE \ No newline at end of file +} // namespace FLASH_NAMESPACE