csrc/flash_api.cpp (8 changes: 4 additions & 4 deletions)
@@ -232,7 +232,7 @@ std::tuple<at::Tensor, at::Tensor> set_params_splitkv(
 ) {
 
     // This needs to match with run_mha_fwd_splitkv_dispatch
-    const int block_n = head_size <= 64 ? 256 : (head_size <= 128 ? 128 : 64);
+    const int block_n = head_size <= 64 ? 64 : (head_size < 128 ? 64 : 32);
     const int num_n_blocks = (max_seqlen_k + block_n - 1) / block_n;
     // Technically kBlockM = 64 only for the splitKV kernels, not the standard kernel.
     // In any case we don't expect seqlen_q to be larger than 64 for inference.
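
Note on this hunk: block_n now has to agree exactly with the kBlockN table in run_mha_fwd_splitkv_dispatch (last hunk of this PR); before the change the two tables disagreed (256/128/64 here vs. 128/64/32 in the dispatch). A minimal standalone sketch of the new mapping, with illustrative head sizes and sequence length (the helper name block_n_for is ours, not from the PR):

    #include <cstdio>
    #include <initializer_list>

    // Mirrors the new host-side table: 64-wide key blocks below head dim 128,
    // 32-wide at 128 and above.
    static int block_n_for(int head_size) {
        return head_size <= 64 ? 64 : (head_size < 128 ? 64 : 32);
    }

    int main() {
        const int max_seqlen_k = 4096;  // illustrative decode-time KV length
        for (int head_size : {32, 64, 96, 128, 256}) {
            const int block_n = block_n_for(head_size);
            const int num_n_blocks = (max_seqlen_k + block_n - 1) / block_n;  // ceil division
            std::printf("head_size=%3d -> block_n=%2d, num_n_blocks=%3d\n",
                        head_size, block_n, num_n_blocks);
        }
        return 0;
    }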
@@ -259,9 +259,9 @@ std::tuple<at::Tensor, at::Tensor> set_params_splitkv(
     // See: https://github.com/SmallDoges/flash-dmattn/issues/47
     // Regardless of how it is set externally, always set num_splits back to 1.
     // This is to avoid the extra memory overhead of Split-KV.
-    params.num_splits = 1;
-    softmax_lse_accum.reset();
-    out_accum.reset();
+    // params.num_splits = 1;
+    // softmax_lse_accum.reset();
+    // out_accum.reset();
 
     return std::make_tuple(softmax_lse_accum, out_accum);
 }
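
With the hard override commented out, a caller- or heuristic-chosen num_splits > 1 now survives, so the fp32 partial-result buffers actually get allocated. For orientation, upstream FlashAttention sizes them along these lines (a sketch following upstream set_params_splitkv; the names and shapes are assumptions here, not this PR's verbatim code):

    // Sketch after upstream FlashAttention's set_params_splitkv; treat names
    // and shapes as assumptions for illustration.
    if (params.num_splits > 1) {
        // One fp32 partial LSE and one fp32 partial output per split: this is
        // the "extra memory overhead of Split-KV" the old override avoided.
        softmax_lse_accum = torch::empty(
            {params.num_splits, batch_size, num_heads, max_seqlen_q},
            opts.dtype(at::kFloat));
        out_accum = torch::empty(
            {params.num_splits, batch_size, num_heads, max_seqlen_q, head_size_rounded},
            opts.dtype(at::kFloat));
        params.softmax_lseaccum_ptr = softmax_lse_accum.data_ptr();
        params.oaccum_ptr = out_accum.data_ptr();
    }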
csrc/src/flash_fwd_kernel.h (17 changes: 11 additions & 6 deletions)
@@ -707,8 +707,13 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, cons
     const int n_block_min = n_split_idx * n_blocks_per_split;
     int n_block_max = std::min(cute::ceil_div(binfo.actual_seqlen_k, kBlockN), (n_split_idx + 1) * n_blocks_per_split);
     if (Is_causal) {
-        n_block_max = std::min(n_block_max,
-                               cute::ceil_div((m_block + 1) * kBlockM + binfo.actual_seqlen_k - binfo.actual_seqlen_q, kBlockN));
+        n_block_max = std::min(
+            n_block_max,
+            cute::ceil_div(
+                (m_block + 1) * kBlockM + binfo.actual_seqlen_k - binfo.actual_seqlen_q,
+                kBlockN
+            )
+        );
     }
     if (n_block_min >= n_block_max) {  // This also covers the case where n_block_max <= 0
         // We exit early and write 0 to gOaccum and -inf to gLSEaccum.
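
The reformatted causal bound is unchanged math: a query row-block may only attend to key blocks up to its own diagonal, shifted by seqlen_k - seqlen_q. A worked example with assumed tile sizes (kBlockM = kBlockN = 64, matching the dispatch below) for a single-token decode:

    #include <algorithm>
    #include <cstdio>

    constexpr int ceil_div(int a, int b) { return (a + b - 1) / b; }

    int main() {
        // Illustrative values, not from the diff.
        const int kBlockM = 64, kBlockN = 64;
        const int seqlen_q = 1, seqlen_k = 4096;  // single-token decode step
        const int m_block = 0;
        const int n_split_idx = 3, n_blocks_per_split = 16;

        const int n_block_min = n_split_idx * n_blocks_per_split;             // 48
        int n_block_max = std::min(ceil_div(seqlen_k, kBlockN),               // 64
                                   (n_split_idx + 1) * n_blocks_per_split);   // 64
        // Causal: query position m sees keys up to m + seqlen_k - seqlen_q.
        n_block_max = std::min(
            n_block_max,
            ceil_div((m_block + 1) * kBlockM + seqlen_k - seqlen_q, kBlockN)); // 65, so 64 stands
        std::printf("this split processes key blocks [%d, %d)\n", n_block_min, n_block_max);
        return 0;
    }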
@@ -863,9 +868,9 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, cons
     auto thr_mma = tiled_mma.get_thread_slice(tidx);
     Tensor tSrQ = thr_mma.partition_fragment_A(sQ); // (MMA, MMA_M, MMA_K)
     Tensor tSrK = thr_mma.partition_fragment_B(sK); // (MMA, MMA_N, MMA_K)
-    Tensor tOrVt = thr_mma.partition_fragment_B(sVtNoSwizzle); // (MMA, MMA_K, MMA_N)
     Tensor tSrMask = partition_fragment_C(tiled_mma, Shape<Int<kBlockM>, Int<kBlockN>>{}); // (MMA, MMA_M, MMA_N)
     Tensor tSrBias = partition_fragment_C(tiled_mma, Shape<Int<kBlockM>, Int<kBlockN>>{}); // (MMA, MMA_M, MMA_N)
+    Tensor tOrVt = thr_mma.partition_fragment_B(sVtNoSwizzle); // (MMA, MMA_K, MMA_N)
     Tensor acc_o = partition_fragment_C(tiled_mma, Shape<Int<kBlockM>, Int<kHeadDim>>{}); // (MMA, MMA_M, MMA_K)
 
     // Copy Atom retiling
@@ -875,15 +880,15 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, cons
     auto smem_tiled_copy_K = make_tiled_copy_B(typename Kernel_traits::SmemCopyAtom{}, tiled_mma);
     auto smem_thr_copy_K = smem_tiled_copy_K.get_thread_slice(tidx);
     Tensor tSsK = smem_thr_copy_K.partition_S(sK);
-    auto smem_tiled_copy_V = make_tiled_copy_B(typename Kernel_traits::SmemCopyAtomTransposed{}, tiled_mma);
-    auto smem_thr_copy_V = smem_tiled_copy_V.get_thread_slice(tidx);
-    Tensor tOsVt = smem_thr_copy_V.partition_S(sVt);
     auto smem_tiled_copy_Mask = make_tiled_copy_C(typename Kernel_traits::SmemCopyAtomO{}, tiled_mma);
     auto smem_thr_copy_Mask = smem_tiled_copy_Mask.get_thread_slice(tidx);
     Tensor tSsMask = smem_thr_copy_Mask.partition_S(sMask);
     auto smem_tiled_copy_Bias = make_tiled_copy_C(typename Kernel_traits::SmemCopyAtomO{}, tiled_mma);
     auto smem_thr_copy_Bias = smem_tiled_copy_Bias.get_thread_slice(tidx);
     Tensor tSsBias = smem_thr_copy_Bias.partition_S(sBias);
+    auto smem_tiled_copy_V = make_tiled_copy_B(typename Kernel_traits::SmemCopyAtomTransposed{}, tiled_mma);
+    auto smem_thr_copy_V = smem_tiled_copy_V.get_thread_slice(tidx);
+    Tensor tOsVt = smem_thr_copy_V.partition_S(sVt);
 
     // PREDICATES
     // Construct identity layout for sQ and sK
csrc/src/flash_fwd_launch_template.h (49 changes: 22 additions & 27 deletions)
@@ -38,9 +38,9 @@ DEFINE_FLASH_FORWARD_KERNEL(flash_fwd_kernel, bool Is_dropout, bool Is_causal, b
 #endif
 }
 
-DEFINE_FLASH_FORWARD_KERNEL(flash_fwd_splitkv_kernel, bool Is_causal, bool Is_even_MN, bool Is_even_K, bool Is_softcap, bool Split, bool Append_KV) {
+DEFINE_FLASH_FORWARD_KERNEL(flash_fwd_splitkv_kernel, bool Is_causal, bool Is_even_MN, bool Is_even_K, bool Is_softcap, bool Split) {
 #if defined(ARCH_SUPPORTS_FLASH)
-    FLASH_NAMESPACE::compute_attn_splitkv<Kernel_traits, Is_causal, Is_even_MN, Is_even_K, Is_softcap, Split, Append_KV>(params);
+    FLASH_NAMESPACE::compute_attn_splitkv<Kernel_traits, Is_causal, Is_even_MN, Is_even_K, Is_softcap, Split>(params);
 #else
     FLASH_UNSUPPORTED_ARCH
 #endif
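
Dropping Append_KV from the kernel signature is not only cosmetic: each boolean template parameter doubles the number of split-KV kernel bodies nvcc must compile per kernel-traits instantiation. A back-of-the-envelope upper bound (the SWITCH macros prune some combinations in practice):

    #include <cstdio>

    int main() {
        // Is_causal, Is_even_MN, Is_even_K, Is_softcap, Split (+ Append_KV before this PR)
        const int bools_before = 6;
        const int bools_after  = 5;
        std::printf("max split-KV instantiations per traits: %d -> %d\n",
                    1 << bools_before, 1 << bools_after);  // 64 -> 32
        return 0;
    }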
@@ -74,7 +74,7 @@ void run_flash_fwd(Flash_fwd_params &params, cudaStream_t stream) {
             // If head dim > 128, set IsEvenMNConst to false to reduce number of templates
             auto kernel = &flash_fwd_kernel<Kernel_traits, Is_dropout && !Is_softcap, Is_causal, IsEvenMNConst && IsEvenKConst && !ReturnSoftmaxConst && Kernel_traits::kHeadDim <= 128, IsEvenKConst && !ReturnSoftmaxConst, Is_softcap, ReturnSoftmaxConst && Is_dropout && !Is_softcap>;
             // auto kernel = &flash_fwd_kernel<Kernel_traits, false, Is_causal, false, false, true, true, false>;
-            // printf("IsEvenMNConst = %d, IsEvenKConst = %d, Is_causal = %d, ReturnSoftmaxConst = %d, Is_dropout = %d\n", int(IsEvenMNConst), int(IsEvenKConst), int(Is_causal), int(ReturnSoftmaxConst), int(Is_dropout));
+            // printf("run_flash_fwd: IsEvenMNConst = %d, IsEvenKConst = %d, Is_causal = %d, ReturnSoftmaxConst = %d, Is_dropout = %d\n", int(IsEvenMNConst), int(IsEvenKConst), int(Is_causal), int(ReturnSoftmaxConst), int(Is_dropout));
             // auto kernel = &flash_fwd_kernel<Kernel_traits, false, Is_causal, false, true, true, false>;
             if (smem_size >= 48 * 1024) {
                 C10_CUDA_CHECK(cudaFuncSetAttribute(
@@ -83,7 +83,7 @@ void run_flash_fwd(Flash_fwd_params &params, cudaStream_t stream) {
             // int ctas_per_sm;
             // cudaError status_ = cudaOccupancyMaxActiveBlocksPerMultiprocessor(
             //     &ctas_per_sm, kernel, Kernel_traits::kNThreads, smem_size);
-            // printf("smem_size = %d, CTAs per SM = %d\n", int(smem_size), ctas_per_sm);
+            // printf("run_flash_fwd: smem_size = %d, CTAs per SM = %d\n", int(smem_size), ctas_per_sm);
             kernel<<<grid, Kernel_traits::kNThreads, smem_size, stream>>>(params);
             C10_CUDA_KERNEL_LAUNCH_CHECK();
         });
@@ -104,26 +104,23 @@ void run_flash_splitkv_fwd(Flash_fwd_params &params, cudaStream_t stream) {
     BOOL_SWITCH(is_even_MN, IsEvenMNConst, [&] {
         EVENK_SWITCH(is_even_K, IsEvenKConst, [&] {
             BOOL_SWITCH(params.num_splits > 1, Split, [&] {
-                BOOL_SWITCH(params.knew_ptr != nullptr, Append_KV, [&] {
-                    SOFTCAP_SWITCH(params.softcap > 0.0, Is_softcap, [&] {
-                        // If Append_KV, then we must have seqlen_offsets, which means cu_seqlens_k != nullptr.
-                        // If not IsEvenKConst, we also set IsEvenMNConst to false to reduce number of templates.
-                        // If Is_local, set Is_causal to false
-                        auto kernel = &flash_fwd_splitkv_kernel<Kernel_traits, Is_causal, IsEvenMNConst && !Append_KV && IsEvenKConst && Kernel_traits::kHeadDim <= 128, IsEvenKConst, Is_softcap, Split, Append_KV>;
-                        // auto kernel = &flash_fwd_splitkv_kernel<Kernel_traits, Is_causal, false, true, Split, Append_KV>;
-                        // auto kernel = &flash_fwd_splitkv_kernel<Kernel_traits, Is_causal, false, IsEvenKConst>;
-                        // printf("Split = %d, Append_KV = %d, Is_causal = %d, IsEvenMNConst = %d, IsEvenKConst = %d, Is_softcap = %d\n", int(Split), int(Append_KV), int(Is_causal), int(IsEvenMNConst), int(IsEvenKConst), int(Is_softcap));
-                        if (smem_size >= 48 * 1024) {
-                            C10_CUDA_CHECK(cudaFuncSetAttribute(
-                                kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size));
-                        }
-                        // int ctas_per_sm;
-                        // cudaError status_ = cudaOccupancyMaxActiveBlocksPerMultiprocessor(
-                        //     &ctas_per_sm, kernel, Kernel_traits::kNThreads, smem_size);
-                        // printf("smem_size = %d, CTAs per SM = %d\n", int(smem_size), ctas_per_sm);
-                        kernel<<<grid, Kernel_traits::kNThreads, smem_size, stream>>>(params);
-                        C10_CUDA_KERNEL_LAUNCH_CHECK();
-                    });
+                SOFTCAP_SWITCH(params.softcap > 0.0, Is_softcap, [&] {
+                    // If not IsEvenKConst, we also set IsEvenMNConst to false to reduce number of templates.
+                    // If Is_local, set Is_causal to false
+                    auto kernel = &flash_fwd_splitkv_kernel<Kernel_traits, Is_causal, IsEvenMNConst && IsEvenKConst && Kernel_traits::kHeadDim <= 128, IsEvenKConst, Is_softcap, Split>;
+                    // auto kernel = &flash_fwd_splitkv_kernel<Kernel_traits, Is_causal, false, true, Split>;
+                    // auto kernel = &flash_fwd_splitkv_kernel<Kernel_traits, Is_causal, false, IsEvenKConst>;
+                    // printf("run_flash_splitkv_fwd: Split = %d, Is_causal = %d, IsEvenMNConst = %d, IsEvenKConst = %d, Is_softcap = %d\n", int(Split), int(Is_causal), int(IsEvenMNConst), int(IsEvenKConst), int(Is_softcap));
+                    if (smem_size >= 48 * 1024) {
+                        C10_CUDA_CHECK(cudaFuncSetAttribute(
+                            kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size));
+                    }
+                    // int ctas_per_sm;
+                    // cudaError status_ = cudaOccupancyMaxActiveBlocksPerMultiprocessor(
+                    //     &ctas_per_sm, kernel, Kernel_traits::kNThreads, smem_size);
+                    // printf("run_flash_splitkv_fwd: smem_size = %d, CTAs per SM = %d\n", int(smem_size), ctas_per_sm);
+                    kernel<<<grid, Kernel_traits::kNThreads, smem_size, stream>>>(params);
+                    C10_CUDA_KERNEL_LAUNCH_CHECK();
                 });
             });
         });
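
For context, the SWITCH macros convert a runtime boolean into a constexpr template argument by compiling both branches and selecting one at runtime; removing the Append_KV BOOL_SWITCH strips one level of this nesting, halving the compiled combinations. A minimal sketch of the pattern, modeled on FlashAttention's static_switch.h (this repo's definition may differ in detail):

    // BOOL_SWITCH pattern: dispatch a runtime bool to a compile-time constant
    // by instantiating the body lambda once per branch.
    #define BOOL_SWITCH(COND, CONST_NAME, ...)            \
        [&] {                                             \
            if (COND) {                                   \
                constexpr static bool CONST_NAME = true;  \
                return __VA_ARGS__();                     \
            } else {                                      \
                constexpr static bool CONST_NAME = false; \
                return __VA_ARGS__();                     \
            }                                             \
        }()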
@@ -158,9 +155,7 @@ void run_flash_splitkv_fwd(Flash_fwd_params &params, cudaStream_t stream) {
 template<typename T, int Headdim, bool Is_causal>
 void run_mha_fwd_splitkv_dispatch(Flash_fwd_params &params, cudaStream_t stream) {
     constexpr static int kBlockM = 64; // Fixed for all head dimensions
-    // TD [2023-08-28]: nvcc segfaults for headdim 96 with block size 64 x 256,
-    // and for headdim 192 with block size 64 x 128.
-    constexpr static int kBlockN = Headdim <= 64 ? 128 : (Headdim <= 128 ? 64 : 32);
+    constexpr static int kBlockN = Headdim <= 64 ? 64 : (Headdim < 128 ? 64 : 32);
     run_flash_splitkv_fwd<Flash_fwd_kernel_traits<Headdim, kBlockM, kBlockN, 4, false, false, T>, Is_causal>(params, stream);
 }
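
This kBlockN expression is now character-for-character the block_n expression in set_params_splitkv (first hunk of this PR), which is exactly the consistency its comment demands. A compile-time guard one could keep next to either copy (hypothetical, not part of the diff; the helper name is ours):

    // Hypothetical guard pinning down the host/kernel block-size agreement.
    constexpr int splitkv_block_n(int head_size) {
        return head_size <= 64 ? 64 : (head_size < 128 ? 64 : 32);
    }
    static_assert(splitkv_block_n(32)  == 64, "head dims <= 64 use 64-wide key blocks");
    static_assert(splitkv_block_n(96)  == 64, "head dims in (64, 128) use 64-wide key blocks");
    static_assert(splitkv_block_n(128) == 32, "head dims >= 128 use 32-wide key blocks");
    static_assert(splitkv_block_n(256) == 32, "head dims >= 128 use 32-wide key blocks");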
