diff --git a/benchmarks/benchmark_forward_equivalence.py b/benchmarks/benchmark_forward_equivalence.py index e3e9918..fac85d2 100644 --- a/benchmarks/benchmark_forward_equivalence.py +++ b/benchmarks/benchmark_forward_equivalence.py @@ -256,7 +256,6 @@ def dynamic_mask_attention_cuda( value_states, # [batch, key_len, num_kv_heads, head_dim] attn_mask=attn_mask, # [batch, num_kv_heads, query_len, key_len] attn_bias=attn_bias, # [batch, num_kv_heads, query_len, key_len] - dropout_p=0.0, is_causal=is_causal, scale=scaling, softcap=0.0, diff --git a/benchmarks/benchmark_forward_performance.py b/benchmarks/benchmark_forward_performance.py index 759668a..0c1c042 100644 --- a/benchmarks/benchmark_forward_performance.py +++ b/benchmarks/benchmark_forward_performance.py @@ -265,7 +265,6 @@ def dynamic_mask_attention_cuda( value_states, # [batch, key_len, num_kv_heads, head_dim] attn_mask=attn_mask, # [batch, num_kv_heads, query_len, key_len] attn_bias=attn_bias, # [batch, num_kv_heads, query_len, key_len] - dropout_p=0.0, is_causal=is_causal, scale=scaling, softcap=0.0, diff --git a/csrc/flash_api.cpp b/csrc/flash_api.cpp index a59f841..937d02b 100644 --- a/csrc/flash_api.cpp +++ b/csrc/flash_api.cpp @@ -47,7 +47,6 @@ void set_params_fprop( void *seqused_k, void *p_d, void *softmax_lse_d, - float p_dropout, float softmax_scale, bool is_causal, const float softcap, @@ -134,20 +133,6 @@ void set_params_fprop( params.scale_softmax_log2 = softmax_scale * M_LOG2E; } - // Set this to probability of keeping an element to simplify things. - params.p_dropout = 1.f - p_dropout; - // Convert p from float to int so we don't have to convert the random uint to float to compare. - // [Minor] We want to round down since when we do the comparison we use <= instead of < - // params.p_dropout_in_uint = uint32_t(std::floor(params.p_dropout * 4294967295.0)); - // params.p_dropout_in_uint16_t = uint16_t(std::floor(params.p_dropout * 65535.0)); - params.p_dropout_in_uint8_t = uint8_t(std::floor(params.p_dropout * 255.0)); - params.rp_dropout = 1.f / params.p_dropout; - params.scale_softmax_rp_dropout = params.rp_dropout * params.scale_softmax; - TORCH_CHECK(p_dropout < 1.f); - #ifdef FLASHATTENTION_DISABLE_DROPOUT - TORCH_CHECK(p_dropout == 0.0f, "This flash attention build does not support dropout."); - #endif - params.is_causal = is_causal; params.is_seqlens_k_cumulative = true; @@ -223,7 +208,6 @@ std::tuple set_params_splitkv( const int max_seqlen_k, const int max_seqlen_q, const int head_size_rounded, - const float p_dropout, const int num_splits, const int num_sm, struct c10::TensorOptions opts @@ -239,19 +223,17 @@ std::tuple set_params_splitkv( at::Tensor softmax_lse_accum; at::Tensor out_accum; - if (p_dropout == 0.0f) { // SplitKV is not implemented for dropout - if (num_splits < 1) { - // We multiply number of SMs by 2 to hard-code the fact that we're using 128 threads per block. 
- params.num_splits = num_splits_heuristic(batch_size * num_heads * num_m_blocks, num_sm * 2, num_n_blocks, 128); - } - if (params.num_splits > 1) { - softmax_lse_accum = torch::empty({params.num_splits, batch_size, num_heads, max_seqlen_q}, opts.dtype(at::kFloat)); - out_accum = torch::empty({params.num_splits, batch_size, num_heads, max_seqlen_q, head_size_rounded}, opts.dtype(at::kFloat)); - params.softmax_lseaccum_ptr = softmax_lse_accum.data_ptr(); - params.oaccum_ptr = out_accum.data_ptr(); - } - TORCH_CHECK(params.num_splits <= 128, "num_splits > 128 not supported"); + if (num_splits < 1) { + // We multiply number of SMs by 2 to hard-code the fact that we're using 128 threads per block. + params.num_splits = num_splits_heuristic(batch_size * num_heads * num_m_blocks, num_sm * 2, num_n_blocks, 128); + } + if (params.num_splits > 1) { + softmax_lse_accum = torch::empty({params.num_splits, batch_size, num_heads, max_seqlen_q}, opts.dtype(at::kFloat)); + out_accum = torch::empty({params.num_splits, batch_size, num_heads, max_seqlen_q, head_size_rounded}, opts.dtype(at::kFloat)); + params.softmax_lseaccum_ptr = softmax_lse_accum.data_ptr(); + params.oaccum_ptr = out_accum.data_ptr(); } + TORCH_CHECK(params.num_splits <= 128, "num_splits > 128 not supported"); // Temporarily disable Split-KV, because some bugs are still being fixed. // See: https://github.com/SmallDoges/flash-dmattn/issues/47 @@ -272,12 +254,10 @@ mha_fwd( const at::Tensor &mask, // batch_size x num_heads_k x seqlen_q x seqlen_k const at::Tensor &bias, // batch_size x num_heads_k x seqlen_q x seqlen_k std::optional &out_, // batch_size x seqlen_q x num_heads x round_multiple(head_size, 8) - const float p_dropout, const float softmax_scale, bool is_causal, const float softcap, - const bool return_softmax, - std::optional gen_ + const bool return_softmax ) { // Otherwise the kernel will be launched from cuda:0 device @@ -313,14 +293,12 @@ mha_fwd( TORCH_CHECK(head_size % 8 == 0, "query, key, value, and out_ must have a head_size that is a multiple of 8"); TORCH_CHECK(num_heads % num_heads_k == 0, "Number of heads in key/value must divide number of heads in query"); - if (softcap > 0.f) { TORCH_CHECK(p_dropout == 0.f, "Softcapping does not support dropout for now"); } - // causal=true is the same as causal=false in this case if (seqlen_q == 1) { is_causal = false; } // Faster to transpose q from (b, 1, (nheads_kv ngroups), d) to (b, ngroups, nheads_kv, d) in this case // H/t Daniel Haziza - const int seqlenq_ngroups_swapped = seqlen_q == 1 && num_heads > num_heads_k && p_dropout == 0.f && head_size % 8 == 0; + const int seqlenq_ngroups_swapped = seqlen_q == 1 && num_heads > num_heads_k && head_size % 8 == 0; const int ngroups = num_heads / num_heads_k; if (seqlenq_ngroups_swapped) { q = q.reshape({batch_size, num_heads_k, ngroups, head_size}).transpose(1, 2); @@ -357,12 +335,10 @@ mha_fwd( auto softmax_lse = torch::empty({batch_size, num_heads, seqlen_q}, opts.dtype(at::kFloat)); at::Tensor p; - // Only return softmax if there's dropout to reduce compilation time + if (return_softmax) { - TORCH_CHECK(p_dropout > 0.0f, "return_softmax is only supported when p_dropout > 0.0"); p = torch::empty({ batch_size, num_heads, seqlen_q_rounded, seqlen_k_rounded }, opts); - } - else { + } else { p = torch::empty({ 0 }, opts); } @@ -380,7 +356,6 @@ mha_fwd( /*seqused_k=*/nullptr, return_softmax ? 
p.data_ptr() : nullptr, softmax_lse.data_ptr(), - p_dropout, softmax_scale, is_causal, softcap @@ -390,26 +365,9 @@ mha_fwd( at::Tensor softmax_lse_accum, out_accum; std::tie(softmax_lse_accum, out_accum) = set_params_splitkv( params, batch_size, num_heads, head_size, seqlen_k, seqlen_q, - head_size_rounded, p_dropout, /*num_splits*/ 0, get_num_sm(get_current_device()), opts + head_size_rounded, /*num_splits*/ 0, get_num_sm(get_current_device()), opts ); - // number of times random will be generated per thread, to offset philox counter in thc random - // state - // We use a custom RNG that increases the offset by batch_size * nheads * 32. - int64_t counter_offset = params.b * params.h * 32; - auto options = torch::TensorOptions().dtype(torch::kFloat32).device(torch::kCUDA); - auto rng_state = torch::empty({2}, options.dtype(torch::kInt64)); - // Forward kernel will populate memory with the seed and offset. - params.rng_state = reinterpret_cast(rng_state.data_ptr()); - - if (p_dropout > 0.0) { - auto gen = at::get_generator_or_default( - gen_, at::cuda::detail::getDefaultCUDAGenerator()); - // See Note [Acquire lock when using random generators] - std::lock_guard lock(gen->mutex_); - params.philox_args = gen->philox_cuda_state(counter_offset); - } - if (seqlen_k > 0) { auto stream = at::cuda::getCurrentCUDAStream().stream(); run_mha_fwd(params, stream); @@ -424,7 +382,7 @@ mha_fwd( q = q.transpose(1, 2).reshape({batch_size, 1, num_heads_k * seqlen_q, head_size}); softmax_lse = softmax_lse.reshape({batch_size, num_heads_k * seqlen_q, 1}); } - return {out, softmax_lse, p, rng_state}; + return {out, softmax_lse, p}; } std::vector @@ -442,13 +400,11 @@ mha_varlen_fwd( std::optional &block_table_, // batch_size x max_num_blocks_per_seq int max_seqlen_q, const int max_seqlen_k, - const float p_dropout, const float softmax_scale, const bool zero_tensors, bool is_causal, const float softcap, - const bool return_softmax, - std::optional gen_ + const bool return_softmax ) { // Otherwise the kernel will be launched from cuda:0 device at::cuda::CUDAGuard device_guard{q.device()}; @@ -494,8 +450,6 @@ mha_varlen_fwd( const int head_size = sizes[2]; const int num_heads_k = paged_KV ? k.size(2) : k.size(1); - if (softcap > 0.f) { TORCH_CHECK(p_dropout == 0.f, "Softcapping does not support dropout for now"); } - const int max_num_blocks_per_seq = !paged_KV ? 0 : block_table.size(1); const int num_blocks = !paged_KV ? 0 : k.size(0); const int page_block_size = !paged_KV ? 
1 : k.size(1); @@ -507,7 +461,7 @@ mha_varlen_fwd( // Faster to transpose q from (b, 1, (nheads_kv ngroups), d) to (b, ngroups, nheads_kv, d) in this case // H/t Daniel Haziza - const int seqlenq_ngroups_swapped = max_seqlen_q == 1 && num_heads > num_heads_k && p_dropout == 0.f && head_size % 8 == 0; + const int seqlenq_ngroups_swapped = max_seqlen_q == 1 && num_heads > num_heads_k && head_size % 8 == 0; const int ngroups = num_heads / num_heads_k; if (seqlenq_ngroups_swapped) { q = q.reshape({batch_size, num_heads_k, ngroups, head_size}).transpose(1, 2).reshape({batch_size * ngroups, num_heads_k, head_size}); @@ -568,12 +522,10 @@ mha_varlen_fwd( auto opts = q.options(); auto softmax_lse = torch::empty({num_heads, total_q}, opts.dtype(at::kFloat)); at::Tensor p; - // Only return softmax if there's dropout to reduce compilation time + if (return_softmax) { - TORCH_CHECK(p_dropout > 0.0f, "return_softmax is only supported when p_dropout > 0.0"); p = torch::empty({ batch_size, num_heads, seqlen_q_rounded, seqlen_k_rounded }, opts); - } - else { + } else { p = torch::empty({ 0 }, opts); } @@ -597,7 +549,6 @@ mha_varlen_fwd( seqused_k.has_value() ? seqused_k.value().data_ptr() : nullptr, return_softmax ? p.data_ptr() : nullptr, softmax_lse.data_ptr(), - p_dropout, softmax_scale, is_causal, softcap, @@ -621,7 +572,7 @@ mha_varlen_fwd( set_params_splitkv( params, batch_size, num_heads, head_size, max_seqlen_k, max_seqlen_q, head_size_rounded, - p_dropout, /*num_splits*/ 0, get_num_sm(get_current_device()), opts + /*num_splits*/ 0, get_num_sm(get_current_device()), opts ); } @@ -635,23 +586,6 @@ mha_varlen_fwd( params.leftpad_k = static_cast(leftpad_k.data_ptr()); } - // number of times random will be generated per thread, to offset philox counter in thc random - // state - // We use a custom RNG that increases the offset by batch_size * nheads * 32. - int64_t counter_offset = params.b * params.h * 32; - auto options = torch::TensorOptions().dtype(torch::kFloat32).device(torch::kCUDA); - auto rng_state = torch::empty({2}, options.dtype(torch::kInt64)); - // Forward kernel will populate memory with the seed and offset. - params.rng_state = reinterpret_cast(rng_state.data_ptr()); - - if (p_dropout > 0.0) { - auto gen = at::get_generator_or_default( - gen_, at::cuda::detail::getDefaultCUDAGenerator()); - // See Note [Acquire lock when using random generators] - std::lock_guard lock(gen->mutex_); - params.philox_args = gen->philox_cuda_state(counter_offset); - } - if (max_seqlen_k > 0) { auto stream = at::cuda::getCurrentCUDAStream().stream(); run_mha_fwd(params, stream, paged_KV); @@ -669,7 +603,7 @@ mha_varlen_fwd( softmax_lse = softmax_lse.reshape({num_heads * max_seqlen_q, batch_size}); } - return {out, softmax_lse, p, rng_state}; + return {out, softmax_lse, p}; } } // namespace FLASH_NAMESPACE diff --git a/csrc/src/dropout.h b/csrc/src/dropout.h deleted file mode 100644 index 9077b79..0000000 --- a/csrc/src/dropout.h +++ /dev/null @@ -1,95 +0,0 @@ -/****************************************************************************** - * Copyright (c) 2024, Tri Dao. 
- ******************************************************************************/ - -#pragma once - -#include "namespace_config.h" -#include "philox.cuh" -#include "utils.h" - -namespace FLASH_NAMESPACE { - -struct Dropout { - - const unsigned long long seed, offset; - const uint8_t p_dropout_in_uint8_t; - - __forceinline__ __device__ Dropout(const unsigned long long seed, const unsigned long long offset, - const uint8_t p_dropout_in_uint8_t, - const int bid, const int hid, const int tid, const int nheads) - : seed(seed) - , offset(offset + (bid * nheads + hid) * 32 + tid % 32) - , p_dropout_in_uint8_t(p_dropout_in_uint8_t) { - } - - template - __forceinline__ __device__ void apply_dropout(Tensor &tensor_, - int block_row_start, int block_col_start, int block_row_stride) { - // convert shape from (4, MMA_M, MMA_N) to (8, MMA_M, MMA_N / 2) - Tensor tensor = make_tensor(tensor_.data(), FLASH_NAMESPACE::convert_layout_acc_dropout(tensor_.layout())); - using T = typename Engine::value_type; - auto encode_dropout = [](bool keep, T val) { - return keep ? val : (encode_dropout_in_sign_bit ? -val : T(0)); - }; - static_assert(decltype(size<2>(tensor))::value % 2 == 0); - const uint16_t p_dropout_8bit_in_uint16_t = uint16_t(p_dropout_in_uint8_t); - const uint32_t p_dropout_8bit_in_uint32_t = (uint32_t(p_dropout_8bit_in_uint16_t) << 16) | uint32_t(p_dropout_8bit_in_uint16_t); - // if (cute::thread0()) { printf("threshold2 = 0x%x\n", p_dropout_8bit_in_uint32_t); } - #pragma unroll - for (int m = 0; m < size<1>(tensor); ++m, block_row_start += block_row_stride) { - uint2 rowcol = make_uint2(block_row_start, block_col_start); - #pragma unroll - for (int n = 0; n < size<2>(tensor) / 2; ++n, ++rowcol.y) { - // if (cute::thread(32, 0)) { printf("m = %d, n = %d, row = %d, col = %d\n", m, n, int(rowcol.x), int(rowcol.y));} - uint4 random_uint4 = FLASH_NAMESPACE::philox(seed, reinterpret_cast(rowcol), offset); - // if (cute::thread0()) { printf("philox = %u, %d, %d, %d\n", random_uint4.x, random_uint4.y, random_uint4.z, random_uint4.w);} - uint8_t (&rnd_8)[16] = reinterpret_cast(random_uint4); - // Special implementation for 16-bit types: we duplicate the threshold to the - // low and high 16 bits of a 32-bit value, then use the f16x2 comparison instruction - // to get a mask. The low 16 bits of the mask will be either 0xffff or 0x0000, - // and the high 16 bits will be either 0xffff or 0x0000, depending on whether - // the random value is less than the threshold. - // We then do a bit-wise AND between the mask and the original value (in 32-bit). - // We're exploiting the fact that floating point comparison is equivalent to integer - // comparison, since we're comparing unsigned integers whose top 8-bits are zero. 
- if (!encode_dropout_in_sign_bit - && (std::is_same::value || std::is_same::value)) { - uint16_t rnd_16[16]; - #pragma unroll - for (int i = 0; i < 16; i++) { rnd_16[i] = uint16_t(rnd_8[i]); } - uint32_t (&rnd_32)[8] = reinterpret_cast(rnd_16); - #pragma unroll - for (int j = 0; j < 2; j++) { - Tensor tensor_uint32 = recast(tensor(_, m, n * 2 + j)); - // if (cute::thread0()) { printf("random = 0x%x, 0x%x, 0x%x, 0x%x\n", rnd_32[j * 4 + 0], rnd_32[j * 4 + 1], rnd_32[j * 4 + 2], rnd_32[j * 4 + 3]); } - // if (cute::thread0()) { printf("tensor_uint32 = 0x%x, 0x%x, 0x%x, 0x%x\n", tensor_uint32(0), tensor_uint32(1), tensor_uint32(2), tensor_uint32(3)); } - #pragma unroll - for (int i = 0; i < 4; i++) { - uint32_t mask; - asm volatile("set.le.u32.f16x2 %0, %1, %2;\n" : "=r"(mask) : "r"(rnd_32[j * 4 + i]), "r"(p_dropout_8bit_in_uint32_t)); - tensor_uint32(i) &= mask; - } - // if (cute::thread0()) { printf("tensor_uint32 = 0x%x, 0x%x, 0x%x, 0x%x\n", tensor_uint32(0), tensor_uint32(1), tensor_uint32(2), tensor_uint32(3)); } - } - } else { - #pragma unroll - for (int j = 0; j < 2; j++) { - #pragma unroll - for (int i = 0; i < 8; i++) { - tensor(i, m, n * 2 + j) = encode_dropout(rnd_8[j * 8 + i] <= p_dropout_in_uint8_t, tensor(i, m, n * 2 + j)); - } - Tensor tensor_uint32 = recast(tensor(_, m, n * 2 + j)); - // if (cute::thread0()) { printf("tensor_uint32 = 0x%x, 0x%x, 0x%x, 0x%x\n", tensor_uint32(0), tensor_uint32(1), tensor_uint32(2), tensor_uint32(3)); } - } - } - // // if ((threadIdx.x == 0) && (blockIdx.x == 0) && (blockIdx.y == 0)) { - // // printf("n = %d, ph Philox: %u, %u, %u, %u\n", n, rnd_8.x, rnd_8.y, rnd_8.z, rnd_8.w); - // // } - } - } - } - -}; - -} // namespace FLASH_NAMESPACE diff --git a/csrc/src/flash.h b/csrc/src/flash.h index 95d7f52..fc613dd 100644 --- a/csrc/src/flash.h +++ b/csrc/src/flash.h @@ -93,6 +93,7 @@ struct Flash_fwd_params : public QKV_params, public Mask_params, public Bias_par // The scaling factors for the kernel. float scale_softmax; float scale_softmax_log2; + float softcap; // array of length b+1 holding starting offset of each sequence. int * __restrict__ cu_seqlens_q; @@ -128,23 +129,6 @@ struct Flash_fwd_params : public QKV_params, public Mask_params, public Bias_par index_t block_table_batch_stride; int page_block_size; - // The dropout probability (probability of keeping an activation). - float p_dropout; - // uint32_t p_dropout_in_uint; - // uint16_t p_dropout_in_uint16_t; - uint8_t p_dropout_in_uint8_t; - - // Scale factor of 1 / (1 - p_dropout). - float rp_dropout; - float scale_softmax_rp_dropout; - float softcap; - - // Random state. - at::PhiloxCudaState philox_args; - - // Pointer to the RNG seed (idx 0) and offset (idx 1). - uint64_t * rng_state; - bool is_bf16; bool is_causal; @@ -152,8 +136,6 @@ struct Flash_fwd_params : public QKV_params, public Mask_params, public Bias_par // Otherwise it's cu_seqlens_k[bidb], i.e., we use cu_seqlens_k to store the sequence lengths of K. bool is_seqlens_k_cumulative; - bool is_rotary_interleaved; - int num_splits; // For split-KV version bool unpadded_lse; // For varlen paths: LSE is in [nheads, total_seqlen_q] format instead of [b, nheads, seqlen_q]. 
diff --git a/csrc/src/flash_fwd_kernel.h b/csrc/src/flash_fwd_kernel.h index 8699ceb..a07ae1b 100644 --- a/csrc/src/flash_fwd_kernel.h +++ b/csrc/src/flash_fwd_kernel.h @@ -5,7 +5,6 @@ #pragma once #include "namespace_config.h" -#include // For at::cuda::philox::unpack #include @@ -18,8 +17,6 @@ #include "utils.h" #include "softmax.h" #include "mask.h" -#include "dropout.h" -#include "rotary.h" namespace FLASH_NAMESPACE { @@ -47,7 +44,7 @@ __forceinline__ __device__ auto get_lse_tile(const Params ¶ms, const int bid return local_tile(mLSE_slice, Shape>{}, make_coord(m_block)); } -template +template inline __device__ void compute_attn_1rowblock(const Params ¶ms, const int bidb, const int bidh, const int m_block) { using Element = typename Kernel_traits::Element; @@ -65,17 +62,6 @@ inline __device__ void compute_attn_1rowblock(const Params ¶ms, const int bi constexpr int kHeadDim = Kernel_traits::kHeadDim; // head_dim constexpr int kNWarps = Kernel_traits::kNWarps; - auto seed_offset = at::cuda::philox::unpack(params.philox_args); - FLASH_NAMESPACE::Dropout dropout(std::get<0>(seed_offset), std::get<1>(seed_offset), params.p_dropout_in_uint8_t, - bidb, bidh, tidx, params.h); - - // Save seed and offset for backward, before any early exiting. Otherwise the 0-th thread block might - // exit early and no one saves the rng states. - if (Is_dropout && blockIdx.x == 0 && blockIdx.y == 0 && blockIdx.z == 0 && tidx == 0) { - params.rng_state[0] = std::get<0>(seed_offset); - params.rng_state[1] = std::get<1>(seed_offset); - } - // Check if there are any queries to process in the block const BlockInfo binfo(params, bidb); if (m_block * kBlockM >= binfo.actual_seqlen_q) return; @@ -477,20 +463,10 @@ inline __device__ void compute_attn_1rowblock(const Params ¶ms, const int bi // Convert acc_s from fp32 to fp16/bf16 Tensor rP = FLASH_NAMESPACE::convert_type(acc_s); - int block_row_idx = m_block * (kBlockM / 16) + tidx / 32; - int block_col_idx = n_block * (kBlockN / 32); if (Return_softmax) { - Tensor rP_drop = make_fragment_like(rP); - cute::copy(rP, rP_drop); - dropout.template apply_dropout( - rP_drop, block_row_idx, block_col_idx, kNWarps - ); - cute::copy(rP_drop, tSgS); + cute::copy(rP, tSgS); tSgS.data() = tSgS.data() + (-kBlockN); } - if (Is_dropout) { - dropout.apply_dropout(rP, block_row_idx, block_col_idx, kNWarps); - } // Reshape rP from (MMA=4, MMA_M, MMA_N) to ((4, 2), MMA_M, MMA_N / 2) // if using m16n8k16 or (4, MMA_M, MMA_N) if using m16n8k8. @@ -574,20 +550,10 @@ inline __device__ void compute_attn_1rowblock(const Params ¶ms, const int bi // Convert acc_s from fp32 to fp16/bf16 Tensor rP = FLASH_NAMESPACE::convert_type(acc_s); - int block_row_idx = m_block * (kBlockM / 16) + tidx / 32; - int block_col_idx = n_block * (kBlockN / 32); if (Return_softmax) { - Tensor rP_drop = make_fragment_like(rP); - cute::copy(rP, rP_drop); - dropout.template apply_dropout( - rP_drop, block_row_idx, block_col_idx, kNWarps - ); - cute::copy(rP_drop, tSgS); + cute::copy(rP, tSgS); tSgS.data() = tSgS.data() + (-kBlockN); } - if (Is_dropout) { - dropout.apply_dropout(rP, block_row_idx, block_col_idx, kNWarps); - } // Reshape rP from (MMA=4, MMA_M, MMA_N) to ((4, 2), MMA_M, MMA_N / 2) // if using m16n8k16 or (4, MMA_M, MMA_N) if using m16n8k8. 
@@ -603,7 +569,7 @@ inline __device__ void compute_attn_1rowblock(const Params ¶ms, const int bi // Epilogue - Tensor lse = softmax.template normalize_softmax_lse(acc_o, params.scale_softmax, params.rp_dropout); + Tensor lse = softmax.template normalize_softmax_lse(acc_o, params.scale_softmax); // Convert acc_o from fp32 to fp16/bf16 Tensor rO = FLASH_NAMESPACE::convert_type(acc_o); @@ -1198,7 +1164,7 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params ¶ms, cons // Epilogue - Tensor lse = softmax.template normalize_softmax_lse(acc_o, params.scale_softmax); + Tensor lse = softmax.template normalize_softmax_lse(acc_o, params.scale_softmax); // if (cute::thread0()) { print(lse); } Tensor sOaccum = make_tensor(make_smem_ptr(reinterpret_cast(smem_)), typename Kernel_traits::SmemLayoutO{}); // (SMEM_M,SMEM_N) @@ -1276,7 +1242,7 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params ¶ms, cons //////////////////////////////////////////////////////////////////////////////////////////////////// -template +template inline __device__ void compute_attn(const Params ¶ms) { const int m_block = blockIdx.x; // The block index for the batch. @@ -1284,15 +1250,7 @@ inline __device__ void compute_attn(const Params ¶ms) { // The block index for the head. const int bidh = blockIdx.z; - // We want the fwd and bwd to generate the same dropout pattern (RNG), without restricting - // them to have the same number of threads or have to traverse the attention matrix - // in the same order. - // In the Philox RNG, we use the offset to store the batch, head, and the lane id - // (within a warp). We use the subsequence to store the location of the 16 x 32 blocks within - // the attention matrix. This way, as long as we have the batch, head, and the location of - // the 16 x 32 block within the attention matrix, we can generate the exact same dropout pattern. - - FLASH_NAMESPACE::compute_attn_1rowblock(params, bidb, bidh, m_block); + FLASH_NAMESPACE::compute_attn_1rowblock(params, bidb, bidh, m_block); } //////////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/csrc/src/flash_fwd_launch_template.h b/csrc/src/flash_fwd_launch_template.h index af992e0..89feeae 100644 --- a/csrc/src/flash_fwd_launch_template.h +++ b/csrc/src/flash_fwd_launch_template.h @@ -30,9 +30,9 @@ namespace FLASH_NAMESPACE { template \ __global__ void kernelName(KERNEL_PARAM_MODIFIER const Flash_fwd_params params) -DEFINE_FLASH_FORWARD_KERNEL(flash_fwd_kernel, bool Is_dropout, bool Is_causal, bool Is_even_MN, bool Is_even_K, bool Is_softcap, bool Return_softmax) { +DEFINE_FLASH_FORWARD_KERNEL(flash_fwd_kernel, bool Is_causal, bool Is_even_MN, bool Is_even_K, bool Is_softcap, bool Return_softmax) { #if defined(ARCH_SUPPORTS_FLASH) - FLASH_NAMESPACE::compute_attn(params); + FLASH_NAMESPACE::compute_attn(params); #else FLASH_UNSUPPORTED_ARCH #endif @@ -51,7 +51,7 @@ DEFINE_FLASH_FORWARD_KERNEL(flash_fwd_splitkv_combine_kernel, int kBlockM, int L FLASH_NAMESPACE::combine_attn_seqk_parallel(params); } -template +template void run_flash_fwd(Flash_fwd_params ¶ms, cudaStream_t stream) { const size_t smem_size = Kernel_traits::kSmemSize; // printf("smem_size = %d (includes mask memory)\n", int(smem_size)); @@ -72,9 +72,9 @@ void run_flash_fwd(Flash_fwd_params ¶ms, cudaStream_t stream) { // If not IsEvenKConst, we also set IsEvenMNConst to false to reduce number of templates. 
// If return_softmax, set IsEvenMNConst to false to reduce number of templates // If head dim > 128, set IsEvenMNConst to false to reduce number of templates - auto kernel = &flash_fwd_kernel; + auto kernel = &flash_fwd_kernel; // auto kernel = &flash_fwd_kernel; - // printf("run_flash_fwd: IsEvenMNConst = %d, IsEvenKConst = %d, Is_causal = %d, ReturnSoftmaxConst = %d, Is_dropout = %d\n", int(IsEvenMNConst), int(IsEvenKConst), int(Is_causal), int(ReturnSoftmaxConst), int(Is_dropout)); + // printf("run_flash_fwd: IsEvenMNConst = %d, IsEvenKConst = %d, Is_causal = %d, ReturnSoftmaxConst = %d, int(IsEvenMNConst), int(IsEvenKConst), int(Is_causal), int(ReturnSoftmaxConst)); // auto kernel = &flash_fwd_kernel; if (smem_size >= 48 * 1024) { C10_CUDA_CHECK(cudaFuncSetAttribute( @@ -162,31 +162,19 @@ void run_mha_fwd_splitkv_dispatch(Flash_fwd_params ¶ms, cudaStream_t stream) template void run_mha_fwd_hdim32(Flash_fwd_params ¶ms, cudaStream_t stream) { constexpr static int Headdim = 32; - DROPOUT_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] { - run_flash_fwd, Is_dropout, Is_causal>(params, stream); - }); + run_flash_fwd, Is_causal>(params, stream); } template void run_mha_fwd_hdim64(Flash_fwd_params ¶ms, cudaStream_t stream) { constexpr static int Headdim = 64; - DROPOUT_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] { - if constexpr(!Is_dropout) { - // Using 8 warps is 18% slower for seqlen=2k, 2 warps is 5% slower - // Using block size (64 x 128) is 27% slower for seqlen=2k - // Using block size (128 x 64) is 85% slower for seqlen=2k, because of register spilling - run_flash_fwd, Is_dropout, Is_causal>(params, stream); - // run_flash_fwd, Is_dropout, Is_causal>(params, stream); - // run_flash_fwd, Is_dropout, Is_causal>(params, stream); - // run_flash_fwd, Is_dropout, Is_causal>(params, stream); - } else { - run_flash_fwd, Is_dropout, Is_causal>(params, stream); - // run_flash_fwd, Is_dropout, Is_causal>(params, stream); - // run_flash_fwd, Is_dropout, Is_causal>(params, stream); - // run_flash_fwd, Is_dropout, Is_causal>(params, stream); - // run_flash_fwd, Is_dropout, Is_causal>(params, stream); - } - }); + // Using 8 warps is 18% slower for seqlen=2k, 2 warps is 5% slower + // Using block size (64 x 128) is 27% slower for seqlen=2k + // Using block size (128 x 64) is 85% slower for seqlen=2k, because of register spilling + run_flash_fwd, Is_causal>(params, stream); + // run_flash_fwd, Is_causal>(params, stream); + // run_flash_fwd, Is_causal>(params, stream); + // run_flash_fwd, Is_causal>(params, stream); } template @@ -194,23 +182,21 @@ void run_mha_fwd_hdim96(Flash_fwd_params ¶ms, cudaStream_t stream) { constexpr static int Headdim = 96; auto [cc_major, cc_minor] = get_compute_capability(get_current_device()); bool is_sm8x = cc_major == 8 && cc_minor > 0; - DROPOUT_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] { - // For sm86 or sm89, 64 x 64 is the fastest for causal (because it's square), - if (is_sm8x) { - if constexpr(!Is_causal) { - run_flash_fwd, Is_dropout, Is_causal>(params, stream); - } else { - run_flash_fwd, Is_dropout, Is_causal>(params, stream); - } + // For sm86 or sm89, 64 x 64 is the fastest for causal (because it's square), + if (is_sm8x) { + if constexpr(!Is_causal) { + run_flash_fwd, Is_causal>(params, stream); } else { - run_flash_fwd, Is_dropout, Is_causal>(params, stream); + run_flash_fwd, Is_causal>(params, stream); } - // run_flash_fwd, Is_dropout, Is_causal>(params, stream); - // run_flash_fwd, Is_dropout, Is_causal>(params, stream); - // These two are always 
slower - // run_flash_fwd>(params, stream); - // run_flash_fwd>(params, stream); - }); + } else { + run_flash_fwd, Is_causal>(params, stream); + } + // run_flash_fwd, Is_causal>(params, stream); + // run_flash_fwd, Is_causal>(params, stream); + // These two are always slower + // run_flash_fwd>(params, stream); + // run_flash_fwd>(params, stream); } template @@ -218,51 +204,36 @@ void run_mha_fwd_hdim128(Flash_fwd_params ¶ms, cudaStream_t stream) { constexpr static int Headdim = 128; auto [cc_major, cc_minor] = get_compute_capability(get_current_device()); bool is_sm8x = cc_major == 8 && cc_minor > 0; - DROPOUT_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] { - if constexpr(!Is_dropout) { - // For sm86 or sm89, 64 x 32 (40 KB smem) is the fastest for causal and non-causal since we get 2 CTAs per SM. - // Use block configuration (kBlockM = 64, kBlockN = 64) for better memory alignment - if (is_sm8x) { - if constexpr(!Is_causal) { - run_flash_fwd, Is_dropout, Is_causal>(params, stream); - } else { - run_flash_fwd, Is_dropout, Is_causal>(params, stream); - } - } else { - run_flash_fwd, Is_dropout, Is_causal>(params, stream); - } - // run_flash_fwd, Is_dropout, Is_causal>(params, stream); - // run_flash_fwd, Is_dropout, Is_causal>(params, stream); - // run_flash_fwd, Is_dropout, Is_causal>(params, stream); - // Using 8 warps (128 x 128 and 256 x 64) is 28% slower for seqlen=2k - // run_flash_fwd, Is_dropout, Is_causal>(params, stream); - // run_flash_fwd, Is_dropout, Is_causal>(params, stream); - // 1st ones are good for H100, A100 - // 2nd one is good for A6000 bc we get slightly better occupancy + // For sm86 or sm89, 64 x 32 (40 KB smem) is the fastest for causal and non-causal since we get 2 CTAs per SM. + // Use block configuration (kBlockM = 64, kBlockN = 64) for better memory alignment + if (is_sm8x) { + if constexpr(!Is_causal) { + run_flash_fwd, Is_causal>(params, stream); } else { - run_flash_fwd, Is_dropout, Is_causal>(params, stream); - // run_flash_fwd, Is_dropout, Is_causal>(params, stream); - // run_flash_fwd, Is_dropout, Is_causal>(params, stream); - // run_flash_fwd, Is_dropout, Is_causal>(params, stream); + run_flash_fwd, Is_causal>(params, stream); } - }); + } else { + run_flash_fwd, Is_causal>(params, stream); + } + // run_flash_fwd, Is_causal>(params, stream); + // run_flash_fwd, Is_causal>(params, stream); + // run_flash_fwd, Is_causal>(params, stream); + // Using 8 warps (128 x 128 and 256 x 64) is 28% slower for seqlen=2k + // run_flash_fwd, Is_causal>(params, stream); + // run_flash_fwd, Is_causal>(params, stream); + // 1st ones are good for H100, A100 + // 2nd one is good for A6000 bc we get slightly better occupancy } template void run_mha_fwd_hdim192(Flash_fwd_params ¶ms, cudaStream_t stream) { constexpr static int Headdim = 192; - DROPOUT_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] { - if constexpr(!Is_dropout) { - run_flash_fwd, Is_dropout, Is_causal>(params, stream); - } else { - run_flash_fwd, Is_dropout, Is_causal>(params, stream); - } - // run_flash_fwd, Is_dropout, Is_causal>(params, stream); - // run_flash_fwd, Is_dropout, Is_causal>(params, stream); - // run_flash_fwd>(params, stream); - // run_flash_fwd>(params, stream); - // run_flash_fwd>(params, stream); - }); + run_flash_fwd, Is_causal>(params, stream); + // run_flash_fwd, Is_causal>(params, stream); + // run_flash_fwd, Is_causal>(params, stream); + // run_flash_fwd>(params, stream); + // run_flash_fwd>(params, stream); + // run_flash_fwd>(params, stream); } template @@ -279,15 +250,14 @@ void 
run_mha_fwd_hdim256(Flash_fwd_params ¶ms, cudaStream_t stream) { C10_CUDA_CHECK(status_); } // printf("max_smem_per_sm = %d, max_smem_per_block = %d\n", max_smem_per_sm, max_smem_per_block); - DROPOUT_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] { - // For A100, we want to run with 64 x 64 (112KB smem). - // For H100 we want to run with 64 x 32 (72KB smem) since then we can get 2 CTAs per SM. - if (max_smem_per_block >= 2 * Headdim * (128 + 2 * 64) && max_smem_per_sm < 4 * Headdim * (64 + 2 * 64)) { - run_flash_fwd, Is_dropout, Is_causal>(params, stream); - } else { - run_flash_fwd, Is_dropout, Is_causal>(params, stream); - } - }); + + // For A100, we want to run with 64 x 64 (112KB smem). + // For H100 we want to run with 64 x 32 (72KB smem) since then we can get 2 CTAs per SM. + if (max_smem_per_block >= 2 * Headdim * (128 + 2 * 64) && max_smem_per_sm < 4 * Headdim * (64 + 2 * 64)) { + run_flash_fwd, Is_causal>(params, stream); + } else { + run_flash_fwd, Is_causal>(params, stream); + } } } // namespace FLASH_NAMESPACE diff --git a/csrc/src/philox.cuh b/csrc/src/philox.cuh deleted file mode 100644 index 5205f45..0000000 --- a/csrc/src/philox.cuh +++ /dev/null @@ -1,53 +0,0 @@ -// Pytorch also has an implementation of Philox RNG: https://github.com/pytorch/pytorch/blob/8ca3c881db3e3510fcb7725389f6a0633c9b992c/torch/csrc/jit/tensorexpr/cuda_random.h -#pragma once -// Philox CUDA. - -#include "namespace_config.h" - -namespace FLASH_NAMESPACE { - -struct ull2 { - unsigned long long x; - unsigned long long y; -}; - -__forceinline__ __device__ uint2 mulhilo32(const unsigned int a, const unsigned int b) { - uint2 *res; - unsigned long long tmp; - asm ("mul.wide.u32 %0, %1, %2;\n\t" - : "=l"(tmp) - : "r"(a), "r"(b)); - res = (uint2*)(&tmp); - return *res; -} - -__forceinline__ __device__ uint4 philox_single_round(const uint4 ctr, const uint2 key) { - constexpr unsigned long kPhiloxSA = 0xD2511F53; - constexpr unsigned long kPhiloxSB = 0xCD9E8D57; - uint2 res0 = mulhilo32(kPhiloxSA, ctr.x); - uint2 res1 = mulhilo32(kPhiloxSB, ctr.z); - uint4 ret = {res1.y ^ ctr.y ^ key.x, res1.x, res0.y ^ ctr.w ^ key.y, res0.x}; - return ret; -} - -__forceinline__ __device__ uint4 philox(unsigned long long seed, - unsigned long long subsequence, - unsigned long long offset) { - constexpr unsigned long kPhilox10A = 0x9E3779B9; - constexpr unsigned long kPhilox10B = 0xBB67AE85; - uint2 key = reinterpret_cast(seed); - uint4 counter; - ull2 *tmp = reinterpret_cast(&counter); - tmp->x = offset; - tmp->y = subsequence; - #pragma unroll - for (int i = 0; i < 6; i++) { - counter = philox_single_round(counter, key); - key.x += (kPhilox10A); - key.y += (kPhilox10B); - } - uint4 output = philox_single_round(counter, key); - return output; -} - -} // namespace FLASH_NAMESPACE diff --git a/csrc/src/rotary.h b/csrc/src/rotary.h deleted file mode 100644 index dbae24c..0000000 --- a/csrc/src/rotary.h +++ /dev/null @@ -1,153 +0,0 @@ -/****************************************************************************** - * Copyright (c) 2024, Tri Dao. 
- ******************************************************************************/ - -#pragma once - -#include - -#include "namespace_config.h" -#include "utils.h" - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -namespace FLASH_NAMESPACE { - -using namespace cute; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template -__forceinline__ __device__ void copy_rotary_interleaved(Tensor const &S, - Tensor &D, - Tensor const &Cos, - Tensor const &Sin, - Tensor const &identity_MN, - const int max_MN, const int min_MN, - const int dim, const int rotary_dim) { - CUTE_STATIC_ASSERT_V(rank(S) == Int<3>{}); - CUTE_STATIC_ASSERT_V(rank(D) == Int<3>{}); - CUTE_STATIC_ASSERT_V(size<0>(S) == size<0>(D)); // MMA - CUTE_STATIC_ASSERT_V(size<1>(S) == size<1>(D)); // MMA_M - CUTE_STATIC_ASSERT_V(size<2>(S) == size<2>(D)); // MMA_K - CUTE_STATIC_ASSERT_V(size<1>(S) == size<1>(Cos)); // MMA_M - CUTE_STATIC_ASSERT_V(size<2>(S) == size<2>(Cos)); // MMA_K - CUTE_STATIC_ASSERT_V(size<1>(S) == size<1>(Sin)); // MMA_M - CUTE_STATIC_ASSERT_V(size<2>(S) == size<2>(Sin)); // MMA_K - CUTE_STATIC_ASSERT_V(size<0>(Cos) == size<0>(Sin)); // MMA_K - static_assert(decltype(size<0>(S))::value == decltype(size<0>(Cos))::value * 2); - static_assert(decltype(size<0>(Cos))::value % 2 == 0); // Since we do fast conversion from fp16/bf16 to fp32 - Tensor rCos = make_fragment_like(Cos); - Tensor rSin = make_fragment_like(Sin); - Tensor rS = make_fragment_like(S); - #pragma unroll - for (int m = 0; m < size<1>(S); ++m) { - if (get<0>(identity_MN(0, m, 0)) >= min_MN && get<0>(identity_MN(0, m, 0)) < max_MN) { - #pragma unroll - for (int k = 0; k < size<2>(S); ++k) { - if (Is_even_K || get<1>(identity_MN(0, 0, k)) < dim) { - cute::copy(S(_, m, k), rS(_, m, k)); - if (get<1>(identity_MN(0, 0, k)) < rotary_dim) { - cute::copy(Cos(_, m, k), rCos(_, m, k)); - cute::copy(Sin(_, m, k), rSin(_, m, k)); - Tensor S_fp32 = convert_type(rS(_, m, k)); - Tensor cos_fp32 = convert_type(rCos(_, m, k)); - Tensor sin_fp32 = convert_type(rSin(_, m, k)); - #pragma unroll - for (int i = 0; i < size<0>(rS) / 2; ++i) { - float real = S_fp32(2 * i) * cos_fp32(i) - S_fp32(2 * i + 1) * sin_fp32(i); - float imag = S_fp32(2 * i) * sin_fp32(i) + S_fp32(2 * i + 1) * cos_fp32(i); - S_fp32(2 * i) = real; - S_fp32(2 * i + 1) = imag; - } - // Idk but I need to copy for the convert_type to work - Tensor S_fp32_copy = make_fragment_like(S_fp32); - cute::copy(S_fp32, S_fp32_copy); - using T = typename Engine0::value_type; - Tensor S_og_type = convert_type(S_fp32_copy); - cute::copy(S_og_type, rS(_, m, k)); - } - cute::copy(rS(_, m, k), D(_, m, k)); - } else if (Clear_OOB_K) { - cute::clear(D(_, m, k)); - } - } - } - } -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template -__forceinline__ __device__ void copy_rotary_contiguous(Tensor const &S, - Tensor &D, - Tensor const &Cos, - Tensor const &Sin, - Tensor const &identity_MN, - const int max_MN, const int min_MN, - const int dim, const int rotary_dim) { - CUTE_STATIC_ASSERT_V(rank(S) == Int<3>{}); - CUTE_STATIC_ASSERT_V(rank(D) == Int<3>{}); - CUTE_STATIC_ASSERT_V(size<0>(S) == size<0>(D)); // MMA - CUTE_STATIC_ASSERT_V(size<1>(S) == size<1>(D)); // MMA_M - CUTE_STATIC_ASSERT_V(size<2>(S) == size<2>(D)); // MMA_K - CUTE_STATIC_ASSERT_V(size<1>(S) == size<1>(Cos)); // MMA_M - CUTE_STATIC_ASSERT_V(size<2>(S) == size<2>(Cos)); // MMA_K - 
CUTE_STATIC_ASSERT_V(size<1>(S) == size<1>(Sin)); // MMA_M - CUTE_STATIC_ASSERT_V(size<2>(S) == size<2>(Sin)); // MMA_K - CUTE_STATIC_ASSERT_V(size<0>(S) == size<0>(Cos)); // MMA - CUTE_STATIC_ASSERT_V(size<0>(Cos) == size<0>(Sin)); - static_assert(decltype(size<0>(Cos))::value % 2 == 0); // Since we do fast conversion from fp16/bf16 to fp32 - Tensor rCos = make_fragment_like(Cos); - Tensor rSin = make_fragment_like(Sin); - Tensor rS = make_fragment_like(S); - Tensor rS_other = make_fragment_like(rS(_, 0, 0)); - #pragma unroll - for (int m = 0; m < size<1>(S); ++m) { - if (get<0>(identity_MN(0, m, 0)) >= min_MN && get<0>(identity_MN(0, m, 0)) < max_MN) { - #pragma unroll - for (int k = 0; k < size<2>(S); ++k) { - if (Is_even_K || get<1>(identity_MN(0, 0, k)) < dim) { - cute::copy(S(_, m, k), rS(_, m, k)); - if (get<1>(identity_MN(0, 0, k)) < rotary_dim) { - const bool is_left = get<1>(identity_MN(0, 0, k)) < rotary_dim / 2; - Tensor gS_other = make_tensor(S(_, m, k).data() + (is_left ? rotary_dim / 2 : -rotary_dim / 2), S(_, m, k).layout()); - cute::copy(gS_other, rS_other); - // if (cute::thread0()) { print_tensor(rS(_, m, k)); print_tensor(rS_other); } - Tensor gCos = make_tensor(Cos(_, m, k).data() + (is_left ? 0 : -rotary_dim / 2), Cos(_, m, k).layout()); - Tensor gSin = make_tensor(Sin(_, m, k).data() + (is_left ? 0 : -rotary_dim / 2), Sin(_, m, k).layout()); - cute::copy(gCos, rCos(_, m, k)); - cute::copy(gSin, rSin(_, m, k)); - // if (cute::thread0()) { print_tensor(rCos(_, m, k)); print_tensor(rSin(_, m, k)); } - Tensor S_fp32 = convert_type(rS(_, m, k)); - Tensor S_other_fp32 = convert_type(rS_other); - Tensor cos_fp32 = convert_type(rCos(_, m, k)); - Tensor sin_fp32 = convert_type(rSin(_, m, k)); - #pragma unroll - for (int i = 0; i < size<0>(rS); ++i) { - S_fp32(i) = S_fp32(i) * cos_fp32(i) + S_other_fp32(i) * (is_left ? -sin_fp32(i) : sin_fp32(i)); - } - // Idk but I need to copy for the convert_type to work - Tensor S_fp32_copy = make_fragment_like(S_fp32); - cute::copy(S_fp32, S_fp32_copy); - using T = typename Engine0::value_type; - Tensor S_og_type = convert_type(S_fp32_copy); - cute::copy(S_og_type, rS(_, m, k)); - // if (cute::thread0()) { print_tensor(rS(_, m, k)); } - } - cute::copy(rS(_, m, k), D(_, m, k)); - } else if (Clear_OOB_K) { - cute::clear(D(_, m, k)); - } - } - } - } -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -} // namespace FLASH_NAMESPACE diff --git a/csrc/src/softmax.h b/csrc/src/softmax.h index 0aeacdb..d00d056 100644 --- a/csrc/src/softmax.h +++ b/csrc/src/softmax.h @@ -11,7 +11,6 @@ #include #include "namespace_config.h" -#include "philox.cuh" #include "utils.h" namespace FLASH_NAMESPACE { @@ -199,8 +198,8 @@ struct Softmax { } }; - template - __forceinline__ __device__ TensorT normalize_softmax_lse(Tensor0 &acc_o, float softmax_scale, float rp_dropout=1.0) { + template + __forceinline__ __device__ TensorT normalize_softmax_lse(Tensor0 &acc_o, float softmax_scale) { SumOp sum_op; quad_allreduce_(row_sum, row_sum, sum_op); TensorT lse = make_fragment_like(row_sum); @@ -211,7 +210,7 @@ struct Softmax { float sum = row_sum(mi); float inv_sum = (sum == 0.f || sum != sum) ? 1.f : 1.f / sum; lse(mi) = (sum == 0.f || sum != sum) ? (Split ? -INFINITY : INFINITY) : row_max(mi) * softmax_scale + __logf(sum); - float scale = !Is_dropout ? 
inv_sum : inv_sum * rp_dropout; + float scale = inv_sum; #pragma unroll for (int ni = 0; ni < size<1>(acc_o_rowcol); ++ni) { acc_o_rowcol(mi, ni) *= scale; } } diff --git a/csrc/src/static_switch.h b/csrc/src/static_switch.h index 00322bd..d0504da 100644 --- a/csrc/src/static_switch.h +++ b/csrc/src/static_switch.h @@ -26,16 +26,6 @@ } \ }() -#ifdef FLASHATTENTION_DISABLE_DROPOUT - #define DROPOUT_SWITCH(COND, CONST_NAME, ...) \ - [&] { \ - constexpr static bool CONST_NAME = false; \ - return __VA_ARGS__(); \ - }() -#else - #define DROPOUT_SWITCH BOOL_SWITCH -#endif - #ifdef FLASHATTENTION_DISABLE_UNEVEN_K #define EVENK_SWITCH(COND, CONST_NAME, ...) \ [&] { \ diff --git a/flash_dmattn/flash_dmattn_interface.py b/flash_dmattn/flash_dmattn_interface.py index f91a014..75c9814 100644 --- a/flash_dmattn/flash_dmattn_interface.py +++ b/flash_dmattn/flash_dmattn_interface.py @@ -13,7 +13,7 @@ def maybe_contiguous(x): return x.contiguous() if x is not None and x.stride(-1) != 1 else x -def _get_block_size_n(device, head_dim, is_dropout, is_causal): +def _get_block_size_n(device, head_dim, is_causal): # This should match the block sizes in the CUDA kernel assert head_dim <= 256 major, minor = torch.cuda.get_device_capability(device) @@ -23,14 +23,14 @@ def _get_block_size_n(device, head_dim, is_dropout, is_causal): if head_dim <= 32: return 128 if head_dim <= 64: - return 128 if not is_dropout else 64 + return 128 elif head_dim <= 96: return 64 elif head_dim <= 128: if is_sm8x: - return 64 if (not is_dropout and is_causal) else 32 + return 64 if (is_causal) else 32 else: - return 64 if not is_dropout else 32 + return 64 elif head_dim <= 192: return 64 elif head_dim <= 224: @@ -73,28 +73,25 @@ def _flash_dmattn_forward( v: torch.Tensor, mask: torch.Tensor, bias: torch.Tensor, - dropout_p: float, softmax_scale: float, is_causal: bool, softcap: float, return_softmax: bool ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: q, k, v = [maybe_contiguous(x) for x in (q, k, v)] - out, softmax_lse, S_dmask, rng_state = flash_dmattn_gpu.fwd( + out, softmax_lse, S_dmask = flash_dmattn_gpu.fwd( q, k, v, mask, bias, None, - dropout_p, softmax_scale, is_causal, softcap, return_softmax, - None, ) - return out, softmax_lse, S_dmask, rng_state + return out, softmax_lse, S_dmask @_torch_register_fake_wrapper("flash_dmattn::_flash_dmattn_forward") @@ -104,7 +101,6 @@ def _flash_dmattn_forward_fake( v: torch.Tensor, mask: torch.Tensor, bias: torch.Tensor, - dropout_p: float, softmax_scale: float, is_causal: bool, softcap: float, @@ -118,9 +114,8 @@ def _flash_dmattn_forward_fake( p = torch.empty((0,), dtype=q.dtype, device=q.device, layout=q.layout) if return_softmax: p = torch.empty((batch_size, num_heads, round_multiple(seqlen_q, 128), round_multiple(seqlen_k, 128)), dtype=q.dtype, device=q.device, layout=q.layout) - rng_state = torch.empty((2,), dtype=torch.int64, device=q.device) - return out, softmax_lse, p, rng_state + return out, softmax_lse, p _wrapped_flash_dmattn_forward = _flash_dmattn_forward @@ -137,7 +132,6 @@ def _flash_dmattn_varlen_forward( cu_seqlens_k: torch.Tensor, max_seqlen_q: int, max_seqlen_k: int, - dropout_p: float, softmax_scale: float, is_causal: bool, softcap: float = 0.0, @@ -148,7 +142,7 @@ def _flash_dmattn_varlen_forward( zero_tensors: bool = False, ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: q, k, v = [maybe_contiguous(x) for x in (q, k, v)] - out, softmax_lse, S_dmask, rng_state = flash_dmattn_gpu.varlen_fwd( + out, softmax_lse, S_dmask = 
flash_dmattn_gpu.varlen_fwd( q, k, v, @@ -162,17 +156,15 @@ def _flash_dmattn_varlen_forward( block_table, max_seqlen_q, max_seqlen_k, - dropout_p, softmax_scale, zero_tensors, is_causal, softcap, return_softmax, - None, ) # if out.isnan().any() or softmax_lse.isnan().any(): # breakpoint() - return out, softmax_lse, S_dmask, rng_state + return out, softmax_lse, S_dmask @_torch_register_fake_wrapper("flash_dmattn::_flash_dmattn_varlen_forward") @@ -186,7 +178,6 @@ def _flash_dmattn_varlen_forward_fake( cu_seqlens_k: torch.Tensor, max_seqlen_q: int, max_seqlen_k: int, - dropout_p: float, softmax_scale: float, is_causal: bool, softcap: float = 0.0, @@ -208,8 +199,7 @@ def _flash_dmattn_varlen_forward_fake( seqlen_k_rounded = round_multiple(max_seqlen_k, 128) if return_softmax: p = torch.empty((batch_size, num_heads, seqlen_q_rounded, seqlen_k_rounded), dtype=q.dtype, device=q.device, layout=q.layout) - rng_state = torch.empty((2,), dtype=torch.int64, device=q.device) - return out, softmax_lse, p, rng_state + return out, softmax_lse, p _wrapped_flash_dmattn_varlen_forward = _flash_dmattn_varlen_forward @@ -229,12 +219,10 @@ def _flash_dmattn_backward( dk: Optional[torch.Tensor], dv: Optional[torch.Tensor], dbias: Optional[torch.Tensor], - dropout_p: float, softmax_scale: float, is_causal: bool, softcap: float, deterministic: bool, - rng_state: Optional[torch.Tensor] = None, ) -> torch.Tensor: # dq, dk, dv, dbias are allocated by us so they should already be contiguous dout, q, k, v, mask, bias, out = [maybe_contiguous(x) for x in (dout, q, k, v, mask, bias, out)] @@ -257,13 +245,11 @@ def _flash_dmattn_backward( dk, dv, dbias, - dropout_p, softmax_scale, is_causal, softcap, deterministic, None, - rng_state, ) return softmax_d @@ -282,12 +268,10 @@ def _flash_dmattn_backward_fake( dk: Optional[torch.Tensor], dv: Optional[torch.Tensor], dbias: Optional[torch.Tensor], - dropout_p: float, softmax_scale: float, is_causal: bool, softcap: float, deterministic: bool, - rng_state: Optional[torch.Tensor] = None, ) -> torch.Tensor: dout, q, k, v, mask, bias, out = [maybe_contiguous(x) for x in (dout, q, k, v, mask, bias, out)] if dq is None: @@ -325,12 +309,10 @@ def _flash_dmattn_varlen_backward( cu_seqlens_k: torch.Tensor, max_seqlen_q: int, max_seqlen_k: int, - dropout_p: float, softmax_scale: float, is_causal: bool, softcap: float, deterministic: bool, - rng_state: Optional[torch.Tensor] = None, zero_tensors: bool = False, ) -> torch.Tensor: # dq, dk, dv are allocated by us so they should already be contiguous @@ -358,14 +340,12 @@ def _flash_dmattn_varlen_backward( cu_seqlens_k, max_seqlen_q, max_seqlen_k, - dropout_p, softmax_scale, zero_tensors, is_causal, softcap, deterministic, None, - rng_state, ) # if dk.isnan().any() or dk.isnan().any() or dv.isnan().any() or softmax_d.isnan().any(): # breakpoint() @@ -390,12 +370,10 @@ def _flash_dmattn_varlen_backward_fake( cu_seqlens_k: torch.Tensor, max_seqlen_q: int, max_seqlen_k: int, - dropout_p: float, softmax_scale: float, is_causal: bool, softcap: float, deterministic: bool, - rng_state: Optional[torch.Tensor] = None, zero_tensors: bool = False, ) -> torch.Tensor: dout, q, k, v, mask, bias, out = [maybe_contiguous(x) for x in (dout, q, k, v, mask, bias, out)] @@ -425,7 +403,6 @@ def forward( qkv: torch.Tensor, mask: Optional[torch.Tensor], bias: Optional[torch.Tensor], - dropout_p: Optional[float], softmax_scale: Optional[float], is_causal: Optional[bool], softcap: Optional[float], @@ -440,10 +417,6 @@ def forward( mask = torch.ones((batch_size, 
num_heads, seqlen, seqlen), dtype=qkv.dtype, device=qkv.device) if bias is None: bias = torch.zeros((batch_size, num_heads, seqlen, seqlen), dtype=qkv.dtype, device=qkv.device) - if dropout_p is None: - dropout_p = 0.0 - if dropout_p < 0.0 or dropout_p > 1.0: - raise ValueError(f"Invalid dropout_p: {dropout_p}. It should be in [0, 1].") if is_causal is None: is_causal = False if softcap is None: @@ -462,22 +435,20 @@ def forward( k = torch.nn.functional.pad(k, [0, 8 - head_size_og % 8]) v = torch.nn.functional.pad(v, [0, 8 - head_size_og % 8]) - out_padded, softmax_lse, S_dmask, rng_state = _wrapped_flash_dmattn_forward( + out_padded, softmax_lse, S_dmask = _wrapped_flash_dmattn_forward( q, k, v, mask, bias, - dropout_p, softmax_scale, is_causal=is_causal, softcap=softcap, - return_softmax=return_softmax and dropout_p > 0, + return_softmax=return_softmax, ) if is_grad: - ctx.save_for_backward(q, k, v, mask, bias, out_padded, softmax_lse, rng_state) - ctx.dropout_p = dropout_p + ctx.save_for_backward(q, k, v, mask, bias, out_padded, softmax_lse) ctx.softmax_scale = softmax_scale ctx.is_causal = is_causal ctx.softcap = softcap @@ -492,7 +463,7 @@ def backward( dout: torch.Tensor, *args: Any, ): - q, k, v, mask, bias, out, softmax_lse, rng_state = ctx.saved_tensors + q, k, v, mask, bias, out, softmax_lse = ctx.saved_tensors qkv_shape = q.shape[:-2] + (3, *q.shape[-2:]) dqkv = torch.empty(qkv_shape, dtype=q.dtype, device=q.device) @@ -516,12 +487,10 @@ def backward( dqkv[:, :, 1], dqkv[:, :, 2], dbias, - ctx.dropout_p, ctx.softmax_scale, ctx.is_causal, ctx.softcap, ctx.deterministic, - rng_state=rng_state, ) dqkv = dqkv[..., : dout.shape[-1]] # We could have padded the head dimension @@ -537,7 +506,6 @@ def forward( bias: Optional[torch.Tensor], cu_seqlens: torch.Tensor, max_seqlen: int, - dropout_p: Optional[float], softmax_scale: Optional[float], is_causal: Optional[bool], softcap: Optional[float], @@ -553,10 +521,6 @@ def forward( mask = torch.ones((batch_size, num_heads, max_seqlen, max_seqlen), dtype=qkv.dtype, device=qkv.device) if bias is None: bias = torch.zeros((batch_size, num_heads, max_seqlen, max_seqlen), dtype=qkv.dtype, device=qkv.device) - if dropout_p is None: - dropout_p = 0.0 - if dropout_p < 0.0 or dropout_p > 1.0: - raise ValueError(f"Invalid dropout_p: {dropout_p}. 
It should be in [0, 1].") if softmax_scale is None: softmax_scale = qkv.shape[-1] ** (-0.5) if is_causal is None: @@ -575,7 +539,7 @@ def forward( k = torch.nn.functional.pad(k, [0, 8 - head_size_og % 8]) v = torch.nn.functional.pad(v, [0, 8 - head_size_og % 8]) - out_padded, softmax_lse, S_dmask, rng_state = _wrapped_flash_dmattn_varlen_forward( + out_padded, softmax_lse, S_dmask = _wrapped_flash_dmattn_varlen_forward( q, k, v, @@ -585,17 +549,15 @@ def forward( cu_seqlens, max_seqlen, max_seqlen, - dropout_p, softmax_scale, is_causal=is_causal, softcap=softcap, - return_softmax=return_softmax and dropout_p > 0, + return_softmax=return_softmax, block_table=None, ) if is_grad: - ctx.save_for_backward(q, k, v, mask, bias, out_padded, softmax_lse, cu_seqlens, rng_state) - ctx.dropout_p = dropout_p + ctx.save_for_backward(q, k, v, mask, bias, out_padded, softmax_lse, cu_seqlens) ctx.max_seqlen = max_seqlen ctx.softmax_scale = softmax_scale ctx.is_causal = is_causal @@ -611,7 +573,7 @@ def backward( dout: torch.Tensor, *args: Any, ): - q, k, v, mask, bias, out, softmax_lse, cu_seqlens, rng_state = ctx.saved_tensors + q, k, v, mask, bias, out, softmax_lse, cu_seqlens = ctx.saved_tensors qkv_shape = q.shape[:-2] + (3, *q.shape[-2:]) dqkv = torch.empty(qkv_shape, dtype=q.dtype, device=q.device) @@ -639,12 +601,10 @@ def backward( cu_seqlens, ctx.max_seqlen, ctx.max_seqlen, - ctx.dropout_p, ctx.softmax_scale, ctx.is_causal, ctx.softcap, ctx.deterministic, - rng_state=rng_state, ) dqkv = dqkv[..., : dout.shape[-1]] # We could have padded the head dimension @@ -659,7 +619,6 @@ def forward( kv: torch.Tensor, mask: Optional[torch.Tensor], bias: Optional[torch.Tensor], - dropout_p: Optional[float], softmax_scale: Optional[float], is_causal: Optional[bool], softcap: Optional[float], @@ -678,10 +637,6 @@ def forward( mask = torch.ones((batch_size, num_heads, seqlen_q, seqlen_k), dtype=q.dtype, device=q.device) if bias is None: bias = torch.zeros((batch_size, num_heads, seqlen_q, seqlen_k), dtype=q.dtype, device=q.device) - if dropout_p is None: - dropout_p = 0.0 - if dropout_p < 0.0 or dropout_p > 1.0: - raise ValueError(f"Invalid dropout_p: {dropout_p}. 
It should be in [0, 1].")
         if softmax_scale is None:
             softmax_scale = q.shape[-1] ** (-0.5)
         if is_causal is None:
@@ -700,22 +655,20 @@ def forward(
             k = torch.nn.functional.pad(k, [0, 8 - head_size_og % 8])
             v = torch.nn.functional.pad(v, [0, 8 - head_size_og % 8])
 
-        out_padded, softmax_lse, S_dmask, rng_state = _wrapped_flash_dmattn_forward(
+        out_padded, softmax_lse, S_dmask = _wrapped_flash_dmattn_forward(
             q,
             k,
             v,
             mask,
             bias,
-            dropout_p,
             softmax_scale,
             is_causal=is_causal,
             softcap=softcap,
-            return_softmax=return_softmax and dropout_p > 0,
+            return_softmax=return_softmax,
         )
 
         if is_grad:
-            ctx.save_for_backward(q, k, v, mask, bias, out_padded, softmax_lse, rng_state)
-            ctx.dropout_p = dropout_p
+            ctx.save_for_backward(q, k, v, mask, bias, out_padded, softmax_lse)
             ctx.softmax_scale = softmax_scale
             ctx.is_causal = is_causal
             ctx.softcap = softcap
@@ -730,7 +683,7 @@ def backward(
         dout: torch.Tensor,
         *args: Any
     ):
-        q, k, v, mask, bias, out, softmax_lse, rng_state = ctx.saved_tensors
+        q, k, v, mask, bias, out, softmax_lse = ctx.saved_tensors
 
         kv_shape = k.shape[:-2] + (2, *k.shape[-2:])
         dq = torch.empty_like(q)
@@ -755,12 +708,10 @@ def backward(
             dkv[:, :, 0],
             dkv[:, :, 1],
             dbias,
-            ctx.dropout_p,
             ctx.softmax_scale,
             ctx.is_causal,
             ctx.softcap,
             ctx.deterministic,
-            rng_state=rng_state,
         )
         dq = dq[..., : dout.shape[-1]]  # We could have padded the head dimension
 
@@ -780,7 +731,6 @@ def forward(
         cu_seqlens_k: torch.Tensor,
         max_seqlen_q: int,
         max_seqlen_k: int,
-        dropout_p: Optional[float],
         softmax_scale: Optional[float],
         is_causal: Optional[bool],
         softcap: Optional[float],
@@ -799,10 +749,6 @@ def forward(
             mask = torch.ones((batch_size, num_heads, max_seqlen_q, max_seqlen_k), dtype=q.dtype, device=q.device)
         if bias is None:
             bias = torch.zeros((batch_size, num_heads, max_seqlen_q, max_seqlen_k), dtype=q.dtype, device=q.device)
-        if dropout_p is None:
-            dropout_p = 0.0
-        if dropout_p < 0.0 or dropout_p > 1.0:
-            raise ValueError(f"Invalid dropout_p: {dropout_p}. It should be in [0, 1].")
         if softmax_scale is None:
             softmax_scale = q.shape[-1] ** (-0.5)
         if is_causal is None:
@@ -821,7 +767,7 @@ def forward(
             k = torch.nn.functional.pad(k, [0, 8 - head_size_og % 8])
             v = torch.nn.functional.pad(v, [0, 8 - head_size_og % 8])
 
-        out_padded, softmax_lse, S_dmask, rng_state = _wrapped_flash_dmattn_varlen_forward(
+        out_padded, softmax_lse, S_dmask = _wrapped_flash_dmattn_varlen_forward(
             q,
             k,
             v,
@@ -831,19 +777,17 @@ def forward(
             cu_seqlens_k,
             max_seqlen_q,
             max_seqlen_k,
-            dropout_p,
             softmax_scale,
             is_causal=is_causal,
             softcap=softcap,
-            return_softmax=return_softmax and dropout_p > 0,
+            return_softmax=return_softmax,
             block_table=None,
         )
 
         if is_grad:
             ctx.save_for_backward(
-                q, k, v, mask, bias, out_padded, softmax_lse, cu_seqlens_q, cu_seqlens_k, rng_state
+                q, k, v, mask, bias, out_padded, softmax_lse, cu_seqlens_q, cu_seqlens_k
             )
-            ctx.dropout_p = dropout_p
             ctx.max_seqlen_q = max_seqlen_q
             ctx.max_seqlen_k = max_seqlen_k
             ctx.softmax_scale = softmax_scale
@@ -860,7 +804,7 @@ def backward(
         dout: torch.Tensor,
         *args: Any,
     ):
-        q, k, v, mask, bias, out, softmax_lse, cu_seqlens_q, cu_seqlens_k, rng_state = ctx.saved_tensors
+        q, k, v, mask, bias, out, softmax_lse, cu_seqlens_q, cu_seqlens_k = ctx.saved_tensors
 
         kv_shape = k.shape[:-2] + (2, *k.shape[-2:])
         dq = torch.empty_like(q)
@@ -889,12 +833,10 @@ def backward(
             cu_seqlens_k,
             ctx.max_seqlen_q,
             ctx.max_seqlen_k,
-            ctx.dropout_p,
             ctx.softmax_scale,
             ctx.is_causal,
             ctx.softcap,
             ctx.deterministic,
-            rng_state=rng_state,
         )
         dq = dq[..., : dout.shape[-1]]  # We could have padded the head dimension
 
@@ -911,7 +853,6 @@ def forward(
         v: torch.Tensor,
         mask: Optional[torch.Tensor],
         bias: Optional[torch.Tensor],
-        dropout_p: Optional[float],
         softmax_scale: Optional[float],
         is_causal: Optional[bool],
         softcap: Optional[float],
@@ -929,10 +870,6 @@ def forward(
             mask = torch.ones((batch_size, num_heads, seqlen_q, seqlen_k), dtype=q.dtype, device=q.device)
         if bias is None:
             bias = torch.zeros((batch_size, num_heads, seqlen_q, seqlen_k), dtype=q.dtype, device=q.device)
-        if dropout_p is None:
-            dropout_p = 0.0
-        if dropout_p < 0.0 or dropout_p > 1.0:
-            raise ValueError(f"Invalid dropout_p: {dropout_p}. It should be in [0, 1].")
         if softmax_scale is None:
             softmax_scale = q.shape[-1] ** (-0.5)
         if is_causal is None:
@@ -950,22 +887,20 @@ def forward(
             k = torch.nn.functional.pad(k, [0, 8 - head_size_og % 8])
             v = torch.nn.functional.pad(v, [0, 8 - head_size_og % 8])
 
-        out_padded, softmax_lse, S_dmask, rng_state = _wrapped_flash_dmattn_forward(
+        out_padded, softmax_lse, S_dmask = _wrapped_flash_dmattn_forward(
             q,
             k,
             v,
             mask,
             bias,
-            dropout_p,
             softmax_scale,
             is_causal=is_causal,
             softcap=softcap,
-            return_softmax=return_softmax and dropout_p > 0,
+            return_softmax=return_softmax,
         )
 
         if is_grad:
-            ctx.save_for_backward(q, k, v, mask, bias, out_padded, softmax_lse, rng_state)
-            ctx.dropout_p = dropout_p
+            ctx.save_for_backward(q, k, v, mask, bias, out_padded, softmax_lse)
             ctx.softmax_scale = softmax_scale
             ctx.is_causal = is_causal
             ctx.softcap = softcap
@@ -980,7 +915,7 @@ def backward(
         dout: torch.Tensor,
         *args: Any,
     ):
-        q, k, v, mask, bias, out, softmax_lse, rng_state = ctx.saved_tensors
+        q, k, v, mask, bias, out, softmax_lse = ctx.saved_tensors
 
         dq, dk, dv, dbias = torch.empty_like(q), torch.empty_like(k), torch.empty_like(v), torch.empty_like(bias)
         head_size_og = dout.size(3)
@@ -1001,12 +936,10 @@ def backward(
             dk,
             dv,
             dbias,
-            ctx.dropout_p,
             ctx.softmax_scale,
             ctx.is_causal,
             ctx.softcap,
             ctx.deterministic,
-            rng_state=rng_state,
         )
         dq = dq[..., : dout.shape[-1]]  # We could have padded the head dimension
 
@@ -1028,7 +961,6 @@ def forward(
         cu_seqlens_k: torch.Tensor,
         max_seqlen_q: int,
         max_seqlen_k: int,
-        dropout_p: Optional[float],
         softmax_scale: Optional[float],
         is_causal: Optional[bool],
         softcap: Optional[float],
@@ -1047,10 +979,6 @@ def forward(
             mask = torch.ones((batch_size, num_heads, max_seqlen_q, max_seqlen_k), dtype=q.dtype, device=q.device)
         if bias is None:
             bias = torch.zeros((batch_size, num_heads, max_seqlen_q, max_seqlen_k), dtype=q.dtype, device=q.device)
-        if dropout_p is None:
-            dropout_p = 0.0
-        if dropout_p < 0.0 or dropout_p > 1.0:
-            raise ValueError(f"Invalid dropout_p: {dropout_p}. It should be in [0, 1].")
         if softmax_scale is None:
             softmax_scale = q.shape[-1] ** (-0.5)
         if is_causal is None:
@@ -1068,7 +996,7 @@ def forward(
             k = torch.nn.functional.pad(k, [0, 8 - head_size_og % 8])
             v = torch.nn.functional.pad(v, [0, 8 - head_size_og % 8])
 
-        out_padded, softmax_lse, S_dmask, rng_state = _wrapped_flash_dmattn_varlen_forward(
+        out_padded, softmax_lse, S_dmask = _wrapped_flash_dmattn_varlen_forward(
             q,
             k,
             v,
@@ -1078,19 +1006,17 @@ def forward(
             cu_seqlens_k,
             max_seqlen_q,
             max_seqlen_k,
-            dropout_p,
             softmax_scale,
             is_causal=is_causal,
             softcap=softcap,
-            return_softmax=return_softmax and dropout_p > 0,
+            return_softmax=return_softmax,
             block_table=block_table,
         )
 
         if is_grad:
             ctx.save_for_backward(
-                q, k, v, mask, bias, out_padded, softmax_lse, cu_seqlens_q, cu_seqlens_k, rng_state
+                q, k, v, mask, bias, out_padded, softmax_lse, cu_seqlens_q, cu_seqlens_k
             )
-            ctx.dropout_p = dropout_p
             ctx.max_seqlen_q = max_seqlen_q
             ctx.max_seqlen_k = max_seqlen_k
             ctx.softmax_scale = softmax_scale
@@ -1107,7 +1033,7 @@ def backward(
         dout: torch.Tensor,
         *args: Any,
     ):
-        q, k, v, mask, bias, out, softmax_lse, cu_seqlens_q, cu_seqlens_k, rng_state = ctx.saved_tensors
+        q, k, v, mask, bias, out, softmax_lse, cu_seqlens_q, cu_seqlens_k = ctx.saved_tensors
 
         dq, dk, dv, dbias = torch.empty_like(q), torch.empty_like(k), torch.empty_like(v), torch.empty_like(bias)
         head_size_og = dout.size(2)
@@ -1132,12 +1058,10 @@ def backward(
             cu_seqlens_k,
             ctx.max_seqlen_q,
             ctx.max_seqlen_k,
-            ctx.dropout_p,
             ctx.softmax_scale,
             ctx.is_causal,
             ctx.softcap,
             ctx.deterministic,
-            rng_state=rng_state,
         )
         dq = dq[..., : dout.shape[-1]]  # We could have padded the head dimension
 
@@ -1150,14 +1074,13 @@ def flash_dmattn_qkvpacked_func(
     qkv: torch.Tensor,
     attn_mask: Optional[torch.Tensor] = None,
     attn_bias: Optional[torch.Tensor] = None,
-    dropout_p: Optional[float] = None,
     is_causal: Optional[bool] = None,
     scale: Optional[float] = None,
     softcap: Optional[float] = None,
     deterministic: Optional[bool] = None,
     return_attn_probs: Optional[bool] = None,
 ):
-    """dropout_p should be set to 0.0 during evaluation
+    """
     If Q, K, V are already stacked into 1 tensor, this function will be faster than
     calling flash_dmattn_func on Q, K, V since the backward pass avoids explicit concatenation
     of the gradients of Q, K, V.
@@ -1173,7 +1096,6 @@ def flash_dmattn_qkvpacked_func(
             If None, no mask is applied.
         attn_bias: (batch_size, nheads, seqlen, seqlen). Attention Bias to add to the attention scores.
             If None, no bias is applied.
-        dropout_p: float. Dropout probability.
         is_causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).
         scale: float. The scaling of QK^T before applying softmax.
             Default to 1 / sqrt(headdim).
@@ -1189,14 +1111,12 @@ def flash_dmattn_qkvpacked_func(
             logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax
             normalization factor).
         S_dmask [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen, seqlen).
-            The output of softmax (possibly with different scaling). It also encodes the dropout
-            pattern (negative means that location was dropped, nonnegative means it was kept).
+            The output of softmax (possibly with different scaling).
     """
     return FlashDMAttnQKVPackedFunc.apply(
         qkv,
         attn_mask,
         attn_bias,
-        dropout_p,
         scale,
         is_causal,
         softcap,
@@ -1211,14 +1131,13 @@ def flash_dmattn_kvpacked_func(
     kv: torch.Tensor,
     attn_mask: Optional[torch.Tensor] = None,
     attn_bias: Optional[torch.Tensor] = None,
-    dropout_p: Optional[float] = None,
     scale: Optional[float] = None,
     is_causal: Optional[bool] = None,
     softcap: Optional[float] = None,
     deterministic: Optional[bool] = None,
     return_attn_probs: Optional[bool] = None,
 ):
-    """dropout_p should be set to 0.0 during evaluation
+    """
     If K, V are already stacked into 1 tensor, this function will be faster than
     calling flash_dmattn_func on Q, K, V since the backward pass avoids explicit concatenation
     of the gradients of K, V.
@@ -1246,7 +1165,6 @@ def flash_dmattn_kvpacked_func(
             If None, no mask is applied.
         attn_bias: (batch_size, nheads, seqlen_q, seqlen_k). Attention Bias to add to the attention scores.
             If None, no bias is applied.
-        dropout_p: float. Dropout probability.
         is_causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).
         scale: float. The scaling of QK^T before applying softmax.
             Default to 1 / sqrt(headdim).
@@ -1262,15 +1180,13 @@ def flash_dmattn_kvpacked_func(
             logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax
             normalization factor).
         S_dmask [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen, seqlen).
-            The output of softmax (possibly with different scaling). It also encodes the dropout
-            pattern (negative means that location was dropped, nonnegative means it was kept).
+            The output of softmax (possibly with different scaling).
     """
     return FlashDMAttnKVPackedFunc.apply(
         q,
         kv,
         attn_mask,
         attn_bias,
-        dropout_p,
         scale,
         is_causal,
         softcap,
@@ -1286,14 +1202,13 @@ def flash_dmattn_func(
     value: torch.Tensor,
     attn_mask: Optional[torch.Tensor] = None,
     attn_bias: Optional[torch.Tensor] = None,
-    dropout_p: Optional[float] = None,
     scale: Optional[float] = None,
     is_causal: Optional[bool] = None,
     softcap: Optional[float] = None,
     deterministic: Optional[bool] = None,
     return_attn_probs: Optional[bool] = None,
 ):
-    """dropout_p should be set to 0.0 during evaluation
+    """
     Supports multi-query and grouped-query attention (MQA/GQA) by passing in KV with fewer heads
     than Q. Note that the number of heads in Q must be divisible by the number of heads in KV.
     For example, if Q has 6 heads and K, V have 2 heads, head 0, 1, 2 of Q will attention to head
@@ -1319,7 +1234,6 @@ def flash_dmattn_func(
             If None, no mask is applied.
         attn_bias: (batch_size, nheads, seqlen_q, seqlen_k). Attention Bias to add to the attention scores.
             If None, no bias is applied.
-        dropout_p: float. Dropout probability.
         is_causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).
         scale: float. The scaling of QK^T before applying softmax.
             Default to 1 / sqrt(headdim).
@@ -1334,8 +1248,7 @@ def flash_dmattn_func(
             logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax
             normalization factor).
         S_dmask [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen, seqlen).
-            The output of softmax (possibly with different scaling). It also encodes the dropout
-            pattern (negative means that location was dropped, nonnegative means it was kept).
+            The output of softmax (possibly with different scaling).
     """
     return FlashDMAttnFunc.apply(
         query,
@@ -1343,7 +1256,6 @@ def flash_dmattn_func(
         value,
         attn_mask,
         attn_bias,
-        dropout_p,
         scale,
         is_causal,
         softcap,
@@ -1359,14 +1271,13 @@ def flash_dmattn_varlen_qkvpacked_func(
     attn_bias: Optional[torch.Tensor] = None,
     cu_seqlens: torch.Tensor = None,
     max_seqlen: int = None,
-    dropout_p: Optional[float] = None,
     scale: Optional[float] = None,
     is_causal: Optional[bool] = None,
     softcap: Optional[float] = None,
     deterministic: Optional[bool] = None,
     return_attn_probs: Optional[bool] = None,
 ):
-    """dropout_p should be set to 0.0 during evaluation
+    """
     If Q, K, V are already stacked into 1 tensor, this function will be faster than
     calling flash_dmattn_varlen_func on Q, K, V since the backward pass avoids explicit concatenation
     of the gradients of Q, K, V.
@@ -1382,7 +1293,6 @@ def flash_dmattn_varlen_qkvpacked_func(
         cu_seqlens: (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths
            of the sequences in the batch, used to index into qkv.
         max_seqlen: int. Maximum sequence length in the batch.
-        dropout_p: float. Dropout probability.
         is_causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).
         scale: float. The scaling of QK^T before applying softmax.
             Default to 1 / sqrt(headdim).
@@ -1398,8 +1308,7 @@ def flash_dmattn_varlen_qkvpacked_func(
             logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax
             normalization factor).
         S_dmask [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen, seqlen).
-            The output of softmax (possibly with different scaling). It also encodes the dropout
-            pattern (negative means that location was dropped, nonnegative means it was kept).
+            The output of softmax (possibly with different scaling).
     """
     return FlashDMAttnVarlenQKVPackedFunc.apply(
         qkv,
@@ -1407,7 +1316,6 @@ def flash_dmattn_varlen_qkvpacked_func(
         attn_bias,
         cu_seqlens,
         max_seqlen,
-        dropout_p,
         scale,
         is_causal,
         softcap,
@@ -1426,14 +1334,13 @@ def flash_dmattn_varlen_kvpacked_func(
     cu_seqlens_k: torch.Tensor = None,
     max_seqlen_q: int = None,
     max_seqlen_k: int = None,
-    dropout_p: Optional[float] = None,
     scale: Optional[float] = None,
     is_causal: Optional[bool] = None,
     softcap: Optional[float] = None,
     deterministic: Optional[bool] = None,
     return_attn_probs: Optional[bool] = None,
 ):
-    """dropout_p should be set to 0.0 during evaluation
+    """
     If K, V are already stacked into 1 tensor, this function will be faster than
     calling flash_dmattn_func on Q, K, V since the backward pass avoids explicit concatenation
     of the gradients of K, V.
@@ -1467,7 +1374,6 @@ def flash_dmattn_varlen_kvpacked_func(
             of the sequences in the batch, used to index into kv.
         max_seqlen_q: int. Maximum query sequence length in the batch.
         max_seqlen_k: int. Maximum key sequence length in the batch.
-        dropout_p: float. Dropout probability.
         is_causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).
         scale: float. The scaling of QK^T before applying softmax.
             Default to 1 / sqrt(headdim).
@@ -1483,8 +1389,7 @@ def flash_dmattn_varlen_kvpacked_func(
             logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax
             normalization factor).
         S_dmask [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen, seqlen).
-            The output of softmax (possibly with different scaling). It also encodes the dropout
-            pattern (negative means that location was dropped, nonnegative means it was kept).
+            The output of softmax (possibly with different scaling).
     """
     return FlashDMAttnVarlenKVPackedFunc.apply(
         q,
@@ -1495,7 +1400,6 @@ def flash_dmattn_varlen_kvpacked_func(
         cu_seqlens_k,
         max_seqlen_q,
         max_seqlen_k,
-        dropout_p,
         scale,
         is_causal,
         softcap,
@@ -1515,7 +1419,6 @@ def flash_dmattn_varlen_func(
     cu_seqlens_k: torch.Tensor = None,
     max_seqlen_q: int = None,
     max_seqlen_k: int = None,
-    dropout_p: Optional[float] = None,
     scale: Optional[float] = None,
     is_causal: Optional[bool] = None,
     softcap: Optional[float] = None,
@@ -1523,7 +1426,7 @@ def flash_dmattn_varlen_func(
     return_attn_probs: Optional[bool] = None,
     block_table: Optional[torch.Tensor] = None,
 ):
-    """dropout_p should be set to 0.0 during evaluation
+    """
     Supports multi-query and grouped-query attention (MQA/GQA) by passing in K, V with fewer heads
     than Q. Note that the number of heads in Q must be divisible by the number of heads in KV.
     For example, if Q has 6 heads and K, V have 2 heads, head 0, 1, 2 of Q will attention to head
@@ -1555,7 +1458,6 @@ def flash_dmattn_varlen_func(
             of the sequences in the batch, used to index into kv.
         max_seqlen_q: int. Maximum query sequence length in the batch.
         max_seqlen_k: int. Maximum key sequence length in the batch.
-        dropout_p: float. Dropout probability.
         is_causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).
         scale: float. The scaling of QK^T before applying softmax.
             Default to 1 / sqrt(headdim).
@@ -1571,8 +1473,7 @@ def flash_dmattn_varlen_func(
             logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax
             normalization factor).
         S_dmask [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen, seqlen).
-            The output of softmax (possibly with different scaling). It also encodes the dropout
-            pattern (negative means that location was dropped, nonnegative means it was kept).
+            The output of softmax (possibly with different scaling).
     """
     return FlashDMAttnVarlenFunc.apply(
         query,
@@ -1584,7 +1485,6 @@ def flash_dmattn_varlen_func(
         cu_seqlens_k,
         max_seqlen_q,
         max_seqlen_k,
-        dropout_p,
         scale,
         is_causal,
         softcap,
diff --git a/setup.py b/setup.py
index 3f805ad..e9b53e6 100644
--- a/setup.py
+++ b/setup.py
@@ -244,7 +244,6 @@ def append_nvcc_threads(nvcc_extra_args):
                     # "--ptxas-options=-O2",
                     # "-lineinfo",
                     "-DFLASHATTENTION_DISABLE_BACKWARD",  # Only forward pass
-                    # "-DFLASHATTENTION_DISABLE_DROPOUT",
                     # "-DFLASHATTENTION_DISABLE_SOFTCAP",
                     # "-DFLASHATTENTION_DISABLE_UNEVEN_K",
                 ]