21 changes: 6 additions & 15 deletions csrc/src/flash.h
@@ -15,11 +15,11 @@ namespace FLASH_NAMESPACE {
constexpr int TOTAL_DIM = 0;
constexpr int H_DIM = 1;
constexpr int D_DIM = 2;
typedef int64_t index_t;
Copilot AI May 22, 2025

The global typedef for index_t now makes local alias definitions in structs redundant. Ensure any previous local definitions are removed so that all modules use the global index_t consistently.
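To illustrate the point, a minimal, compilable sketch (not code from this PR; the namespace name and the q/k stride fields are placeholders for illustration, only zero_hold_batch_stride appears in the diff above):

#include <cstdint>

namespace flash_sketch {

// Single file-scope definition, mirroring the `typedef int64_t index_t;` added above.
typedef int64_t index_t;

struct QKV_params {
    // No local `using index_t = int64_t;` alias needed anymore.
    index_t q_batch_stride;   // placeholder stride fields
    index_t k_batch_stride;
};

struct ZeroHold_params {
    index_t zero_hold_batch_stride;   // resolves to the same global index_t
};

} // namespace flash_sketch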

////////////////////////////////////////////////////////////////////////////////////////////////////

struct QKV_params {
using index_t = int64_t;
// The QKV matrices.
void *__restrict__ q_ptr; // Query tensor [batch_size, num_heads, query_len, head_dim]
void *__restrict__ k_ptr; // Key tensor [batch_size, num_kv_heads, key_len, head_dim]
@@ -46,20 +46,12 @@ struct QKV_params {
////////////////////////////////////////////////////////////////////////////////////////////////////

struct ZeroHold_params {
using index_t = int64_t;

void *__restrict__ zero_hold_ptr; // Zero-hold states tensor [batch_size, num_kv_heads, query_len, key_len]

// The stride of the zero-hold states tensor.
index_t zero_hold_batch_stride; // Stride between batches of zero-hold states
index_t zero_hold_head_stride; // Stride between heads of zero-hold states
index_t zero_hold_query_stride; // Stride for the third dimension (query_len) of zero-hold states
// Assuming last dim (key_len) has stride 1 for the zero_hold_states_ptr

index_t causal_mask_batch_stride; // Stride between batches of causal_mask
index_t causal_mask_head_stride; // Stride for the second dimension (size 1) of causal_mask
index_t causal_mask_query_len_stride; // Stride for the third dimension (query_len) of causal_mask
// Assuming last dim (key_len) has stride 1 for the causal_mask_ptr
index_t zero_hold_batch_stride; // Stride between batches of zero-hold states
index_t zero_hold_head_stride; // Stride between heads of zero-hold states
index_t zero_hold_row_stride; // Stride for the third dimension (key_len) of zero-hold states

// The keep window size.
int keep_window_size; // Number of tokens to keep in top-k (0 means don't apply top-k)
@@ -73,7 +65,6 @@ struct Flash_fwd_params : public QKV_params, public ZeroHold_params {
void *k_ptr = nullptr;
void *v_ptr = nullptr;
void *zero_hold_ptr = nullptr;
void *causal_mask_ptr = nullptr;

// Input tensor for the bias
void *b_ptr = nullptr;
@@ -207,7 +198,7 @@ struct Flash_bwd_params : public Flash_fwd_params {
index_t dv_head_stride;
index_t dzero_hold_batch_stride;
index_t dzero_hold_head_stride;
index_t dzero_hold_query_stride;
index_t dzero_hold_row_stride;

// The pointer to the softmax d sum.
void *__restrict__ dsoftmax_sum;
579 changes: 216 additions & 363 deletions csrc/src/flash_attention_fwd_kernel.h

Large diffs are not rendered by default.

78 changes: 0 additions & 78 deletions csrc/src/flash_fwd_launch_template.h
@@ -38,19 +38,6 @@ DEFINE_FLASH_FORWARD_KERNEL(flash_fwd_kernel, bool Is_causal, bool Is_even_MN, b
#endif
}

DEFINE_FLASH_FORWARD_KERNEL(flash_fwd_splitkv_kernel, bool Is_causal, bool Is_even_MN, bool Is_even_K, bool Split, bool Append_KV) {
#if defined(ARCH_SUPPORTS_FLASH)
FLASH_NAMESPACE::compute_attn_splitkv<Kernel_traits, Is_causal, Is_even_MN, Is_even_K, Split, Append_KV>(params);
#else
FLASH_UNSUPPORTED_ARCH
#endif
}

DEFINE_FLASH_FORWARD_KERNEL(flash_fwd_splitkv_combine_kernel, int kBlockM, int Log_max_splits, bool Is_even_K) {
static_assert(Log_max_splits >= 1);
FLASH_NAMESPACE::combine_attn_seqk_parallel<Kernel_traits, kBlockM, Log_max_splits, Is_even_K>(params);
}

template<typename Kernel_traits, bool Is_causal>
void run_flash_fwd(Flash_fwd_params &params, cudaStream_t stream) {
const size_t smem_size = Kernel_traits::kSmemSizeWithMask;
@@ -84,71 +71,6 @@ void run_flash_fwd(Flash_fwd_params &params, cudaStream_t stream) {
});
}

template<typename Kernel_traits, bool Is_causal>
void run_flash_splitkv_fwd(Flash_fwd_params &params, cudaStream_t stream) {
static_assert(!Kernel_traits::Is_Q_in_regs, "SplitKV implementation does not support Is_Q_in_regs");
static_assert(!Kernel_traits::Share_Q_K_smem, "SplitKV implementation does not support Share_Q_K_smem");

const size_t smem_size = Kernel_traits::kSmemSizeWithMask;

const int num_m_block = (params.seqlen_q + Kernel_traits::kBlockM - 1) / Kernel_traits::kBlockM;
dim3 grid(num_m_block, params.num_splits > 1 ? params.num_splits : params.b, params.num_splits > 1 ? params.b * params.h : params.h);
const bool is_even_MN = params.cu_seqlens_q == nullptr && params.cu_seqlens_k == nullptr && params.seqlen_k % Kernel_traits::kBlockN == 0 && params.seqlen_q % Kernel_traits::kBlockM == 0;
const bool is_even_K = params.d == Kernel_traits::kHeadDim;

BOOL_SWITCH(is_even_MN, IsEvenMNConst, [&] {
EVENK_SWITCH(is_even_K, IsEvenKConst, [&] {
BOOL_SWITCH(params.num_splits > 1, Split, [&] {
BOOL_SWITCH(params.knew_ptr != nullptr, Append_KV, [&] {
// If Append_KV, then we must have seqlen_offsets, which means cu_seqlens_k != nullptr.
// If not IsEvenKConst, we also set IsEvenMNConst to false to reduce number of templates.
auto kernel = &flash_fwd_splitkv_kernel<Kernel_traits, Is_causal, IsEvenMNConst && !Append_KV && IsEvenKConst && Kernel_traits::kHeadDim <= 128, IsEvenKConst, Split, Append_KV>;
if (smem_size >= 48 * 1024) {
C10_CUDA_CHECK(cudaFuncSetAttribute(
kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size));
}
kernel<<<grid, Kernel_traits::kNThreads, smem_size, stream>>>(params);
C10_CUDA_KERNEL_LAUNCH_CHECK();
});
});
});
});
if (params.num_splits > 1) {
// We want kBlockM to be as small as possible for more parallelism.
// With 128 threads we can load 512 elements at a time, so if headdim is divisible by 128, kBlockM = 4.
// If headdim is divisible by 64, then we set kBlockM = 8, etc.
constexpr static int kBlockM = Kernel_traits::kHeadDim % 128 == 0 ? 4 : (Kernel_traits::kHeadDim % 64 == 0 ? 8 : 16);
dim3 grid_combine((params.b * params.h * params.seqlen_q + kBlockM - 1) / kBlockM);
EVENK_SWITCH(is_even_K, IsEvenKConst, [&] {
if (params.num_splits <= 2) {
flash_fwd_splitkv_combine_kernel<Kernel_traits, kBlockM, 1, IsEvenKConst><<<grid_combine, Kernel_traits::kNThreads, 0, stream>>>(params);
} else if (params.num_splits <= 4) {
flash_fwd_splitkv_combine_kernel<Kernel_traits, kBlockM, 2, IsEvenKConst><<<grid_combine, Kernel_traits::kNThreads, 0, stream>>>(params);
} else if (params.num_splits <= 8) {
flash_fwd_splitkv_combine_kernel<Kernel_traits, kBlockM, 3, IsEvenKConst><<<grid_combine, Kernel_traits::kNThreads, 0, stream>>>(params);
} else if (params.num_splits <= 16) {
flash_fwd_splitkv_combine_kernel<Kernel_traits, kBlockM, 4, IsEvenKConst><<<grid_combine, Kernel_traits::kNThreads, 0, stream>>>(params);
} else if (params.num_splits <= 32) {
flash_fwd_splitkv_combine_kernel<Kernel_traits, kBlockM, 5, IsEvenKConst><<<grid_combine, Kernel_traits::kNThreads, 0, stream>>>(params);
} else if (params.num_splits <= 64) {
flash_fwd_splitkv_combine_kernel<Kernel_traits, kBlockM, 6, IsEvenKConst><<<grid_combine, Kernel_traits::kNThreads, 0, stream>>>(params);
} else if (params.num_splits <= 128) {
flash_fwd_splitkv_combine_kernel<Kernel_traits, kBlockM, 7, IsEvenKConst><<<grid_combine, Kernel_traits::kNThreads, 0, stream>>>(params);
}
C10_CUDA_KERNEL_LAUNCH_CHECK();
});
}
}

template<typename T, int Headdim, bool Is_causal>
void run_mha_fwd_splitkv_dispatch(Flash_fwd_params &params, cudaStream_t stream) {
constexpr static int kBlockM = 64; // Fixed for all head dimensions
// TD [2023-08-28]: nvcc segfaults for headdim 96 with block size 64 x 256,
// and for headdim 192 with block size 64 x 128.
constexpr static int kBlockN = Headdim <= 64 ? 256 : (Headdim <= 128 ? 128 : 64);
run_flash_splitkv_fwd<Flash_fwd_kernel_traits<Headdim, kBlockM, kBlockN, 4, false, false, T>, Is_causal>(params, stream);
}

template<typename T, bool Is_causal>
void run_mha_fwd_hdim32(Flash_fwd_params &params, cudaStream_t stream) {
constexpr static int Headdim = 32;
111 changes: 111 additions & 0 deletions csrc/src/generate_kernels.py
@@ -0,0 +1,111 @@
import argparse
import itertools
import os
from dataclasses import dataclass
from pathlib import Path
from typing import List, Optional

DTYPE_MAP = {
"fp16": "cutlass::half_t",
"bf16": "cutlass::bfloat16_t",
}

SM = [80] # Sm80 kernels support up to
HEAD_DIMENSIONS = [32, 64, 96, 128, 192, 256]
IS_CAUSAL = ["false", "true"]
NAMESPACE_INCLUDE = '#include "namespace_config.h"\n'

def get_fwd_template() -> str:
return NAMESPACE_INCLUDE + """#include "flash_fwd_launch_template.h"

namespace FLASH_NAMESPACE {{

template<>
void run_mha_fwd_<{DTYPE}, {HEAD_DIM}, {IS_CAUSAL}>(Flash_fwd_params &params, cudaStream_t stream) {{
run_mha_fwd_hdim{HEAD_DIM}<{DTYPE}, {IS_CAUSAL}>(params, stream);
}}

}} // namespace FLASH_NAMESPACE"""

def get_fwd_split_template() -> str:
return NAMESPACE_INCLUDE + """#include "flash_fwd_launch_template.h"

namespace FLASH_NAMESPACE {{

template void run_mha_fwd_splitkv_dispatch<{DTYPE}, {HEAD_DIM}, {IS_CAUSAL}>(Flash_fwd_params &params, cudaStream_t stream);

}} // namespace FLASH_NAMESPACE"""

def get_bwd_template() -> str:
return NAMESPACE_INCLUDE + """#include "flash_bwd_launch_template.h"

namespace FLASH_NAMESPACE {{

template<>
void run_mha_bwd_<{DTYPE}, {HEAD_DIM}, {IS_CAUSAL}>(Flash_bwd_params &params, cudaStream_t stream) {{
run_mha_bwd_hdim{HEAD_DIM}<{DTYPE}, {IS_CAUSAL}>(params, stream);
}}

}} // namespace FLASH_NAMESPACE"""

@dataclass
class Kernel:
sm: int
dtype: str
head_dim: int
is_causal: str
direction: str

@property
def template(self) -> str:
template_funcs = {
"fwd": get_fwd_template,
# "bwd": get_bwd_template,
# "fwd_split": get_fwd_split_template
}
template_func = template_funcs[self.direction]
return template_func().format(
DTYPE=DTYPE_MAP[self.dtype],
HEAD_DIM=self.head_dim,
IS_CAUSAL=self.is_causal
)

@property
def filename(self) -> str:
return f"flash_{self.direction}_hdim{self.head_dim}_{self.dtype}_{'causal_' if self.is_causal == 'true' else ''}sm{self.sm}.cu"

def get_all_kernels() -> List[Kernel]:
for direction in ["fwd"]: #, "fwd_split", "bwd"]:
for dtype, head_dim, is_causal, sm in itertools.product(DTYPE_MAP.keys(), HEAD_DIMENSIONS, IS_CAUSAL, SM):
yield Kernel(sm=sm, dtype=dtype, head_dim=head_dim, is_causal=is_causal, direction=direction)

def write_kernel(kernel: Kernel, autogen_dir: Path) -> None:
prelude = """// Copyright (c) 2025, Jingze Shi and Tri Dao.
// Splitting the different head dimensions to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"\n"""
content = prelude + kernel.template
(autogen_dir / kernel.filename).write_text(content)

def main(output_dir: Optional[str]) -> None:
if output_dir is None:
output_dir = Path(__file__).parent
else:
output_dir = Path(output_dir)

for kernel in get_all_kernels():
write_kernel(kernel, output_dir)

if __name__ == "__main__":
parser = argparse.ArgumentParser(
prog="generate_kernels",
description="Generate the flash_attention kernels template instantiations",
)
parser.add_argument(
"-o",
"--output_dir",
required=False,
help="Where to generate the kernels "
" will default to the current directory ",
)
args = parser.parse_args()
main(args.output_dir)
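As a quick usage check, a sketch (not part of this PR; it assumes the script is importable as generate_kernels, e.g. when run from csrc/src, and the output directory name is only an example):

# Generate the instantiation files programmatically
# (equivalent to: python generate_kernels.py -o <output_dir>).
from pathlib import Path

from generate_kernels import get_all_kernels, write_kernel

out_dir = Path("autogen_example")           # example output directory, not a path from this PR
out_dir.mkdir(parents=True, exist_ok=True)  # write_kernel expects the directory to exist
for kernel in get_all_kernels():
    write_kernel(kernel, out_dir)

# With the lists above (2 dtypes x 6 head dims x 2 causal settings x 1 SM, fwd only),
# this writes 24 files, e.g. flash_fwd_hdim128_bf16_causal_sm80.cu and flash_fwd_hdim64_fp16_sm80.cu.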