Merged
51 commits
dff25ca
Adds specialized causal attention kernel for bf16 hdim32
LoserCheems Jun 26, 2025
769a05f
Adds specialized kernel for 32-dim bfloat16 forward pass
LoserCheems Jun 26, 2025
d8f8f5c
Adds causal flash attention kernel for 32-dim heads
LoserCheems Jun 26, 2025
6083258
Adds specialized kernel for 32-dim heads with FP16
LoserCheems Jun 26, 2025
db4fc94
Adds specialized kernel for 64-dim causal attention
LoserCheems Jun 26, 2025
fb433fe
Add bfloat16 forward kernel for 64-dim heads on SM80
LoserCheems Jun 26, 2025
40fb47a
Adds specialized kernel for causal attention with 64-dim heads
LoserCheems Jun 26, 2025
4e776be
Adds FP16 forward kernel for 64-dim heads on SM80
LoserCheems Jun 26, 2025
18abb73
Adds specialized CUDA kernel for bfloat16 causal attention
LoserCheems Jun 26, 2025
abf0c5a
Adds specialized flash attention kernel for hdim96 bf16
LoserCheems Jun 26, 2025
fdf3a71
Adds FP16 causal forward kernel for 96 head dimension
LoserCheems Jun 26, 2025
e335a28
Adds Flash Attention forward kernel for 96-dim heads
LoserCheems Jun 26, 2025
c194663
Adds specialized kernel for bfloat16 causal attention
LoserCheems Jun 26, 2025
9e93777
Adds specialized kernel for bf16 hdim128 forward pass
LoserCheems Jun 26, 2025
92a3395
Adds specialized kernel for 128-dim causal attention
LoserCheems Jun 26, 2025
3fff603
Adds specialized FP16 kernel for head dimension 128
LoserCheems Jun 26, 2025
3c02549
Adds bfloat16 causal flash attention kernel for 192 head dim
LoserCheems Jun 26, 2025
dc3354d
Adds specialized CUDA kernel for bfloat16 head dimension 192
LoserCheems Jun 26, 2025
b829b5b
Adds FP16 causal flash attention kernel for 192 head dim
LoserCheems Jun 26, 2025
b0436da
Adds specialized Flash Attention kernel for 192 head dimension
LoserCheems Jun 26, 2025
c828cdb
Adds specialized CUDA kernel for bfloat16 causal attention
LoserCheems Jun 26, 2025
f9ce0bc
Adds flash attention forward kernel for head dimension 256
LoserCheems Jun 26, 2025
d969390
Adds FP16 causal flash attention kernel for 256 head dim
LoserCheems Jun 26, 2025
86d3b37
Adds specialized kernel for head dimension 256 with FP16
LoserCheems Jun 26, 2025
b069e4b
Adds specialized kernel for 32-dim BF16 causal attention
LoserCheems Jun 26, 2025
67c9c36
Adds split kernel for bfloat16 hdim32 forward pass
LoserCheems Jun 26, 2025
07777f4
Adds split kernel for FP16 causal attention
LoserCheems Jun 26, 2025
860c579
Adds auto-generated kernel for FP16 SM80 hdim32
LoserCheems Jun 26, 2025
090c8be
Adds split kernel for bfloat16 causal attention
LoserCheems Jun 26, 2025
c67b1ef
Adds split kernel for bfloat16 head dimension 64
LoserCheems Jun 26, 2025
c331e5b
Adds CUDA kernel for 64-dim causal attention
LoserCheems Jun 26, 2025
2d715f1
Adds split kernel file for 64-dim heads with FP16
LoserCheems Jun 26, 2025
5baaedb
Adds specialized kernel for 96-dim causal attention
LoserCheems Jun 26, 2025
8d63d45
Adds split kernel for bfloat16 head dimension 96
LoserCheems Jun 26, 2025
fa967e6
Adds specialized kernel for hdim96 fp16 causal attention
LoserCheems Jun 26, 2025
26b953b
Adds split kernel for head dimension 96 with FP16
LoserCheems Jun 26, 2025
2a4c14b
Adds split kernel for bf16 causal attention
LoserCheems Jun 26, 2025
d8ff35a
Adds split kernel for bf16 hdim128 on SM80
LoserCheems Jun 26, 2025
6aa467c
Adds specialized CUDA kernel for FP16 causal attention
LoserCheems Jun 26, 2025
34a6fe1
Adds split kernel for fp16 hdim128 on SM80
LoserCheems Jun 26, 2025
32943f1
Adds specialized kernel for head dimension 192 with bfloat16
LoserCheems Jun 26, 2025
3699225
Adds specialized kernel for head dimension 192 with bfloat16
LoserCheems Jun 26, 2025
fc4c893
Adds specialized kernel for head dimension 192 with FP16
LoserCheems Jun 26, 2025
7eff592
Adds split kernel for head dimension 192 with FP16
LoserCheems Jun 26, 2025
f4dcef6
Adds specialized kernel for hdim256 bf16 causal attention
LoserCheems Jun 26, 2025
3460dc7
Adds split kernel for head dimension 256 with bfloat16
LoserCheems Jun 26, 2025
b5fb04a
Adds specialized kernel for 256-dim causal attention
LoserCheems Jun 26, 2025
dae5441
Adds split kernel for head dimension 256 with FP16
LoserCheems Jun 26, 2025
9a17410
Enables fwd_split kernel generation and fixes data types
LoserCheems Jun 26, 2025
1421c1b
Changes is_causal field type from bool to str
LoserCheems Jun 26, 2025
51daf66
Changes return type from List to Generator for memory efficiency
LoserCheems Jun 26, 2025
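The last three commits touch the generator script rather than the generated kernels: fwd_split generation is enabled, the is_causal field becomes a string, and kernel enumeration returns a Generator instead of a List. A minimal sketch of what that combination might look like in generate_kernels.py — the names and structure here are assumptions, not the actual script:

```python
# Sketch only: illustrates the changes described by commits 9a17410,
# 1421c1b, and 51daf66; not the real generate_kernels.py.
import itertools
from dataclasses import dataclass
from typing import Generator

DTYPE_MAP = {"fp16": "cutlass::half_t", "bf16": "cutlass::bfloat16_t"}
HEAD_DIMENSIONS = [32, 64, 96, 128, 192, 256]
SM_VERSIONS = [80]

@dataclass
class Kernel:
    sm: int
    dtype: str
    head_dim: int
    is_causal: str   # "true" / "false" as strings (cf. 1421c1b)
    direction: str   # "fwd" or "fwd_split" (cf. 9a17410)

    @property
    def filename(self) -> str:
        split = "_split" if self.direction == "fwd_split" else ""
        causal = "_causal" if self.is_causal == "true" else ""
        return (f"flash_fwd{split}_hdim{self.head_dim}_"
                f"{self.dtype}{causal}_sm{self.sm}.cu")

def get_all_kernels() -> Generator[Kernel, None, None]:
    # Yields kernels lazily instead of materializing a List (cf. 51daf66).
    for direction in ("fwd", "fwd_split"):
        for dtype, head_dim, is_causal, sm in itertools.product(
                DTYPE_MAP, HEAD_DIMENSIONS, ("false", "true"), SM_VERSIONS):
            yield Kernel(sm=sm, dtype=dtype, head_dim=head_dim,
                         is_causal=is_causal, direction=direction)
```

With two dtypes, six head dimensions, two causal settings, and two directions, this enumerates the full matrix of single-purpose translation units that make up the diff below.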
14 changes: 14 additions & 0 deletions csrc/src/flash_fwd_hdim128_bf16_causal_sm80.cu
@@ -0,0 +1,14 @@
// Copyright (c) 2025, Jingze Shi and Tri Dao.
// Splitting the different head dimensions to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"
#include "namespace_config.h"
#include "flash_fwd_launch_template.h"

namespace FLASH_NAMESPACE {

template<>
void run_mha_fwd_<cutlass::bfloat16_t, 128, true>(Flash_fwd_params &params, cudaStream_t stream) {
run_mha_fwd_hdim128<cutlass::bfloat16_t, true>(params, stream);
}

} // namespace FLASH_NAMESPACE
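Every generated forward file below follows this exact shape: a fourteen-line translation unit that explicitly specializes run_mha_fwd_ for one (dtype, head_dim, is_causal) triple and forwards to the matching launch-template helper. A hypothetical sketch of the rendering step that could emit such a file — the template text mirrors the file above, but the function and constant names are illustrative, not taken from generate_kernels.py:

```python
# Hypothetical rendering step for the auto-generated files in this diff.
KERNEL_IMPL_TEMPLATE_FWD = """\
// Copyright (c) 2025, Jingze Shi and Tri Dao.
// Splitting the different head dimensions to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"
#include "namespace_config.h"
#include "flash_fwd_launch_template.h"

namespace FLASH_NAMESPACE {{

template<>
void run_mha_fwd_<{dtype}, {head_dim}, {is_causal}>(Flash_fwd_params &params, cudaStream_t stream) {{
    run_mha_fwd_hdim{head_dim}<{dtype}, {is_causal}>(params, stream);
}}

}} // namespace FLASH_NAMESPACE
"""

def render_fwd_kernel(dtype_cpp: str, head_dim: int, is_causal: str) -> str:
    return KERNEL_IMPL_TEMPLATE_FWD.format(
        dtype=dtype_cpp, head_dim=head_dim, is_causal=is_causal)

# Reproduces csrc/src/flash_fwd_hdim128_bf16_causal_sm80.cu above:
print(render_fwd_kernel("cutlass::bfloat16_t", 128, "true"))
```

Keeping is_causal as the string "true"/"false" lets it drop straight into both the filename and the template argument, which is presumably why commit 1421c1b changes the field type from bool to str.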
14 changes: 14 additions & 0 deletions csrc/src/flash_fwd_hdim128_bf16_sm80.cu
@@ -0,0 +1,14 @@
// Copyright (c) 2025, Jingze Shi and Tri Dao.
// Splitting the different head dimensions to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"
#include "namespace_config.h"
#include "flash_fwd_launch_template.h"

namespace FLASH_NAMESPACE {

template<>
void run_mha_fwd_<cutlass::bfloat16_t, 128, false>(Flash_fwd_params &params, cudaStream_t stream) {
run_mha_fwd_hdim128<cutlass::bfloat16_t, false>(params, stream);
}

} // namespace FLASH_NAMESPACE
14 changes: 14 additions & 0 deletions csrc/src/flash_fwd_hdim128_fp16_causal_sm80.cu
@@ -0,0 +1,14 @@
// Copyright (c) 2025, Jingze Shi and Tri Dao.
// Splitting the different head dimensions to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"
#include "namespace_config.h"
#include "flash_fwd_launch_template.h"

namespace FLASH_NAMESPACE {

template<>
void run_mha_fwd_<cutlass::half_t, 128, true>(Flash_fwd_params &params, cudaStream_t stream) {
run_mha_fwd_hdim128<cutlass::half_t, true>(params, stream);
}

} // namespace FLASH_NAMESPACE
14 changes: 14 additions & 0 deletions csrc/src/flash_fwd_hdim128_fp16_sm80.cu
@@ -0,0 +1,14 @@
// Copyright (c) 2025, Jingze Shi and Tri Dao.
// Splitting the different head dimensions to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"
#include "namespace_config.h"
#include "flash_fwd_launch_template.h"

namespace FLASH_NAMESPACE {

template<>
void run_mha_fwd_<cutlass::half_t, 128, false>(Flash_fwd_params &params, cudaStream_t stream) {
run_mha_fwd_hdim128<cutlass::half_t, false>(params, stream);
}

} // namespace FLASH_NAMESPACE
14 changes: 14 additions & 0 deletions csrc/src/flash_fwd_hdim192_bf16_causal_sm80.cu
@@ -0,0 +1,14 @@
// Copyright (c) 2025, Jingze Shi and Tri Dao.
// Splitting the different head dimensions to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"
#include "namespace_config.h"
#include "flash_fwd_launch_template.h"

namespace FLASH_NAMESPACE {

template<>
void run_mha_fwd_<cutlass::bfloat16_t, 192, true>(Flash_fwd_params &params, cudaStream_t stream) {
run_mha_fwd_hdim192<cutlass::bfloat16_t, true>(params, stream);
}

} // namespace FLASH_NAMESPACE
14 changes: 14 additions & 0 deletions csrc/src/flash_fwd_hdim192_bf16_sm80.cu
@@ -0,0 +1,14 @@
// Copyright (c) 2025, Jingze Shi and Tri Dao.
// Splitting the different head dimensions to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"
#include "namespace_config.h"
#include "flash_fwd_launch_template.h"

namespace FLASH_NAMESPACE {

template<>
void run_mha_fwd_<cutlass::bfloat16_t, 192, false>(Flash_fwd_params &params, cudaStream_t stream) {
run_mha_fwd_hdim192<cutlass::bfloat16_t, false>(params, stream);
}

} // namespace FLASH_NAMESPACE
14 changes: 14 additions & 0 deletions csrc/src/flash_fwd_hdim192_fp16_causal_sm80.cu
@@ -0,0 +1,14 @@
// Copyright (c) 2025, Jingze Shi and Tri Dao.
// Splitting the different head dimensions to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"
#include "namespace_config.h"
#include "flash_fwd_launch_template.h"

namespace FLASH_NAMESPACE {

template<>
void run_mha_fwd_<cutlass::half_t, 192, true>(Flash_fwd_params &params, cudaStream_t stream) {
run_mha_fwd_hdim192<cutlass::half_t, true>(params, stream);
}

} // namespace FLASH_NAMESPACE
14 changes: 14 additions & 0 deletions csrc/src/flash_fwd_hdim192_fp16_sm80.cu
@@ -0,0 +1,14 @@
// Copyright (c) 2025, Jingze Shi and Tri Dao.
// Splitting the different head dimensions to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"
#include "namespace_config.h"
#include "flash_fwd_launch_template.h"

namespace FLASH_NAMESPACE {

template<>
void run_mha_fwd_<cutlass::half_t, 192, false>(Flash_fwd_params &params, cudaStream_t stream) {
run_mha_fwd_hdim192<cutlass::half_t, false>(params, stream);
}

} // namespace FLASH_NAMESPACE
14 changes: 14 additions & 0 deletions csrc/src/flash_fwd_hdim256_bf16_causal_sm80.cu
@@ -0,0 +1,14 @@
// Copyright (c) 2025, Jingze Shi and Tri Dao.
// Splitting the different head dimensions to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"
#include "namespace_config.h"
#include "flash_fwd_launch_template.h"

namespace FLASH_NAMESPACE {

template<>
void run_mha_fwd_<cutlass::bfloat16_t, 256, true>(Flash_fwd_params &params, cudaStream_t stream) {
run_mha_fwd_hdim256<cutlass::bfloat16_t, true>(params, stream);
}

} // namespace FLASH_NAMESPACE
14 changes: 14 additions & 0 deletions csrc/src/flash_fwd_hdim256_bf16_sm80.cu
@@ -0,0 +1,14 @@
// Copyright (c) 2025, Jingze Shi and Tri Dao.
// Splitting the different head dimensions to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"
#include "namespace_config.h"
#include "flash_fwd_launch_template.h"

namespace FLASH_NAMESPACE {

template<>
void run_mha_fwd_<cutlass::bfloat16_t, 256, false>(Flash_fwd_params &params, cudaStream_t stream) {
run_mha_fwd_hdim256<cutlass::bfloat16_t, false>(params, stream);
}

} // namespace FLASH_NAMESPACE
14 changes: 14 additions & 0 deletions csrc/src/flash_fwd_hdim256_fp16_causal_sm80.cu
@@ -0,0 +1,14 @@
// Copyright (c) 2025, Jingze Shi and Tri Dao.
// Splitting the different head dimensions to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"
#include "namespace_config.h"
#include "flash_fwd_launch_template.h"

namespace FLASH_NAMESPACE {

template<>
void run_mha_fwd_<cutlass::half_t, 256, true>(Flash_fwd_params &params, cudaStream_t stream) {
run_mha_fwd_hdim256<cutlass::half_t, true>(params, stream);
}

} // namespace FLASH_NAMESPACE
14 changes: 14 additions & 0 deletions csrc/src/flash_fwd_hdim256_fp16_sm80.cu
@@ -0,0 +1,14 @@
// Copyright (c) 2025, Jingze Shi and Tri Dao.
// Splitting the different head dimensions to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"
#include "namespace_config.h"
#include "flash_fwd_launch_template.h"

namespace FLASH_NAMESPACE {

template<>
void run_mha_fwd_<cutlass::half_t, 256, false>(Flash_fwd_params &params, cudaStream_t stream) {
run_mha_fwd_hdim256<cutlass::half_t, false>(params, stream);
}

} // namespace FLASH_NAMESPACE
14 changes: 14 additions & 0 deletions csrc/src/flash_fwd_hdim32_bf16_causal_sm80.cu
@@ -0,0 +1,14 @@
// Copyright (c) 2025, Jingze Shi and Tri Dao.
// Splitting the different head dimensions to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"
#include "namespace_config.h"
#include "flash_fwd_launch_template.h"

namespace FLASH_NAMESPACE {

template<>
void run_mha_fwd_<cutlass::bfloat16_t, 32, true>(Flash_fwd_params &params, cudaStream_t stream) {
run_mha_fwd_hdim32<cutlass::bfloat16_t, true>(params, stream);
}

} // namespace FLASH_NAMESPACE
14 changes: 14 additions & 0 deletions csrc/src/flash_fwd_hdim32_bf16_sm80.cu
@@ -0,0 +1,14 @@
// Copyright (c) 2025, Jingze Shi and Tri Dao.
// Splitting the different head dimensions to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"
#include "namespace_config.h"
#include "flash_fwd_launch_template.h"

namespace FLASH_NAMESPACE {

template<>
void run_mha_fwd_<cutlass::bfloat16_t, 32, false>(Flash_fwd_params &params, cudaStream_t stream) {
run_mha_fwd_hdim32<cutlass::bfloat16_t, false>(params, stream);
}

} // namespace FLASH_NAMESPACE
14 changes: 14 additions & 0 deletions csrc/src/flash_fwd_hdim32_fp16_causal_sm80.cu
@@ -0,0 +1,14 @@
// Copyright (c) 2025, Jingze Shi and Tri Dao.
// Splitting the different head dimensions to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"
#include "namespace_config.h"
#include "flash_fwd_launch_template.h"

namespace FLASH_NAMESPACE {

template<>
void run_mha_fwd_<cutlass::half_t, 32, true>(Flash_fwd_params &params, cudaStream_t stream) {
run_mha_fwd_hdim32<cutlass::half_t, true>(params, stream);
}

} // namespace FLASH_NAMESPACE
14 changes: 14 additions & 0 deletions csrc/src/flash_fwd_hdim32_fp16_sm80.cu
@@ -0,0 +1,14 @@
// Copyright (c) 2025, Jingze Shi and Tri Dao.
// Splitting the different head dimensions to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"
#include "namespace_config.h"
#include "flash_fwd_launch_template.h"

namespace FLASH_NAMESPACE {

template<>
void run_mha_fwd_<cutlass::half_t, 32, false>(Flash_fwd_params &params, cudaStream_t stream) {
run_mha_fwd_hdim32<cutlass::half_t, false>(params, stream);
}

} // namespace FLASH_NAMESPACE
14 changes: 14 additions & 0 deletions csrc/src/flash_fwd_hdim64_bf16_causal_sm80.cu
@@ -0,0 +1,14 @@
// Copyright (c) 2025, Jingze Shi and Tri Dao.
// Splitting the different head dimensions to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"
#include "namespace_config.h"
#include "flash_fwd_launch_template.h"

namespace FLASH_NAMESPACE {

template<>
void run_mha_fwd_<cutlass::bfloat16_t, 64, true>(Flash_fwd_params &params, cudaStream_t stream) {
run_mha_fwd_hdim64<cutlass::bfloat16_t, true>(params, stream);
}

} // namespace FLASH_NAMESPACE
14 changes: 14 additions & 0 deletions csrc/src/flash_fwd_hdim64_bf16_sm80.cu
@@ -0,0 +1,14 @@
// Copyright (c) 2025, Jingze Shi and Tri Dao.
// Splitting the different head dimensions to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"
#include "namespace_config.h"
#include "flash_fwd_launch_template.h"

namespace FLASH_NAMESPACE {

template<>
void run_mha_fwd_<cutlass::bfloat16_t, 64, false>(Flash_fwd_params &params, cudaStream_t stream) {
run_mha_fwd_hdim64<cutlass::bfloat16_t, false>(params, stream);
}

} // namespace FLASH_NAMESPACE
14 changes: 14 additions & 0 deletions csrc/src/flash_fwd_hdim64_fp16_causal_sm80.cu
@@ -0,0 +1,14 @@
// Copyright (c) 2025, Jingze Shi and Tri Dao.
// Splitting the different head dimensions to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"
#include "namespace_config.h"
#include "flash_fwd_launch_template.h"

namespace FLASH_NAMESPACE {

template<>
void run_mha_fwd_<cutlass::half_t, 64, true>(Flash_fwd_params &params, cudaStream_t stream) {
run_mha_fwd_hdim64<cutlass::half_t, true>(params, stream);
}

} // namespace FLASH_NAMESPACE
14 changes: 14 additions & 0 deletions csrc/src/flash_fwd_hdim64_fp16_sm80.cu
@@ -0,0 +1,14 @@
// Copyright (c) 2025, Jingze Shi and Tri Dao.
// Splitting the different head dimensions to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"
#include "namespace_config.h"
#include "flash_fwd_launch_template.h"

namespace FLASH_NAMESPACE {

template<>
void run_mha_fwd_<cutlass::half_t, 64, false>(Flash_fwd_params &params, cudaStream_t stream) {
run_mha_fwd_hdim64<cutlass::half_t, false>(params, stream);
}

} // namespace FLASH_NAMESPACE
14 changes: 14 additions & 0 deletions csrc/src/flash_fwd_hdim96_bf16_causal_sm80.cu
@@ -0,0 +1,14 @@
// Copyright (c) 2025, Jingze Shi and Tri Dao.
// Splitting the different head dimensions to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"
#include "namespace_config.h"
#include "flash_fwd_launch_template.h"

namespace FLASH_NAMESPACE {

template<>
void run_mha_fwd_<cutlass::bfloat16_t, 96, true>(Flash_fwd_params &params, cudaStream_t stream) {
run_mha_fwd_hdim96<cutlass::bfloat16_t, true>(params, stream);
}

} // namespace FLASH_NAMESPACE
14 changes: 14 additions & 0 deletions csrc/src/flash_fwd_hdim96_bf16_sm80.cu
@@ -0,0 +1,14 @@
// Copyright (c) 2025, Jingze Shi and Tri Dao.
// Splitting the different head dimensions to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"
#include "namespace_config.h"
#include "flash_fwd_launch_template.h"

namespace FLASH_NAMESPACE {

template<>
void run_mha_fwd_<cutlass::bfloat16_t, 96, false>(Flash_fwd_params &params, cudaStream_t stream) {
run_mha_fwd_hdim96<cutlass::bfloat16_t, false>(params, stream);
}

} // namespace FLASH_NAMESPACE
14 changes: 14 additions & 0 deletions csrc/src/flash_fwd_hdim96_fp16_causal_sm80.cu
@@ -0,0 +1,14 @@
// Copyright (c) 2025, Jingze Shi and Tri Dao.
// Splitting the different head dimensions to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"
#include "namespace_config.h"
#include "flash_fwd_launch_template.h"

namespace FLASH_NAMESPACE {

template<>
void run_mha_fwd_<cutlass::half_t, 96, true>(Flash_fwd_params &params, cudaStream_t stream) {
run_mha_fwd_hdim96<cutlass::half_t, true>(params, stream);
}

} // namespace FLASH_NAMESPACE
14 changes: 14 additions & 0 deletions csrc/src/flash_fwd_hdim96_fp16_sm80.cu
@@ -0,0 +1,14 @@
// Copyright (c) 2025, Jingze Shi and Tri Dao.
// Splitting the different head dimensions to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"
#include "namespace_config.h"
#include "flash_fwd_launch_template.h"

namespace FLASH_NAMESPACE {

template<>
void run_mha_fwd_<cutlass::half_t, 96, false>(Flash_fwd_params &params, cudaStream_t stream) {
run_mha_fwd_hdim96<cutlass::half_t, false>(params, stream);
}

} // namespace FLASH_NAMESPACE
11 changes: 11 additions & 0 deletions csrc/src/flash_fwd_split_hdim128_bf16_causal_sm80.cu
@@ -0,0 +1,11 @@
// Copyright (c) 2025, Jingze Shi and Tri Dao.
// Splitting the different head dimensions to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"
#include "namespace_config.h"
#include "flash_fwd_launch_template.h"

namespace FLASH_NAMESPACE {

template void run_mha_fwd_splitkv_dispatch<cutlass::bfloat16_t, 128, true>(Flash_fwd_params &params, cudaStream_t stream);

} // namespace FLASH_NAMESPACE
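Note the shape difference from the forward files above: the split-KV variants use a one-line explicit template instantiation of run_mha_fwd_splitkv_dispatch rather than an explicit specialization with a forwarding body, which is why these files are 11 lines instead of 14.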