Adds dropout support and softcap feature to flash attention #29
Conversation
- Removes causal attention constraint to enable broader kernel usage
- Integrates dropout parameter throughout kernel dispatch functions and adds conditional logic based on dropout state
- Introduces softcap switching mechanism for attention score capping (see the sketch below)
- Adds split-KV kernel implementation with combine functionality for improved memory efficiency
- Optimizes kernel selection based on dropout state - uses different block configurations when dropout is enabled versus disabled
- Updates shared memory size calculation and removes mask-specific memory references
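For context on the softcap feature: softcapping bounds the raw attention scores before the softmax by squashing them through a scaled tanh, so no score exceeds the cap in magnitude. Below is a minimal host-side sketch of the operation only; the PR applies it inside the fused CUDA kernels, and the helper name here is illustrative.

```cpp
#include <cmath>
#include <vector>

// Illustrative helper: maps each raw attention score s to
// softcap * tanh(s / softcap), bounding it to (-softcap, softcap).
// A softcap of 0 is treated as "disabled", matching the
// params.softcap > 0.0 check used by the SOFTCAP_SWITCH further down.
void apply_softcap(std::vector<float>& scores, float softcap) {
    if (softcap <= 0.0f) return;  // softcap disabled
    for (float& s : scores) {
        s = softcap * std::tanh(s / softcap);
    }
}
```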
Pull Request Overview
Adds dropout control and a softcap feature to the Flash Attention forward kernels, removes the causal-only constraint, and optimizes kernel launches based on dropout and memory usage.
- Removes the static assertion that restricted the kernels to causal attention, allowing broader kernel usage
- Integrates an Is_dropout flag and a SOFTCAP_SWITCH for conditional softcap logic
- Updates shared memory sizing and split-KV combine logic, and refactors kernel dispatch across head dimensions
Comments suppressed due to low confidence (3)
csrc/src/flash_fwd_launch_template.h:75
- The last template argument restricts Return_softmax to only when Is_dropout is true and !Is_softcap, which breaks cases where softmax output is requested without dropout. It should be ReturnSoftmaxConst directly.
auto kernel = &flash_fwd_kernel<Kernel_traits, Is_dropout && !Is_softcap, Is_causal, IsEvenMNConst && IsEvenKConst && !ReturnSoftmaxConst && Kernel_traits::kHeadDim <= 128, IsEvenKConst && !ReturnSoftmaxConst, Is_softcap, ReturnSoftmaxConst && Is_dropout && !Is_softcap>;
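If the intent is to honor the caller's request for the softmax output independently of dropout, the change this comment describes would pass the flag through unchanged in the final template argument. A sketch of that adjustment, assuming the same surrounding switch and constants as in the line above (not code from the PR itself):

```cpp
// Sketch of the suggested change: use ReturnSoftmaxConst directly as the
// final template argument instead of gating it on Is_dropout / !Is_softcap.
auto kernel = &flash_fwd_kernel<Kernel_traits, Is_dropout && !Is_softcap, Is_causal,
                                IsEvenMNConst && IsEvenKConst && !ReturnSoftmaxConst && Kernel_traits::kHeadDim <= 128,
                                IsEvenKConst && !ReturnSoftmaxConst, Is_softcap,
                                ReturnSoftmaxConst>;
```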
csrc/src/flash_fwd_launch_template.h:71
- New branches introduced by SOFTCAP_SWITCH need unit or integration tests for both enabled and disabled softcap paths to ensure correct behavior.
SOFTCAP_SWITCH(params.softcap > 0.0, Is_softcap, [&] {
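For readers unfamiliar with the pattern, a switch macro like this typically follows the BOOL_SWITCH idiom used elsewhere in flash-attention: it turns a runtime boolean into a constexpr constant so both the softcap and non-softcap kernel variants are instantiated at compile time. A minimal sketch follows; the actual SOFTCAP_SWITCH definition in this PR may differ.

```cpp
// Sketch of a BOOL_SWITCH-style macro: COND is evaluated at runtime, but
// CONST_NAME is a compile-time constant inside the lambda body, so the
// compiler emits one kernel instantiation per branch.
#define SOFTCAP_SWITCH(COND, CONST_NAME, ...)       \
  [&] {                                             \
    if (COND) {                                     \
      constexpr static bool CONST_NAME = true;      \
      return __VA_ARGS__();                         \
    } else {                                        \
      constexpr static bool CONST_NAME = false;     \
      return __VA_ARGS__();                         \
    }                                               \
  }()
```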
csrc/src/flash_fwd_launch_template.h:54
- [nitpick] The signature of run_flash_fwd now includes Is_dropout. Please update or add function documentation to reflect the new parameter and its effect on kernel dispatch.
template<typename Kernel_traits, bool Is_dropout, bool Is_causal>
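One possible shape for the requested documentation; the wording is a suggestion, and the parameter list of run_flash_fwd beyond the template line above is assumed from the usual flash-attention signature:

```cpp
// Runs the Flash Attention forward pass.
//   Kernel_traits - compile-time block/warp configuration for the kernel.
//   Is_dropout    - when true, dispatches the kernel variants that apply
//                   attention dropout; the head-dimension dispatchers also
//                   select different block configurations for this case.
//   Is_causal     - when true, applies the causal (lower-triangular) mask.
template<typename Kernel_traits, bool Is_dropout, bool Is_causal>
void run_flash_fwd(Flash_fwd_params &params, cudaStream_t stream);
```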
```cpp
// auto kernel = &flash_fwd_kernel<Kernel_traits, false, Is_causal, false, false, true, true, false>;
// printf("IsEvenMNConst = %d, IsEvenKConst = %d, Is_causal = %d, ReturnSoftmaxConst = %d, Is_dropout = %d\n", int(IsEvenMNConst), int(IsEvenKConst), int(Is_causal), int(ReturnSoftmaxConst), int(Is_dropout));
// auto kernel = &flash_fwd_kernel<Kernel_traits, false, Is_causal, false, true, true, false>;
if (smem_size >= 48 * 1024) {
    C10_CUDA_CHECK(cudaFuncSetAttribute(
        kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size));
}
// int ctas_per_sm;
// cudaError status_ = cudaOccupancyMaxActiveBlocksPerMultiprocessor(
//     &ctas_per_sm, kernel, Kernel_traits::kNThreads, smem_size);
// printf("smem_size = %d, CTAs per SM = %d\n", int(smem_size), ctas_per_sm);
```
Copilot AI · Jun 23, 2025
[nitpick] Commented-out debug/template lines can clutter the code. Consider removing or consolidating these before merging.
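Background on the smem_size check in the snippet above: a kernel may use at most 48 KB of dynamic shared memory by default, and requesting more requires an explicit opt-in through cudaFuncSetAttribute before launch, which is what the non-commented lines do. A standalone sketch of the pattern (toy kernel; the 64 KB size assumes a GPU that supports the larger carve-out):

```cpp
#include <cuda_runtime.h>

__global__ void toy_kernel(float *out) {
    extern __shared__ float smem[];   // dynamic shared memory
    smem[threadIdx.x] = static_cast<float>(threadIdx.x);
    __syncthreads();
    out[threadIdx.x] = smem[threadIdx.x];
}

int main() {
    const int smem_size = 64 * 1024;  // above the 48 KB default limit
    // Opt in to the larger dynamic shared-memory allocation for this kernel,
    // mirroring the cudaFuncSetAttribute call in the launch template.
    cudaFuncSetAttribute(toy_kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size);

    float *out = nullptr;
    cudaMalloc(&out, 128 * sizeof(float));
    toy_kernel<<<1, 128, smem_size>>>(out);
    cudaDeviceSynchronize();
    cudaFree(out);
    return 0;
}
```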