diff --git a/csrc/flash_api.cpp b/csrc/flash_api.cpp index 2e04840..1d948ea 100644 --- a/csrc/flash_api.cpp +++ b/csrc/flash_api.cpp @@ -255,6 +255,14 @@ std::tuple set_params_splitkv( TORCH_CHECK(params.num_splits <= 128, "num_splits > 128 not supported"); } + // Temporarily disable Split-KV, because some bugs are still being fixed. + // See: https://github.com/SmallDoges/flash-dmattn/issues/47 + // Regardless of how it is set externally, always set num_splits back to 1. + // This is to avoid the extra memory overhead of Split-KV. + params.num_splits = 1; + softmax_lse_accum.reset(); + out_accum.reset(); + return std::make_tuple(softmax_lse_accum, out_accum); }