
Conversation

@LoserCheems
Collaborator

• Removes the causal attention constraint to enable broader kernel usage
• Integrates the dropout parameter throughout the kernel dispatch functions and adds conditional logic based on dropout state
• Introduces a softcap switching mechanism for attention score capping
• Adds a split-KV kernel implementation with combine functionality for improved memory efficiency
• Optimizes kernel selection based on dropout state: uses different block configurations when dropout is enabled versus disabled (sketched below)
• Updates the shared memory size calculation and removes mask-specific memory references

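The dropout-dependent block selection mentioned above typically looks like the following in FlashAttention-style launch templates. This is a minimal sketch, not the PR's exact code: it assumes a BOOL_SWITCH-style dispatch macro from csrc/src/static_switch.h, and the Flash_fwd_kernel_traits tile sizes (128x64 with dropout, 128x128 without) are illustrative placeholders.

// Hypothetical per-head-dim dispatcher: use a smaller key/value tile when dropout
// is enabled, since the RNG state and dropout mask add per-thread overhead, and a
// larger tile otherwise. Tile sizes below are assumptions, not the PR's values.
template<typename T, bool Is_causal>
void run_mha_fwd_hdim64(Flash_fwd_params &params, cudaStream_t stream) {
    constexpr static int Headdim = 64;
    BOOL_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
        if constexpr (Is_dropout) {
            run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 64, 4, false, false, T>, Is_dropout, Is_causal>(params, stream);
        } else {
            run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 128, 4, false, false, T>, Is_dropout, Is_causal>(params, stream);
        }
    });
}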
Contributor

Copilot AI left a comment


Pull Request Overview

Adds dropout control and a softcap feature to the Flash Attention forward kernels, removes the causal-only constraint, and optimizes kernel launches based on dropout and memory usage.

  • Removes the static assertion preventing causal attention to allow broader kernel usage
  • Integrates an Is_dropout flag and a SOFTCAP_SWITCH for conditional softcap logic
  • Updates shared memory sizing, split-KV combine logic, and refactors kernel dispatch across head dimensions
Comments suppressed due to low confidence (3)

csrc/src/flash_fwd_launch_template.h:75

  • The last template argument restricts Return_softmax to the case where Is_dropout is true and !Is_softcap, which breaks cases where the softmax output is requested without dropout. It should be ReturnSoftmaxConst directly.
                    auto kernel = &flash_fwd_kernel<Kernel_traits, Is_dropout && !Is_softcap, Is_causal, IsEvenMNConst && IsEvenKConst && !ReturnSoftmaxConst && Kernel_traits::kHeadDim <= 128, IsEvenKConst && !ReturnSoftmaxConst, Is_softcap, ReturnSoftmaxConst && Is_dropout && !Is_softcap>;
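A sketch of the fix described, keeping every other template argument as quoted and changing only the final one so that ReturnSoftmaxConst is passed through unconditionally:

                    // Reviewer-suggested form: the last argument is ReturnSoftmaxConst itself,
                    // no longer gated on Is_dropout && !Is_softcap.
                    auto kernel = &flash_fwd_kernel<Kernel_traits, Is_dropout && !Is_softcap, Is_causal, IsEvenMNConst && IsEvenKConst && !ReturnSoftmaxConst && Kernel_traits::kHeadDim <= 128, IsEvenKConst && !ReturnSoftmaxConst, Is_softcap, ReturnSoftmaxConst>;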

csrc/src/flash_fwd_launch_template.h:71

  • New branches introduced by SOFTCAP_SWITCH need unit or integration tests for both enabled and disabled softcap paths to ensure correct behavior.
                SOFTCAP_SWITCH(params.softcap > 0.0, Is_softcap, [&] {
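For reference, SOFTCAP_SWITCH presumably follows the BOOL_SWITCH pattern in csrc/src/static_switch.h, which compiles both branches and selects one at runtime; the sketch below is an assumption about its shape, not the PR's definition, and its two branches are exactly the enabled/disabled paths the tests would need to exercise.

// Assumed SOFTCAP_SWITCH shape (BOOL_SWITCH-style): the lambda body is compiled
// once with CONST_NAME = true and once with CONST_NAME = false.
#define SOFTCAP_SWITCH(COND, CONST_NAME, ...)      \
  [&] {                                            \
    if (COND) {                                    \
      constexpr static bool CONST_NAME = true;     \
      return __VA_ARGS__();                        \
    } else {                                       \
      constexpr static bool CONST_NAME = false;    \
      return __VA_ARGS__();                        \
    }                                              \
  }()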

csrc/src/flash_fwd_launch_template.h:54

  • [nitpick] The signature of run_flash_fwd now includes Is_dropout. Please update or add function documentation to reflect the new parameter and its effect on kernel dispatch.
template<typename Kernel_traits, bool Is_dropout, bool Is_causal>
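A sketch of the documentation the comment asks for, assuming the usual (Flash_fwd_params &, cudaStream_t) launcher signature; the wording is illustrative rather than the PR's:

// Dispatches the forward attention kernel for one compile-time configuration.
//   Kernel_traits - tile sizes, warp count, and element type for this launch.
//   Is_dropout    - when true, launches the dropout-enabled kernel variant; the
//                   per-head-dim dispatchers may pick different block configurations
//                   than the no-dropout path.
//   Is_causal     - when true, applies the causal attention mask.
template<typename Kernel_traits, bool Is_dropout, bool Is_causal>
void run_flash_fwd(Flash_fwd_params &params, cudaStream_t stream);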

Comment on lines +76 to +86
// auto kernel = &flash_fwd_kernel<Kernel_traits, false, Is_causal, false, false, true, true, false>;
// printf("IsEvenMNConst = %d, IsEvenKConst = %d, Is_causal = %d, ReturnSoftmaxConst = %d, Is_dropout = %d\n", int(IsEvenMNConst), int(IsEvenKConst), int(Is_causal), int(ReturnSoftmaxConst), int(Is_dropout));
// auto kernel = &flash_fwd_kernel<Kernel_traits, false, Is_causal, false, true, true, false>;
if (smem_size >= 48 * 1024) {
    C10_CUDA_CHECK(cudaFuncSetAttribute(
        kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size));
}
// int ctas_per_sm;
// cudaError status_ = cudaOccupancyMaxActiveBlocksPerMultiprocessor(
//     &ctas_per_sm, kernel, Kernel_traits::kNThreads, smem_size);
// printf("smem_size = %d, CTAs per SM = %d\n", int(smem_size), ctas_per_sm);

Copilot AI Jun 23, 2025


[nitpick] Commented-out debug/template lines can clutter the code. Consider removing or consolidating these before merging.

Suggested change (remove the commented-out lines; keep only the shared-memory check):

-// auto kernel = &flash_fwd_kernel<Kernel_traits, false, Is_causal, false, false, true, true, false>;
-// printf("IsEvenMNConst = %d, IsEvenKConst = %d, Is_causal = %d, ReturnSoftmaxConst = %d, Is_dropout = %d\n", int(IsEvenMNConst), int(IsEvenKConst), int(Is_causal), int(ReturnSoftmaxConst), int(Is_dropout));
-// auto kernel = &flash_fwd_kernel<Kernel_traits, false, Is_causal, false, true, true, false>;
 if (smem_size >= 48 * 1024) {
     C10_CUDA_CHECK(cudaFuncSetAttribute(
         kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size));
 }
-// int ctas_per_sm;
-// cudaError status_ = cudaOccupancyMaxActiveBlocksPerMultiprocessor(
-//     &ctas_per_sm, kernel, Kernel_traits::kNThreads, smem_size);
-// printf("smem_size = %d, CTAs per SM = %d\n", int(smem_size), ctas_per_sm);

@wubingheng111 merged commit cc2aee0 into main on Jun 23, 2025

Labels

feature New feature request


5 participants