diff --git a/csrc/flash_api.cpp b/csrc/flash_api.cpp
index b44a185..c783d5d 100644
--- a/csrc/flash_api.cpp
+++ b/csrc/flash_api.cpp
@@ -259,9 +259,9 @@ std::tuple set_params_splitkv(
     // See: https://github.com/SmallDoges/flash-dmattn/issues/47
     // Regardless of how it is set externally, always set num_splits back to 1.
     // This is to avoid the extra memory overhead of Split-KV.
-    // params.num_splits = 1;
-    // softmax_lse_accum.reset();
-    // out_accum.reset();
+    params.num_splits = 1;
+    softmax_lse_accum.reset();
+    out_accum.reset();
 
     return std::make_tuple(softmax_lse_accum, out_accum);
 }
@@ -288,11 +288,10 @@ mha_fwd(
 
     auto [cc_major, cc_minor] = get_compute_capability(get_current_device());
     bool is_sm8x_min = cc_major >= 8;
-    TORCH_CHECK(is_sm8x_min, "FlashAttention only supports Ampere GPUs or newer.");
+    TORCH_CHECK(is_sm8x_min, "FlashDynamicMaskAttention only supports Ampere GPUs or newer.");
 
     auto q_dtype = q.dtype();
-    TORCH_CHECK(q_dtype == torch::kFloat16 || q_dtype == torch::kBFloat16,
-                "FlashAttention only support fp16 and bf16 data type");
+    TORCH_CHECK(q_dtype == torch::kFloat16 || q_dtype == torch::kBFloat16, "FlashDynamicMaskAttention only support fp16 and bf16 data type");
     TORCH_CHECK(k.dtype() == q_dtype, "query and key must have the same dtype");
     TORCH_CHECK(v.dtype() == q_dtype, "query and value must have the same dtype");
     TORCH_CHECK(attn_mask.dtype() == q_dtype, "attn_mask must have the same dtype as inputs");
@@ -313,7 +312,7 @@ mha_fwd(
     const int seqlen_k = k.size(1);
     const int num_heads_k = k.size(2);
     TORCH_CHECK(batch_size > 0, "batch size must be positive");
-    TORCH_CHECK(head_size <= 256, "FlashAttention forward only supports head dimension at most 256");
+    TORCH_CHECK(head_size <= 256, "FlashDynamicMaskAttention forward only supports head dimension at most 256");
     TORCH_CHECK(head_size % 8 == 0, "query, key, value, and out_ must have a head_size that is a multiple of 8");
     TORCH_CHECK(num_heads % num_heads_k == 0, "Number of heads in key/value must divide number of heads in query");