diff --git a/docs/en/features/multi_streams.md b/docs/en/features/multi_streams.md
index e8197bcf..7cb07a0e 100644
--- a/docs/en/features/multi_streams.md
+++ b/docs/en/features/multi_streams.md
@@ -13,9 +13,9 @@ This overlap of computation and communication effectively hides the communicatio
 
 ## Usage
 
-xLLM provides the gflags parameter enable_comp_comm_overlap, which defaults to false. To enable this feature, set it to true in xLLM’s service startup script, as:
+xLLM provides the gflags parameter `enable_multi_stream_parallel`, which defaults to `false`. To enable this feature, set it to `true` in xLLM’s service startup script, for example:
 
 ```shell
---enable_comp_comm_overlap=true
+--enable_multi_stream_parallel=true
 ```
 
diff --git a/docs/zh/features/multi_streams.md b/docs/zh/features/multi_streams.md
index 9b39a843..385449d0 100644
--- a/docs/zh/features/multi_streams.md
+++ b/docs/zh/features/multi_streams.md
@@ -11,9 +11,9 @@ xLLM在模型图层支持了多流并行功能，将输入的batch拆分成2个m
 
 ## 使用方式
 
-xLLM中提供了gflags参数`enable_comp_comm_overlap`，默认false，如需开启在xLLM的服务启动脚本中设置为true即可，示例如下：
+xLLM中提供了gflags参数`enable_multi_stream_parallel`，默认false，如需开启在xLLM的服务启动脚本中设置为true即可，示例如下：
 
 ```shell
---enable_comp_comm_overlap=true
+--enable_multi_stream_parallel=true
 ```
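The docs above describe splitting a batch into two micro batches so that one micro batch computes while the other communicates. The following is a purely conceptual C++ sketch of that interleaving, not xLLM's implementation: every name in it (`MicroBatch`, `compute_layer`, `all_reduce`) is hypothetical, and `std::async` stands in for the second device stream.

```cpp
// Conceptual sketch only. Two micro batches alternate so that while one is
// computing a layer, the other's collective communication (e.g. a
// tensor-parallel all-reduce) is in flight.
#include <future>
#include <vector>

struct MicroBatch {
  std::vector<float> activations;  // stand-in for device tensors
};

static void compute_layer(MicroBatch& mb) { /* matmuls, attention, ... */ }
static void all_reduce(MicroBatch& mb) { /* collective communication */ }

void run_layer_with_overlap(MicroBatch& mb0, MicroBatch& mb1) {
  // "Stream 1" communicates mb1's previous output asynchronously while
  // "stream 0" computes mb0's next layer on the current thread.
  auto comm = std::async(std::launch::async, [&] { all_reduce(mb1); });
  compute_layer(mb0);
  comm.wait();
  // Roles swap: mb0 communicates while mb1 computes.
  auto comm2 = std::async(std::launch::async, [&] { all_reduce(mb0); });
  compute_layer(mb1);
  comm2.wait();
}
```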
diff --git a/xllm/core/common/global_flags.cpp b/xllm/core/common/global_flags.cpp
index fab6529e..b8f4a96d 100644
--- a/xllm/core/common/global_flags.cpp
+++ b/xllm/core/common/global_flags.cpp
@@ -315,3 +315,11 @@ DEFINE_string(store_master_server_entry,
 DEFINE_string(store_metadata_connstring,
               "",
               "The address of the kv cache store metadata service.");
+
+// --- for computation-communication parallelism ---
+
+DEFINE_bool(
+    enable_multi_stream_parallel,
+    false,
+    "Whether to enable computation-communication parallelism using two "
+    "streams and two micro batches in the prefill stage.");
diff --git a/xllm/core/common/global_flags.h b/xllm/core/common/global_flags.h
index 6d6964fc..d518a18d 100755
--- a/xllm/core/common/global_flags.h
+++ b/xllm/core/common/global_flags.h
@@ -159,6 +159,8 @@ DECLARE_string(store_master_server_entry);
 
 DECLARE_string(store_metadata_connstring);
 
+DECLARE_bool(enable_multi_stream_parallel);
+
 DECLARE_bool(enable_profile_step_time);
 
 DECLARE_bool(enable_profile_token_budget);
diff --git a/xllm/core/common/options.cpp b/xllm/core/common/options.cpp
old mode 100755
new mode 100644
index 2e840480..6b9f0ab6
--- a/xllm/core/common/options.cpp
+++ b/xllm/core/common/options.cpp
@@ -53,7 +53,8 @@ std::string Options::to_string() const {
      << ", enable_kvcache_store: " << enable_kvcache_store()
      << ", store_protocol: " << store_protocol()
      << ", store_master_server_entry: " << store_master_server_entry()
-     << ", store_metadata_connstring: " << store_metadata_connstring();
+     << ", store_metadata_connstring: " << store_metadata_connstring()
+     << ", enable_multi_stream_parallel: " << enable_multi_stream_parallel();
   ss << "]";
   return ss.str();
 }
diff --git a/xllm/core/common/options.h b/xllm/core/common/options.h
index 2a36368b..e99f50b7 100644
--- a/xllm/core/common/options.h
+++ b/xllm/core/common/options.h
@@ -141,6 +141,8 @@ class Options {
 
   PROPERTY(std::string, store_metadata_connstring) = "";
 
+  PROPERTY(bool, enable_multi_stream_parallel) = false;
+
   PROPERTY(bool, enable_profile_step_time) = false;
 
   PROPERTY(bool, enable_profile_token_budget) = false;
diff --git a/xllm/core/framework/batch/batch.h b/xllm/core/framework/batch/batch.h
index 320d206b..97039c5e 100644
--- a/xllm/core/framework/batch/batch.h
+++ b/xllm/core/framework/batch/batch.h
@@ -75,13 +75,16 @@ class Batch {
   // process the accepted output embedding
   void process_embedding_output(const torch::Tensor& embedding);
 
-  // split the whole batch into several micro batches
-  std::vector<Batch> split(const size_t num_micro_batches);
-
   const std::vector<size_t>& get_allowed_max_tokens() const {
     return allowed_max_tokens_;
   }
 
+  void set_batch_prefill_status(const bool all_seqs_in_prefill) {
+    all_seqs_in_prefill_ = all_seqs_in_prefill;
+  }
+
+  bool get_batch_prefill_status() const { return all_seqs_in_prefill_; }
+
  private:
   bool update_sequence_state(Sequence* seq, bool enable_schedule_overlap);
@@ -102,6 +105,9 @@ class Batch {
 
   // mm_data in the batch
   std::vector<MMData> mm_data_vec_;
+
+  // all sequences in this batch are in prefill stage
+  bool all_seqs_in_prefill_ = true;
 };
 
 } // namespace xllm
diff --git a/xllm/core/framework/batch/batch_factory.cpp b/xllm/core/framework/batch/batch_factory.cpp
index 2e242766..16485165 100644
--- a/xllm/core/framework/batch/batch_factory.cpp
+++ b/xllm/core/framework/batch/batch_factory.cpp
@@ -43,11 +43,10 @@ std::vector<Batch> BatchFactory::create_batches(
 
     // if dp enabled, each sequence is required to
     // dispatch to the same rank in the whole lifetime
-    if (sequence->dp_rank() >= 0) {
-      batches[sequence->dp_rank()].add(sequence, token_budget);
-    } else {
-      batches[i % dp_size_].add(sequence, token_budget);
-      sequence->set_dp_rank(i % dp_size_);
+    batches[sequence->dp_rank()].add(sequence, token_budget);
+    if (sequence->stage() == SequenceStage::DECODE &&
+        sequence->kv_state().kv_cache_tokens_num() > 0) {
+      batches[sequence->dp_rank()].set_batch_prefill_status(false);
     }
   }
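The batch_factory change above only records whether a batch is still entirely in prefill; the consumer of that status is not shown in this diff. A hedged sketch of how a scheduler-side caller might combine the new flag with it (`should_use_multi_stream` is a hypothetical helper, and the include paths are illustrative):

```cpp
// Hypothetical helper, not part of this diff: gate the two-stream path on
// both the new gflag and the batch-wide prefill status tracked above.
#include "core/common/global_flags.h"
#include "core/framework/batch/batch.h"

namespace xllm {

bool should_use_multi_stream(const Batch& batch) {
  // Overlap only applies while every sequence in the batch is in prefill;
  // a batch containing decode sequences falls back to the single-stream path.
  return FLAGS_enable_multi_stream_parallel &&
         batch.get_batch_prefill_status();
}

}  // namespace xllm
```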
diff --git a/xllm/core/framework/sampling/CMakeLists.txt b/xllm/core/framework/sampling/CMakeLists.txt
index 0e4ff2f0..20e4c480 100644
--- a/xllm/core/framework/sampling/CMakeLists.txt
+++ b/xllm/core/framework/sampling/CMakeLists.txt
@@ -22,10 +22,11 @@ cc_library(
 
 cc_test(
   NAME
-    rejection_sampler_test
+    sampler_test
   SRCS
     rejection_sampler_test.cpp
     rejection_sampler.cpp
+    sampling_params_test.cpp
   DEPS
     absl::strings
     GTest::gtest_main
@@ -33,12 +34,12 @@ cc_test(
     :sampler
     glog::glog
 )
-target_link_libraries(rejection_sampler_test PRIVATE brpc OpenSSL::SSL OpenSSL::Crypto leveldb::leveldb ZLIB::ZLIB protobuf::libprotobuf)
-target_link_libraries(rejection_sampler_test
+target_link_libraries(sampler_test PRIVATE brpc OpenSSL::SSL OpenSSL::Crypto leveldb::leveldb ZLIB::ZLIB protobuf::libprotobuf)
+target_link_libraries(sampler_test
   PUBLIC Python::Python $<$<BOOL:${USE_NPU}>:ascendcl> $<$<BOOL:${USE_NPU}>:hccl> $<$<BOOL:${USE_NPU}>:c_sec> $<$<BOOL:${USE_NPU}>:nnopbase>)
-add_dependencies(rejection_sampler_test brpc-static)
\ No newline at end of file
+add_dependencies(sampler_test brpc-static)
\ No newline at end of file
diff --git a/xllm/core/framework/sampling/sampling_params.cpp b/xllm/core/framework/sampling/sampling_params.cpp
index d8162d96..7f2f576b 100644
--- a/xllm/core/framework/sampling/sampling_params.cpp
+++ b/xllm/core/framework/sampling/sampling_params.cpp
@@ -139,4 +139,42 @@ void SamplingParameters::init(
   this->is_embeddings = is_embeddings;
 }
 
+void SamplingParameters::concat(const SamplingParameters& param) {
+  // selected_token_idxes and sample_idxes are accumulated variables across
+  // all sequences in the batch, so the offset of the first
+  // SamplingParameters is added to the second SamplingParameters
+  this->selected_token_idxes =
+      safe_concat(this->selected_token_idxes,
+                  (param.selected_token_idxes.defined()
+                       ? (param.selected_token_idxes +
+                          this->selected_token_idxes[-1] + torch::tensor(1))
+                       : param.selected_token_idxes),
+                  0);
+  this->sample_idxes = safe_concat(
+      this->sample_idxes,
+      (param.sample_idxes.defined()
+           ? (param.sample_idxes + this->sample_idxes[-1] + torch::tensor(1))
+           : param.sample_idxes),
+      0);
+  this->frequency_penalties =
+      safe_concat(this->frequency_penalties, param.frequency_penalties, 0);
+  this->repetition_penalties =
+      safe_concat(this->repetition_penalties, param.repetition_penalties, 0);
+  this->temperatures = safe_concat(this->temperatures, param.temperatures, 0);
+  this->top_p = safe_concat(this->top_p, param.top_p, 0);
+  this->top_k = safe_concat(this->top_k, param.top_k, 0);
+  this->unique_token_ids =
+      safe_concat(this->unique_token_ids, param.unique_token_ids, 0);
+  this->unique_token_counts =
+      safe_concat(this->unique_token_counts, param.unique_token_counts, 0);
+  this->unique_token_ids_lens =
+      safe_concat(this->unique_token_ids_lens, param.unique_token_ids_lens, 0);
+  this->do_sample = safe_concat(this->do_sample, param.do_sample, 0);
+  this->logprobs = this->logprobs || param.logprobs;
+  this->is_embeddings = this->is_embeddings || param.is_embeddings;
+  this->max_top_logprobs =
+      std::max(this->max_top_logprobs, param.max_top_logprobs);
+  return;
+}
+
 } // namespace xllm
diff --git a/xllm/core/framework/sampling/sampling_params.h b/xllm/core/framework/sampling/sampling_params.h
index 30374007..f900a5c5 100644
--- a/xllm/core/framework/sampling/sampling_params.h
+++ b/xllm/core/framework/sampling/sampling_params.h
@@ -78,6 +78,9 @@ struct SamplingParameters {
     return params;
   }
 
+  // concat two SamplingParameters into one
+  void concat(const SamplingParameters& param);
+
   // selected tokens are tokens for sampling the next token,
   // including the generated tokens and the last prompt token
   // IntTensor
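One caveat worth noting: `concat()` indexes `this->selected_token_idxes[-1]`, so it assumes the left-hand parameters already contain at least one selected token whenever the right-hand ones are defined. A hedged usage sketch (`merge_micro_batches` is a hypothetical helper name, not from this diff):

```cpp
// Re-join the sampling state of two micro batches before the sampler runs.
#include "sampling_params.h"

xllm::SamplingParameters merge_micro_batches(
    xllm::SamplingParameters first,
    const xllm::SamplingParameters& second) {
  // concat() shifts second's selected_token_idxes / sample_idxes by the last
  // offset in `first`, then appends all per-sequence tensors via safe_concat.
  first.concat(second);
  return first;
}
```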
diff --git a/xllm/core/framework/sampling/sampling_params_test.cpp b/xllm/core/framework/sampling/sampling_params_test.cpp
new file mode 100644
index 00000000..9bbaad6f
--- /dev/null
+++ b/xllm/core/framework/sampling/sampling_params_test.cpp
@@ -0,0 +1,118 @@
+/* Copyright 2025 The xLLM Authors. All Rights Reserved.
+Copyright 2024 The ScaleLLM Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    https://github.com/jd-opensource/xllm/blob/main/LICENSE
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "sampling_params.h"
+
+#include <gtest/gtest.h>
+#include <torch/torch.h>
+#include <vector>
+
+namespace xllm {
+
+TEST(SamplingParamsTest, NormalConcat) {
+  // construct sampling_parameters_1
+  RequestSamplingParam request_1, request_2;
+  std::vector<int32_t> selected_token_idxes_1{11, 23};
+  std::vector<int32_t> sample_idxes_1{0, 1};
+  std::vector<std::vector<int64_t>> unique_token_ids_vec_1{
+      std::vector<int64_t>{
+          151645, 100022, 104202, 104167, 198, 77091, 872, 220, 151644},
+      std::vector<int64_t>{
+          151645, 100022, 104202, 104167, 198, 77091, 872, 220, 151644}};
+  std::vector<std::vector<int32_t>> unique_token_counts_vec_1{
+      std::vector<int32_t>{1, 1, 1, 1, 3, 1, 1, 1, 2},
+      std::vector<int32_t>{1, 1, 1, 1, 3, 1, 1, 1, 2}};
+  std::vector<int32_t> unique_token_lens_vec_1{9, 9};
+
+  SamplingParameters sampling_parameters_1;
+  sampling_parameters_1.init(
+      std::vector<const RequestSamplingParam*>{&request_1, &request_2},
+      selected_token_idxes_1,
+      sample_idxes_1,
+      unique_token_ids_vec_1,
+      unique_token_counts_vec_1,
+      unique_token_lens_vec_1);
+
+  // construct sampling_parameters_2
+  RequestSamplingParam request_3, request_4;
+  std::vector<int32_t> selected_token_idxes_2{13, 28};
+  std::vector<int32_t> sample_idxes_2{0, 1};
+  std::vector<std::vector<int64_t>> unique_token_ids_vec_2{
+      std::vector<int64_t>{151645,
+                           119414,
+                           100287,
+                           26288,
+                           101239,
+                           198,
+                           77091,
+                           106055,
+                           872,
+                           220,
+                           151644},
+      std::vector<int64_t>{0,
+                           62112,
+                           9370,
+                           107425,
+                           151645,
+                           99489,
+                           106309,
+                           198,
+                           77091,
+                           71618,
+                           872,
+                           220,
+                           151644}};
+  std::vector<std::vector<int32_t>> unique_token_counts_vec_2{
+      std::vector<int32_t>{1, 1, 1, 1, 1, 3, 1, 1, 1, 1, 2},
+      std::vector<int32_t>{0, 1, 1, 1, 1, 1, 1, 3, 1, 1, 1, 1, 2}};
+  std::vector<int32_t> unique_token_lens_vec_2{11, 12};
+
+  SamplingParameters sampling_parameters_2;
+  sampling_parameters_2.init(
+      std::vector<const RequestSamplingParam*>{&request_3, &request_4},
+      selected_token_idxes_2,
+      sample_idxes_2,
+      unique_token_ids_vec_2,
+      unique_token_counts_vec_2,
+      unique_token_lens_vec_2);
+
+  // construct expected output
+  torch::Tensor result_selected_token_idxes = torch::tensor({11, 23, 37, 52});
+  torch::Tensor result_sample_idxes = torch::tensor({0, 1, 2, 3});
+
+  // execute concat
+  sampling_parameters_1.concat(sampling_parameters_2);
+
+  // check results
+  EXPECT_TRUE(torch::equal(sampling_parameters_1.selected_token_idxes,
+                           result_selected_token_idxes));
+  EXPECT_TRUE(
+      torch::equal(sampling_parameters_1.sample_idxes, result_sample_idxes));
+}
+
+TEST(SamplingParamsTest, AbnormalConcat) {
+  // construct two default SamplingParameters
+  SamplingParameters sampling_parameters_1, sampling_parameters_2;
+
+  // execute concat
+  sampling_parameters_1.concat(sampling_parameters_2);
+
+  // check results
+  EXPECT_FALSE(sampling_parameters_1.selected_token_idxes.defined());
+  EXPECT_FALSE(sampling_parameters_1.sample_idxes.defined());
+}
+
+} // namespace xllm
diff --git a/xllm/core/util/tensor_helper.h b/xllm/core/util/tensor_helper.h
index 276eaf04..714810c9 100644
--- a/xllm/core/util/tensor_helper.h
+++ b/xllm/core/util/tensor_helper.h
@@ -119,4 +119,16 @@ inline bool file_exists(const std::string& path) {
   return file.good();
 }
 
+inline torch::Tensor safe_concat(const torch::Tensor& t1,
+                                 const torch::Tensor& t2,
+                                 const uint32_t dim) {
+  if (t1.defined() && t2.defined()) {
+    return torch::cat({t1, t2}, dim);
+  } else if (!t1.defined()) {
+    return t2;
+  } else {
+    return t1;
+  }
+}
+
 } // namespace xllm
\ No newline at end of file
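A quick illustration of `safe_concat`'s contract, not part of the diff (the include path is illustrative): concatenate when both tensors are defined, otherwise pass the defined tensor, or an undefined one, through unchanged.

```cpp
#include <torch/torch.h>

#include "core/util/tensor_helper.h"

int main() {
  torch::Tensor a = torch::tensor({1, 2});
  torch::Tensor undef;  // default-constructed, .defined() == false
  auto r1 = xllm::safe_concat(a, undef, 0);      // returns a unchanged
  auto r2 = xllm::safe_concat(undef, a, 0);      // returns a unchanged
  auto r3 = xllm::safe_concat(a, a, 0);          // {1, 2, 1, 2}
  auto r4 = xllm::safe_concat(undef, undef, 0);  // stays undefined
  return (r3.equal(torch::tensor({1, 2, 1, 2})) && !r4.defined()) ? 0 : 1;
}
```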
diff --git a/xllm/pybind/args.py b/xllm/pybind/args.py
index 85926a09..f1869555 100644
--- a/xllm/pybind/args.py
+++ b/xllm/pybind/args.py
@@ -35,6 +35,7 @@ def __init__(self):
         self.parser.add_argument('--enable_disagg_pd', action='store_true', help='Enable disaggregated prefill and decode execution.')
         self.parser.add_argument('--enable_schedule_overlap', action='store_true', help='Whether to enable schedule overlap.')
         self.parser.add_argument('--kv_cache_transfer_mode', type=str, default='PUSH', help='The mode of kv cache transfer(e.g. PUSH, PULL).')
+        self.parser.add_argument('--enable_multi_stream_parallel', action='store_true', help='Whether to enable computation communication overlap.')
 
     def parse_args(self):
         return self.parser.parse_args()
diff --git a/xllm/xllm.cpp b/xllm/xllm.cpp
index 328692e8..c9486355 100755
--- a/xllm/xllm.cpp
+++ b/xllm/xllm.cpp
@@ -150,6 +150,7 @@ int run() {
       .store_protocol(FLAGS_store_protocol)
       .store_master_server_entry(FLAGS_store_master_server_entry)
       .store_metadata_connstring(FLAGS_store_metadata_connstring)
+      .enable_multi_stream_parallel(FLAGS_enable_multi_stream_parallel)
       .enable_profile_step_time(FLAGS_enable_profile_step_time)
      .enable_profile_token_budget(FLAGS_enable_profile_token_budget)
       .enable_latency_aware_schedule(FLAGS_enable_latency_aware_schedule)