From 77b5ff482899b27e0de6b3258bb363e81dd101f6 Mon Sep 17 00:00:00 2001 From: Loser Cheems Date: Fri, 8 Aug 2025 20:59:28 +0800 Subject: [PATCH 01/10] Removes unused dropout parameter from attention calls Eliminates the hardcoded dropout_p=0.0 parameter from dynamic mask attention function calls in benchmark files. Since dropout was disabled (set to 0.0), removing this parameter simplifies the function calls without affecting functionality. --- benchmarks/benchmark_forward_equivalence.py | 1 - benchmarks/benchmark_forward_performance.py | 1 - 2 files changed, 2 deletions(-) diff --git a/benchmarks/benchmark_forward_equivalence.py b/benchmarks/benchmark_forward_equivalence.py index e3e9918..fac85d2 100644 --- a/benchmarks/benchmark_forward_equivalence.py +++ b/benchmarks/benchmark_forward_equivalence.py @@ -256,7 +256,6 @@ def dynamic_mask_attention_cuda( value_states, # [batch, key_len, num_kv_heads, head_dim] attn_mask=attn_mask, # [batch, num_kv_heads, query_len, key_len] attn_bias=attn_bias, # [batch, num_kv_heads, query_len, key_len] - dropout_p=0.0, is_causal=is_causal, scale=scaling, softcap=0.0, diff --git a/benchmarks/benchmark_forward_performance.py b/benchmarks/benchmark_forward_performance.py index 759668a..0c1c042 100644 --- a/benchmarks/benchmark_forward_performance.py +++ b/benchmarks/benchmark_forward_performance.py @@ -265,7 +265,6 @@ def dynamic_mask_attention_cuda( value_states, # [batch, key_len, num_kv_heads, head_dim] attn_mask=attn_mask, # [batch, num_kv_heads, query_len, key_len] attn_bias=attn_bias, # [batch, num_kv_heads, query_len, key_len] - dropout_p=0.0, is_causal=is_causal, scale=scaling, softcap=0.0, From 00ff6fe08011a9b56e4fa95d2c049a79592ab6b0 Mon Sep 17 00:00:00 2001 From: Loser Cheems Date: Fri, 8 Aug 2025 21:01:27 +0800 Subject: [PATCH 02/10] Removes dropout, rotary, and Philox RNG implementations Cleans up codebase by removing specialized CUDA implementations for dropout operations, rotary positional encoding, and Philox random number generation. These components were likely moved to a different location or are no longer needed in the current architecture. --- csrc/src/dropout.h | 95 --------------------------- csrc/src/philox.cuh | 53 --------------- csrc/src/rotary.h | 153 -------------------------------------------- 3 files changed, 301 deletions(-) delete mode 100644 csrc/src/dropout.h delete mode 100644 csrc/src/philox.cuh delete mode 100644 csrc/src/rotary.h diff --git a/csrc/src/dropout.h b/csrc/src/dropout.h deleted file mode 100644 index 9077b79..0000000 --- a/csrc/src/dropout.h +++ /dev/null @@ -1,95 +0,0 @@ -/****************************************************************************** - * Copyright (c) 2024, Tri Dao. - ******************************************************************************/ - -#pragma once - -#include "namespace_config.h" -#include "philox.cuh" -#include "utils.h" - -namespace FLASH_NAMESPACE { - -struct Dropout { - - const unsigned long long seed, offset; - const uint8_t p_dropout_in_uint8_t; - - __forceinline__ __device__ Dropout(const unsigned long long seed, const unsigned long long offset, - const uint8_t p_dropout_in_uint8_t, - const int bid, const int hid, const int tid, const int nheads) - : seed(seed) - , offset(offset + (bid * nheads + hid) * 32 + tid % 32) - , p_dropout_in_uint8_t(p_dropout_in_uint8_t) { - } - - template - __forceinline__ __device__ void apply_dropout(Tensor &tensor_, - int block_row_start, int block_col_start, int block_row_stride) { - // convert shape from (4, MMA_M, MMA_N) to (8, MMA_M, MMA_N / 2) - Tensor tensor = make_tensor(tensor_.data(), FLASH_NAMESPACE::convert_layout_acc_dropout(tensor_.layout())); - using T = typename Engine::value_type; - auto encode_dropout = [](bool keep, T val) { - return keep ? val : (encode_dropout_in_sign_bit ? -val : T(0)); - }; - static_assert(decltype(size<2>(tensor))::value % 2 == 0); - const uint16_t p_dropout_8bit_in_uint16_t = uint16_t(p_dropout_in_uint8_t); - const uint32_t p_dropout_8bit_in_uint32_t = (uint32_t(p_dropout_8bit_in_uint16_t) << 16) | uint32_t(p_dropout_8bit_in_uint16_t); - // if (cute::thread0()) { printf("threshold2 = 0x%x\n", p_dropout_8bit_in_uint32_t); } - #pragma unroll - for (int m = 0; m < size<1>(tensor); ++m, block_row_start += block_row_stride) { - uint2 rowcol = make_uint2(block_row_start, block_col_start); - #pragma unroll - for (int n = 0; n < size<2>(tensor) / 2; ++n, ++rowcol.y) { - // if (cute::thread(32, 0)) { printf("m = %d, n = %d, row = %d, col = %d\n", m, n, int(rowcol.x), int(rowcol.y));} - uint4 random_uint4 = FLASH_NAMESPACE::philox(seed, reinterpret_cast(rowcol), offset); - // if (cute::thread0()) { printf("philox = %u, %d, %d, %d\n", random_uint4.x, random_uint4.y, random_uint4.z, random_uint4.w);} - uint8_t (&rnd_8)[16] = reinterpret_cast(random_uint4); - // Special implementation for 16-bit types: we duplicate the threshold to the - // low and high 16 bits of a 32-bit value, then use the f16x2 comparison instruction - // to get a mask. The low 16 bits of the mask will be either 0xffff or 0x0000, - // and the high 16 bits will be either 0xffff or 0x0000, depending on whether - // the random value is less than the threshold. - // We then do a bit-wise AND between the mask and the original value (in 32-bit). - // We're exploiting the fact that floating point comparison is equivalent to integer - // comparison, since we're comparing unsigned integers whose top 8-bits are zero. - if (!encode_dropout_in_sign_bit - && (std::is_same::value || std::is_same::value)) { - uint16_t rnd_16[16]; - #pragma unroll - for (int i = 0; i < 16; i++) { rnd_16[i] = uint16_t(rnd_8[i]); } - uint32_t (&rnd_32)[8] = reinterpret_cast(rnd_16); - #pragma unroll - for (int j = 0; j < 2; j++) { - Tensor tensor_uint32 = recast(tensor(_, m, n * 2 + j)); - // if (cute::thread0()) { printf("random = 0x%x, 0x%x, 0x%x, 0x%x\n", rnd_32[j * 4 + 0], rnd_32[j * 4 + 1], rnd_32[j * 4 + 2], rnd_32[j * 4 + 3]); } - // if (cute::thread0()) { printf("tensor_uint32 = 0x%x, 0x%x, 0x%x, 0x%x\n", tensor_uint32(0), tensor_uint32(1), tensor_uint32(2), tensor_uint32(3)); } - #pragma unroll - for (int i = 0; i < 4; i++) { - uint32_t mask; - asm volatile("set.le.u32.f16x2 %0, %1, %2;\n" : "=r"(mask) : "r"(rnd_32[j * 4 + i]), "r"(p_dropout_8bit_in_uint32_t)); - tensor_uint32(i) &= mask; - } - // if (cute::thread0()) { printf("tensor_uint32 = 0x%x, 0x%x, 0x%x, 0x%x\n", tensor_uint32(0), tensor_uint32(1), tensor_uint32(2), tensor_uint32(3)); } - } - } else { - #pragma unroll - for (int j = 0; j < 2; j++) { - #pragma unroll - for (int i = 0; i < 8; i++) { - tensor(i, m, n * 2 + j) = encode_dropout(rnd_8[j * 8 + i] <= p_dropout_in_uint8_t, tensor(i, m, n * 2 + j)); - } - Tensor tensor_uint32 = recast(tensor(_, m, n * 2 + j)); - // if (cute::thread0()) { printf("tensor_uint32 = 0x%x, 0x%x, 0x%x, 0x%x\n", tensor_uint32(0), tensor_uint32(1), tensor_uint32(2), tensor_uint32(3)); } - } - } - // // if ((threadIdx.x == 0) && (blockIdx.x == 0) && (blockIdx.y == 0)) { - // // printf("n = %d, ph Philox: %u, %u, %u, %u\n", n, rnd_8.x, rnd_8.y, rnd_8.z, rnd_8.w); - // // } - } - } - } - -}; - -} // namespace FLASH_NAMESPACE diff --git a/csrc/src/philox.cuh b/csrc/src/philox.cuh deleted file mode 100644 index 5205f45..0000000 --- a/csrc/src/philox.cuh +++ /dev/null @@ -1,53 +0,0 @@ -// Pytorch also has an implementation of Philox RNG: https://github.com/pytorch/pytorch/blob/8ca3c881db3e3510fcb7725389f6a0633c9b992c/torch/csrc/jit/tensorexpr/cuda_random.h -#pragma once -// Philox CUDA. - -#include "namespace_config.h" - -namespace FLASH_NAMESPACE { - -struct ull2 { - unsigned long long x; - unsigned long long y; -}; - -__forceinline__ __device__ uint2 mulhilo32(const unsigned int a, const unsigned int b) { - uint2 *res; - unsigned long long tmp; - asm ("mul.wide.u32 %0, %1, %2;\n\t" - : "=l"(tmp) - : "r"(a), "r"(b)); - res = (uint2*)(&tmp); - return *res; -} - -__forceinline__ __device__ uint4 philox_single_round(const uint4 ctr, const uint2 key) { - constexpr unsigned long kPhiloxSA = 0xD2511F53; - constexpr unsigned long kPhiloxSB = 0xCD9E8D57; - uint2 res0 = mulhilo32(kPhiloxSA, ctr.x); - uint2 res1 = mulhilo32(kPhiloxSB, ctr.z); - uint4 ret = {res1.y ^ ctr.y ^ key.x, res1.x, res0.y ^ ctr.w ^ key.y, res0.x}; - return ret; -} - -__forceinline__ __device__ uint4 philox(unsigned long long seed, - unsigned long long subsequence, - unsigned long long offset) { - constexpr unsigned long kPhilox10A = 0x9E3779B9; - constexpr unsigned long kPhilox10B = 0xBB67AE85; - uint2 key = reinterpret_cast(seed); - uint4 counter; - ull2 *tmp = reinterpret_cast(&counter); - tmp->x = offset; - tmp->y = subsequence; - #pragma unroll - for (int i = 0; i < 6; i++) { - counter = philox_single_round(counter, key); - key.x += (kPhilox10A); - key.y += (kPhilox10B); - } - uint4 output = philox_single_round(counter, key); - return output; -} - -} // namespace FLASH_NAMESPACE diff --git a/csrc/src/rotary.h b/csrc/src/rotary.h deleted file mode 100644 index dbae24c..0000000 --- a/csrc/src/rotary.h +++ /dev/null @@ -1,153 +0,0 @@ -/****************************************************************************** - * Copyright (c) 2024, Tri Dao. - ******************************************************************************/ - -#pragma once - -#include - -#include "namespace_config.h" -#include "utils.h" - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -namespace FLASH_NAMESPACE { - -using namespace cute; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template -__forceinline__ __device__ void copy_rotary_interleaved(Tensor const &S, - Tensor &D, - Tensor const &Cos, - Tensor const &Sin, - Tensor const &identity_MN, - const int max_MN, const int min_MN, - const int dim, const int rotary_dim) { - CUTE_STATIC_ASSERT_V(rank(S) == Int<3>{}); - CUTE_STATIC_ASSERT_V(rank(D) == Int<3>{}); - CUTE_STATIC_ASSERT_V(size<0>(S) == size<0>(D)); // MMA - CUTE_STATIC_ASSERT_V(size<1>(S) == size<1>(D)); // MMA_M - CUTE_STATIC_ASSERT_V(size<2>(S) == size<2>(D)); // MMA_K - CUTE_STATIC_ASSERT_V(size<1>(S) == size<1>(Cos)); // MMA_M - CUTE_STATIC_ASSERT_V(size<2>(S) == size<2>(Cos)); // MMA_K - CUTE_STATIC_ASSERT_V(size<1>(S) == size<1>(Sin)); // MMA_M - CUTE_STATIC_ASSERT_V(size<2>(S) == size<2>(Sin)); // MMA_K - CUTE_STATIC_ASSERT_V(size<0>(Cos) == size<0>(Sin)); // MMA_K - static_assert(decltype(size<0>(S))::value == decltype(size<0>(Cos))::value * 2); - static_assert(decltype(size<0>(Cos))::value % 2 == 0); // Since we do fast conversion from fp16/bf16 to fp32 - Tensor rCos = make_fragment_like(Cos); - Tensor rSin = make_fragment_like(Sin); - Tensor rS = make_fragment_like(S); - #pragma unroll - for (int m = 0; m < size<1>(S); ++m) { - if (get<0>(identity_MN(0, m, 0)) >= min_MN && get<0>(identity_MN(0, m, 0)) < max_MN) { - #pragma unroll - for (int k = 0; k < size<2>(S); ++k) { - if (Is_even_K || get<1>(identity_MN(0, 0, k)) < dim) { - cute::copy(S(_, m, k), rS(_, m, k)); - if (get<1>(identity_MN(0, 0, k)) < rotary_dim) { - cute::copy(Cos(_, m, k), rCos(_, m, k)); - cute::copy(Sin(_, m, k), rSin(_, m, k)); - Tensor S_fp32 = convert_type(rS(_, m, k)); - Tensor cos_fp32 = convert_type(rCos(_, m, k)); - Tensor sin_fp32 = convert_type(rSin(_, m, k)); - #pragma unroll - for (int i = 0; i < size<0>(rS) / 2; ++i) { - float real = S_fp32(2 * i) * cos_fp32(i) - S_fp32(2 * i + 1) * sin_fp32(i); - float imag = S_fp32(2 * i) * sin_fp32(i) + S_fp32(2 * i + 1) * cos_fp32(i); - S_fp32(2 * i) = real; - S_fp32(2 * i + 1) = imag; - } - // Idk but I need to copy for the convert_type to work - Tensor S_fp32_copy = make_fragment_like(S_fp32); - cute::copy(S_fp32, S_fp32_copy); - using T = typename Engine0::value_type; - Tensor S_og_type = convert_type(S_fp32_copy); - cute::copy(S_og_type, rS(_, m, k)); - } - cute::copy(rS(_, m, k), D(_, m, k)); - } else if (Clear_OOB_K) { - cute::clear(D(_, m, k)); - } - } - } - } -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template -__forceinline__ __device__ void copy_rotary_contiguous(Tensor const &S, - Tensor &D, - Tensor const &Cos, - Tensor const &Sin, - Tensor const &identity_MN, - const int max_MN, const int min_MN, - const int dim, const int rotary_dim) { - CUTE_STATIC_ASSERT_V(rank(S) == Int<3>{}); - CUTE_STATIC_ASSERT_V(rank(D) == Int<3>{}); - CUTE_STATIC_ASSERT_V(size<0>(S) == size<0>(D)); // MMA - CUTE_STATIC_ASSERT_V(size<1>(S) == size<1>(D)); // MMA_M - CUTE_STATIC_ASSERT_V(size<2>(S) == size<2>(D)); // MMA_K - CUTE_STATIC_ASSERT_V(size<1>(S) == size<1>(Cos)); // MMA_M - CUTE_STATIC_ASSERT_V(size<2>(S) == size<2>(Cos)); // MMA_K - CUTE_STATIC_ASSERT_V(size<1>(S) == size<1>(Sin)); // MMA_M - CUTE_STATIC_ASSERT_V(size<2>(S) == size<2>(Sin)); // MMA_K - CUTE_STATIC_ASSERT_V(size<0>(S) == size<0>(Cos)); // MMA - CUTE_STATIC_ASSERT_V(size<0>(Cos) == size<0>(Sin)); - static_assert(decltype(size<0>(Cos))::value % 2 == 0); // Since we do fast conversion from fp16/bf16 to fp32 - Tensor rCos = make_fragment_like(Cos); - Tensor rSin = make_fragment_like(Sin); - Tensor rS = make_fragment_like(S); - Tensor rS_other = make_fragment_like(rS(_, 0, 0)); - #pragma unroll - for (int m = 0; m < size<1>(S); ++m) { - if (get<0>(identity_MN(0, m, 0)) >= min_MN && get<0>(identity_MN(0, m, 0)) < max_MN) { - #pragma unroll - for (int k = 0; k < size<2>(S); ++k) { - if (Is_even_K || get<1>(identity_MN(0, 0, k)) < dim) { - cute::copy(S(_, m, k), rS(_, m, k)); - if (get<1>(identity_MN(0, 0, k)) < rotary_dim) { - const bool is_left = get<1>(identity_MN(0, 0, k)) < rotary_dim / 2; - Tensor gS_other = make_tensor(S(_, m, k).data() + (is_left ? rotary_dim / 2 : -rotary_dim / 2), S(_, m, k).layout()); - cute::copy(gS_other, rS_other); - // if (cute::thread0()) { print_tensor(rS(_, m, k)); print_tensor(rS_other); } - Tensor gCos = make_tensor(Cos(_, m, k).data() + (is_left ? 0 : -rotary_dim / 2), Cos(_, m, k).layout()); - Tensor gSin = make_tensor(Sin(_, m, k).data() + (is_left ? 0 : -rotary_dim / 2), Sin(_, m, k).layout()); - cute::copy(gCos, rCos(_, m, k)); - cute::copy(gSin, rSin(_, m, k)); - // if (cute::thread0()) { print_tensor(rCos(_, m, k)); print_tensor(rSin(_, m, k)); } - Tensor S_fp32 = convert_type(rS(_, m, k)); - Tensor S_other_fp32 = convert_type(rS_other); - Tensor cos_fp32 = convert_type(rCos(_, m, k)); - Tensor sin_fp32 = convert_type(rSin(_, m, k)); - #pragma unroll - for (int i = 0; i < size<0>(rS); ++i) { - S_fp32(i) = S_fp32(i) * cos_fp32(i) + S_other_fp32(i) * (is_left ? -sin_fp32(i) : sin_fp32(i)); - } - // Idk but I need to copy for the convert_type to work - Tensor S_fp32_copy = make_fragment_like(S_fp32); - cute::copy(S_fp32, S_fp32_copy); - using T = typename Engine0::value_type; - Tensor S_og_type = convert_type(S_fp32_copy); - cute::copy(S_og_type, rS(_, m, k)); - // if (cute::thread0()) { print_tensor(rS(_, m, k)); } - } - cute::copy(rS(_, m, k), D(_, m, k)); - } else if (Clear_OOB_K) { - cute::clear(D(_, m, k)); - } - } - } - } -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -} // namespace FLASH_NAMESPACE From 027c2c8c3ca53a274f259ddc20c42ecb6a5dfccc Mon Sep 17 00:00:00 2001 From: Loser Cheems Date: Fri, 8 Aug 2025 21:01:37 +0800 Subject: [PATCH 03/10] Removes unused dropout switch macro Eliminates the DROPOUT_SWITCH macro definition which was no longer needed in the codebase, simplifying the conditional compilation logic and reducing code complexity. --- csrc/src/static_switch.h | 10 ---------- 1 file changed, 10 deletions(-) diff --git a/csrc/src/static_switch.h b/csrc/src/static_switch.h index 00322bd..d0504da 100644 --- a/csrc/src/static_switch.h +++ b/csrc/src/static_switch.h @@ -26,16 +26,6 @@ } \ }() -#ifdef FLASHATTENTION_DISABLE_DROPOUT - #define DROPOUT_SWITCH(COND, CONST_NAME, ...) \ - [&] { \ - constexpr static bool CONST_NAME = false; \ - return __VA_ARGS__(); \ - }() -#else - #define DROPOUT_SWITCH BOOL_SWITCH -#endif - #ifdef FLASHATTENTION_DISABLE_UNEVEN_K #define EVENK_SWITCH(COND, CONST_NAME, ...) \ [&] { \ From c685a4907249dc7c455ed116e52ba63e31392655 Mon Sep 17 00:00:00 2001 From: Loser Cheems Date: Fri, 8 Aug 2025 21:01:52 +0800 Subject: [PATCH 04/10] Removes dropout support from softmax normalization Simplifies the normalize_softmax_lse function by removing the Is_dropout template parameter and associated dropout scaling logic. Eliminates the rp_dropout parameter and its usage in scale calculation, streamlining the function interface and reducing complexity. Also removes the unused philox.cuh include that was likely related to dropout random number generation. --- csrc/src/softmax.h | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/csrc/src/softmax.h b/csrc/src/softmax.h index 0aeacdb..d00d056 100644 --- a/csrc/src/softmax.h +++ b/csrc/src/softmax.h @@ -11,7 +11,6 @@ #include #include "namespace_config.h" -#include "philox.cuh" #include "utils.h" namespace FLASH_NAMESPACE { @@ -199,8 +198,8 @@ struct Softmax { } }; - template - __forceinline__ __device__ TensorT normalize_softmax_lse(Tensor0 &acc_o, float softmax_scale, float rp_dropout=1.0) { + template + __forceinline__ __device__ TensorT normalize_softmax_lse(Tensor0 &acc_o, float softmax_scale) { SumOp sum_op; quad_allreduce_(row_sum, row_sum, sum_op); TensorT lse = make_fragment_like(row_sum); @@ -211,7 +210,7 @@ struct Softmax { float sum = row_sum(mi); float inv_sum = (sum == 0.f || sum != sum) ? 1.f : 1.f / sum; lse(mi) = (sum == 0.f || sum != sum) ? (Split ? -INFINITY : INFINITY) : row_max(mi) * softmax_scale + __logf(sum); - float scale = !Is_dropout ? inv_sum : inv_sum * rp_dropout; + float scale = inv_sum; #pragma unroll for (int ni = 0; ni < size<1>(acc_o_rowcol); ++ni) { acc_o_rowcol(mi, ni) *= scale; } } From a5b1b496fdba2537fb3c18166d11cb2549e545b8 Mon Sep 17 00:00:00 2001 From: Loser Cheems Date: Fri, 8 Aug 2025 21:02:16 +0800 Subject: [PATCH 05/10] Removes dropout support from flash attention kernels Simplifies the kernel interface by eliminating the Is_dropout template parameter and associated conditional logic throughout the forward pass implementations. Reduces template instantiation complexity and removes branching logic that was previously used to handle dropout variations for different head dimensions. Streamlines kernel dispatch by removing DROPOUT_SWITCH macros and consolidating execution paths that were previously split based on dropout configuration. --- csrc/src/flash_fwd_launch_template.h | 144 +++++++++++---------------- 1 file changed, 57 insertions(+), 87 deletions(-) diff --git a/csrc/src/flash_fwd_launch_template.h b/csrc/src/flash_fwd_launch_template.h index af992e0..89feeae 100644 --- a/csrc/src/flash_fwd_launch_template.h +++ b/csrc/src/flash_fwd_launch_template.h @@ -30,9 +30,9 @@ namespace FLASH_NAMESPACE { template \ __global__ void kernelName(KERNEL_PARAM_MODIFIER const Flash_fwd_params params) -DEFINE_FLASH_FORWARD_KERNEL(flash_fwd_kernel, bool Is_dropout, bool Is_causal, bool Is_even_MN, bool Is_even_K, bool Is_softcap, bool Return_softmax) { +DEFINE_FLASH_FORWARD_KERNEL(flash_fwd_kernel, bool Is_causal, bool Is_even_MN, bool Is_even_K, bool Is_softcap, bool Return_softmax) { #if defined(ARCH_SUPPORTS_FLASH) - FLASH_NAMESPACE::compute_attn(params); + FLASH_NAMESPACE::compute_attn(params); #else FLASH_UNSUPPORTED_ARCH #endif @@ -51,7 +51,7 @@ DEFINE_FLASH_FORWARD_KERNEL(flash_fwd_splitkv_combine_kernel, int kBlockM, int L FLASH_NAMESPACE::combine_attn_seqk_parallel(params); } -template +template void run_flash_fwd(Flash_fwd_params ¶ms, cudaStream_t stream) { const size_t smem_size = Kernel_traits::kSmemSize; // printf("smem_size = %d (includes mask memory)\n", int(smem_size)); @@ -72,9 +72,9 @@ void run_flash_fwd(Flash_fwd_params ¶ms, cudaStream_t stream) { // If not IsEvenKConst, we also set IsEvenMNConst to false to reduce number of templates. // If return_softmax, set IsEvenMNConst to false to reduce number of templates // If head dim > 128, set IsEvenMNConst to false to reduce number of templates - auto kernel = &flash_fwd_kernel; + auto kernel = &flash_fwd_kernel; // auto kernel = &flash_fwd_kernel; - // printf("run_flash_fwd: IsEvenMNConst = %d, IsEvenKConst = %d, Is_causal = %d, ReturnSoftmaxConst = %d, Is_dropout = %d\n", int(IsEvenMNConst), int(IsEvenKConst), int(Is_causal), int(ReturnSoftmaxConst), int(Is_dropout)); + // printf("run_flash_fwd: IsEvenMNConst = %d, IsEvenKConst = %d, Is_causal = %d, ReturnSoftmaxConst = %d, int(IsEvenMNConst), int(IsEvenKConst), int(Is_causal), int(ReturnSoftmaxConst)); // auto kernel = &flash_fwd_kernel; if (smem_size >= 48 * 1024) { C10_CUDA_CHECK(cudaFuncSetAttribute( @@ -162,31 +162,19 @@ void run_mha_fwd_splitkv_dispatch(Flash_fwd_params ¶ms, cudaStream_t stream) template void run_mha_fwd_hdim32(Flash_fwd_params ¶ms, cudaStream_t stream) { constexpr static int Headdim = 32; - DROPOUT_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] { - run_flash_fwd, Is_dropout, Is_causal>(params, stream); - }); + run_flash_fwd, Is_causal>(params, stream); } template void run_mha_fwd_hdim64(Flash_fwd_params ¶ms, cudaStream_t stream) { constexpr static int Headdim = 64; - DROPOUT_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] { - if constexpr(!Is_dropout) { - // Using 8 warps is 18% slower for seqlen=2k, 2 warps is 5% slower - // Using block size (64 x 128) is 27% slower for seqlen=2k - // Using block size (128 x 64) is 85% slower for seqlen=2k, because of register spilling - run_flash_fwd, Is_dropout, Is_causal>(params, stream); - // run_flash_fwd, Is_dropout, Is_causal>(params, stream); - // run_flash_fwd, Is_dropout, Is_causal>(params, stream); - // run_flash_fwd, Is_dropout, Is_causal>(params, stream); - } else { - run_flash_fwd, Is_dropout, Is_causal>(params, stream); - // run_flash_fwd, Is_dropout, Is_causal>(params, stream); - // run_flash_fwd, Is_dropout, Is_causal>(params, stream); - // run_flash_fwd, Is_dropout, Is_causal>(params, stream); - // run_flash_fwd, Is_dropout, Is_causal>(params, stream); - } - }); + // Using 8 warps is 18% slower for seqlen=2k, 2 warps is 5% slower + // Using block size (64 x 128) is 27% slower for seqlen=2k + // Using block size (128 x 64) is 85% slower for seqlen=2k, because of register spilling + run_flash_fwd, Is_causal>(params, stream); + // run_flash_fwd, Is_causal>(params, stream); + // run_flash_fwd, Is_causal>(params, stream); + // run_flash_fwd, Is_causal>(params, stream); } template @@ -194,23 +182,21 @@ void run_mha_fwd_hdim96(Flash_fwd_params ¶ms, cudaStream_t stream) { constexpr static int Headdim = 96; auto [cc_major, cc_minor] = get_compute_capability(get_current_device()); bool is_sm8x = cc_major == 8 && cc_minor > 0; - DROPOUT_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] { - // For sm86 or sm89, 64 x 64 is the fastest for causal (because it's square), - if (is_sm8x) { - if constexpr(!Is_causal) { - run_flash_fwd, Is_dropout, Is_causal>(params, stream); - } else { - run_flash_fwd, Is_dropout, Is_causal>(params, stream); - } + // For sm86 or sm89, 64 x 64 is the fastest for causal (because it's square), + if (is_sm8x) { + if constexpr(!Is_causal) { + run_flash_fwd, Is_causal>(params, stream); } else { - run_flash_fwd, Is_dropout, Is_causal>(params, stream); + run_flash_fwd, Is_causal>(params, stream); } - // run_flash_fwd, Is_dropout, Is_causal>(params, stream); - // run_flash_fwd, Is_dropout, Is_causal>(params, stream); - // These two are always slower - // run_flash_fwd>(params, stream); - // run_flash_fwd>(params, stream); - }); + } else { + run_flash_fwd, Is_causal>(params, stream); + } + // run_flash_fwd, Is_causal>(params, stream); + // run_flash_fwd, Is_causal>(params, stream); + // These two are always slower + // run_flash_fwd>(params, stream); + // run_flash_fwd>(params, stream); } template @@ -218,51 +204,36 @@ void run_mha_fwd_hdim128(Flash_fwd_params ¶ms, cudaStream_t stream) { constexpr static int Headdim = 128; auto [cc_major, cc_minor] = get_compute_capability(get_current_device()); bool is_sm8x = cc_major == 8 && cc_minor > 0; - DROPOUT_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] { - if constexpr(!Is_dropout) { - // For sm86 or sm89, 64 x 32 (40 KB smem) is the fastest for causal and non-causal since we get 2 CTAs per SM. - // Use block configuration (kBlockM = 64, kBlockN = 64) for better memory alignment - if (is_sm8x) { - if constexpr(!Is_causal) { - run_flash_fwd, Is_dropout, Is_causal>(params, stream); - } else { - run_flash_fwd, Is_dropout, Is_causal>(params, stream); - } - } else { - run_flash_fwd, Is_dropout, Is_causal>(params, stream); - } - // run_flash_fwd, Is_dropout, Is_causal>(params, stream); - // run_flash_fwd, Is_dropout, Is_causal>(params, stream); - // run_flash_fwd, Is_dropout, Is_causal>(params, stream); - // Using 8 warps (128 x 128 and 256 x 64) is 28% slower for seqlen=2k - // run_flash_fwd, Is_dropout, Is_causal>(params, stream); - // run_flash_fwd, Is_dropout, Is_causal>(params, stream); - // 1st ones are good for H100, A100 - // 2nd one is good for A6000 bc we get slightly better occupancy + // For sm86 or sm89, 64 x 32 (40 KB smem) is the fastest for causal and non-causal since we get 2 CTAs per SM. + // Use block configuration (kBlockM = 64, kBlockN = 64) for better memory alignment + if (is_sm8x) { + if constexpr(!Is_causal) { + run_flash_fwd, Is_causal>(params, stream); } else { - run_flash_fwd, Is_dropout, Is_causal>(params, stream); - // run_flash_fwd, Is_dropout, Is_causal>(params, stream); - // run_flash_fwd, Is_dropout, Is_causal>(params, stream); - // run_flash_fwd, Is_dropout, Is_causal>(params, stream); + run_flash_fwd, Is_causal>(params, stream); } - }); + } else { + run_flash_fwd, Is_causal>(params, stream); + } + // run_flash_fwd, Is_causal>(params, stream); + // run_flash_fwd, Is_causal>(params, stream); + // run_flash_fwd, Is_causal>(params, stream); + // Using 8 warps (128 x 128 and 256 x 64) is 28% slower for seqlen=2k + // run_flash_fwd, Is_causal>(params, stream); + // run_flash_fwd, Is_causal>(params, stream); + // 1st ones are good for H100, A100 + // 2nd one is good for A6000 bc we get slightly better occupancy } template void run_mha_fwd_hdim192(Flash_fwd_params ¶ms, cudaStream_t stream) { constexpr static int Headdim = 192; - DROPOUT_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] { - if constexpr(!Is_dropout) { - run_flash_fwd, Is_dropout, Is_causal>(params, stream); - } else { - run_flash_fwd, Is_dropout, Is_causal>(params, stream); - } - // run_flash_fwd, Is_dropout, Is_causal>(params, stream); - // run_flash_fwd, Is_dropout, Is_causal>(params, stream); - // run_flash_fwd>(params, stream); - // run_flash_fwd>(params, stream); - // run_flash_fwd>(params, stream); - }); + run_flash_fwd, Is_causal>(params, stream); + // run_flash_fwd, Is_causal>(params, stream); + // run_flash_fwd, Is_causal>(params, stream); + // run_flash_fwd>(params, stream); + // run_flash_fwd>(params, stream); + // run_flash_fwd>(params, stream); } template @@ -279,15 +250,14 @@ void run_mha_fwd_hdim256(Flash_fwd_params ¶ms, cudaStream_t stream) { C10_CUDA_CHECK(status_); } // printf("max_smem_per_sm = %d, max_smem_per_block = %d\n", max_smem_per_sm, max_smem_per_block); - DROPOUT_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] { - // For A100, we want to run with 64 x 64 (112KB smem). - // For H100 we want to run with 64 x 32 (72KB smem) since then we can get 2 CTAs per SM. - if (max_smem_per_block >= 2 * Headdim * (128 + 2 * 64) && max_smem_per_sm < 4 * Headdim * (64 + 2 * 64)) { - run_flash_fwd, Is_dropout, Is_causal>(params, stream); - } else { - run_flash_fwd, Is_dropout, Is_causal>(params, stream); - } - }); + + // For A100, we want to run with 64 x 64 (112KB smem). + // For H100 we want to run with 64 x 32 (72KB smem) since then we can get 2 CTAs per SM. + if (max_smem_per_block >= 2 * Headdim * (128 + 2 * 64) && max_smem_per_sm < 4 * Headdim * (64 + 2 * 64)) { + run_flash_fwd, Is_causal>(params, stream); + } else { + run_flash_fwd, Is_causal>(params, stream); + } } } // namespace FLASH_NAMESPACE From d7075c9a4703aada974d2daad713021382ac0da1 Mon Sep 17 00:00:00 2001 From: Loser Cheems Date: Fri, 8 Aug 2025 21:04:53 +0800 Subject: [PATCH 06/10] Removes dropout-related fields from Flash_fwd_params Cleans up the parameter structure by removing unused dropout probability fields, scaling factors, random state management, and rotary interleaving flag. Moves softcap field to improve struct organization and readability. --- csrc/src/flash.h | 20 +------------------- 1 file changed, 1 insertion(+), 19 deletions(-) diff --git a/csrc/src/flash.h b/csrc/src/flash.h index 95d7f52..fc613dd 100644 --- a/csrc/src/flash.h +++ b/csrc/src/flash.h @@ -93,6 +93,7 @@ struct Flash_fwd_params : public QKV_params, public Mask_params, public Bias_par // The scaling factors for the kernel. float scale_softmax; float scale_softmax_log2; + float softcap; // array of length b+1 holding starting offset of each sequence. int * __restrict__ cu_seqlens_q; @@ -128,23 +129,6 @@ struct Flash_fwd_params : public QKV_params, public Mask_params, public Bias_par index_t block_table_batch_stride; int page_block_size; - // The dropout probability (probability of keeping an activation). - float p_dropout; - // uint32_t p_dropout_in_uint; - // uint16_t p_dropout_in_uint16_t; - uint8_t p_dropout_in_uint8_t; - - // Scale factor of 1 / (1 - p_dropout). - float rp_dropout; - float scale_softmax_rp_dropout; - float softcap; - - // Random state. - at::PhiloxCudaState philox_args; - - // Pointer to the RNG seed (idx 0) and offset (idx 1). - uint64_t * rng_state; - bool is_bf16; bool is_causal; @@ -152,8 +136,6 @@ struct Flash_fwd_params : public QKV_params, public Mask_params, public Bias_par // Otherwise it's cu_seqlens_k[bidb], i.e., we use cu_seqlens_k to store the sequence lengths of K. bool is_seqlens_k_cumulative; - bool is_rotary_interleaved; - int num_splits; // For split-KV version bool unpadded_lse; // For varlen paths: LSE is in [nheads, total_seqlen_q] format instead of [b, nheads, seqlen_q]. From cb7997bffccee5d6df354534470b54aa7db2ade0 Mon Sep 17 00:00:00 2001 From: Loser Cheems Date: Fri, 8 Aug 2025 21:05:35 +0800 Subject: [PATCH 07/10] Removes dropout support from flash attention API Eliminates dropout functionality across forward pass implementations to simplify the codebase and reduce compilation overhead. Removes dropout parameter handling, probability calculations, random number generation setup, and dropout-related conditional logic from both regular and variable-length attention functions. Simplifies split-KV logic by removing dropout conditional checks and enables certain optimizations that were previously gated by dropout requirements. Updates return signatures to exclude RNG state tensors that are no longer needed without dropout functionality. --- csrc/flash_api.cpp | 110 +++++++++------------------------------------ 1 file changed, 22 insertions(+), 88 deletions(-) diff --git a/csrc/flash_api.cpp b/csrc/flash_api.cpp index a59f841..937d02b 100644 --- a/csrc/flash_api.cpp +++ b/csrc/flash_api.cpp @@ -47,7 +47,6 @@ void set_params_fprop( void *seqused_k, void *p_d, void *softmax_lse_d, - float p_dropout, float softmax_scale, bool is_causal, const float softcap, @@ -134,20 +133,6 @@ void set_params_fprop( params.scale_softmax_log2 = softmax_scale * M_LOG2E; } - // Set this to probability of keeping an element to simplify things. - params.p_dropout = 1.f - p_dropout; - // Convert p from float to int so we don't have to convert the random uint to float to compare. - // [Minor] We want to round down since when we do the comparison we use <= instead of < - // params.p_dropout_in_uint = uint32_t(std::floor(params.p_dropout * 4294967295.0)); - // params.p_dropout_in_uint16_t = uint16_t(std::floor(params.p_dropout * 65535.0)); - params.p_dropout_in_uint8_t = uint8_t(std::floor(params.p_dropout * 255.0)); - params.rp_dropout = 1.f / params.p_dropout; - params.scale_softmax_rp_dropout = params.rp_dropout * params.scale_softmax; - TORCH_CHECK(p_dropout < 1.f); - #ifdef FLASHATTENTION_DISABLE_DROPOUT - TORCH_CHECK(p_dropout == 0.0f, "This flash attention build does not support dropout."); - #endif - params.is_causal = is_causal; params.is_seqlens_k_cumulative = true; @@ -223,7 +208,6 @@ std::tuple set_params_splitkv( const int max_seqlen_k, const int max_seqlen_q, const int head_size_rounded, - const float p_dropout, const int num_splits, const int num_sm, struct c10::TensorOptions opts @@ -239,19 +223,17 @@ std::tuple set_params_splitkv( at::Tensor softmax_lse_accum; at::Tensor out_accum; - if (p_dropout == 0.0f) { // SplitKV is not implemented for dropout - if (num_splits < 1) { - // We multiply number of SMs by 2 to hard-code the fact that we're using 128 threads per block. - params.num_splits = num_splits_heuristic(batch_size * num_heads * num_m_blocks, num_sm * 2, num_n_blocks, 128); - } - if (params.num_splits > 1) { - softmax_lse_accum = torch::empty({params.num_splits, batch_size, num_heads, max_seqlen_q}, opts.dtype(at::kFloat)); - out_accum = torch::empty({params.num_splits, batch_size, num_heads, max_seqlen_q, head_size_rounded}, opts.dtype(at::kFloat)); - params.softmax_lseaccum_ptr = softmax_lse_accum.data_ptr(); - params.oaccum_ptr = out_accum.data_ptr(); - } - TORCH_CHECK(params.num_splits <= 128, "num_splits > 128 not supported"); + if (num_splits < 1) { + // We multiply number of SMs by 2 to hard-code the fact that we're using 128 threads per block. + params.num_splits = num_splits_heuristic(batch_size * num_heads * num_m_blocks, num_sm * 2, num_n_blocks, 128); + } + if (params.num_splits > 1) { + softmax_lse_accum = torch::empty({params.num_splits, batch_size, num_heads, max_seqlen_q}, opts.dtype(at::kFloat)); + out_accum = torch::empty({params.num_splits, batch_size, num_heads, max_seqlen_q, head_size_rounded}, opts.dtype(at::kFloat)); + params.softmax_lseaccum_ptr = softmax_lse_accum.data_ptr(); + params.oaccum_ptr = out_accum.data_ptr(); } + TORCH_CHECK(params.num_splits <= 128, "num_splits > 128 not supported"); // Temporarily disable Split-KV, because some bugs are still being fixed. // See: https://github.com/SmallDoges/flash-dmattn/issues/47 @@ -272,12 +254,10 @@ mha_fwd( const at::Tensor &mask, // batch_size x num_heads_k x seqlen_q x seqlen_k const at::Tensor &bias, // batch_size x num_heads_k x seqlen_q x seqlen_k std::optional &out_, // batch_size x seqlen_q x num_heads x round_multiple(head_size, 8) - const float p_dropout, const float softmax_scale, bool is_causal, const float softcap, - const bool return_softmax, - std::optional gen_ + const bool return_softmax ) { // Otherwise the kernel will be launched from cuda:0 device @@ -313,14 +293,12 @@ mha_fwd( TORCH_CHECK(head_size % 8 == 0, "query, key, value, and out_ must have a head_size that is a multiple of 8"); TORCH_CHECK(num_heads % num_heads_k == 0, "Number of heads in key/value must divide number of heads in query"); - if (softcap > 0.f) { TORCH_CHECK(p_dropout == 0.f, "Softcapping does not support dropout for now"); } - // causal=true is the same as causal=false in this case if (seqlen_q == 1) { is_causal = false; } // Faster to transpose q from (b, 1, (nheads_kv ngroups), d) to (b, ngroups, nheads_kv, d) in this case // H/t Daniel Haziza - const int seqlenq_ngroups_swapped = seqlen_q == 1 && num_heads > num_heads_k && p_dropout == 0.f && head_size % 8 == 0; + const int seqlenq_ngroups_swapped = seqlen_q == 1 && num_heads > num_heads_k && head_size % 8 == 0; const int ngroups = num_heads / num_heads_k; if (seqlenq_ngroups_swapped) { q = q.reshape({batch_size, num_heads_k, ngroups, head_size}).transpose(1, 2); @@ -357,12 +335,10 @@ mha_fwd( auto softmax_lse = torch::empty({batch_size, num_heads, seqlen_q}, opts.dtype(at::kFloat)); at::Tensor p; - // Only return softmax if there's dropout to reduce compilation time + if (return_softmax) { - TORCH_CHECK(p_dropout > 0.0f, "return_softmax is only supported when p_dropout > 0.0"); p = torch::empty({ batch_size, num_heads, seqlen_q_rounded, seqlen_k_rounded }, opts); - } - else { + } else { p = torch::empty({ 0 }, opts); } @@ -380,7 +356,6 @@ mha_fwd( /*seqused_k=*/nullptr, return_softmax ? p.data_ptr() : nullptr, softmax_lse.data_ptr(), - p_dropout, softmax_scale, is_causal, softcap @@ -390,26 +365,9 @@ mha_fwd( at::Tensor softmax_lse_accum, out_accum; std::tie(softmax_lse_accum, out_accum) = set_params_splitkv( params, batch_size, num_heads, head_size, seqlen_k, seqlen_q, - head_size_rounded, p_dropout, /*num_splits*/ 0, get_num_sm(get_current_device()), opts + head_size_rounded, /*num_splits*/ 0, get_num_sm(get_current_device()), opts ); - // number of times random will be generated per thread, to offset philox counter in thc random - // state - // We use a custom RNG that increases the offset by batch_size * nheads * 32. - int64_t counter_offset = params.b * params.h * 32; - auto options = torch::TensorOptions().dtype(torch::kFloat32).device(torch::kCUDA); - auto rng_state = torch::empty({2}, options.dtype(torch::kInt64)); - // Forward kernel will populate memory with the seed and offset. - params.rng_state = reinterpret_cast(rng_state.data_ptr()); - - if (p_dropout > 0.0) { - auto gen = at::get_generator_or_default( - gen_, at::cuda::detail::getDefaultCUDAGenerator()); - // See Note [Acquire lock when using random generators] - std::lock_guard lock(gen->mutex_); - params.philox_args = gen->philox_cuda_state(counter_offset); - } - if (seqlen_k > 0) { auto stream = at::cuda::getCurrentCUDAStream().stream(); run_mha_fwd(params, stream); @@ -424,7 +382,7 @@ mha_fwd( q = q.transpose(1, 2).reshape({batch_size, 1, num_heads_k * seqlen_q, head_size}); softmax_lse = softmax_lse.reshape({batch_size, num_heads_k * seqlen_q, 1}); } - return {out, softmax_lse, p, rng_state}; + return {out, softmax_lse, p}; } std::vector @@ -442,13 +400,11 @@ mha_varlen_fwd( std::optional &block_table_, // batch_size x max_num_blocks_per_seq int max_seqlen_q, const int max_seqlen_k, - const float p_dropout, const float softmax_scale, const bool zero_tensors, bool is_causal, const float softcap, - const bool return_softmax, - std::optional gen_ + const bool return_softmax ) { // Otherwise the kernel will be launched from cuda:0 device at::cuda::CUDAGuard device_guard{q.device()}; @@ -494,8 +450,6 @@ mha_varlen_fwd( const int head_size = sizes[2]; const int num_heads_k = paged_KV ? k.size(2) : k.size(1); - if (softcap > 0.f) { TORCH_CHECK(p_dropout == 0.f, "Softcapping does not support dropout for now"); } - const int max_num_blocks_per_seq = !paged_KV ? 0 : block_table.size(1); const int num_blocks = !paged_KV ? 0 : k.size(0); const int page_block_size = !paged_KV ? 1 : k.size(1); @@ -507,7 +461,7 @@ mha_varlen_fwd( // Faster to transpose q from (b, 1, (nheads_kv ngroups), d) to (b, ngroups, nheads_kv, d) in this case // H/t Daniel Haziza - const int seqlenq_ngroups_swapped = max_seqlen_q == 1 && num_heads > num_heads_k && p_dropout == 0.f && head_size % 8 == 0; + const int seqlenq_ngroups_swapped = max_seqlen_q == 1 && num_heads > num_heads_k && head_size % 8 == 0; const int ngroups = num_heads / num_heads_k; if (seqlenq_ngroups_swapped) { q = q.reshape({batch_size, num_heads_k, ngroups, head_size}).transpose(1, 2).reshape({batch_size * ngroups, num_heads_k, head_size}); @@ -568,12 +522,10 @@ mha_varlen_fwd( auto opts = q.options(); auto softmax_lse = torch::empty({num_heads, total_q}, opts.dtype(at::kFloat)); at::Tensor p; - // Only return softmax if there's dropout to reduce compilation time + if (return_softmax) { - TORCH_CHECK(p_dropout > 0.0f, "return_softmax is only supported when p_dropout > 0.0"); p = torch::empty({ batch_size, num_heads, seqlen_q_rounded, seqlen_k_rounded }, opts); - } - else { + } else { p = torch::empty({ 0 }, opts); } @@ -597,7 +549,6 @@ mha_varlen_fwd( seqused_k.has_value() ? seqused_k.value().data_ptr() : nullptr, return_softmax ? p.data_ptr() : nullptr, softmax_lse.data_ptr(), - p_dropout, softmax_scale, is_causal, softcap, @@ -621,7 +572,7 @@ mha_varlen_fwd( set_params_splitkv( params, batch_size, num_heads, head_size, max_seqlen_k, max_seqlen_q, head_size_rounded, - p_dropout, /*num_splits*/ 0, get_num_sm(get_current_device()), opts + /*num_splits*/ 0, get_num_sm(get_current_device()), opts ); } @@ -635,23 +586,6 @@ mha_varlen_fwd( params.leftpad_k = static_cast(leftpad_k.data_ptr()); } - // number of times random will be generated per thread, to offset philox counter in thc random - // state - // We use a custom RNG that increases the offset by batch_size * nheads * 32. - int64_t counter_offset = params.b * params.h * 32; - auto options = torch::TensorOptions().dtype(torch::kFloat32).device(torch::kCUDA); - auto rng_state = torch::empty({2}, options.dtype(torch::kInt64)); - // Forward kernel will populate memory with the seed and offset. - params.rng_state = reinterpret_cast(rng_state.data_ptr()); - - if (p_dropout > 0.0) { - auto gen = at::get_generator_or_default( - gen_, at::cuda::detail::getDefaultCUDAGenerator()); - // See Note [Acquire lock when using random generators] - std::lock_guard lock(gen->mutex_); - params.philox_args = gen->philox_cuda_state(counter_offset); - } - if (max_seqlen_k > 0) { auto stream = at::cuda::getCurrentCUDAStream().stream(); run_mha_fwd(params, stream, paged_KV); @@ -669,7 +603,7 @@ mha_varlen_fwd( softmax_lse = softmax_lse.reshape({num_heads * max_seqlen_q, batch_size}); } - return {out, softmax_lse, p, rng_state}; + return {out, softmax_lse, p}; } } // namespace FLASH_NAMESPACE From 35ac6d810eacc8eaf41a321e0056ff1973b0ba92 Mon Sep 17 00:00:00 2001 From: Loser Cheems Date: Fri, 8 Aug 2025 21:06:07 +0800 Subject: [PATCH 08/10] Removes dropout functionality from flash attention interface Eliminates dropout parameter and related logic across all attention functions and classes. Simplifies block size calculation by removing dropout-dependent branching logic. Removes random number generator state handling and validation checks for dropout probability. Streamlines the interface by focusing on core attention computation without stochastic regularization, reducing complexity in function signatures and internal logic. --- flash_dmattn/flash_dmattn_interface.py | 192 ++++++------------------- 1 file changed, 46 insertions(+), 146 deletions(-) diff --git a/flash_dmattn/flash_dmattn_interface.py b/flash_dmattn/flash_dmattn_interface.py index f91a014..75c9814 100644 --- a/flash_dmattn/flash_dmattn_interface.py +++ b/flash_dmattn/flash_dmattn_interface.py @@ -13,7 +13,7 @@ def maybe_contiguous(x): return x.contiguous() if x is not None and x.stride(-1) != 1 else x -def _get_block_size_n(device, head_dim, is_dropout, is_causal): +def _get_block_size_n(device, head_dim, is_causal): # This should match the block sizes in the CUDA kernel assert head_dim <= 256 major, minor = torch.cuda.get_device_capability(device) @@ -23,14 +23,14 @@ def _get_block_size_n(device, head_dim, is_dropout, is_causal): if head_dim <= 32: return 128 if head_dim <= 64: - return 128 if not is_dropout else 64 + return 128 elif head_dim <= 96: return 64 elif head_dim <= 128: if is_sm8x: - return 64 if (not is_dropout and is_causal) else 32 + return 64 if (is_causal) else 32 else: - return 64 if not is_dropout else 32 + return 64 elif head_dim <= 192: return 64 elif head_dim <= 224: @@ -73,28 +73,25 @@ def _flash_dmattn_forward( v: torch.Tensor, mask: torch.Tensor, bias: torch.Tensor, - dropout_p: float, softmax_scale: float, is_causal: bool, softcap: float, return_softmax: bool ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: q, k, v = [maybe_contiguous(x) for x in (q, k, v)] - out, softmax_lse, S_dmask, rng_state = flash_dmattn_gpu.fwd( + out, softmax_lse, S_dmask = flash_dmattn_gpu.fwd( q, k, v, mask, bias, None, - dropout_p, softmax_scale, is_causal, softcap, return_softmax, - None, ) - return out, softmax_lse, S_dmask, rng_state + return out, softmax_lse, S_dmask @_torch_register_fake_wrapper("flash_dmattn::_flash_dmattn_forward") @@ -104,7 +101,6 @@ def _flash_dmattn_forward_fake( v: torch.Tensor, mask: torch.Tensor, bias: torch.Tensor, - dropout_p: float, softmax_scale: float, is_causal: bool, softcap: float, @@ -118,9 +114,8 @@ def _flash_dmattn_forward_fake( p = torch.empty((0,), dtype=q.dtype, device=q.device, layout=q.layout) if return_softmax: p = torch.empty((batch_size, num_heads, round_multiple(seqlen_q, 128), round_multiple(seqlen_k, 128)), dtype=q.dtype, device=q.device, layout=q.layout) - rng_state = torch.empty((2,), dtype=torch.int64, device=q.device) - return out, softmax_lse, p, rng_state + return out, softmax_lse, p _wrapped_flash_dmattn_forward = _flash_dmattn_forward @@ -137,7 +132,6 @@ def _flash_dmattn_varlen_forward( cu_seqlens_k: torch.Tensor, max_seqlen_q: int, max_seqlen_k: int, - dropout_p: float, softmax_scale: float, is_causal: bool, softcap: float = 0.0, @@ -148,7 +142,7 @@ def _flash_dmattn_varlen_forward( zero_tensors: bool = False, ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: q, k, v = [maybe_contiguous(x) for x in (q, k, v)] - out, softmax_lse, S_dmask, rng_state = flash_dmattn_gpu.varlen_fwd( + out, softmax_lse, S_dmask = flash_dmattn_gpu.varlen_fwd( q, k, v, @@ -162,17 +156,15 @@ def _flash_dmattn_varlen_forward( block_table, max_seqlen_q, max_seqlen_k, - dropout_p, softmax_scale, zero_tensors, is_causal, softcap, return_softmax, - None, ) # if out.isnan().any() or softmax_lse.isnan().any(): # breakpoint() - return out, softmax_lse, S_dmask, rng_state + return out, softmax_lse, S_dmask @_torch_register_fake_wrapper("flash_dmattn::_flash_dmattn_varlen_forward") @@ -186,7 +178,6 @@ def _flash_dmattn_varlen_forward_fake( cu_seqlens_k: torch.Tensor, max_seqlen_q: int, max_seqlen_k: int, - dropout_p: float, softmax_scale: float, is_causal: bool, softcap: float = 0.0, @@ -208,8 +199,7 @@ def _flash_dmattn_varlen_forward_fake( seqlen_k_rounded = round_multiple(max_seqlen_k, 128) if return_softmax: p = torch.empty((batch_size, num_heads, seqlen_q_rounded, seqlen_k_rounded), dtype=q.dtype, device=q.device, layout=q.layout) - rng_state = torch.empty((2,), dtype=torch.int64, device=q.device) - return out, softmax_lse, p, rng_state + return out, softmax_lse, p _wrapped_flash_dmattn_varlen_forward = _flash_dmattn_varlen_forward @@ -229,12 +219,10 @@ def _flash_dmattn_backward( dk: Optional[torch.Tensor], dv: Optional[torch.Tensor], dbias: Optional[torch.Tensor], - dropout_p: float, softmax_scale: float, is_causal: bool, softcap: float, deterministic: bool, - rng_state: Optional[torch.Tensor] = None, ) -> torch.Tensor: # dq, dk, dv, dbias are allocated by us so they should already be contiguous dout, q, k, v, mask, bias, out = [maybe_contiguous(x) for x in (dout, q, k, v, mask, bias, out)] @@ -257,13 +245,11 @@ def _flash_dmattn_backward( dk, dv, dbias, - dropout_p, softmax_scale, is_causal, softcap, deterministic, None, - rng_state, ) return softmax_d @@ -282,12 +268,10 @@ def _flash_dmattn_backward_fake( dk: Optional[torch.Tensor], dv: Optional[torch.Tensor], dbias: Optional[torch.Tensor], - dropout_p: float, softmax_scale: float, is_causal: bool, softcap: float, deterministic: bool, - rng_state: Optional[torch.Tensor] = None, ) -> torch.Tensor: dout, q, k, v, mask, bias, out = [maybe_contiguous(x) for x in (dout, q, k, v, mask, bias, out)] if dq is None: @@ -325,12 +309,10 @@ def _flash_dmattn_varlen_backward( cu_seqlens_k: torch.Tensor, max_seqlen_q: int, max_seqlen_k: int, - dropout_p: float, softmax_scale: float, is_causal: bool, softcap: float, deterministic: bool, - rng_state: Optional[torch.Tensor] = None, zero_tensors: bool = False, ) -> torch.Tensor: # dq, dk, dv are allocated by us so they should already be contiguous @@ -358,14 +340,12 @@ def _flash_dmattn_varlen_backward( cu_seqlens_k, max_seqlen_q, max_seqlen_k, - dropout_p, softmax_scale, zero_tensors, is_causal, softcap, deterministic, None, - rng_state, ) # if dk.isnan().any() or dk.isnan().any() or dv.isnan().any() or softmax_d.isnan().any(): # breakpoint() @@ -390,12 +370,10 @@ def _flash_dmattn_varlen_backward_fake( cu_seqlens_k: torch.Tensor, max_seqlen_q: int, max_seqlen_k: int, - dropout_p: float, softmax_scale: float, is_causal: bool, softcap: float, deterministic: bool, - rng_state: Optional[torch.Tensor] = None, zero_tensors: bool = False, ) -> torch.Tensor: dout, q, k, v, mask, bias, out = [maybe_contiguous(x) for x in (dout, q, k, v, mask, bias, out)] @@ -425,7 +403,6 @@ def forward( qkv: torch.Tensor, mask: Optional[torch.Tensor], bias: Optional[torch.Tensor], - dropout_p: Optional[float], softmax_scale: Optional[float], is_causal: Optional[bool], softcap: Optional[float], @@ -440,10 +417,6 @@ def forward( mask = torch.ones((batch_size, num_heads, seqlen, seqlen), dtype=qkv.dtype, device=qkv.device) if bias is None: bias = torch.zeros((batch_size, num_heads, seqlen, seqlen), dtype=qkv.dtype, device=qkv.device) - if dropout_p is None: - dropout_p = 0.0 - if dropout_p < 0.0 or dropout_p > 1.0: - raise ValueError(f"Invalid dropout_p: {dropout_p}. It should be in [0, 1].") if is_causal is None: is_causal = False if softcap is None: @@ -462,22 +435,20 @@ def forward( k = torch.nn.functional.pad(k, [0, 8 - head_size_og % 8]) v = torch.nn.functional.pad(v, [0, 8 - head_size_og % 8]) - out_padded, softmax_lse, S_dmask, rng_state = _wrapped_flash_dmattn_forward( + out_padded, softmax_lse, S_dmask = _wrapped_flash_dmattn_forward( q, k, v, mask, bias, - dropout_p, softmax_scale, is_causal=is_causal, softcap=softcap, - return_softmax=return_softmax and dropout_p > 0, + return_softmax=return_softmax, ) if is_grad: - ctx.save_for_backward(q, k, v, mask, bias, out_padded, softmax_lse, rng_state) - ctx.dropout_p = dropout_p + ctx.save_for_backward(q, k, v, mask, bias, out_padded, softmax_lse) ctx.softmax_scale = softmax_scale ctx.is_causal = is_causal ctx.softcap = softcap @@ -492,7 +463,7 @@ def backward( dout: torch.Tensor, *args: Any, ): - q, k, v, mask, bias, out, softmax_lse, rng_state = ctx.saved_tensors + q, k, v, mask, bias, out, softmax_lse = ctx.saved_tensors qkv_shape = q.shape[:-2] + (3, *q.shape[-2:]) dqkv = torch.empty(qkv_shape, dtype=q.dtype, device=q.device) @@ -516,12 +487,10 @@ def backward( dqkv[:, :, 1], dqkv[:, :, 2], dbias, - ctx.dropout_p, ctx.softmax_scale, ctx.is_causal, ctx.softcap, ctx.deterministic, - rng_state=rng_state, ) dqkv = dqkv[..., : dout.shape[-1]] # We could have padded the head dimension @@ -537,7 +506,6 @@ def forward( bias: Optional[torch.Tensor], cu_seqlens: torch.Tensor, max_seqlen: int, - dropout_p: Optional[float], softmax_scale: Optional[float], is_causal: Optional[bool], softcap: Optional[float], @@ -553,10 +521,6 @@ def forward( mask = torch.ones((batch_size, num_heads, max_seqlen, max_seqlen), dtype=qkv.dtype, device=qkv.device) if bias is None: bias = torch.zeros((batch_size, num_heads, max_seqlen, max_seqlen), dtype=qkv.dtype, device=qkv.device) - if dropout_p is None: - dropout_p = 0.0 - if dropout_p < 0.0 or dropout_p > 1.0: - raise ValueError(f"Invalid dropout_p: {dropout_p}. It should be in [0, 1].") if softmax_scale is None: softmax_scale = qkv.shape[-1] ** (-0.5) if is_causal is None: @@ -575,7 +539,7 @@ def forward( k = torch.nn.functional.pad(k, [0, 8 - head_size_og % 8]) v = torch.nn.functional.pad(v, [0, 8 - head_size_og % 8]) - out_padded, softmax_lse, S_dmask, rng_state = _wrapped_flash_dmattn_varlen_forward( + out_padded, softmax_lse, S_dmask = _wrapped_flash_dmattn_varlen_forward( q, k, v, @@ -585,17 +549,15 @@ def forward( cu_seqlens, max_seqlen, max_seqlen, - dropout_p, softmax_scale, is_causal=is_causal, softcap=softcap, - return_softmax=return_softmax and dropout_p > 0, + return_softmax=return_softmax, block_table=None, ) if is_grad: - ctx.save_for_backward(q, k, v, mask, bias, out_padded, softmax_lse, cu_seqlens, rng_state) - ctx.dropout_p = dropout_p + ctx.save_for_backward(q, k, v, mask, bias, out_padded, softmax_lse, cu_seqlens) ctx.max_seqlen = max_seqlen ctx.softmax_scale = softmax_scale ctx.is_causal = is_causal @@ -611,7 +573,7 @@ def backward( dout: torch.Tensor, *args: Any, ): - q, k, v, mask, bias, out, softmax_lse, cu_seqlens, rng_state = ctx.saved_tensors + q, k, v, mask, bias, out, softmax_lse, cu_seqlens = ctx.saved_tensors qkv_shape = q.shape[:-2] + (3, *q.shape[-2:]) dqkv = torch.empty(qkv_shape, dtype=q.dtype, device=q.device) @@ -639,12 +601,10 @@ def backward( cu_seqlens, ctx.max_seqlen, ctx.max_seqlen, - ctx.dropout_p, ctx.softmax_scale, ctx.is_causal, ctx.softcap, ctx.deterministic, - rng_state=rng_state, ) dqkv = dqkv[..., : dout.shape[-1]] # We could have padded the head dimension @@ -659,7 +619,6 @@ def forward( kv: torch.Tensor, mask: Optional[torch.Tensor], bias: Optional[torch.Tensor], - dropout_p: Optional[float], softmax_scale: Optional[float], is_causal: Optional[bool], softcap: Optional[float], @@ -678,10 +637,6 @@ def forward( mask = torch.ones((batch_size, num_heads, seqlen_q, seqlen_k), dtype=q.dtype, device=q.device) if bias is None: bias = torch.zeros((batch_size, num_heads, seqlen_q, seqlen_k), dtype=q.dtype, device=q.device) - if dropout_p is None: - dropout_p = 0.0 - if dropout_p < 0.0 or dropout_p > 1.0: - raise ValueError(f"Invalid dropout_p: {dropout_p}. It should be in [0, 1].") if softmax_scale is None: softmax_scale = q.shape[-1] ** (-0.5) if is_causal is None: @@ -700,22 +655,20 @@ def forward( k = torch.nn.functional.pad(k, [0, 8 - head_size_og % 8]) v = torch.nn.functional.pad(v, [0, 8 - head_size_og % 8]) - out_padded, softmax_lse, S_dmask, rng_state = _wrapped_flash_dmattn_forward( + out_padded, softmax_lse, S_dmask = _wrapped_flash_dmattn_forward( q, k, v, mask, bias, - dropout_p, softmax_scale, is_causal=is_causal, softcap=softcap, - return_softmax=return_softmax and dropout_p > 0, + return_softmax=return_softmax, ) if is_grad: - ctx.save_for_backward(q, k, v, mask, bias, out_padded, softmax_lse, rng_state) - ctx.dropout_p = dropout_p + ctx.save_for_backward(q, k, v, mask, bias, out_padded, softmax_lse) ctx.softmax_scale = softmax_scale ctx.is_causal = is_causal ctx.softcap = softcap @@ -730,7 +683,7 @@ def backward( dout: torch.Tensor, *args: Any ): - q, k, v, mask, bias, out, softmax_lse, rng_state = ctx.saved_tensors + q, k, v, mask, bias, out, softmax_lse = ctx.saved_tensors kv_shape = k.shape[:-2] + (2, *k.shape[-2:]) dq = torch.empty_like(q) @@ -755,12 +708,10 @@ def backward( dkv[:, :, 0], dkv[:, :, 1], dbias, - ctx.dropout_p, ctx.softmax_scale, ctx.is_causal, ctx.softcap, ctx.deterministic, - rng_state=rng_state, ) dq = dq[..., : dout.shape[-1]] # We could have padded the head dimension @@ -780,7 +731,6 @@ def forward( cu_seqlens_k: torch.Tensor, max_seqlen_q: int, max_seqlen_k: int, - dropout_p: Optional[float], softmax_scale: Optional[float], is_causal: Optional[bool], softcap: Optional[float], @@ -799,10 +749,6 @@ def forward( mask = torch.ones((batch_size, num_heads, max_seqlen_q, max_seqlen_k), dtype=q.dtype, device=q.device) if bias is None: bias = torch.zeros((batch_size, num_heads, max_seqlen_q, max_seqlen_k), dtype=q.dtype, device=q.device) - if dropout_p is None: - dropout_p = 0.0 - if dropout_p < 0.0 or dropout_p > 1.0: - raise ValueError(f"Invalid dropout_p: {dropout_p}. It should be in [0, 1].") if softmax_scale is None: softmax_scale = q.shape[-1] ** (-0.5) if is_causal is None: @@ -821,7 +767,7 @@ def forward( k = torch.nn.functional.pad(k, [0, 8 - head_size_og % 8]) v = torch.nn.functional.pad(v, [0, 8 - head_size_og % 8]) - out_padded, softmax_lse, S_dmask, rng_state = _wrapped_flash_dmattn_varlen_forward( + out_padded, softmax_lse, S_dmask = _wrapped_flash_dmattn_varlen_forward( q, k, v, @@ -831,19 +777,17 @@ def forward( cu_seqlens_k, max_seqlen_q, max_seqlen_k, - dropout_p, softmax_scale, is_causal=is_causal, softcap=softcap, - return_softmax=return_softmax and dropout_p > 0, + return_softmax=return_softmax, block_table=None, ) if is_grad: ctx.save_for_backward( - q, k, v, mask, bias, out_padded, softmax_lse, cu_seqlens_q, cu_seqlens_k, rng_state + q, k, v, mask, bias, out_padded, softmax_lse, cu_seqlens_q, cu_seqlens_k ) - ctx.dropout_p = dropout_p ctx.max_seqlen_q = max_seqlen_q ctx.max_seqlen_k = max_seqlen_k ctx.softmax_scale = softmax_scale @@ -860,7 +804,7 @@ def backward( dout: torch.Tensor, *args: Any, ): - q, k, v, mask, bias, out, softmax_lse, cu_seqlens_q, cu_seqlens_k, rng_state = ctx.saved_tensors + q, k, v, mask, bias, out, softmax_lse, cu_seqlens_q, cu_seqlens_k = ctx.saved_tensors kv_shape = k.shape[:-2] + (2, *k.shape[-2:]) dq = torch.empty_like(q) @@ -889,12 +833,10 @@ def backward( cu_seqlens_k, ctx.max_seqlen_q, ctx.max_seqlen_k, - ctx.dropout_p, ctx.softmax_scale, ctx.is_causal, ctx.softcap, ctx.deterministic, - rng_state=rng_state, ) dq = dq[..., : dout.shape[-1]] # We could have padded the head dimension @@ -911,7 +853,6 @@ def forward( v: torch.Tensor, mask: Optional[torch.Tensor], bias: Optional[torch.Tensor], - dropout_p: Optional[float], softmax_scale: Optional[float], is_causal: Optional[bool], softcap: Optional[float], @@ -929,10 +870,6 @@ def forward( mask = torch.ones((batch_size, num_heads, seqlen_q, seqlen_k), dtype=q.dtype, device=q.device) if bias is None: bias = torch.zeros((batch_size, num_heads, seqlen_q, seqlen_k), dtype=q.dtype, device=q.device) - if dropout_p is None: - dropout_p = 0.0 - if dropout_p < 0.0 or dropout_p > 1.0: - raise ValueError(f"Invalid dropout_p: {dropout_p}. It should be in [0, 1].") if softmax_scale is None: softmax_scale = q.shape[-1] ** (-0.5) if is_causal is None: @@ -950,22 +887,20 @@ def forward( k = torch.nn.functional.pad(k, [0, 8 - head_size_og % 8]) v = torch.nn.functional.pad(v, [0, 8 - head_size_og % 8]) - out_padded, softmax_lse, S_dmask, rng_state = _wrapped_flash_dmattn_forward( + out_padded, softmax_lse, S_dmask = _wrapped_flash_dmattn_forward( q, k, v, mask, bias, - dropout_p, softmax_scale, is_causal=is_causal, softcap=softcap, - return_softmax=return_softmax and dropout_p > 0, + return_softmax=return_softmax, ) if is_grad: - ctx.save_for_backward(q, k, v, mask, bias, out_padded, softmax_lse, rng_state) - ctx.dropout_p = dropout_p + ctx.save_for_backward(q, k, v, mask, bias, out_padded, softmax_lse) ctx.softmax_scale = softmax_scale ctx.is_causal = is_causal ctx.softcap = softcap @@ -980,7 +915,7 @@ def backward( dout: torch.Tensor, *args: Any, ): - q, k, v, mask, bias, out, softmax_lse, rng_state = ctx.saved_tensors + q, k, v, mask, bias, out, softmax_lse = ctx.saved_tensors dq, dk, dv, dbias = torch.empty_like(q), torch.empty_like(k), torch.empty_like(v), torch.empty_like(bias) head_size_og = dout.size(3) @@ -1001,12 +936,10 @@ def backward( dk, dv, dbias, - ctx.dropout_p, ctx.softmax_scale, ctx.is_causal, ctx.softcap, ctx.deterministic, - rng_state=rng_state, ) dq = dq[..., : dout.shape[-1]] # We could have padded the head dimension @@ -1028,7 +961,6 @@ def forward( cu_seqlens_k: torch.Tensor, max_seqlen_q: int, max_seqlen_k: int, - dropout_p: Optional[float], softmax_scale: Optional[float], is_causal: Optional[bool], softcap: Optional[float], @@ -1047,10 +979,6 @@ def forward( mask = torch.ones((batch_size, num_heads, max_seqlen_q, max_seqlen_k), dtype=q.dtype, device=q.device) if bias is None: bias = torch.zeros((batch_size, num_heads, max_seqlen_q, max_seqlen_k), dtype=q.dtype, device=q.device) - if dropout_p is None: - dropout_p = 0.0 - if dropout_p < 0.0 or dropout_p > 1.0: - raise ValueError(f"Invalid dropout_p: {dropout_p}. It should be in [0, 1].") if softmax_scale is None: softmax_scale = q.shape[-1] ** (-0.5) if is_causal is None: @@ -1068,7 +996,7 @@ def forward( k = torch.nn.functional.pad(k, [0, 8 - head_size_og % 8]) v = torch.nn.functional.pad(v, [0, 8 - head_size_og % 8]) - out_padded, softmax_lse, S_dmask, rng_state = _wrapped_flash_dmattn_varlen_forward( + out_padded, softmax_lse, S_dmask = _wrapped_flash_dmattn_varlen_forward( q, k, v, @@ -1078,19 +1006,17 @@ def forward( cu_seqlens_k, max_seqlen_q, max_seqlen_k, - dropout_p, softmax_scale, is_causal=is_causal, softcap=softcap, - return_softmax=return_softmax and dropout_p > 0, + return_softmax=return_softmax, block_table=block_table, ) if is_grad: ctx.save_for_backward( - q, k, v, mask, bias, out_padded, softmax_lse, cu_seqlens_q, cu_seqlens_k, rng_state + q, k, v, mask, bias, out_padded, softmax_lse, cu_seqlens_q, cu_seqlens_k ) - ctx.dropout_p = dropout_p ctx.max_seqlen_q = max_seqlen_q ctx.max_seqlen_k = max_seqlen_k ctx.softmax_scale = softmax_scale @@ -1107,7 +1033,7 @@ def backward( dout: torch.Tensor, *args: Any, ): - q, k, v, mask, bias, out, softmax_lse, cu_seqlens_q, cu_seqlens_k, rng_state = ctx.saved_tensors + q, k, v, mask, bias, out, softmax_lse, cu_seqlens_q, cu_seqlens_k = ctx.saved_tensors dq, dk, dv, dbias = torch.empty_like(q), torch.empty_like(k), torch.empty_like(v), torch.empty_like(bias) head_size_og = dout.size(2) @@ -1132,12 +1058,10 @@ def backward( cu_seqlens_k, ctx.max_seqlen_q, ctx.max_seqlen_k, - ctx.dropout_p, ctx.softmax_scale, ctx.is_causal, ctx.softcap, ctx.deterministic, - rng_state=rng_state, ) dq = dq[..., : dout.shape[-1]] # We could have padded the head dimension @@ -1150,14 +1074,13 @@ def flash_dmattn_qkvpacked_func( qkv: torch.Tensor, attn_mask: Optional[torch.Tensor] = None, attn_bias: Optional[torch.Tensor] = None, - dropout_p: Optional[float] = None, is_causal: Optional[bool] = None, scale: Optional[float] = None, softcap: Optional[float] = None, deterministic: Optional[bool] = None, return_attn_probs: Optional[bool] = None, ): - """dropout_p should be set to 0.0 during evaluation + """ If Q, K, V are already stacked into 1 tensor, this function will be faster than calling flash_dmattn_func on Q, K, V since the backward pass avoids explicit concatenation of the gradients of Q, K, V. @@ -1173,7 +1096,6 @@ def flash_dmattn_qkvpacked_func( If None, no mask is applied. attn_bias: (batch_size, nheads, seqlen, seqlen). Attention Bias to add to the attention scores. If None, no bias is applied. - dropout_p: float. Dropout probability. is_causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling). scale: float. The scaling of QK^T before applying softmax. Default to 1 / sqrt(headdim). @@ -1189,14 +1111,12 @@ def flash_dmattn_qkvpacked_func( logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax normalization factor). S_dmask [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen, seqlen). - The output of softmax (possibly with different scaling). It also encodes the dropout - pattern (negative means that location was dropped, nonnegative means it was kept). + The output of softmax (possibly with different scaling). """ return FlashDMAttnQKVPackedFunc.apply( qkv, attn_mask, attn_bias, - dropout_p, scale, is_causal, softcap, @@ -1211,14 +1131,13 @@ def flash_dmattn_kvpacked_func( kv: torch.Tensor, attn_mask: Optional[torch.Tensor] = None, attn_bias: Optional[torch.Tensor] = None, - dropout_p: Optional[float] = None, scale: Optional[float] = None, is_causal: Optional[bool] = None, softcap: Optional[float] = None, deterministic: Optional[bool] = None, return_attn_probs: Optional[bool] = None, ): - """dropout_p should be set to 0.0 during evaluation + """ If K, V are already stacked into 1 tensor, this function will be faster than calling flash_dmattn_func on Q, K, V since the backward pass avoids explicit concatenation of the gradients of K, V. @@ -1246,7 +1165,6 @@ def flash_dmattn_kvpacked_func( If None, no mask is applied. attn_bias: (batch_size, nheads, seqlen_q, seqlen_k). Attention Bias to add to the attention scores. If None, no bias is applied. - dropout_p: float. Dropout probability. is_causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling). scale: float. The scaling of QK^T before applying softmax. Default to 1 / sqrt(headdim). @@ -1262,15 +1180,13 @@ def flash_dmattn_kvpacked_func( logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax normalization factor). S_dmask [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen, seqlen). - The output of softmax (possibly with different scaling). It also encodes the dropout - pattern (negative means that location was dropped, nonnegative means it was kept). + The output of softmax (possibly with different scaling). """ return FlashDMAttnKVPackedFunc.apply( q, kv, attn_mask, attn_bias, - dropout_p, scale, is_causal, softcap, @@ -1286,14 +1202,13 @@ def flash_dmattn_func( value: torch.Tensor, attn_mask: Optional[torch.Tensor] = None, attn_bias: Optional[torch.Tensor] = None, - dropout_p: Optional[float] = None, scale: Optional[float] = None, is_causal: Optional[bool] = None, softcap: Optional[float] = None, deterministic: Optional[bool] = None, return_attn_probs: Optional[bool] = None, ): - """dropout_p should be set to 0.0 during evaluation + """ Supports multi-query and grouped-query attention (MQA/GQA) by passing in KV with fewer heads than Q. Note that the number of heads in Q must be divisible by the number of heads in KV. For example, if Q has 6 heads and K, V have 2 heads, head 0, 1, 2 of Q will attention to head @@ -1319,7 +1234,6 @@ def flash_dmattn_func( If None, no mask is applied. attn_bias: (batch_size, nheads, seqlen_q, seqlen_k). Attention Bias to add to the attention scores. If None, no bias is applied. - dropout_p: float. Dropout probability. is_causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling). scale: float. The scaling of QK^T before applying softmax. Default to 1 / sqrt(headdim). @@ -1334,8 +1248,7 @@ def flash_dmattn_func( logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax normalization factor). S_dmask [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen, seqlen). - The output of softmax (possibly with different scaling). It also encodes the dropout - pattern (negative means that location was dropped, nonnegative means it was kept). + The output of softmax (possibly with different scaling). """ return FlashDMAttnFunc.apply( query, @@ -1343,7 +1256,6 @@ def flash_dmattn_func( value, attn_mask, attn_bias, - dropout_p, scale, is_causal, softcap, @@ -1359,14 +1271,13 @@ def flash_dmattn_varlen_qkvpacked_func( attn_bias: Optional[torch.Tensor] = None, cu_seqlens: torch.Tensor = None, max_seqlen: int = None, - dropout_p: Optional[float] = None, scale: Optional[float] = None, is_causal: Optional[bool] = None, softcap: Optional[float] = None, deterministic: Optional[bool] = None, return_attn_probs: Optional[bool] = None, ): - """dropout_p should be set to 0.0 during evaluation + """ If Q, K, V are already stacked into 1 tensor, this function will be faster than calling flash_dmattn_varlen_func on Q, K, V since the backward pass avoids explicit concatenation of the gradients of Q, K, V. @@ -1382,7 +1293,6 @@ def flash_dmattn_varlen_qkvpacked_func( cu_seqlens: (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths of the sequences in the batch, used to index into qkv. max_seqlen: int. Maximum sequence length in the batch. - dropout_p: float. Dropout probability. is_causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling). scale: float. The scaling of QK^T before applying softmax. Default to 1 / sqrt(headdim). @@ -1398,8 +1308,7 @@ def flash_dmattn_varlen_qkvpacked_func( logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax normalization factor). S_dmask [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen, seqlen). - The output of softmax (possibly with different scaling). It also encodes the dropout - pattern (negative means that location was dropped, nonnegative means it was kept). + The output of softmax (possibly with different scaling). """ return FlashDMAttnVarlenQKVPackedFunc.apply( qkv, @@ -1407,7 +1316,6 @@ def flash_dmattn_varlen_qkvpacked_func( attn_bias, cu_seqlens, max_seqlen, - dropout_p, scale, is_causal, softcap, @@ -1426,14 +1334,13 @@ def flash_dmattn_varlen_kvpacked_func( cu_seqlens_k: torch.Tensor = None, max_seqlen_q: int = None, max_seqlen_k: int = None, - dropout_p: Optional[float] = None, scale: Optional[float] = None, is_causal: Optional[bool] = None, softcap: Optional[float] = None, deterministic: Optional[bool] = None, return_attn_probs: Optional[bool] = None, ): - """dropout_p should be set to 0.0 during evaluation + """ If K, V are already stacked into 1 tensor, this function will be faster than calling flash_dmattn_func on Q, K, V since the backward pass avoids explicit concatenation of the gradients of K, V. @@ -1467,7 +1374,6 @@ def flash_dmattn_varlen_kvpacked_func( of the sequences in the batch, used to index into kv. max_seqlen_q: int. Maximum query sequence length in the batch. max_seqlen_k: int. Maximum key sequence length in the batch. - dropout_p: float. Dropout probability. is_causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling). scale: float. The scaling of QK^T before applying softmax. Default to 1 / sqrt(headdim). @@ -1483,8 +1389,7 @@ def flash_dmattn_varlen_kvpacked_func( logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax normalization factor). S_dmask [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen, seqlen). - The output of softmax (possibly with different scaling). It also encodes the dropout - pattern (negative means that location was dropped, nonnegative means it was kept). + The output of softmax (possibly with different scaling). """ return FlashDMAttnVarlenKVPackedFunc.apply( q, @@ -1495,7 +1400,6 @@ def flash_dmattn_varlen_kvpacked_func( cu_seqlens_k, max_seqlen_q, max_seqlen_k, - dropout_p, scale, is_causal, softcap, @@ -1515,7 +1419,6 @@ def flash_dmattn_varlen_func( cu_seqlens_k: torch.Tensor = None, max_seqlen_q: int = None, max_seqlen_k: int = None, - dropout_p: Optional[float] = None, scale: Optional[float] = None, is_causal: Optional[bool] = None, softcap: Optional[float] = None, @@ -1523,7 +1426,7 @@ def flash_dmattn_varlen_func( return_attn_probs: Optional[bool] = None, block_table: Optional[torch.Tensor] = None, ): - """dropout_p should be set to 0.0 during evaluation + """ Supports multi-query and grouped-query attention (MQA/GQA) by passing in K, V with fewer heads than Q. Note that the number of heads in Q must be divisible by the number of heads in KV. For example, if Q has 6 heads and K, V have 2 heads, head 0, 1, 2 of Q will attention to head @@ -1555,7 +1458,6 @@ def flash_dmattn_varlen_func( of the sequences in the batch, used to index into kv. max_seqlen_q: int. Maximum query sequence length in the batch. max_seqlen_k: int. Maximum key sequence length in the batch. - dropout_p: float. Dropout probability. is_causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling). scale: float. The scaling of QK^T before applying softmax. Default to 1 / sqrt(headdim). @@ -1571,8 +1473,7 @@ def flash_dmattn_varlen_func( logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax normalization factor). S_dmask [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen, seqlen). - The output of softmax (possibly with different scaling). It also encodes the dropout - pattern (negative means that location was dropped, nonnegative means it was kept). + The output of softmax (possibly with different scaling). """ return FlashDMAttnVarlenFunc.apply( query, @@ -1584,7 +1485,6 @@ def flash_dmattn_varlen_func( cu_seqlens_k, max_seqlen_q, max_seqlen_k, - dropout_p, scale, is_causal, softcap, From 1cf7fd450b942105b81b73a11f9167a8edfd7a13 Mon Sep 17 00:00:00 2001 From: Loser Cheems Date: Fri, 8 Aug 2025 21:06:40 +0800 Subject: [PATCH 09/10] Removes commented dropout disable flag Cleans up build configuration by removing unused commented compilation flag for disabling dropout functionality --- setup.py | 1 - 1 file changed, 1 deletion(-) diff --git a/setup.py b/setup.py index 3f805ad..e9b53e6 100644 --- a/setup.py +++ b/setup.py @@ -244,7 +244,6 @@ def append_nvcc_threads(nvcc_extra_args): # "--ptxas-options=-O2", # "-lineinfo", "-DFLASHATTENTION_DISABLE_BACKWARD", # Only forward pass - # "-DFLASHATTENTION_DISABLE_DROPOUT", # "-DFLASHATTENTION_DISABLE_SOFTCAP", # "-DFLASHATTENTION_DISABLE_UNEVEN_K", ] From 9fa7885c82da7368bf1be7f8a943ca0dccfbb6e2 Mon Sep 17 00:00:00 2001 From: Loser Cheems Date: Fri, 8 Aug 2025 21:08:32 +0800 Subject: [PATCH 10/10] Removes dropout functionality from flash attention kernel Eliminates dropout-related template parameters, includes, and implementation code throughout the attention computation functions. Simplifies the kernel interface by removing Is_dropout template parameter and associated dropout logic including RNG state management, dropout application during attention computation, and dropout-specific normalization paths. Streamlines the codebase by removing dependencies on ATen CUDA utilities and dropout/rotary header files that are no longer needed. --- csrc/src/flash_fwd_kernel.h | 56 +++++-------------------------------- 1 file changed, 7 insertions(+), 49 deletions(-) diff --git a/csrc/src/flash_fwd_kernel.h b/csrc/src/flash_fwd_kernel.h index 8699ceb..a07ae1b 100644 --- a/csrc/src/flash_fwd_kernel.h +++ b/csrc/src/flash_fwd_kernel.h @@ -5,7 +5,6 @@ #pragma once #include "namespace_config.h" -#include // For at::cuda::philox::unpack #include @@ -18,8 +17,6 @@ #include "utils.h" #include "softmax.h" #include "mask.h" -#include "dropout.h" -#include "rotary.h" namespace FLASH_NAMESPACE { @@ -47,7 +44,7 @@ __forceinline__ __device__ auto get_lse_tile(const Params ¶ms, const int bid return local_tile(mLSE_slice, Shape>{}, make_coord(m_block)); } -template +template inline __device__ void compute_attn_1rowblock(const Params ¶ms, const int bidb, const int bidh, const int m_block) { using Element = typename Kernel_traits::Element; @@ -65,17 +62,6 @@ inline __device__ void compute_attn_1rowblock(const Params ¶ms, const int bi constexpr int kHeadDim = Kernel_traits::kHeadDim; // head_dim constexpr int kNWarps = Kernel_traits::kNWarps; - auto seed_offset = at::cuda::philox::unpack(params.philox_args); - FLASH_NAMESPACE::Dropout dropout(std::get<0>(seed_offset), std::get<1>(seed_offset), params.p_dropout_in_uint8_t, - bidb, bidh, tidx, params.h); - - // Save seed and offset for backward, before any early exiting. Otherwise the 0-th thread block might - // exit early and no one saves the rng states. - if (Is_dropout && blockIdx.x == 0 && blockIdx.y == 0 && blockIdx.z == 0 && tidx == 0) { - params.rng_state[0] = std::get<0>(seed_offset); - params.rng_state[1] = std::get<1>(seed_offset); - } - // Check if there are any queries to process in the block const BlockInfo binfo(params, bidb); if (m_block * kBlockM >= binfo.actual_seqlen_q) return; @@ -477,20 +463,10 @@ inline __device__ void compute_attn_1rowblock(const Params ¶ms, const int bi // Convert acc_s from fp32 to fp16/bf16 Tensor rP = FLASH_NAMESPACE::convert_type(acc_s); - int block_row_idx = m_block * (kBlockM / 16) + tidx / 32; - int block_col_idx = n_block * (kBlockN / 32); if (Return_softmax) { - Tensor rP_drop = make_fragment_like(rP); - cute::copy(rP, rP_drop); - dropout.template apply_dropout( - rP_drop, block_row_idx, block_col_idx, kNWarps - ); - cute::copy(rP_drop, tSgS); + cute::copy(rP, tSgS); tSgS.data() = tSgS.data() + (-kBlockN); } - if (Is_dropout) { - dropout.apply_dropout(rP, block_row_idx, block_col_idx, kNWarps); - } // Reshape rP from (MMA=4, MMA_M, MMA_N) to ((4, 2), MMA_M, MMA_N / 2) // if using m16n8k16 or (4, MMA_M, MMA_N) if using m16n8k8. @@ -574,20 +550,10 @@ inline __device__ void compute_attn_1rowblock(const Params ¶ms, const int bi // Convert acc_s from fp32 to fp16/bf16 Tensor rP = FLASH_NAMESPACE::convert_type(acc_s); - int block_row_idx = m_block * (kBlockM / 16) + tidx / 32; - int block_col_idx = n_block * (kBlockN / 32); if (Return_softmax) { - Tensor rP_drop = make_fragment_like(rP); - cute::copy(rP, rP_drop); - dropout.template apply_dropout( - rP_drop, block_row_idx, block_col_idx, kNWarps - ); - cute::copy(rP_drop, tSgS); + cute::copy(rP, tSgS); tSgS.data() = tSgS.data() + (-kBlockN); } - if (Is_dropout) { - dropout.apply_dropout(rP, block_row_idx, block_col_idx, kNWarps); - } // Reshape rP from (MMA=4, MMA_M, MMA_N) to ((4, 2), MMA_M, MMA_N / 2) // if using m16n8k16 or (4, MMA_M, MMA_N) if using m16n8k8. @@ -603,7 +569,7 @@ inline __device__ void compute_attn_1rowblock(const Params ¶ms, const int bi // Epilogue - Tensor lse = softmax.template normalize_softmax_lse(acc_o, params.scale_softmax, params.rp_dropout); + Tensor lse = softmax.template normalize_softmax_lse(acc_o, params.scale_softmax); // Convert acc_o from fp32 to fp16/bf16 Tensor rO = FLASH_NAMESPACE::convert_type(acc_o); @@ -1198,7 +1164,7 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params ¶ms, cons // Epilogue - Tensor lse = softmax.template normalize_softmax_lse(acc_o, params.scale_softmax); + Tensor lse = softmax.template normalize_softmax_lse(acc_o, params.scale_softmax); // if (cute::thread0()) { print(lse); } Tensor sOaccum = make_tensor(make_smem_ptr(reinterpret_cast(smem_)), typename Kernel_traits::SmemLayoutO{}); // (SMEM_M,SMEM_N) @@ -1276,7 +1242,7 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params ¶ms, cons //////////////////////////////////////////////////////////////////////////////////////////////////// -template +template inline __device__ void compute_attn(const Params ¶ms) { const int m_block = blockIdx.x; // The block index for the batch. @@ -1284,15 +1250,7 @@ inline __device__ void compute_attn(const Params ¶ms) { // The block index for the head. const int bidh = blockIdx.z; - // We want the fwd and bwd to generate the same dropout pattern (RNG), without restricting - // them to have the same number of threads or have to traverse the attention matrix - // in the same order. - // In the Philox RNG, we use the offset to store the batch, head, and the lane id - // (within a warp). We use the subsequence to store the location of the 16 x 32 blocks within - // the attention matrix. This way, as long as we have the batch, head, and the location of - // the 16 x 32 block within the attention matrix, we can generate the exact same dropout pattern. - - FLASH_NAMESPACE::compute_attn_1rowblock(params, bidb, bidh, m_block); + FLASH_NAMESPACE::compute_attn_1rowblock(params, bidb, bidh, m_block); } ////////////////////////////////////////////////////////////////////////////////////////////////////