Conversation

@LoserCheems (Collaborator) commented on Aug 8, 2025

Fix #94
Eliminate all dropout-related parameters and logic from the flash attention implementation. This cleanup simplifies the codebase and reduces complexity by removing unused components and streamlining function interfaces.

Eliminates the hardcoded dropout_p=0.0 parameter from dynamic mask attention function calls in benchmark files.

Since dropout was disabled (set to 0.0), removing this parameter simplifies the function calls without affecting functionality.

Cleans up the codebase by removing specialized CUDA implementations for dropout operations, rotary positional encoding, and Philox random number generation.

These components were likely moved to a different location or are no longer needed in the current architecture.

Eliminates the DROPOUT_SWITCH macro definition which was no longer needed in the codebase, simplifying the conditional compilation logic and reducing code complexity.
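
For context, a switch macro of this kind lowers a runtime flag into a compile-time boolean so that separate kernel variants can be instantiated. The sketch below assumes the BOOL_SWITCH pattern used by upstream flash-attention; it is illustrative, not the exact definition removed here.

// Hypothetical sketch of the removed macro, following the upstream BOOL_SWITCH idiom:
// the lambda makes COND available as a constexpr bool inside the kernel-launch body.
#define DROPOUT_SWITCH(COND, CONST_NAME, ...)      \
  [&] {                                            \
    if (COND) {                                    \
      constexpr static bool CONST_NAME = true;     \
      return __VA_ARGS__();                        \
    } else {                                       \
      constexpr static bool CONST_NAME = false;    \
      return __VA_ARGS__();                        \
    }                                              \
  }()
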
Simplifies the normalize_softmax_lse function by removing the Is_dropout template parameter and associated dropout scaling logic.

Eliminates the rp_dropout parameter and its usage in scale calculation, streamlining the function interface and reducing complexity.

Also removes the unused philox.cuh include that was likely related to dropout random number generation.
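
To make the scaling change concrete, the removed logic roughly amounts to dropping a keep-probability factor from the final rescale. The snippet below is a standalone sketch with illustrative names, not the code in softmax.h.

// Before (sketch): with dropout compiled in, the output row was rescaled by
// inv_sum * rp_dropout, where rp_dropout = 1 / (1 - p_dropout).
float rescale_with_dropout(float row_sum, bool is_dropout, float rp_dropout) {
    float inv_sum = (row_sum == 0.f || row_sum != row_sum) ? 1.f : 1.f / row_sum;
    return is_dropout ? inv_sum * rp_dropout : inv_sum;
}

// After (sketch): without the Is_dropout parameter and rp_dropout argument,
// the rescale is simply the reciprocal of the softmax row sum.
float rescale(float row_sum) {
    return (row_sum == 0.f || row_sum != row_sum) ? 1.f : 1.f / row_sum;
}
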
Simplifies the kernel interface by eliminating the Is_dropout template parameter and associated conditional logic throughout the forward pass implementations.

Reduces template instantiation complexity and removes branching logic that was previously used to handle dropout variations for different head dimensions.

Streamlines kernel dispatch by removing DROPOUT_SWITCH macros and consolidating execution paths that were previously split based on dropout configuration.
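
Illustratively, dropping the dropout switch halves the number of kernel instantiations per configuration. The sketch below uses placeholder names, not the real run_flash_fwd launch code.

// Placeholder templates standing in for the real kernel launchers.
template <bool Is_causal, bool Is_dropout> void run_kernel_before() {}
template <bool Is_causal> void run_kernel_after() {}

// Before (sketch): nested switches lowered both flags to compile-time booleans,
// so every configuration existed in a dropout and a non-dropout variant.
void dispatch_before(bool is_causal, bool is_dropout) {
    if (is_causal) { is_dropout ? run_kernel_before<true, true>()  : run_kernel_before<true, false>(); }
    else           { is_dropout ? run_kernel_before<false, true>() : run_kernel_before<false, false>(); }
}

// After (sketch): only the remaining flags need compile-time branching.
void dispatch_after(bool is_causal) {
    is_causal ? run_kernel_after<true>() : run_kernel_after<false>();
}
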
Cleans up the parameter structure by removing unused dropout probability fields, scaling factors, random state management, and the rotary interleaving flag.

Moves softcap field to improve struct organization and readability.
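
For reference, the dropped members are of the kind upstream flash-attention keeps in its forward-parameter struct; the listing below is an assumption about the affected fields, not the literal diff of flash.h.

#include <cstdint>

// Assumed sketch of the members removed from the forward-parameter struct
// (names mirror upstream flash-attention; the fork's exact fields may differ).
struct Removed_dropout_fields_sketch {
    float p_dropout;                 // dropout probability
    uint8_t p_dropout_in_uint8_t;    // quantized keep-probability compared against random bytes
    float rp_dropout;                // 1 / (1 - p_dropout)
    float scale_softmax_rp_dropout;  // softmax scale pre-multiplied by rp_dropout
    uint64_t *rng_state;             // Philox seed/offset handed back for the backward pass
    bool is_rotary_interleaved;      // rotary interleaving flag, also removed in this PR
};
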
Eliminates dropout functionality across forward pass implementations to simplify the codebase and reduce compilation overhead.

Removes dropout parameter handling, probability calculations, random number generation setup, and dropout-related conditional logic from both regular and variable-length attention functions.

Simplifies split-KV logic by removing dropout conditional checks and enables certain optimizations that were previously gated by dropout requirements.
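
Concretely, split-KV kernels in upstream flash-attention do not support dropout, so the split path was gated on a zero dropout probability; with dropout gone, that check disappears. A hypothetical sketch of the gate:

// Hypothetical gating sketch (illustrative only, not the code in flash_api.cpp).
bool use_splitkv_before(float p_dropout, int num_splits) {
    // Split-KV was only eligible when dropout was disabled.
    return p_dropout == 0.0f && num_splits != 1;
}

bool use_splitkv_after(int num_splits) {
    // The eligibility check reduces to the split-count heuristic alone.
    return num_splits != 1;
}
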

Updates return signatures to exclude RNG state tensors that are no longer needed without dropout functionality.
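
The return-signature change can be pictured as dropping the saved RNG state from the tuple handed back to Python. The sketch below uses a stand-in Tensor type and hypothetical function names, not the actual mha_fwd signature in flash_api.cpp.

#include <vector>

struct Tensor {};  // stand-in for at::Tensor, purely for illustration

// Before (sketch): the RNG state was returned so the backward pass could replay
// the same dropout mask.
std::vector<Tensor> forward_before() {
    Tensor out, softmax_lse, p, rng_state;
    return {out, softmax_lse, p, rng_state};
}

// After (sketch): with dropout removed there is no RNG state to save or return.
std::vector<Tensor> forward_after() {
    Tensor out, softmax_lse, p;
    return {out, softmax_lse, p};
}
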
Eliminates dropout parameter and related logic across all attention functions and classes. Simplifies block size calculation by removing dropout-dependent branching logic. Removes random number generator state handling and validation checks for dropout probability.

Streamlines the interface by focusing on core attention computation without stochastic regularization, reducing complexity in function signatures and internal logic.

Cleans up the build configuration by removing the unused, commented-out compilation flag for disabling dropout functionality.

Eliminates dropout-related template parameters, includes, and implementation code throughout the attention computation functions.

Simplifies the kernel interface by removing Is_dropout template parameter and associated dropout logic including RNG state management, dropout application during attention computation, and dropout-specific normalization paths.

Streamlines the codebase by removing dependencies on ATen CUDA utilities and dropout/rotary header files that are no longer needed.
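
As a point of reference for the removed in-kernel logic, applying dropout to attention probabilities means zeroing each entry with probability p and rescaling survivors by 1 / (1 - p) so the expected value is unchanged. The host-side sketch below illustrates the idea only; the deleted kernel code drew its random bits from a Philox generator in registers.

#include <random>

// Illustrative host-side sketch of the dropout step removed from the forward kernel.
void apply_dropout_sketch(float *attn_probs, int n, float p_dropout, std::mt19937 &gen) {
    std::uniform_real_distribution<float> uniform(0.0f, 1.0f);
    const float rp_dropout = 1.0f / (1.0f - p_dropout);
    for (int i = 0; i < n; ++i) {
        attn_probs[i] = (uniform(gen) < p_dropout) ? 0.0f : attn_probs[i] * rp_dropout;
    }
}
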
Copilot AI left a comment

Pull Request Overview

This PR removes all dropout functionality from the flash attention implementation to simplify the codebase and reduce complexity. The change eliminates unused components and streamlines function interfaces by removing dropout parameters and related logic.

  • Removes dropout parameters from all flash attention function signatures
  • Eliminates dropout-related CUDA kernels and template parameters
  • Simplifies attention computation by removing dropout application logic

Reviewed Changes

Copilot reviewed 13 out of 13 changed files in this pull request and generated 1 comment.

Summary per file:

setup.py: Comments out the FLASHATTENTION_DISABLE_DROPOUT compilation flag
flash_dmattn/flash_dmattn_interface.py: Removes dropout_p parameters and related validation logic from all attention functions
csrc/src/static_switch.h: Removes DROPOUT_SWITCH macro definition
csrc/src/softmax.h: Removes dropout template parameter and logic from normalize_softmax_lse
csrc/src/rotary.h: Deletes entire rotary position encoding implementation file
csrc/src/philox.cuh: Deletes Philox random number generator implementation
csrc/src/flash_fwd_launch_template.h: Removes dropout template parameters from kernel launches
csrc/src/flash_fwd_kernel.h: Removes dropout application logic and RNG state handling
csrc/src/flash.h: Removes dropout-related parameters from Flash_fwd_params struct
csrc/src/dropout.h: Deletes entire dropout implementation file
csrc/flash_api.cpp: Removes dropout parameter handling and RNG state management
benchmarks/*.py: Updates benchmark scripts to remove dropout_p=0.0 parameters

auto kernel = &flash_fwd_kernel<Kernel_traits, Is_causal, IsEvenMNConst && IsEvenKConst && !ReturnSoftmaxConst && Kernel_traits::kHeadDim <= 128, IsEvenKConst && !ReturnSoftmaxConst, Is_softcap, ReturnSoftmaxConst && !Is_softcap>;
// auto kernel = &flash_fwd_kernel<Kernel_traits, false, Is_causal, false, false, true, true, false>;
// printf("run_flash_fwd: IsEvenMNConst = %d, IsEvenKConst = %d, Is_causal = %d, ReturnSoftmaxConst = %d, Is_dropout = %d\n", int(IsEvenMNConst), int(IsEvenKConst), int(Is_causal), int(ReturnSoftmaxConst), int(Is_dropout));
// printf("run_flash_fwd: IsEvenMNConst = %d, IsEvenKConst = %d, Is_causal = %d, ReturnSoftmaxConst = %d, int(IsEvenMNConst), int(IsEvenKConst), int(Is_causal), int(ReturnSoftmaxConst));

Copilot AI Aug 8, 2025

The commented-out printf is malformed: the format string is never closed, so the four int(...) arguments end up inside the string literal. The format string should be terminated (e.g. with "\n") before the arguments, as in the suggestion below.

Suggested change
// printf("run_flash_fwd: IsEvenMNConst = %d, IsEvenKConst = %d, Is_causal = %d, ReturnSoftmaxConst = %d, int(IsEvenMNConst), int(IsEvenKConst), int(Is_causal), int(ReturnSoftmaxConst));
// printf("run_flash_fwd: IsEvenMNConst = %d, IsEvenKConst = %d, Is_causal = %d, ReturnSoftmaxConst = %d\n", int(IsEvenMNConst), int(IsEvenKConst), int(Is_causal), int(ReturnSoftmaxConst));

@LoserCheems merged commit b2a148f into main on Aug 8, 2025

Successfully merging this pull request may close these issues: Remove the less frequently used functions.
