diff --git a/csrc/flash_api.cpp b/csrc/flash_api.cpp
index 2db3e99..b44a185 100644
--- a/csrc/flash_api.cpp
+++ b/csrc/flash_api.cpp
@@ -232,7 +232,7 @@ std::tuple<at::Tensor, at::Tensor> set_params_splitkv(
 ) {

     // This needs to match with run_mha_fwd_splitkv_dispatch
-    const int block_n = head_size <= 64 ? 256 : (head_size <= 128 ? 128 : 64);
+    const int block_n = head_size <= 64 ? 64 : (head_size < 128 ? 64 : 32);
     const int num_n_blocks = (max_seqlen_k + block_n - 1) / block_n;
     // Technically kBlockM = 64 only for the splitKV kernels, not the standard kernel.
     // In any case we don't expect seqlen_q to be larger than 64 for inference.
@@ -259,9 +259,9 @@ std::tuple<at::Tensor, at::Tensor> set_params_splitkv(
     // See: https://github.com/SmallDoges/flash-dmattn/issues/47
     // Regardless of how it is set externally, always set num_splits back to 1.
     // This is to avoid the extra memory overhead of Split-KV.
-    params.num_splits = 1;
-    softmax_lse_accum.reset();
-    out_accum.reset();
+    // params.num_splits = 1;
+    // softmax_lse_accum.reset();
+    // out_accum.reset();

     return std::make_tuple(softmax_lse_accum, out_accum);
 }
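Context for the `block_n` change: as the comment says, this host-side table must stay in lockstep with `kBlockN` in `run_mha_fwd_splitkv_dispatch` (last hunk of this diff), because the `num_n_blocks` derived from it sizes the Split-KV accumulators. A minimal sketch of that invariant as a standalone check; the helper names and test values are illustrative, not part of the patch:

```cpp
#include <cassert>
#include <initializer_list>

// Illustrative mirrors of the two formulas this patch keeps in sync.
constexpr int host_block_n(int head_size) {    // set_params_splitkv
    return head_size <= 64 ? 64 : (head_size < 128 ? 64 : 32);
}
constexpr int kernel_block_n(int head_dim) {   // run_mha_fwd_splitkv_dispatch
    return head_dim <= 64 ? 64 : (head_dim < 128 ? 64 : 32);
}

int main() {
    for (int hd : {32, 64, 96, 128, 192, 256}) {
        assert(host_block_n(hd) == kernel_block_n(hd));
    }
    // num_n_blocks = ceil(max_seqlen_k / block_n), e.g. 4096 keys at
    // head_size 128 -> block_n = 32 -> 128 key blocks per (batch, head).
    const int max_seqlen_k = 4096, block_n = host_block_n(128);
    const int num_n_blocks = (max_seqlen_k + block_n - 1) / block_n;
    assert(num_n_blocks == 128);
    return 0;
}
```

Before this patch the two tables disagreed (256/128/64 on the host vs. 128/64/32 in the dispatcher); now both say 64/64/32 for head sizes ≤ 64, < 128, and ≥ 128 respectively.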
diff --git a/csrc/src/flash_fwd_kernel.h b/csrc/src/flash_fwd_kernel.h
index d49653a..aea9604 100644
--- a/csrc/src/flash_fwd_kernel.h
+++ b/csrc/src/flash_fwd_kernel.h
@@ -707,8 +707,13 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, cons
     const int n_block_min = n_split_idx * n_blocks_per_split;
     int n_block_max = std::min(cute::ceil_div(binfo.actual_seqlen_k, kBlockN), (n_split_idx + 1) * n_blocks_per_split);
     if (Is_causal) {
-        n_block_max = std::min(n_block_max,
-            cute::ceil_div((m_block + 1) * kBlockM + binfo.actual_seqlen_k - binfo.actual_seqlen_q, kBlockN));
+        n_block_max = std::min(
+            n_block_max,
+            cute::ceil_div(
+                (m_block + 1) * kBlockM + binfo.actual_seqlen_k - binfo.actual_seqlen_q,
+                kBlockN
+            )
+        );
     }
     if (n_block_min >= n_block_max) {  // This also covers the case where n_block_max <= 0
         // We exit early and write 0 to gOaccum and -inf to gLSEaccum.
@@ -863,9 +868,9 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, cons
     auto thr_mma = tiled_mma.get_thread_slice(tidx);
     Tensor tSrQ = thr_mma.partition_fragment_A(sQ);                            // (MMA, MMA_M, MMA_K)
     Tensor tSrK = thr_mma.partition_fragment_B(sK);                            // (MMA, MMA_N, MMA_K)
+    Tensor tOrVt = thr_mma.partition_fragment_B(sVtNoSwizzle);                 // (MMA, MMA_K, MMA_N)
     Tensor tSrMask = partition_fragment_C(tiled_mma, Shape<Int<kBlockM>, Int<kBlockN>>{});   // (MMA, MMA_M, MMA_N)
     Tensor tSrBias = partition_fragment_C(tiled_mma, Shape<Int<kBlockM>, Int<kBlockN>>{});   // (MMA, MMA_M, MMA_N)
-    Tensor tOrVt = thr_mma.partition_fragment_B(sVtNoSwizzle);                 // (MMA, MMA_K, MMA_N)
     Tensor acc_o = partition_fragment_C(tiled_mma, Shape<Int<kBlockM>, Int<kHeadDim>>{});    // (MMA, MMA_M, MMA_K)

     // Copy Atom retiling
@@ -875,15 +880,15 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, cons
     auto smem_tiled_copy_K = make_tiled_copy_B(typename Kernel_traits::SmemCopyAtom{}, tiled_mma);
     auto smem_thr_copy_K = smem_tiled_copy_K.get_thread_slice(tidx);
     Tensor tSsK = smem_thr_copy_K.partition_S(sK);
+    auto smem_tiled_copy_V = make_tiled_copy_B(typename Kernel_traits::SmemCopyAtomTransposed{}, tiled_mma);
+    auto smem_thr_copy_V = smem_tiled_copy_V.get_thread_slice(tidx);
+    Tensor tOsVt = smem_thr_copy_V.partition_S(sVt);
     auto smem_tiled_copy_Mask = make_tiled_copy_C(typename Kernel_traits::SmemCopyAtomO{}, tiled_mma);
     auto smem_thr_copy_Mask = smem_tiled_copy_Mask.get_thread_slice(tidx);
     Tensor tSsMask = smem_thr_copy_Mask.partition_S(sMask);
     auto smem_tiled_copy_Bias = make_tiled_copy_C(typename Kernel_traits::SmemCopyAtomO{}, tiled_mma);
     auto smem_thr_copy_Bias = smem_tiled_copy_Bias.get_thread_slice(tidx);
     Tensor tSsBias = smem_thr_copy_Bias.partition_S(sBias);
-    auto smem_tiled_copy_V = make_tiled_copy_B(typename Kernel_traits::SmemCopyAtomTransposed{}, tiled_mma);
-    auto smem_thr_copy_V = smem_tiled_copy_V.get_thread_slice(tidx);
-    Tensor tOsVt = smem_thr_copy_V.partition_S(sVt);

     // PREDICATES
     // Construct identity layout for sQ and sK
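For the reformatted causal clamp in the first hunk: each query row block only needs key blocks up to its own diagonal, and each split only owns its `[n_block_min, n_block_max)` slice, so a split lying entirely past the diagonal takes the early-exit path mentioned in the context lines (write 0 to `gOaccum`, -inf to `gLSEaccum`). A sketch of that arithmetic with concrete, illustrative numbers, using a plain integer helper in place of `cute::ceil_div` and `binfo`:

```cpp
#include <algorithm>
#include <cstdio>

constexpr int ceil_div(int a, int b) { return (a + b - 1) / b; }

int main() {
    // Illustrative shapes: prefill-style, seqlen_q == seqlen_k.
    const int kBlockM = 64, kBlockN = 32;
    const int actual_seqlen_q = 4096, actual_seqlen_k = 4096;
    const int num_splits = 4;
    const int num_n_blocks = ceil_div(actual_seqlen_k, kBlockN);        // 128
    const int n_blocks_per_split = ceil_div(num_n_blocks, num_splits);  // 32

    const int m_block = 0;  // first query row block
    for (int n_split_idx = 0; n_split_idx < num_splits; ++n_split_idx) {
        const int n_block_min = n_split_idx * n_blocks_per_split;
        int n_block_max = std::min(num_n_blocks, (n_split_idx + 1) * n_blocks_per_split);
        // Causal: row block m may attend keys up to
        // (m_block + 1) * kBlockM + seqlen_k - seqlen_q.
        n_block_max = std::min(
            n_block_max,
            ceil_div((m_block + 1) * kBlockM + actual_seqlen_k - actual_seqlen_q, kBlockN));
        // n_block_min >= n_block_max: this split has no work for this row block.
        std::printf("split %d: [%d, %d)%s\n", n_split_idx, n_block_min, n_block_max,
                    n_block_min >= n_block_max ? "  (early exit)" : "");
    }
    return 0;
}
```

With these numbers the diagonal limit is `ceil_div(64, 32) = 2`, so split 0 processes key blocks [0, 2) and splits 1 through 3 exit early for this row block.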
diff --git a/csrc/src/flash_fwd_launch_template.h b/csrc/src/flash_fwd_launch_template.h
index b5241bb..af992e0 100644
--- a/csrc/src/flash_fwd_launch_template.h
+++ b/csrc/src/flash_fwd_launch_template.h
@@ -38,9 +38,9 @@ DEFINE_FLASH_FORWARD_KERNEL(flash_fwd_kernel, bool Is_dropout, bool Is_causal, b
 #endif
 }
-DEFINE_FLASH_FORWARD_KERNEL(flash_fwd_splitkv_kernel, bool Is_causal, bool Is_even_MN, bool Is_even_K, bool Is_softcap, bool Split, bool Append_KV) {
+DEFINE_FLASH_FORWARD_KERNEL(flash_fwd_splitkv_kernel, bool Is_causal, bool Is_even_MN, bool Is_even_K, bool Is_softcap, bool Split) {
 #if defined(ARCH_SUPPORTS_FLASH)
-    FLASH_NAMESPACE::compute_attn_splitkv<Kernel_traits, Is_causal, Is_even_MN, Is_even_K, Is_softcap, Split, Append_KV>(params);
+    FLASH_NAMESPACE::compute_attn_splitkv<Kernel_traits, Is_causal, Is_even_MN, Is_even_K, Is_softcap, Split>(params);
 #else
     FLASH_UNSUPPORTED_ARCH
 #endif
 }
@@ -74,7 +74,7 @@ void run_flash_fwd(Flash_fwd_params &params, cudaStream_t stream) {
                     // If head dim > 128, set IsEvenMNConst to false to reduce number of templates
                     auto kernel = &flash_fwd_kernel;
                     // auto kernel = &flash_fwd_kernel;
-                    // printf("IsEvenMNConst = %d, IsEvenKConst = %d, Is_causal = %d, ReturnSoftmaxConst = %d, Is_dropout = %d\n", int(IsEvenMNConst), int(IsEvenKConst), int(Is_causal), int(ReturnSoftmaxConst), int(Is_dropout));
+                    // printf("run_flash_fwd: IsEvenMNConst = %d, IsEvenKConst = %d, Is_causal = %d, ReturnSoftmaxConst = %d, Is_dropout = %d\n", int(IsEvenMNConst), int(IsEvenKConst), int(Is_causal), int(ReturnSoftmaxConst), int(Is_dropout));
                     // auto kernel = &flash_fwd_kernel;
                     if (smem_size >= 48 * 1024) {
                         C10_CUDA_CHECK(cudaFuncSetAttribute(
@@ -83,7 +83,7 @@ void run_flash_fwd(Flash_fwd_params &params, cudaStream_t stream) {
                     // int ctas_per_sm;
                     // cudaError status_ = cudaOccupancyMaxActiveBlocksPerMultiprocessor(
                     //     &ctas_per_sm, kernel, Kernel_traits::kNThreads, smem_size);
-                    // printf("smem_size = %d, CTAs per SM = %d\n", int(smem_size), ctas_per_sm);
+                    // printf("run_flash_fwd: smem_size = %d, CTAs per SM = %d\n", int(smem_size), ctas_per_sm);
                     kernel<<<grid, block, smem_size, stream>>>(params);
                     C10_CUDA_KERNEL_LAUNCH_CHECK();
                 });
@@ -104,26 +104,23 @@ void run_flash_splitkv_fwd(Flash_fwd_params &params, cudaStream_t stream) {
     BOOL_SWITCH(is_even_MN, IsEvenMNConst, [&] {
         EVENK_SWITCH(is_even_K, IsEvenKConst, [&] {
             BOOL_SWITCH(params.num_splits > 1, Split, [&] {
-                BOOL_SWITCH(params.knew_ptr != nullptr, Append_KV, [&] {
-                    SOFTCAP_SWITCH(params.softcap > 0.0, Is_softcap, [&] {
-                        // If Append_KV, then we must have seqlen_offsets, which means cu_seqlens_k != nullptr.
-                        // If not IsEvenKConst, we also set IsEvenMNConst to false to reduce number of templates.
-                        // If Is_local, set Is_causal to false
-                        auto kernel = &flash_fwd_splitkv_kernel;
-                        // auto kernel = &flash_fwd_splitkv_kernel;
-                        // auto kernel = &flash_fwd_splitkv_kernel;
-                        // printf("Split = %d, Append_KV = %d, Is_causal = %d, IsEvenMNConst = %d, IsEvenKConst = %d, Is_softcap = %d\n", int(Split), int(Append_KV), int(Is_causal), int(IsEvenMNConst), int(IsEvenKConst), int(Is_softcap));
-                        if (smem_size >= 48 * 1024) {
-                            C10_CUDA_CHECK(cudaFuncSetAttribute(
-                                kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size));
-                        }
-                        // int ctas_per_sm;
-                        // cudaError status_ = cudaOccupancyMaxActiveBlocksPerMultiprocessor(
-                        //     &ctas_per_sm, kernel, Kernel_traits::kNThreads, smem_size);
-                        // printf("smem_size = %d, CTAs per SM = %d\n", int(smem_size), ctas_per_sm);
-                        kernel<<<grid, block, smem_size, stream>>>(params);
-                        C10_CUDA_KERNEL_LAUNCH_CHECK();
-                    });
+                SOFTCAP_SWITCH(params.softcap > 0.0, Is_softcap, [&] {
+                    // If not IsEvenKConst, we also set IsEvenMNConst to false to reduce number of templates.
+                    // If Is_local, set Is_causal to false
+                    auto kernel = &flash_fwd_splitkv_kernel;
+                    // auto kernel = &flash_fwd_splitkv_kernel;
+                    // auto kernel = &flash_fwd_splitkv_kernel;
+                    // printf("run_flash_splitkv_fwd: Split = %d, Is_causal = %d, IsEvenMNConst = %d, IsEvenKConst = %d, Is_softcap = %d\n", int(Split), int(Is_causal), int(IsEvenMNConst), int(IsEvenKConst), int(Is_softcap));
+                    if (smem_size >= 48 * 1024) {
+                        C10_CUDA_CHECK(cudaFuncSetAttribute(
+                            kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size));
+                    }
+                    // int ctas_per_sm;
+                    // cudaError status_ = cudaOccupancyMaxActiveBlocksPerMultiprocessor(
+                    //     &ctas_per_sm, kernel, Kernel_traits::kNThreads, smem_size);
+                    // printf("run_flash_splitkv_fwd: smem_size = %d, CTAs per SM = %d\n", int(smem_size), ctas_per_sm);
+                    kernel<<<grid, block, smem_size, stream>>>(params);
+                    C10_CUDA_KERNEL_LAUNCH_CHECK();
                 });
             });
         });
@@ -158,9 +155,7 @@ void run_flash_splitkv_fwd(Flash_fwd_params &params, cudaStream_t stream) {
 template<typename T, int Headdim, bool Is_causal>
 void run_mha_fwd_splitkv_dispatch(Flash_fwd_params &params, cudaStream_t stream) {
     constexpr static int kBlockM = 64;  // Fixed for all head dimensions
-    // TD [2023-08-28]: nvcc segfaults for headdim 96 with block size 64 x 256,
-    // and for headdim 192 with block size 64 x 128.
-    constexpr static int kBlockN = Headdim <= 64 ? 128 : (Headdim <= 128 ? 64 : 32);
+    constexpr static int kBlockN = Headdim <= 64 ? 64 : (Headdim < 128 ? 64 : 32);
     run_flash_splitkv_fwd<Flash_fwd_kernel_traits<Headdim, kBlockM, kBlockN, 4, false, false, T>, Is_causal>(params, stream);
 }
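With the forced `num_splits = 1` reset commented out in `flash_api.cpp`, a caller-supplied `num_splits > 1` now actually reaches the `Split` dispatch above, at the cost of re-enabling the `softmax_lse_accum`/`out_accum` workspace that issue #47 was avoiding. For intuition about when splitting pays off, here is a hypothetical heuristic loosely modeled on upstream FlashAttention's `num_splits_heuristic`; the function name and thresholds are illustrative, not this repo's API:

```cpp
#include <algorithm>
#include <cmath>

// Hypothetical: pick the split count that best fills the GPU. Efficiency is
// measured as utilization of the last (partial) wave of CTAs; more splits
// mean more CTAs, but also a larger accum workspace and a final reduction.
int pick_num_splits(int batch_nheads_mblocks, int num_SMs, int num_n_blocks) {
    if (batch_nheads_mblocks >= num_SMs) return 1;  // grid already saturates the GPU
    const int max_splits = std::min({128, num_SMs, num_n_blocks});
    int best_split = 1;
    float best_eff = 0.0f;
    for (int s = 1; s <= max_splits; ++s) {
        const float waves = float(batch_nheads_mblocks) * s / num_SMs;
        const float eff = waves / std::ceil(waves);  // 1.0 == last wave is full
        if (eff > best_eff + 0.01f) {  // require a real gain before splitting more
            best_eff = eff;
            best_split = s;
        }
    }
    return best_split;
}

// e.g. batch 1 x 8 heads x 1 M-block = 8 CTAs on a 108-SM A100: this sketch
// returns 27 splits (8 * 27 = 216 CTAs = exactly two full waves).
```

This is the trade-off the patch reopens: more splits fill more SMs when `batch * num_heads * m_blocks` is small (single-sequence decoding, where seqlen_q is not expected to exceed 64), but each split writes a partial result that must be reduced, which is exactly the memory overhead the commented-out reset used to suppress.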