3 changes: 2 additions & 1 deletion csrc/flash_api.cpp
@@ -311,7 +311,8 @@ std::tuple<at::Tensor, at::Tensor> set_params_splitkv(
) {

// This needs to match with run_mha_fwd_splitkv_dispatch
const int block_n = head_size <= 32 ? 128 : (head_size <= 128 ? 128 : 64);
const int block_n = 64;
// const int block_n = head_size <= 32 ? 128 : (head_size <= 128 ? 128 : 64);
const int num_n_blocks = (max_seqlen_k + block_n - 1) / block_n;
// Technically kBlockM = 64 only for the splitKV kernels, not the standard kernel.
// In any case we don't expect seqlen_q to be larger than 64 for inference.
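The split count comes from plain ceiling division, which is why block_n here has to match the kernel's kBlockN: a mismatched block_n would over- or under-count the number of KV blocks each split covers. A minimal standalone sketch of the arithmetic, with hypothetical sizes:

#include <cstdio>

int main() {
    const int max_seqlen_k = 1000;  // hypothetical KV sequence length
    const int block_n = 64;         // must equal kBlockN in run_mha_fwd_splitkv_dispatch
    const int num_n_blocks = (max_seqlen_k + block_n - 1) / block_n;
    std::printf("num_n_blocks = %d\n", num_n_blocks);  // 16 == ceil(1000 / 64)
    return 0;
}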
70 changes: 29 additions & 41 deletions csrc/src/flash_bwd_launch_template.h
@@ -138,12 +138,11 @@ void run_mha_bwd_hdim32(Flash_bwd_params &params, cudaStream_t stream) {
C10_CUDA_CHECK(status_);
}
if (max_smem_per_block >= 104 * 1024) { // H100 and A100
// 104KB
// 104KB, 1 CTA on A100, 2 CTAs on H100.
run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 128, 128, 8, 4, 4, 4, false, false, T>, Is_causal>(params, stream);
} else { // sm86 and sm89
// 96KB
// We need to adjust no_double_buffer to save some smem, because is_v_in_regs=true will still allocate smem that may overflow
run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 128, 128, 8, 4, 4, 4, false, true, T>, Is_causal>(params, stream);
// 96KB, 2 CTAs on sm86 and sm89.
run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 128, 128, 8, 4, 4, 4, true, false, T>, Is_causal>(params, stream);
}
}
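A note on the trait lists above: as best I recall from FlashAttention-2's kernel_traits.h (treat the exact signature as an assumption, not something this PR states), the backward traits look like the sketch below; the two trailing booleans are the is_v_in_regs and no_double_buffer knobs the old comment mentions, and they trade registers and pipelining for shared memory:

// Recalled shape of the FA2 backward traits; verify against kernel_traits.h:
// template<int kHeadDim_, int kBlockM_, int kBlockN_, int kNWarps_,
//          int AtomLayoutMSdP_ = 1, int AtomLayoutNdKV_ = 2, int AtomLayoutMdQ_ = 2,
//          bool Is_V_in_regs_ = false, bool No_double_buffer_ = false,
//          typename elem_type = cutlass::half_t>
// struct Flash_bwd_kernel_traits;
//
// So <Headdim, 128, 128, 8, 4, 4, 4, true, false, T> holds V in registers during
// the main loop, and <..., false, true, T> drops a double-buffered smem stage;
// either shrinks the footprint at some cost in registers or overlap.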

@@ -158,17 +157,17 @@ void run_mha_bwd_hdim64(Flash_bwd_params &params, cudaStream_t stream) {
if (status_ != cudaSuccess) {
C10_CUDA_CHECK(status_);
}
// printf("max_smem_per_block = %d\n", max_smem_per_block);
// Changing AtomLayoutMdQ from 2 to 4 takes the same time
// This is slightly faster. We want to split M more so we need fewer registers to store LSE.
if (max_smem_per_block >= 144 * 1024) { // H100 and A100
// 144KB
// In the fwd pass, multi-CTA configurations are faster; in the bwd pass their speeds are very close.
// 56KB, 1 CTA on sm86 and sm89, 2 CTAs on A100, 4 CTAs on H100.
// run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 64, 8, 4, 2, 2, false, false, T>, Is_causal>(params, stream);
// 72KB, 1 CTA on sm86 and sm89, 2 CTAs on A100, 3 CTAs on H100.
// run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 128, 8, 2, 4, 4, true, false, T>, Is_causal>(params, stream);
// 144KB, does not fit on sm86/sm89; 1 CTA on A100, 1 CTA on H100.
run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 128, 128, 8, 4, 4, 4, false, false, T>, Is_causal>(params, stream);
// This has a lot of register spilling
// run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 128, 128, 8, 4, 4, 4, true, false, T>>(params, stream);
} else { // sm86 and sm89
// 88KB
run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 128, 8, 2, 4, 4, false, false, T>, Is_causal>(params, stream);
// 72KB, 1 CTA on sm86 and sm89.
run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 128, 8, 2, 4, 4, true, false, T>, Is_causal>(params, stream);
}
// M=128, N=64 is quite slow, I think because we need to read/write dQaccum twice as many times
}
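Every dispatch in this file repeats the same elided preamble: query the opt-in per-block shared memory limit, then branch on it. A self-contained sketch of that query using the standard CUDA runtime call (error handling simplified; the per-arch figures in the comment are from the CUDA docs, roughly 99KB on sm86/sm89, 163KB on A100, 227KB on H100):

#include <cuda_runtime.h>
#include <cstdio>

int main() {
    int device = 0;
    cudaGetDevice(&device);
    int max_smem_per_block = 0;
    cudaError_t status_ = cudaDeviceGetAttribute(
        &max_smem_per_block, cudaDevAttrMaxSharedMemoryPerBlockOptin, device);
    if (status_ != cudaSuccess) {
        std::printf("attribute query failed\n");
        return 1;
    }
    std::printf("max_smem_per_block = %d KB\n", max_smem_per_block / 1024);
    return 0;
}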
@@ -186,11 +185,11 @@ void run_mha_bwd_hdim96(Flash_bwd_params &params, cudaStream_t stream) {
}
// printf("max_smem_per_block = %d\n", max_smem_per_block);
if (max_smem_per_block >= 116 * 1024) { // H100 and A100
// 116KB
// 116KB, 1 CTA on A100, 1 CTA on H100.
run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 128, 8, 2, 4, 4, false, false, T>, Is_causal>(params, stream);
} else { // sm86 and sm89
// 80KB
run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 64, 8, 2, 4, 4, false, false, T>, Is_causal>(params, stream);
// 92KB, 1 CTA on sm86 and sm89.
run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 128, 8, 2, 4, 4, true, false, T>, Is_causal>(params, stream);
}
}

@@ -205,20 +204,12 @@ void run_mha_bwd_hdim128(Flash_bwd_params &params, cudaStream_t stream) {
if (status_ != cudaSuccess) {
C10_CUDA_CHECK(status_);
}
// printf("max_smem_per_block = %d\n", max_smem_per_block);
// run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 32, 64, 8, 2, 2, 2, false, false, T>>(params, stream);
// This is faster, in the case of sequence-parallel bwd (where we need fewer registers).
// Out of these three, the 2nd one is slightly faster (2% faster than the first). Idk why.
// run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 64, 8, 2, 2, 2, false, false, T>>(params, stream);
if (max_smem_per_block >= 224 * 1024) { // H100
// 224KB
run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 128, 128, 8, 2, 4, 2, false, false, T>, Is_causal>(params, stream);
} else if (max_smem_per_block >= 144 * 1024) { // A100
// 144KB
if (max_smem_per_block >= 144 * 1024) { // H100 and A100
// 144KB, 1 CTA on A100, 1 CTA on H100.
run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 128, 8, 2, 4, 2, false, false, T>, Is_causal>(params, stream);
} else { // sm86 and sm89
// 88KB
run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 64, 8, 4, 2, 2, false, true, T>, Is_causal>(params, stream);
// 88KB, 1 CTA on sm86 and sm89.
run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 64, 8, 4, 2, 2, true, false, T>, Is_causal>(params, stream);
}
}
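The "N CTAs on X" annotations follow from shared-memory-limited occupancy: CTAs per SM = floor(smem per SM / smem per CTA), ignoring register and warp-slot limits. A worked check against the hdim128 numbers above, assuming the usual per-SM capacities (100KB on sm86/sm89, 164KB on A100, 228KB on H100):

#include <cstdio>

// Sketch: smem-limited occupancy only; register pressure can lower this further.
constexpr int ctas_per_sm(int smem_per_sm_kb, int smem_per_cta_kb) {
    return smem_per_sm_kb / smem_per_cta_kb;  // integer division == floor
}

int main() {
    // hdim128 bwd, 144KB config: one CTA on A100 and H100, none on sm86/sm89.
    std::printf("A100: %d, H100: %d, sm86: %d\n",
                ctas_per_sm(164, 144), ctas_per_sm(228, 144), ctas_per_sm(100, 144));
    // prints: A100: 1, H100: 1, sm86: 0
    return 0;
}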

@@ -233,15 +224,12 @@ void run_mha_bwd_hdim192(Flash_bwd_params &params, cudaStream_t stream) {
if (status_ != cudaSuccess) {
C10_CUDA_CHECK(status_);
}
if (max_smem_per_block >= 208 * 1024) { // H100
// 208KB
run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 128, 8, 4, 2, 2, false, false, T>, Is_causal>(params, stream);
} else if (max_smem_per_block >= 152 * 1024) { // A100
// 152KB
if (max_smem_per_block >= 136 * 1024) { // H100 and A100
// 136KB, 1 CTA on A100, 1 CTA on H100.
run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 64, 8, 4, 2, 2, false, false, T>, Is_causal>(params, stream);
} else { // sm86 and sm89
// 88KB
run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 32, 64, 8, 4, 2, 2, false, true, T>, Is_causal>(params, stream);
// 96KB, 1 CTA on sm86 and sm89.
run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 64, 8, 4, 2, 2, true, true, T>, Is_causal>(params, stream);
}
}

@@ -256,15 +244,15 @@ void run_mha_bwd_hdim256(Flash_bwd_params &params, cudaStream_t stream) {
if (status_ != cudaSuccess) {
C10_CUDA_CHECK(status_);
}
if (max_smem_per_block >= 200 * 1024) { // H100
// 200KB
if (max_smem_per_block >= 176 * 1024) { // H100
// 176KB, 1 CTA on H100.
run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 64, 8, 4, 2, 2, false, false, T>, Is_causal>(params, stream);
} else if (max_smem_per_block >= 132 * 1024) { // A100
// 132KB
run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 32, 64, 8, 4, 2, 2, false, false, T>, Is_causal>(params, stream);
} else if (max_smem_per_block >= 144 * 1024) { // A100
// 144KB, 1 CTA on A100.
run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 64, 8, 4, 2, 2, false, true, T>, Is_causal>(params, stream);
} else { // sm86 and sm89
// 82KB
run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 32, 32, 8, 4, 1, 2, true, false, T>, Is_causal>(params, stream);
// 96KB, 1 CTA on sm86 and sm89.
run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 32, 8, 4, 1, 2, true, true, T>, Is_causal>(params, stream);
}
}
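Tracing the hdim256 ladder above for each target (per-block opt-in limits assumed to be roughly 227KB on H100, 163KB on A100, 99KB on sm86/sm89), each arch lands on a different branch. A small host-side sketch of the selection logic, with the branches reduced to plain strings:

#include <cstdio>

const char* pick_hdim256_bwd(int max_smem_per_block) {
    if (max_smem_per_block >= 176 * 1024) return "64x64, double buffered (176KB)";   // H100
    if (max_smem_per_block >= 144 * 1024) return "64x64, no double buffer (144KB)";  // A100
    return "64x32, V in regs, no double buffer (96KB)";                              // sm86/sm89
}

int main() {
    std::printf("H100: %s\n", pick_hdim256_bwd(227 * 1024));
    std::printf("A100: %s\n", pick_hdim256_bwd(163 * 1024));
    std::printf("sm86: %s\n", pick_hdim256_bwd(99 * 1024));
    return 0;
}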

70 changes: 53 additions & 17 deletions csrc/src/flash_fwd_launch_template.h
@@ -155,7 +155,8 @@ void run_flash_splitkv_fwd(Flash_fwd_params &params, cudaStream_t stream) {
template<typename T, int Headdim, bool Is_causal>
void run_mha_fwd_splitkv_dispatch(Flash_fwd_params &params, cudaStream_t stream) {
constexpr static int kBlockM = 64; // Fixed for all head dimensions
constexpr static int kBlockN = Headdim <= 32 ? 128 : (Headdim <= 128 ? 128 : 64);
constexpr static int kBlockN = 64; // Fixed for all head dimensions
// constexpr static int kBlockN = Headdim <= 32 ? 128 : (Headdim <= 128 ? 128 : 64);
run_flash_splitkv_fwd<Flash_fwd_kernel_traits<Headdim, kBlockM, kBlockN, 4, false, false, T>, Is_causal>(params, stream);
}
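Since set_params_splitkv in flash_api.cpp hard-codes the same block_n = 64 and relies on a comment to keep the two in sync, one option is a shared constant; a hypothetical sketch (the header name and kSplitKVBlockN are made up for illustration, not part of this PR):

// Hypothetical shared header, e.g. csrc/src/flash_block_sizes.h:
//   constexpr int kSplitKVBlockM = 64;
//   constexpr int kSplitKVBlockN = 64;
//
// flash_api.cpp would then use:
//   const int block_n = kSplitKVBlockN;
// and flash_fwd_launch_template.h:
//   constexpr static int kBlockN = kSplitKVBlockN;
//
// The "needs to match" comment then becomes a compile-time guarantee instead of a convention.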

@@ -171,11 +172,18 @@ void run_mha_fwd_hdim32(Flash_fwd_params &params, cudaStream_t stream) {
if (status_ != cudaSuccess) {
C10_CUDA_CHECK(status_);
}
if (max_smem_per_block >= 176 * 1024) {
run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 128, 4, false, false, T>, Is_causal>(params, stream);
if (max_smem_per_block >= 164 * 1024) {
// 28KB, 3 CTAs on sm86 and sm89, 5 CTAs on A100, 8 CTAs on H100.
run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 64, 64, 4, false, false, T>, Is_causal>(params, stream);
// 48KB, 2 CTAs on sm86 and sm89, 3 CTAs on A100, 4 CTAs on H100.
// run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 64, 4, false, false, T>, Is_causal>(params, stream);
// 88KB, 1 CTA on sm86 and sm89, 1 CTA on A100, 2 CTAs on H100.
// run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 128, 4, false, false, T>, Is_causal>(params, stream);
} else {
run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 64, 4, false, false, T>, Is_causal>(params, stream);
// 24KB, 4 CTAs on sm86 and sm89.
run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 64, 64, 4, true, true, T>, Is_causal>(params, stream);
}

}
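The pattern across the fwd changes: on sm86/sm89 every head dimension now takes a 64 x 64 tile with the last two trait flags set to true. As best I recall from FA2's kernel_traits.h these are Is_Q_in_regs and Share_Q_K_smem (an assumption; this PR does not name them), and both save shared memory at the cost of registers, which is the right trade under a 99KB per-block limit:

// Recalled shape of the FA2 forward traits, to be verified against kernel_traits.h:
// template<int kHeadDim_, int kBlockM_, int kBlockN_, int kNWarps_,
//          bool Is_Q_in_regs_ = false, bool Share_Q_K_smem_ = false,
//          typename elem_type = cutlass::half_t>
// struct Flash_fwd_kernel_traits;
//
// <Headdim, 64, 64, 4, true, true, T> keeps Q in registers after the first GEMM
// and lets Q and K share one smem buffer, e.g. 28KB -> 24KB for hdim32 above.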

template<typename T, bool Is_causal>
@@ -190,11 +198,18 @@ void run_mha_fwd_hdim64(Flash_fwd_params &params, cudaStream_t stream) {
if (status_ != cudaSuccess) {
C10_CUDA_CHECK(status_);
}
if (max_smem_per_block >= 224 * 1024) {
run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 128, 4, false, false, T>, Is_causal>(params, stream);
} else {
if (max_smem_per_block >= 164 * 1024) { // H100 and A100
// 40KB, 2 CTAs on sm86 and sm89, 4 CTAs on A100, 5 CTAs on H100.
run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 64, 64, 4, false, false, T>, Is_causal>(params, stream);
// 64KB, 1 CTA on sm86 and sm89, 2 CTAs on A100, 3 CTAs on H100.
// run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 64, 128, 4, false, false, T>, Is_causal>(params, stream);
// 112KB, does not fit on sm86/sm89; 1 CTA on A100, 2 CTAs on H100.
// run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 128, 4, false, false, T>, Is_causal>(params, stream);
} else { // sm86 and sm89
// 32KB, 3 CTAs on sm86 and sm89.
run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 64, 64, 4, true, true, T>, Is_causal>(params, stream);
}

}

template<typename T, bool Is_causal>
@@ -209,9 +224,15 @@ void run_mha_fwd_hdim96(Flash_fwd_params &params, cudaStream_t stream) {
if (status_ != cudaSuccess) {
C10_CUDA_CHECK(status_);
}
if (max_smem_per_block >= 160 * 1024) {
run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 64, 4, false, false, T>, Is_causal>(params, stream);
} else {
if (max_smem_per_block >= 164 * 1024) { // H100 and A100
// 52KB, 1 CTA on sm86 and sm89, 3 CTAs on A100, 4 CTAs on H100.
run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 64, 64, 4, false, false, T>, Is_causal>(params, stream);
// 80KB, 1 CTA on sm86 and sm89, 2 CTAs on A100, 2 CTAs on H100.
// run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 64, 4, false, false, T>, Is_causal>(params, stream);
// 136KB, does not fit on sm86/sm89; 1 CTA on A100, 1 CTA on H100.
// run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 128, 4, false, false, T>, Is_causal>(params, stream);
} else { // sm86 and sm89
// 40KB, 2 CTAs on sm86 and sm89.
run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 64, 64, 4, true, true, T>, Is_causal>(params, stream);
}
}
@@ -228,19 +249,28 @@ void run_mha_fwd_hdim128(Flash_fwd_params &params, cudaStream_t stream) {
if (status_ != cudaSuccess) {
C10_CUDA_CHECK(status_);
}
if (max_smem_per_block >= 192 * 1024) {
run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 64, 4, false, false, T>, Is_causal>(params, stream);
} else {
// For sm86 and sm89, 64 x 64 (48KB smem) is the fastest for both causal and non-causal, since we get 2 CTAs per SM.
if (max_smem_per_block >= 164 * 1024) { // H100 and A100
// 64KB, 1 CTA on sm86 and sm89, 2 CTAs on A100, 3 CTAs on H100.
run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 64, 64, 4, false, false, T>, Is_causal>(params, stream);
// 96KB, 1 CTA on sm86 and sm89, 1 CTA on A100, 2 CTAs on H100.
// run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 64, 4, false, false, T>, Is_causal>(params, stream);
// 160KB, does not fit on sm86/sm89; 1 CTA on A100, 1 CTA on H100.
// run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 128, 4, false, false, T>, Is_causal>(params, stream);
} else { // sm86 and sm89
// 48KB, 2 CTAs on sm86 and sm89.
run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 64, 64, 4, true, true, T>, Is_causal>(params, stream);
}
}

template<typename T, bool Is_causal>
void run_mha_fwd_hdim192(Flash_fwd_params &params, cudaStream_t stream) {
constexpr static int Headdim = 192;
// 88KB, 1 CTA on sm86 and sm89, 1 CTA on A100, 2 CTAs on H100.
run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 64, 64, 4, false, false, T>, Is_causal>(params, stream);
// 128KB, does not fit on sm86/sm89; 1 CTA on A100, 1 CTA on H100.
// run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 64, 4, false, false, T>, Is_causal>(params, stream);
// 208KB, does not fit on sm86/sm89 or A100; 1 CTA on H100.
// run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 128, 4, false, false, T>, Is_causal>(params, stream);
}
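hdim192 is the one fwd dispatch with no runtime smem branch: the 88KB config fits even the sm86/sm89 per-block limit, so a single kernel serves every arch. If that invariant should be enforced rather than assumed, a compile-time guard along these lines would do (the 99KB figure and the kSmemSize member are assumptions about the traits, not something this diff shows):

// Hypothetical guard, assuming the traits expose their smem footprint as kSmemSize:
// static_assert(Flash_fwd_kernel_traits<192, 64, 64, 4, false, false, cutlass::half_t>::kSmemSize
//                   <= 99 * 1024,
//               "hdim192 fwd tile must fit the smallest supported per-block smem limit");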

template<typename T, bool Is_causal>
@@ -255,9 +285,15 @@ void run_mha_fwd_hdim256(Flash_fwd_params &params, cudaStream_t stream) {
if (status_ != cudaSuccess) {
C10_CUDA_CHECK(status_);
}
if (max_smem_per_block >= 224 * 1024) {
if (max_smem_per_block >= 112 * 1024) { // H100 and A100
// 112KB, does not fit on sm86/sm89; 1 CTA on A100, 2 CTAs on H100.
run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 64, 64, 4, false, false, T>, Is_causal>(params, stream);
} else {
// 192KB, does not fit on sm86/sm89 or A100; 1 CTA on H100.
// run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 64, 4, false, false, T>, Is_causal>(params, stream);
// 256KB, does not fit on sm86/sm89, A100, or H100.
// run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 128, 4, false, false, T>, Is_causal>(params, stream);
} else { // sm86 and sm89
// 80KB, 1 CTA on sm86 and sm89.
run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 64, 64, 4, true, true, T>, Is_causal>(params, stream);
}
}