13 changes: 6 additions & 7 deletions csrc/flash_api.cpp
@@ -259,9 +259,9 @@ std::tuple<at::Tensor, at::Tensor> set_params_splitkv(
// See: https://github.com/SmallDoges/flash-dmattn/issues/47
// Regardless of how it is set externally, always set num_splits back to 1.
// This is to avoid the extra memory overhead of Split-KV.
-// params.num_splits = 1;
-// softmax_lse_accum.reset();
-// out_accum.reset();
+params.num_splits = 1;
+softmax_lse_accum.reset();
+out_accum.reset();

return std::make_tuple(softmax_lse_accum, out_accum);
}
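
For context on the "extra memory overhead of Split-KV" mentioned in the comment above, here is a minimal, hypothetical sketch (not part of this PR) of the per-split accumulators that a Split-KV path typically allocates, modeled on upstream FlashAttention's set_params_splitkv; the function name and parameters are illustrative. Pinning num_splits to 1, as this hunk does, means neither fp32 buffer is ever created.

#include <torch/torch.h>

// Hypothetical sketch, not from this PR: shows the fp32 workspaces a Split-KV
// path would allocate when num_splits > 1. All names here are illustrative.
std::tuple<at::Tensor, at::Tensor> splitkv_accumulators_sketch(
        int num_splits, int batch_size, int num_heads,
        int max_seqlen_q, int head_size_rounded,
        const at::TensorOptions& opts) {
    at::Tensor softmax_lse_accum;  // per-split partial log-sum-exp
    at::Tensor out_accum;          // per-split partial outputs
    if (num_splits > 1) {
        softmax_lse_accum = torch::empty(
            {num_splits, batch_size, num_heads, max_seqlen_q},
            opts.dtype(at::kFloat));
        out_accum = torch::empty(
            {num_splits, batch_size, num_heads, max_seqlen_q, head_size_rounded},
            opts.dtype(at::kFloat));
    }
    return std::make_tuple(softmax_lse_accum, out_accum);
}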
@@ -288,11 +288,10 @@ mha_fwd(

auto [cc_major, cc_minor] = get_compute_capability(get_current_device());
bool is_sm8x_min = cc_major >= 8;
TORCH_CHECK(is_sm8x_min, "FlashAttention only supports Ampere GPUs or newer.");
TORCH_CHECK(is_sm8x_min, "FlashDynamicMaskAttention only supports Ampere GPUs or newer.");

auto q_dtype = q.dtype();
-TORCH_CHECK(q_dtype == torch::kFloat16 || q_dtype == torch::kBFloat16,
-            "FlashAttention only support fp16 and bf16 data type");
+TORCH_CHECK(q_dtype == torch::kFloat16 || q_dtype == torch::kBFloat16, "FlashDynamicMaskAttention only support fp16 and bf16 data type");
Copilot AI Jul 25, 2025


Grammar error: "only support" should be "only supports" to maintain subject-verb agreement.

Suggested change:
-TORCH_CHECK(q_dtype == torch::kFloat16 || q_dtype == torch::kBFloat16, "FlashDynamicMaskAttention only support fp16 and bf16 data type");
+TORCH_CHECK(q_dtype == torch::kFloat16 || q_dtype == torch::kBFloat16, "FlashDynamicMaskAttention only supports fp16 and bf16 data type");

TORCH_CHECK(k.dtype() == q_dtype, "query and key must have the same dtype");
TORCH_CHECK(v.dtype() == q_dtype, "query and value must have the same dtype");
TORCH_CHECK(attn_mask.dtype() == q_dtype, "attn_mask must have the same dtype as inputs");
@@ -313,7 +312,7 @@ mha_fwd(
const int seqlen_k = k.size(1);
const int num_heads_k = k.size(2);
TORCH_CHECK(batch_size > 0, "batch size must be positive");
TORCH_CHECK(head_size <= 256, "FlashAttention forward only supports head dimension at most 256");
TORCH_CHECK(head_size <= 256, "FlashDynamicMaskAttention forward only supports head dimension at most 256");
TORCH_CHECK(head_size % 8 == 0, "query, key, value, and out_ must have a head_size that is a multiple of 8");
TORCH_CHECK(num_heads % num_heads_k == 0, "Number of heads in key/value must divide number of heads in query");
