-
Notifications
You must be signed in to change notification settings - Fork 39
Optimizes CUDA kernel block sizes for better occupancy #49
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
Merged
Changes from all commits
Commits
File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change | ||||||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
|
@@ -112,10 +112,15 @@ void run_flash_splitkv_fwd(Flash_fwd_params ¶ms, cudaStream_t stream) { | |||||||||||||||||
| auto kernel = &flash_fwd_splitkv_kernel<Kernel_traits, Is_causal, IsEvenMNConst && !Append_KV && IsEvenKConst && Kernel_traits::kHeadDim <= 128, IsEvenKConst, Is_softcap, Split, Append_KV>; | ||||||||||||||||||
| // auto kernel = &flash_fwd_splitkv_kernel<Kernel_traits, Is_causal, false, true, Split, Append_KV>; | ||||||||||||||||||
| // auto kernel = &flash_fwd_splitkv_kernel<Kernel_traits, Is_causal, false, IsEvenKConst>; | ||||||||||||||||||
| // printf("Split = %d, Append_KV = %d, Is_causal = %d, IsEvenMNConst = %d, IsEvenKConst = %d, Is_softcap = %d\n", int(Split), int(Append_KV), int(Is_causal), int(IsEvenMNConst), int(IsEvenKConst), int(Is_softcap)); | ||||||||||||||||||
| if (smem_size >= 48 * 1024) { | ||||||||||||||||||
| C10_CUDA_CHECK(cudaFuncSetAttribute( | ||||||||||||||||||
| kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size)); | ||||||||||||||||||
| } | ||||||||||||||||||
| // int ctas_per_sm; | ||||||||||||||||||
| // cudaError status_ = cudaOccupancyMaxActiveBlocksPerMultiprocessor( | ||||||||||||||||||
| // &ctas_per_sm, kernel, Kernel_traits::kNThreads, smem_size); | ||||||||||||||||||
| // printf("smem_size = %d, CTAs per SM = %d\n", int(smem_size), ctas_per_sm); | ||||||||||||||||||
| kernel<<<grid, Kernel_traits::kNThreads, smem_size, stream>>>(params); | ||||||||||||||||||
| C10_CUDA_KERNEL_LAUNCH_CHECK(); | ||||||||||||||||||
| }); | ||||||||||||||||||
|
|
@@ -155,15 +160,15 @@ void run_mha_fwd_splitkv_dispatch(Flash_fwd_params ¶ms, cudaStream_t stream) | |||||||||||||||||
| constexpr static int kBlockM = 64; // Fixed for all head dimensions | ||||||||||||||||||
| // TD [2023-08-28]: nvcc segfaults for headdim 96 with block size 64 x 256, | ||||||||||||||||||
| // and for headdim 192 with block size 64 x 128. | ||||||||||||||||||
| constexpr static int kBlockN = Headdim <= 64 ? 256 : (Headdim <= 128 ? 128 : 64); | ||||||||||||||||||
| constexpr static int kBlockN = Headdim <= 64 ? 128 : (Headdim <= 128 ? 64 : 32); | ||||||||||||||||||
| run_flash_splitkv_fwd<Flash_fwd_kernel_traits<Headdim, kBlockM, kBlockN, 4, false, false, T>, Is_causal>(params, stream); | ||||||||||||||||||
| } | ||||||||||||||||||
|
|
||||||||||||||||||
| template<typename T, bool Is_causal> | ||||||||||||||||||
| void run_mha_fwd_hdim32(Flash_fwd_params ¶ms, cudaStream_t stream) { | ||||||||||||||||||
| constexpr static int Headdim = 32; | ||||||||||||||||||
| DROPOUT_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] { | ||||||||||||||||||
| run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 128, 4, false, false, T>, Is_dropout, Is_causal>(params, stream); | ||||||||||||||||||
| run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 64, 4, false, false, T>, Is_dropout, Is_causal>(params, stream); | ||||||||||||||||||
| }); | ||||||||||||||||||
| } | ||||||||||||||||||
|
|
||||||||||||||||||
|
|
@@ -173,13 +178,15 @@ void run_mha_fwd_hdim64(Flash_fwd_params ¶ms, cudaStream_t stream) { | |||||||||||||||||
| DROPOUT_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] { | ||||||||||||||||||
| if constexpr(!Is_dropout) { | ||||||||||||||||||
| // Using 8 warps is 18% slower for seqlen=2k, 2 warps is 5% slower | ||||||||||||||||||
| // Using block size (64 x 256) is 27% slower for seqlen=2k | ||||||||||||||||||
| // Using block size (256 x 64) is 85% slower for seqlen=2k, because of register spilling | ||||||||||||||||||
| run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 128, 4, false, false, T>, Is_dropout, Is_causal>(params, stream); | ||||||||||||||||||
| // Using block size (64 x 128) is 27% slower for seqlen=2k | ||||||||||||||||||
| // Using block size (128 x 64) is 85% slower for seqlen=2k, because of register spilling | ||||||||||||||||||
| run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 64, 64, 4, false, false, T>, Is_dropout, Is_causal>(params, stream); | ||||||||||||||||||
| // run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 128, 4, false, false, T>, Is_dropout, Is_causal>(params, stream); | ||||||||||||||||||
| // run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 64, 4, true, false, T>, Is_dropout, Is_causal>(params, stream); | ||||||||||||||||||
| // run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 64, 4, true, true, T>, Is_dropout, Is_causal>(params, stream); | ||||||||||||||||||
| } else { | ||||||||||||||||||
| run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 64, 4, false, false, T>, Is_dropout, Is_causal>(params, stream); | ||||||||||||||||||
| run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 64, 64, 4, false, false, T>, Is_dropout, Is_causal>(params, stream); | ||||||||||||||||||
| // run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 64, 4, false, false, T>, Is_dropout, Is_causal>(params, stream); | ||||||||||||||||||
| // run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 64, 4, true, true, T>, Is_dropout, Is_causal>(params, stream); | ||||||||||||||||||
| // run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 64, 4, true, false, T>, Is_dropout, Is_causal>(params, stream); | ||||||||||||||||||
| // run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 128, 4, false, false, T>, Is_dropout, Is_causal>(params, stream); | ||||||||||||||||||
|
|
@@ -196,7 +203,7 @@ void run_mha_fwd_hdim96(Flash_fwd_params ¶ms, cudaStream_t stream) { | |||||||||||||||||
| // For sm86 or sm89, 64 x 64 is the fastest for causal (because it's square), | ||||||||||||||||||
| if (is_sm8x) { | ||||||||||||||||||
| if constexpr(!Is_causal) { | ||||||||||||||||||
| run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 64, 4, false, false, T>, Is_dropout, Is_causal>(params, stream); | ||||||||||||||||||
| run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 64, 64, 4, false, false, T>, Is_dropout, Is_causal>(params, stream); | ||||||||||||||||||
| } else { | ||||||||||||||||||
| run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 64, 64, 4, false, false, T>, Is_dropout, Is_causal>(params, stream); | ||||||||||||||||||
| } | ||||||||||||||||||
|
|
@@ -218,16 +225,16 @@ void run_mha_fwd_hdim128(Flash_fwd_params ¶ms, cudaStream_t stream) { | |||||||||||||||||
| bool is_sm8x = cc_major == 8 && cc_minor > 0; | ||||||||||||||||||
| DROPOUT_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] { | ||||||||||||||||||
| if constexpr(!Is_dropout) { | ||||||||||||||||||
| // For sm86 or sm89, 64 x 64 is the fastest for causal (because it's square), | ||||||||||||||||||
| // and 128 x 32 (48 KB smem) is the fastest for non-causal since we get 2 CTAs per SM. | ||||||||||||||||||
| // For sm86 or sm89, 64 x 32 (40 KB smem) is the fastest for causal and non-causal since we get 2 CTAs per SM. | ||||||||||||||||||
| // Use block configuration (kBlockM = 64, kBlockN = 64) for better memory alignment | ||||||||||||||||||
| if (is_sm8x) { | ||||||||||||||||||
| if constexpr(!Is_causal) { | ||||||||||||||||||
| run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 32, 4, false, false, T>, Is_dropout, Is_causal>(params, stream); | ||||||||||||||||||
| run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 64, 32, 4, false, false, T>, Is_dropout, Is_causal>(params, stream); | ||||||||||||||||||
| } else { | ||||||||||||||||||
| run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 64, 64, 4, false, false, T>, Is_dropout, Is_causal>(params, stream); | ||||||||||||||||||
| run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 64, 32, 4, false, false, T>, Is_dropout, Is_causal>(params, stream); | ||||||||||||||||||
| } | ||||||||||||||||||
| } else { | ||||||||||||||||||
| run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 64, 4, false, false, T>, Is_dropout, Is_causal>(params, stream); | ||||||||||||||||||
| run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 64, 64, 4, false, false, T>, Is_dropout, Is_causal>(params, stream); | ||||||||||||||||||
| } | ||||||||||||||||||
| // run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 64, 4, true, false, T>, Is_dropout, Is_causal>(params, stream); | ||||||||||||||||||
| // run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 64, 4, true, true, T>, Is_dropout, Is_causal>(params, stream); | ||||||||||||||||||
|
|
@@ -238,7 +245,7 @@ void run_mha_fwd_hdim128(Flash_fwd_params ¶ms, cudaStream_t stream) { | |||||||||||||||||
| // 1st ones are good for H100, A100 | ||||||||||||||||||
| // 2nd one is good for A6000 bc we get slightly better occupancy | ||||||||||||||||||
| } else { | ||||||||||||||||||
| run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 32, 4, false, false, T>, Is_dropout, Is_causal>(params, stream); | ||||||||||||||||||
| run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 64, 64, 4, false, false, T>, Is_dropout, Is_causal>(params, stream); | ||||||||||||||||||
| // run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 64, 64, 4, false, false, T>, Is_dropout, Is_causal>(params, stream); | ||||||||||||||||||
| // run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 32, 4, true, false, T>, Is_dropout, Is_causal>(params, stream); | ||||||||||||||||||
| // run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 32, 4, true, true, T>, Is_dropout, Is_causal>(params, stream); | ||||||||||||||||||
|
|
@@ -251,9 +258,9 @@ void run_mha_fwd_hdim192(Flash_fwd_params ¶ms, cudaStream_t stream) { | |||||||||||||||||
| constexpr static int Headdim = 192; | ||||||||||||||||||
| DROPOUT_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] { | ||||||||||||||||||
| if constexpr(!Is_dropout) { | ||||||||||||||||||
| run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 64, 8, false, false, T>, Is_dropout, Is_causal>(params, stream); | ||||||||||||||||||
| run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 32, 32, 4, false, false, T>, Is_dropout, Is_causal>(params, stream); | ||||||||||||||||||
| } else { | ||||||||||||||||||
| run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 64, 64, 4, false, false, T>, Is_dropout, Is_causal>(params, stream); | ||||||||||||||||||
| run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 32, 32, 4, false, false, T>, Is_dropout, Is_causal>(params, stream); | ||||||||||||||||||
| } | ||||||||||||||||||
| // run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 64, 32, 4, false, false, T>, Is_dropout, Is_causal>(params, stream); | ||||||||||||||||||
| // run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 32, 8, false, false, T>, Is_dropout, Is_causal>(params, stream); | ||||||||||||||||||
|
|
@@ -278,17 +285,13 @@ void run_mha_fwd_hdim256(Flash_fwd_params ¶ms, cudaStream_t stream) { | |||||||||||||||||
| } | ||||||||||||||||||
| // printf("max_smem_per_sm = %d, max_smem_per_block = %d\n", max_smem_per_sm, max_smem_per_block); | ||||||||||||||||||
| DROPOUT_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] { | ||||||||||||||||||
| // For A100, we want to run with 128 x 64 (128KB smem). | ||||||||||||||||||
| // For H100 we want to run with 64 x 64 (96KB smem) since then we can get 2 CTAs per SM. | ||||||||||||||||||
| // For A100, we want to run with 64 x 64 (112KB smem). | ||||||||||||||||||
| // For H100 we want to run with 64 x 32 (72KB smem) since then we can get 2 CTAs per SM. | ||||||||||||||||||
| if (max_smem_per_block >= 2 * Headdim * (128 + 2 * 64) && max_smem_per_sm < 4 * Headdim * (64 + 2 * 64)) { | ||||||||||||||||||
|
||||||||||||||||||
| if (max_smem_per_block >= 2 * Headdim * (128 + 2 * 64) && max_smem_per_sm < 4 * Headdim * (64 + 2 * 64)) { | |
| // Constants for shared memory calculations | |
| constexpr static int SMEM_BASE = 128; // Base shared memory size | |
| constexpr static int SMEM_MULTIPLIER = 64; // Multiplier for shared memory size | |
| constexpr static int SMEM_FACTOR = 2; // Factor for additional shared memory | |
| if (max_smem_per_block >= SMEM_FACTOR * Headdim * (SMEM_BASE + SMEM_FACTOR * SMEM_MULTIPLIER) && | |
| max_smem_per_sm < 4 * Headdim * (SMEM_MULTIPLIER + SMEM_FACTOR * SMEM_MULTIPLIER)) { |
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
[nitpick] Consider removing or wrapping this large block of commented-out debug code in a dedicated debug macro to reduce noise in the main dispatch path.