benchmarks/benchmark_forward_equivalence.py (1 change: 0 additions & 1 deletion)
@@ -256,7 +256,6 @@ def dynamic_mask_attention_cuda(
value_states, # [batch, key_len, num_kv_heads, head_dim]
attn_mask=attn_mask, # [batch, num_kv_heads, query_len, key_len]
attn_bias=attn_bias, # [batch, num_kv_heads, query_len, key_len]
dropout_p=0.0,
is_causal=is_causal,
scale=scaling,
softcap=0.0,
benchmarks/benchmark_forward_performance.py (1 change: 0 additions & 1 deletion)
@@ -265,7 +265,6 @@ def dynamic_mask_attention_cuda(
value_states, # [batch, key_len, num_kv_heads, head_dim]
attn_mask=attn_mask, # [batch, num_kv_heads, query_len, key_len]
attn_bias=attn_bias, # [batch, num_kv_heads, query_len, key_len]
dropout_p=0.0,
is_causal=is_causal,
scale=scaling,
softcap=0.0,
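Both benchmark scripts drop only the `dropout_p=0.0` keyword; nothing else in the call changes. A sketch of the updated call site, with the wrapper name (`flash_dmattn_func`) and the query/key tensors assumed, since they sit above the visible hunk:

```python
# Illustrative only: flash_dmattn_func, query_states, and key_states are assumptions
# not shown in this hunk; the keyword arguments below match the diff context.
attn_outputs = flash_dmattn_func(
    query_states,              # [batch, query_len, num_heads, head_dim]
    key_states,                # [batch, key_len, num_kv_heads, head_dim]
    value_states,              # [batch, key_len, num_kv_heads, head_dim]
    attn_mask=attn_mask,       # [batch, num_kv_heads, query_len, key_len]
    attn_bias=attn_bias,       # [batch, num_kv_heads, query_len, key_len]
    is_causal=is_causal,
    scale=scaling,
    softcap=0.0,
)
```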
csrc/flash_api.cpp (110 changes: 22 additions & 88 deletions)
@@ -47,7 +47,6 @@ void set_params_fprop(
void *seqused_k,
void *p_d,
void *softmax_lse_d,
float p_dropout,
float softmax_scale,
bool is_causal,
const float softcap,
@@ -134,20 +133,6 @@ void set_params_fprop(
params.scale_softmax_log2 = softmax_scale * M_LOG2E;
}

// Set this to probability of keeping an element to simplify things.
params.p_dropout = 1.f - p_dropout;
// Convert p from float to int so we don't have to convert the random uint to float to compare.
// [Minor] We want to round down since when we do the comparison we use <= instead of <
// params.p_dropout_in_uint = uint32_t(std::floor(params.p_dropout * 4294967295.0));
// params.p_dropout_in_uint16_t = uint16_t(std::floor(params.p_dropout * 65535.0));
params.p_dropout_in_uint8_t = uint8_t(std::floor(params.p_dropout * 255.0));
params.rp_dropout = 1.f / params.p_dropout;
params.scale_softmax_rp_dropout = params.rp_dropout * params.scale_softmax;
TORCH_CHECK(p_dropout < 1.f);
#ifdef FLASHATTENTION_DISABLE_DROPOUT
TORCH_CHECK(p_dropout == 0.0f, "This flash attention build does not support dropout.");
#endif

params.is_causal = is_causal;
params.is_seqlens_k_cumulative = true;
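The deleted block precomputed dropout bookkeeping: the keep probability, a uint8 comparison threshold (rounded down so the kernel can test random bytes with `<=`), and the reciprocal used to rescale surviving elements. A minimal NumPy sketch of what those removed fields implemented (an illustration of the scheme, not code from this repository):

```python
import numpy as np

p_dropout = 0.1
p_keep = 1.0 - p_dropout                        # params.p_dropout stored the keep probability
threshold = np.uint8(np.floor(p_keep * 255.0))  # params.p_dropout_in_uint8_t, rounded down
rp_keep = 1.0 / p_keep                          # params.rp_dropout

rng = np.random.default_rng(0)
scores = rng.standard_normal((4, 4)).astype(np.float32)
rand_u8 = rng.integers(0, 256, size=scores.shape, dtype=np.uint8)

# Keep an element when its random byte is <= threshold, then rescale by 1 / p_keep.
dropped = np.where(rand_u8 <= threshold, scores * rp_keep, 0.0)
```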

@@ -223,7 +208,6 @@ std::tuple<at::Tensor, at::Tensor> set_params_splitkv(
const int max_seqlen_k,
const int max_seqlen_q,
const int head_size_rounded,
const float p_dropout,
const int num_splits,
const int num_sm,
struct c10::TensorOptions opts
@@ -239,19 +223,17 @@ std::tuple<at::Tensor, at::Tensor> set_params_splitkv(
at::Tensor softmax_lse_accum;
at::Tensor out_accum;

if (p_dropout == 0.0f) { // SplitKV is not implemented for dropout
if (num_splits < 1) {
// We multiply number of SMs by 2 to hard-code the fact that we're using 128 threads per block.
params.num_splits = num_splits_heuristic(batch_size * num_heads * num_m_blocks, num_sm * 2, num_n_blocks, 128);
}
if (params.num_splits > 1) {
softmax_lse_accum = torch::empty({params.num_splits, batch_size, num_heads, max_seqlen_q}, opts.dtype(at::kFloat));
out_accum = torch::empty({params.num_splits, batch_size, num_heads, max_seqlen_q, head_size_rounded}, opts.dtype(at::kFloat));
params.softmax_lseaccum_ptr = softmax_lse_accum.data_ptr();
params.oaccum_ptr = out_accum.data_ptr();
}
TORCH_CHECK(params.num_splits <= 128, "num_splits > 128 not supported");
if (num_splits < 1) {
// We multiply number of SMs by 2 to hard-code the fact that we're using 128 threads per block.
params.num_splits = num_splits_heuristic(batch_size * num_heads * num_m_blocks, num_sm * 2, num_n_blocks, 128);
}
if (params.num_splits > 1) {
softmax_lse_accum = torch::empty({params.num_splits, batch_size, num_heads, max_seqlen_q}, opts.dtype(at::kFloat));
out_accum = torch::empty({params.num_splits, batch_size, num_heads, max_seqlen_q, head_size_rounded}, opts.dtype(at::kFloat));
params.softmax_lseaccum_ptr = softmax_lse_accum.data_ptr();
params.oaccum_ptr = out_accum.data_ptr();
}
TORCH_CHECK(params.num_splits <= 128, "num_splits > 128 not supported");

// Temporarily disable Split-KV, because some bugs are still being fixed.
// See: https://github.com/SmallDoges/flash-dmattn/issues/47
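Apart from losing the `p_dropout == 0.0f` guard, Split-KV works as before: the key/value range is cut into `num_splits` chunks, each chunk writes a partial output and log-sum-exp into `out_accum` and `softmax_lse_accum`, and a final pass merges them. A minimal NumPy sketch of that merge step, shown only to illustrate why the two accumulators are allocated (the standard log-sum-exp combine, not the kernel's actual code):

```python
import numpy as np

def combine_splits(out_accum, lse_accum):
    """out_accum: [num_splits, seqlen_q, head_dim]; lse_accum: [num_splits, seqlen_q]."""
    lse_max = lse_accum.max(axis=0)                        # [seqlen_q]
    weights = np.exp(lse_accum - lse_max)                  # [num_splits, seqlen_q]
    lse_total = lse_max + np.log(weights.sum(axis=0))      # combined log-sum-exp
    out = (weights[..., None] * out_accum).sum(axis=0) / weights.sum(axis=0)[:, None]
    return out, lse_total

rng = np.random.default_rng(0)
out, lse = combine_splits(rng.standard_normal((4, 8, 16)), rng.standard_normal((4, 8)))
```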
@@ -272,12 +254,10 @@ mha_fwd(
const at::Tensor &mask, // batch_size x num_heads_k x seqlen_q x seqlen_k
const at::Tensor &bias, // batch_size x num_heads_k x seqlen_q x seqlen_k
std::optional<at::Tensor> &out_, // batch_size x seqlen_q x num_heads x round_multiple(head_size, 8)
const float p_dropout,
const float softmax_scale,
bool is_causal,
const float softcap,
const bool return_softmax,
std::optional<at::Generator> gen_
const bool return_softmax
) {

// Otherwise the kernel will be launched from cuda:0 device
@@ -313,14 +293,12 @@ mha_fwd(
TORCH_CHECK(head_size % 8 == 0, "query, key, value, and out_ must have a head_size that is a multiple of 8");
TORCH_CHECK(num_heads % num_heads_k == 0, "Number of heads in key/value must divide number of heads in query");

if (softcap > 0.f) { TORCH_CHECK(p_dropout == 0.f, "Softcapping does not support dropout for now"); }

// causal=true is the same as causal=false in this case
if (seqlen_q == 1) { is_causal = false; }

// Faster to transpose q from (b, 1, (nheads_kv ngroups), d) to (b, ngroups, nheads_kv, d) in this case
// H/t Daniel Haziza
const int seqlenq_ngroups_swapped = seqlen_q == 1 && num_heads > num_heads_k && p_dropout == 0.f && head_size % 8 == 0;
const int seqlenq_ngroups_swapped = seqlen_q == 1 && num_heads > num_heads_k && head_size % 8 == 0;
const int ngroups = num_heads / num_heads_k;
if (seqlenq_ngroups_swapped) {
q = q.reshape({batch_size, num_heads_k, ngroups, head_size}).transpose(1, 2);
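The decode-path trick survives with a shorter guard: when `seqlen_q == 1` and query heads outnumber KV heads, the group dimension is folded into the sequence dimension so the kernel parallelizes over groups instead of a single query row. A small PyTorch sketch of the reshape, with tensor shapes chosen purely for illustration:

```python
import torch

batch_size, num_heads, num_heads_k, head_size = 2, 32, 8, 128
ngroups = num_heads // num_heads_k  # 4 query heads share each KV head

# Decoding case: one query token per sequence, GQA layout (b, 1, nheads_kv * ngroups, d).
q = torch.randn(batch_size, 1, num_heads, head_size)

# (b, 1, nheads_kv * ngroups, d) -> (b, ngroups, nheads_kv, d): the group dimension
# now plays the role of seqlen_q, so each KV head sees ngroups query rows.
q_swapped = q.reshape(batch_size, num_heads_k, ngroups, head_size).transpose(1, 2)
print(q_swapped.shape)  # torch.Size([2, 4, 8, 128])
```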
@@ -357,12 +335,10 @@ mha_fwd(

auto softmax_lse = torch::empty({batch_size, num_heads, seqlen_q}, opts.dtype(at::kFloat));
at::Tensor p;
// Only return softmax if there's dropout to reduce compilation time

if (return_softmax) {
TORCH_CHECK(p_dropout > 0.0f, "return_softmax is only supported when p_dropout > 0.0");
p = torch::empty({ batch_size, num_heads, seqlen_q_rounded, seqlen_k_rounded }, opts);
}
else {
} else {
p = torch::empty({ 0 }, opts);
}

@@ -380,7 +356,6 @@ mha_fwd(
/*seqused_k=*/nullptr,
return_softmax ? p.data_ptr() : nullptr,
softmax_lse.data_ptr(),
p_dropout,
softmax_scale,
is_causal,
softcap
@@ -390,26 +365,9 @@ mha_fwd(
at::Tensor softmax_lse_accum, out_accum;
std::tie(softmax_lse_accum, out_accum) = set_params_splitkv(
params, batch_size, num_heads, head_size, seqlen_k, seqlen_q,
head_size_rounded, p_dropout, /*num_splits*/ 0, get_num_sm(get_current_device()), opts
head_size_rounded, /*num_splits*/ 0, get_num_sm(get_current_device()), opts
);

// number of times random will be generated per thread, to offset philox counter in thc random
// state
// We use a custom RNG that increases the offset by batch_size * nheads * 32.
int64_t counter_offset = params.b * params.h * 32;
auto options = torch::TensorOptions().dtype(torch::kFloat32).device(torch::kCUDA);
auto rng_state = torch::empty({2}, options.dtype(torch::kInt64));
// Forward kernel will populate memory with the seed and offset.
params.rng_state = reinterpret_cast<uint64_t*>(rng_state.data_ptr());

if (p_dropout > 0.0) {
auto gen = at::get_generator_or_default<at::CUDAGeneratorImpl>(
gen_, at::cuda::detail::getDefaultCUDAGenerator());
// See Note [Acquire lock when using random generators]
std::lock_guard<std::mutex> lock(gen->mutex_);
params.philox_args = gen->philox_cuda_state(counter_offset);
}

if (seqlen_k > 0) {
auto stream = at::cuda::getCurrentCUDAStream().stream();
run_mha_fwd(params, stream);
@@ -424,7 +382,7 @@ mha_fwd(
q = q.transpose(1, 2).reshape({batch_size, 1, num_heads_k * seqlen_q, head_size});
softmax_lse = softmax_lse.reshape({batch_size, num_heads_k * seqlen_q, 1});
}
return {out, softmax_lse, p, rng_state};
return {out, softmax_lse, p};
}
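With dropout removed, `mha_fwd` no longer allocates an RNG-state tensor, so the returned tuple shrinks from four tensors to three. A hedged sketch of the caller-side difference (the Python binding name is an assumption, not shown in this diff; arguments elided):

```python
# Previously: out, softmax_lse, p, rng_state = flash_dmattn_cuda.fwd(...)
out, softmax_lse, p = flash_dmattn_cuda.fwd(...)  # binding name assumed; rng_state is gone
```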

std::vector<at::Tensor>
@@ -442,13 +400,11 @@ mha_varlen_fwd(
std::optional<at::Tensor> &block_table_, // batch_size x max_num_blocks_per_seq
int max_seqlen_q,
const int max_seqlen_k,
const float p_dropout,
const float softmax_scale,
const bool zero_tensors,
bool is_causal,
const float softcap,
const bool return_softmax,
std::optional<at::Generator> gen_
const bool return_softmax
) {
// Otherwise the kernel will be launched from cuda:0 device
at::cuda::CUDAGuard device_guard{q.device()};
@@ -494,8 +450,6 @@ mha_varlen_fwd(
const int head_size = sizes[2];
const int num_heads_k = paged_KV ? k.size(2) : k.size(1);

if (softcap > 0.f) { TORCH_CHECK(p_dropout == 0.f, "Softcapping does not support dropout for now"); }

const int max_num_blocks_per_seq = !paged_KV ? 0 : block_table.size(1);
const int num_blocks = !paged_KV ? 0 : k.size(0);
const int page_block_size = !paged_KV ? 1 : k.size(1);
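These lines read off the paged-KV geometry: the key cache is stored as `num_blocks` pages of `page_block_size` tokens each, and `block_table` lists, per sequence, which pages it owns. A small PyTorch sketch of how such a table resolves to contiguous keys for one sequence (the gather is illustrative only; the kernel indexes pages in place):

```python
import torch

num_blocks, page_block_size, num_heads_k, head_size = 16, 256, 8, 128
batch_size, max_num_blocks_per_seq = 2, 4

k_cache = torch.randn(num_blocks, page_block_size, num_heads_k, head_size)
block_table = torch.randint(0, num_blocks, (batch_size, max_num_blocks_per_seq),
                            dtype=torch.int32)

# Gather the pages of sequence 0 and flatten them to [seqlen_k, num_heads_k, head_size].
pages = k_cache[block_table[0].long()]              # [max_num_blocks_per_seq, page, heads, dim]
k_seq0 = pages.reshape(-1, num_heads_k, head_size)  # up to 4 * 256 = 1024 key positions
print(k_seq0.shape)                                 # torch.Size([1024, 8, 128])
```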
@@ -507,7 +461,7 @@ mha_varlen_fwd(

// Faster to transpose q from (b, 1, (nheads_kv ngroups), d) to (b, ngroups, nheads_kv, d) in this case
// H/t Daniel Haziza
const int seqlenq_ngroups_swapped = max_seqlen_q == 1 && num_heads > num_heads_k && p_dropout == 0.f && head_size % 8 == 0;
const int seqlenq_ngroups_swapped = max_seqlen_q == 1 && num_heads > num_heads_k && head_size % 8 == 0;
const int ngroups = num_heads / num_heads_k;
if (seqlenq_ngroups_swapped) {
q = q.reshape({batch_size, num_heads_k, ngroups, head_size}).transpose(1, 2).reshape({batch_size * ngroups, num_heads_k, head_size});
@@ -568,12 +522,10 @@ mha_varlen_fwd(
auto opts = q.options();
auto softmax_lse = torch::empty({num_heads, total_q}, opts.dtype(at::kFloat));
at::Tensor p;
// Only return softmax if there's dropout to reduce compilation time

if (return_softmax) {
TORCH_CHECK(p_dropout > 0.0f, "return_softmax is only supported when p_dropout > 0.0");
p = torch::empty({ batch_size, num_heads, seqlen_q_rounded, seqlen_k_rounded }, opts);
}
else {
} else {
p = torch::empty({ 0 }, opts);
}

@@ -597,7 +549,6 @@ mha_varlen_fwd(
seqused_k.has_value() ? seqused_k.value().data_ptr() : nullptr,
return_softmax ? p.data_ptr() : nullptr,
softmax_lse.data_ptr(),
p_dropout,
softmax_scale,
is_causal,
softcap,
@@ -621,7 +572,7 @@ mha_varlen_fwd(
set_params_splitkv(
params, batch_size, num_heads, head_size,
max_seqlen_k, max_seqlen_q, head_size_rounded,
p_dropout, /*num_splits*/ 0, get_num_sm(get_current_device()), opts
/*num_splits*/ 0, get_num_sm(get_current_device()), opts
);
}

@@ -635,23 +586,6 @@ mha_varlen_fwd(
params.leftpad_k = static_cast<int *>(leftpad_k.data_ptr());
}

// number of times random will be generated per thread, to offset philox counter in thc random
// state
// We use a custom RNG that increases the offset by batch_size * nheads * 32.
int64_t counter_offset = params.b * params.h * 32;
auto options = torch::TensorOptions().dtype(torch::kFloat32).device(torch::kCUDA);
auto rng_state = torch::empty({2}, options.dtype(torch::kInt64));
// Forward kernel will populate memory with the seed and offset.
params.rng_state = reinterpret_cast<uint64_t*>(rng_state.data_ptr());

if (p_dropout > 0.0) {
auto gen = at::get_generator_or_default<at::CUDAGeneratorImpl>(
gen_, at::cuda::detail::getDefaultCUDAGenerator());
// See Note [Acquire lock when using random generators]
std::lock_guard<std::mutex> lock(gen->mutex_);
params.philox_args = gen->philox_cuda_state(counter_offset);
}

if (max_seqlen_k > 0) {
auto stream = at::cuda::getCurrentCUDAStream().stream();
run_mha_fwd(params, stream, paged_KV);
@@ -669,7 +603,7 @@ mha_varlen_fwd(
softmax_lse = softmax_lse.reshape({num_heads * max_seqlen_q, batch_size});
}

return {out, softmax_lse, p, rng_state};
return {out, softmax_lse, p};
}

} // namespace FLASH_NAMESPACE
csrc/src/dropout.h (95 changes: 0 additions & 95 deletions)

This file was deleted.
