Remove dropout functionality from flash attention #93
Conversation
Eliminates the hardcoded dropout_p=0.0 parameter from dynamic mask attention function calls in benchmark files. Since dropout was disabled (set to 0.0), removing this parameter simplifies the function calls without affecting functionality.
…h-dmattn into Support-backward
Cleans up codebase by removing specialized CUDA implementations for dropout operations, rotary positional encoding, and Philox random number generation. These components were likely moved to a different location or are no longer needed in the current architecture.
Eliminates the DROPOUT_SWITCH macro definition which was no longer needed in the codebase, simplifying the conditional compilation logic and reducing code complexity.
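For context, here is a minimal sketch of the kind of switch macro being removed, modeled on the BOOL_SWITCH pattern upstream FlashAttention uses in static_switch.h; BOOL_SWITCH and the FLASHATTENTION_DISABLE_DROPOUT guard shown below are assumptions about this fork's exact layout, not a copy of its code:

```cpp
// Hypothetical reconstruction of the removed switch (upstream-style spelling).
#define BOOL_SWITCH(COND, CONST_NAME, ...)            \
    [&] {                                             \
        if (COND) {                                   \
            constexpr static bool CONST_NAME = true;  \
            return __VA_ARGS__();                     \
        } else {                                      \
            constexpr static bool CONST_NAME = false; \
            return __VA_ARGS__();                     \
        }                                             \
    }()

#ifdef FLASHATTENTION_DISABLE_DROPOUT
    // Build-time opt-out: pin the constant to false so no dropout kernels are compiled.
    #define DROPOUT_SWITCH(COND, CONST_NAME, ...)     \
        [&] {                                         \
            constexpr static bool CONST_NAME = false; \
            return __VA_ARGS__();                     \
        }()
#else
    #define DROPOUT_SWITCH BOOL_SWITCH
#endif
```

Deleting DROPOUT_SWITCH (together with the commented-out FLASHATTENTION_DISABLE_DROPOUT define mentioned for setup.py) leaves a single compile path instead of a build-time or runtime choice between dropout and non-dropout kernels.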
Simplifies the normalize_softmax_lse function by removing the Is_dropout template parameter and associated dropout scaling logic. Eliminates the rp_dropout parameter and its usage in scale calculation, streamlining the function interface and reducing complexity. Also removes the unused philox.cuh include that was likely related to dropout random number generation.
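A minimal sketch of the scaling change, assuming the upstream convention that rp_dropout = 1 / (1 - p_dropout); the real normalize_softmax_lse operates on cute tensors inside the softmax epilogue, so the scalar helpers below are illustrative only:

```cpp
// Before (sketch): the accumulator rescale has to undo the dropout keep-probability.
template <bool Is_dropout>
__device__ __forceinline__ float output_scale(float inv_sum, float rp_dropout) {
    return Is_dropout ? inv_sum * rp_dropout : inv_sum;
}

// After (sketch): without dropout the scale is just the softmax normalizer,
// so the template parameter and the rp_dropout argument disappear.
__device__ __forceinline__ float output_scale(float inv_sum) {
    return inv_sum;
}
```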
Simplifies the kernel interface by eliminating the Is_dropout template parameter and associated conditional logic throughout the forward pass implementations. Reduces template instantiation complexity and removes branching logic that was previously used to handle dropout variations for different head dimensions. Streamlines kernel dispatch by removing DROPOUT_SWITCH macros and consolidating execution paths that were previously split based on dropout configuration.
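To illustrate the dispatch change, a compilable sketch with stand-in types; Kernel_traits, run_flash_fwd, and Flash_fwd_params mirror the upstream FlashAttention launcher and are assumptions here, not the fork's exact API:

```cpp
#include <cuda_runtime.h>

struct Flash_fwd_params { float p_dropout; /* ... */ };
struct Kernel_traits {};

// Stand-in launchers: the real ones instantiate and launch flash_fwd_kernel.
template <typename KT, bool Is_dropout, bool Is_causal>
void run_flash_fwd_before(const Flash_fwd_params&, cudaStream_t) {}

template <typename KT, bool Is_causal>
void run_flash_fwd_after(const Flash_fwd_params&, cudaStream_t) {}

// Before: each head-dimension launcher forked on Is_dropout, doubling the
// number of template instantiations per head dimension.
template <bool Is_causal>
void dispatch_before(const Flash_fwd_params& params, cudaStream_t stream) {
    if (params.p_dropout < 1.f) {   // roughly what DROPOUT_SWITCH expanded to
        run_flash_fwd_before<Kernel_traits, true, Is_causal>(params, stream);
    } else {
        run_flash_fwd_before<Kernel_traits, false, Is_causal>(params, stream);
    }
}

// After: a single execution path per head dimension.
template <bool Is_causal>
void dispatch_after(const Flash_fwd_params& params, cudaStream_t stream) {
    run_flash_fwd_after<Kernel_traits, Is_causal>(params, stream);
}
```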
Cleans up the parameter structure by removing unused dropout probability fields, scaling factors, random state management, and rotary interleaving flag. Moves softcap field to improve struct organization and readability.
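As a rough sketch of the struct cleanup; the field names follow the upstream FlashAttention Flash_fwd_params layout and are assumptions about what this fork removed:

```cpp
struct Flash_fwd_params_sketch {
    // ...pointers, strides, and softmax scaling kept by the PR...
    float softcap;                      // kept, relocated for readability

    // Removed by this PR (upstream-style names, shown for orientation only):
    // float p_dropout;                 // keep-probability used by dropout
    // float rp_dropout;                // 1 / (1 - p_dropout) rescale factor
    // float scale_softmax_rp_dropout;  // fused softmax scale * rp_dropout
    // at::PhiloxCudaState philox_args; // RNG seed/offset for the dropout mask
    // uint64_t* rng_state;             // saved so backward can replay the mask
    // bool is_rotary_interleaved;      // rotary layout flag
};
```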
Eliminates dropout functionality across forward pass implementations to simplify the codebase and reduce compilation overhead. Removes dropout parameter handling, probability calculations, random number generation setup, and dropout-related conditional logic from both regular and variable-length attention functions. Simplifies split-KV logic by removing dropout conditional checks and enables certain optimizations that were previously gated by dropout requirements. Updates return signatures to exclude RNG state tensors that are no longer needed without dropout functionality.
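A small sketch of the return-shape change at the C++ entry point; the ATen-based signature and tensor names are modeled on upstream FlashAttention's flash_api.cpp and are assumptions here:

```cpp
#include <vector>
#include <ATen/ATen.h>

// Before: the Philox RNG state rode along in the outputs so the backward pass
// could regenerate the identical dropout mask.
std::vector<at::Tensor> fwd_outputs_with_dropout(const at::Tensor& out,
                                                 const at::Tensor& softmax_lse,
                                                 const at::Tensor& rng_state) {
    return {out, softmax_lse, rng_state};
}

// After: with dropout gone there is no RNG state to return.
std::vector<at::Tensor> fwd_outputs(const at::Tensor& out,
                                    const at::Tensor& softmax_lse) {
    return {out, softmax_lse};
}
```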
Eliminates dropout parameter and related logic across all attention functions and classes. Simplifies block size calculation by removing dropout-dependent branching logic. Removes random number generator state handling and validation checks for dropout probability. Streamlines the interface by focusing on core attention computation without stochastic regularization, reducing complexity in function signatures and internal logic.
Cleans up the build configuration by removing the unused, commented-out compilation flag for disabling dropout functionality.
Eliminates dropout-related template parameters, includes, and implementation code throughout the attention computation functions. Simplifies the kernel interface by removing Is_dropout template parameter and associated dropout logic including RNG state management, dropout application during attention computation, and dropout-specific normalization paths. Streamlines the codebase by removing dependencies on ATen CUDA utilities and dropout/rotary header files that are no longer needed.
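For readers unfamiliar with what "dropout application during attention computation" removes, here is a scalar sketch of the per-score operation; the real kernel applies this to whole tensor tiles using Philox-generated randomness, so the helper below is illustrative only:

```cpp
// Inverted dropout on a single attention score: keep it with probability
// keep_prob and rescale by 1 / keep_prob so the expected value is unchanged.
// uniform_rand stands in for a per-element U(0,1) draw (Philox in the kernel).
__device__ __forceinline__ float apply_dropout_scalar(float score,
                                                      float keep_prob,
                                                      float uniform_rand) {
    return uniform_rand < keep_prob ? score / keep_prob : 0.0f;
}
```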
Pull Request Overview
This PR removes all dropout functionality from the flash attention implementation to simplify the codebase and reduce complexity. The change eliminates unused components and streamlines function interfaces by removing dropout parameters and related logic.
- Removes dropout parameters from all flash attention function signatures
- Eliminates dropout-related CUDA kernels and template parameters
- Simplifies attention computation by removing dropout application logic
Reviewed Changes
Copilot reviewed 13 out of 13 changed files in this pull request and generated 1 comment.
| File | Description |
|---|---|
| setup.py | Comments out the FLASHATTENTION_DISABLE_DROPOUT compilation flag |
| flash_dmattn/flash_dmattn_interface.py | Removes dropout_p parameters and related validation logic from all attention functions |
| csrc/src/static_switch.h | Removes DROPOUT_SWITCH macro definition |
| csrc/src/softmax.h | Removes dropout template parameter and logic from normalize_softmax_lse |
| csrc/src/rotary.h | Deletes entire rotary position encoding implementation file |
| csrc/src/philox.cuh | Deletes Philox random number generator implementation |
| csrc/src/flash_fwd_launch_template.h | Removes dropout template parameters from kernel launches |
| csrc/src/flash_fwd_kernel.h | Removes dropout application logic and RNG state handling |
| csrc/src/flash.h | Removes dropout-related parameters from Flash_fwd_params struct |
| csrc/src/dropout.h | Deletes entire dropout implementation file |
| csrc/flash_api.cpp | Removes dropout parameter handling and RNG state management |
| benchmarks/*.py | Updates benchmark scripts to remove dropout_p=0.0 parameters |
Copilot's single comment is attached to the following hunk in the run_flash_fwd launch code:

```cpp
auto kernel = &flash_fwd_kernel<Kernel_traits, Is_causal, IsEvenMNConst && IsEvenKConst && !ReturnSoftmaxConst && Kernel_traits::kHeadDim <= 128, IsEvenKConst && !ReturnSoftmaxConst, Is_softcap, ReturnSoftmaxConst && !Is_softcap>;
// auto kernel = &flash_fwd_kernel<Kernel_traits, false, Is_causal, false, false, true, true, false>;
// printf("run_flash_fwd: IsEvenMNConst = %d, IsEvenKConst = %d, Is_causal = %d, ReturnSoftmaxConst = %d, Is_dropout = %d\n", int(IsEvenMNConst), int(IsEvenKConst), int(Is_causal), int(ReturnSoftmaxConst), int(Is_dropout));
// printf("run_flash_fwd: IsEvenMNConst = %d, IsEvenKConst = %d, Is_causal = %d, ReturnSoftmaxConst = %d, int(IsEvenMNConst), int(IsEvenKConst), int(Is_causal), int(ReturnSoftmaxConst));
```
Copilot AI commented on Aug 8, 2025:
The commented-out printf statement is malformed: the format string's closing quote and trailing \n are missing, so the four int(...) arguments end up inside the string literal instead of matching the %d specifiers.
Suggested change:

```diff
- // printf("run_flash_fwd: IsEvenMNConst = %d, IsEvenKConst = %d, Is_causal = %d, ReturnSoftmaxConst = %d, int(IsEvenMNConst), int(IsEvenKConst), int(Is_causal), int(ReturnSoftmaxConst));
+ // printf("run_flash_fwd: IsEvenMNConst = %d, IsEvenKConst = %d, Is_causal = %d, ReturnSoftmaxConst = %d\n", int(IsEvenMNConst), int(IsEvenKConst), int(Is_causal), int(ReturnSoftmaxConst));
```
Fix #94
Eliminate all dropout-related parameters and logic from the flash attention implementation. This cleanup simplifies the codebase and reduces complexity by removing unused components and streamlining function interfaces.