4 changes: 2 additions & 2 deletions docs/en/features/multi_streams.md
@@ -13,9 +13,9 @@ This overlap of computation and communication effectively hides the communication

## Usage

-xLLM provides the gflags parameter enable_comp_comm_overlap, which defaults to false. To enable this feature, set it to true in xLLM’s service startup script, as:
+xLLM provides the gflags parameter `enable_multi_stream_parallel`, which defaults to false. To enable this feature, set it to true in xLLM’s service startup script, as:
```shell
---enable_comp_comm_overlap=true
+--enable_multi_stream_parallel=true
```


4 changes: 2 additions & 2 deletions docs/zh/features/multi_streams.md
@@ -11,9 +11,9 @@ xLLM supports multi-stream parallelism at the model graph layer, splitting the input batch into two micro batches

## Usage

-xLLM provides the gflags parameter `enable_comp_comm_overlap`, which defaults to false. To enable this feature, set it to true in xLLM's service startup script, for example:
+xLLM provides the gflags parameter `enable_multi_stream_parallel`, which defaults to false. To enable this feature, set it to true in xLLM's service startup script, for example:
```shell
---enable_comp_comm_overlap=true
+--enable_multi_stream_parallel=true
```


8 changes: 8 additions & 0 deletions xllm/core/common/global_flags.cpp
@@ -315,3 +315,11 @@ DEFINE_string(store_master_server_entry,
DEFINE_string(store_metadata_connstring,
"",
"The address of the kv cache store metadata service.");

+// --- for computation communication parallel ---
+
+DEFINE_bool(
+    enable_multi_stream_parallel,
+    false,
+    "Whether to enable computation-communication parallelism using two "
+    "streams and two micro batches in the prefill stage.");
2 changes: 2 additions & 0 deletions xllm/core/common/global_flags.h
@@ -159,6 +159,8 @@ DECLARE_string(store_master_server_entry);

DECLARE_string(store_metadata_connstring);

+DECLARE_bool(enable_multi_stream_parallel);

DECLARE_bool(enable_profile_step_time);

DECLARE_bool(enable_profile_token_budget);
3 changes: 2 additions & 1 deletion xllm/core/common/options.cpp
100755 → 100644
@@ -53,7 +53,8 @@ std::string Options::to_string() const {
<< ", enable_kvcache_store: " << enable_kvcache_store()
<< ", store_protocol: " << store_protocol()
<< ", store_master_server_entry: " << store_master_server_entry()
<< ", store_metadata_connstring: " << store_metadata_connstring();
<< ", store_metadata_connstring: " << store_metadata_connstring()
<< ", enable_multi_stream_parallel: " << enable_multi_stream_parallel();
ss << "]";
return ss.str();
}
2 changes: 2 additions & 0 deletions xllm/core/common/options.h
@@ -141,6 +141,8 @@ class Options {

PROPERTY(std::string, store_metadata_connstring) = "";

+PROPERTY(bool, enable_multi_stream_parallel) = false;

PROPERTY(bool, enable_profile_step_time) = false;

PROPERTY(bool, enable_profile_token_budget) = false;
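
A note on the `PROPERTY` macro (its definition is not shown in this diff): judging from the getter call in `Options::to_string()` above and the fluent setter chain in `xllm.cpp` below, it plausibly expands to a member plus getter/setter along these lines:

```cpp
// hypothetical expansion of PROPERTY(bool, enable_multi_stream_parallel) = false;
bool enable_multi_stream_parallel_ = false;
bool enable_multi_stream_parallel() const { return enable_multi_stream_parallel_; }
Options& enable_multi_stream_parallel(bool val) {
  enable_multi_stream_parallel_ = val;  // setter returns *this for chaining
  return *this;
}
```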
12 changes: 9 additions & 3 deletions xllm/core/framework/batch/batch.h
@@ -75,13 +75,16 @@ class Batch {
// process the accepted output embedding
void process_embedding_output(const torch::Tensor& embedding);

+  // split the whole batch into several micro batches
+  std::vector<Batch> split(const size_t num_micro_batches);

const std::vector<uint32_t>& get_allowed_max_tokens() const {
return allowed_max_tokens_;
}

+  void set_batch_prefill_status(const bool all_seqs_in_prefill) {
+    all_seqs_in_prefill_ = all_seqs_in_prefill;
+  }

+  bool get_batch_prefill_status() const { return all_seqs_in_prefill_; }

private:
bool update_sequence_state(Sequence* seq, bool enable_schedule_overlap);

@@ -102,6 +105,9 @@

// mm_data in the batch
std::vector<MMData> mm_data_vec_;

+  // all sequences in this batch are in prefill stage
+  bool all_seqs_in_prefill_ = true;
};

} // namespace xllm
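
Only the declaration of `Batch::split` appears in this diff. A hedged sketch of the shape such a split could take — `sequences_` is an assumed member name inferred from `Batch::add`, and this is not the code merged in the PR:

```cpp
// hypothetical implementation sketch, not shown in this diff
std::vector<Batch> Batch::split(const size_t num_micro_batches) {
  std::vector<Batch> micro_batches(num_micro_batches);
  // round-robin sequences so the micro batches stay roughly balanced
  for (size_t i = 0; i < sequences_.size(); ++i) {
    micro_batches[i % num_micro_batches].add(sequences_[i],
                                             allowed_max_tokens_[i]);
  }
  for (auto& mb : micro_batches) {
    mb.set_batch_prefill_status(all_seqs_in_prefill_);  // inherit prefill status
  }
  return micro_batches;
}
```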
9 changes: 4 additions & 5 deletions xllm/core/framework/batch/batch_factory.cpp
@@ -43,11 +43,10 @@ std::vector<Batch> BatchFactory::create_batches(

// if dp enabled, each sequence is required to
// dispatch to the same rank in the whole lifetime
-    if (sequence->dp_rank() >= 0) {
-      batches[sequence->dp_rank()].add(sequence, token_budget);
-    } else {
-      batches[i % dp_size_].add(sequence, token_budget);
-      sequence->set_dp_rank(i % dp_size_);
+    batches[sequence->dp_rank()].add(sequence, token_budget);
+    if (sequence->stage() == SequenceStage::DECODE &&
+        sequence->kv_state().kv_cache_tokens_num() > 0) {
+      batches[sequence->dp_rank()].set_batch_prefill_status(false);
}
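
With this change a batch starts as all-prefill (`all_seqs_in_prefill_ = true`) and is flipped to false as soon as a DECODE-stage sequence with cached KV tokens joins it. Combined with the new flag, a plausible call site looks like the following — hypothetical gating code inferred from the flag description, not shown in this PR:

```cpp
if (FLAGS_enable_multi_stream_parallel && batch.get_batch_prefill_status()) {
  // prefill-only batch: split into two micro batches for two-stream overlap
  auto micro_batches = batch.split(/*num_micro_batches=*/2);
  // ... dispatch each micro batch on its own stream ...
}
```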
}

9 changes: 5 additions & 4 deletions xllm/core/framework/sampling/CMakeLists.txt
@@ -22,23 +22,24 @@ cc_library(

cc_test(
NAME
-rejection_sampler_test
+sampler_test
SRCS
rejection_sampler_test.cpp
rejection_sampler.cpp
+sampling_params_test.cpp
DEPS
absl::strings
GTest::gtest_main
:flags
:sampler
glog::glog
)
-target_link_libraries(rejection_sampler_test PRIVATE brpc OpenSSL::SSL OpenSSL::Crypto leveldb::leveldb ZLIB::ZLIB protobuf::libprotobuf)
-target_link_libraries(rejection_sampler_test
+target_link_libraries(sampler_test PRIVATE brpc OpenSSL::SSL OpenSSL::Crypto leveldb::leveldb ZLIB::ZLIB protobuf::libprotobuf)
+target_link_libraries(sampler_test
PUBLIC
Python::Python
$<$<BOOL:${USE_NPU}>:ascendcl>
$<$<BOOL:${USE_NPU}>:hccl>
$<$<BOOL:${USE_NPU}>:c_sec>
$<$<BOOL:${USE_NPU}>:nnopbase>)
-add_dependencies(rejection_sampler_test brpc-static)
+add_dependencies(sampler_test brpc-static)
38 changes: 38 additions & 0 deletions xllm/core/framework/sampling/sampling_params.cpp
@@ -139,4 +139,42 @@ void SamplingParameters::init(
this->is_embeddings = is_embeddings;
}

+void SamplingParameters::concat(const SamplingParameters& param) {
+  // selected_token_idxes and sample_idxes are accumulated variables across
+  // all sequences in the batch, so the offset of the first
+  // SamplingParameters is added to the second one
+  this->selected_token_idxes =
+      safe_concat(this->selected_token_idxes,
+                  (param.selected_token_idxes.defined()
+                       ? (param.selected_token_idxes +
+                          this->selected_token_idxes[-1] + torch::tensor(1))
+                       : param.selected_token_idxes),
+                  0);
+  this->sample_idxes = safe_concat(
+      this->sample_idxes,
+      (param.sample_idxes.defined()
+           ? (param.sample_idxes + this->sample_idxes[-1] + torch::tensor(1))
+           : param.sample_idxes),
+      0);
+  this->frequency_penalties =
+      safe_concat(this->frequency_penalties, param.frequency_penalties, 0);
+  this->repetition_penalties =
+      safe_concat(this->repetition_penalties, param.repetition_penalties, 0);
+  this->temperatures = safe_concat(this->temperatures, param.temperatures, 0);
+  this->top_p = safe_concat(this->top_p, param.top_p, 0);
+  this->top_k = safe_concat(this->top_k, param.top_k, 0);
+  this->unique_token_ids =
+      safe_concat(this->unique_token_ids, param.unique_token_ids, 0);
+  this->unique_token_counts =
+      safe_concat(this->unique_token_counts, param.unique_token_counts, 0);
+  this->unique_token_ids_lens =
+      safe_concat(this->unique_token_ids_lens, param.unique_token_ids_lens, 0);
+  this->do_sample = safe_concat(this->do_sample, param.do_sample, 0);
+  this->logprobs = this->logprobs || param.logprobs;
+  this->is_embeddings = this->is_embeddings || param.is_embeddings;
+  this->max_top_logprobs =
+      std::max(this->max_top_logprobs, param.max_top_logprobs);
+}

} // namespace xllm
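
The offset rule is concrete with the values used by the new unit test below: the first batch's `selected_token_idxes` end at 23, so the second batch's `{13, 28}` are shifted by 23 + 1 = 24, yielding `{11, 23, 37, 52}`. A standalone LibTorch check of that arithmetic:

```cpp
#include <torch/torch.h>

#include <iostream>

int main() {
  auto first = torch::tensor({11, 23});
  auto second = torch::tensor({13, 28});
  // offset = last accumulated index of the first batch + 1
  auto offset = first[-1] + torch::tensor(1);  // 24
  auto merged = torch::cat({first, second + offset}, 0);
  std::cout << merged << std::endl;  // 11, 23, 37, 52
  return 0;
}
```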
3 changes: 3 additions & 0 deletions xllm/core/framework/sampling/sampling_params.h
@@ -78,6 +78,9 @@ struct SamplingParameters {
return params;
}

+  // concat two SamplingParameters into one
+  void concat(const SamplingParameters& param);

// selected tokens are tokens for sampling the next token,
// including the generated tokens and the last prompt token
// IntTensor
118 changes: 118 additions & 0 deletions xllm/core/framework/sampling/sampling_params_test.cpp
@@ -0,0 +1,118 @@
/* Copyright 2025 The xLLM Authors. All Rights Reserved.
Copyright 2024 The ScaleLLM Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

https://github.com/jd-opensource/xllm/blob/main/LICENSE

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "sampling_params.h"

#include <glog/logging.h>
#include <gtest/gtest.h>
#include <torch/torch.h>

namespace xllm {

TEST(SamplingParamsTest, NormalConcat) {
// construct sampling_parameters_1
RequestSamplingParam request_1, request_2;
std::vector<int32_t> selected_token_idxes_1{11, 23};
std::vector<int32_t> sample_idxes_1{0, 1};
std::vector<std::vector<int64_t>> unique_token_ids_vec_1{
std::vector<int64_t>{
151645, 100022, 104202, 104167, 198, 77091, 872, 220, 151644},
std::vector<int64_t>{
151645, 100022, 104202, 104167, 198, 77091, 872, 220, 151644}};
std::vector<std::vector<int32_t>> unique_token_counts_vec_1{
std::vector<int32_t>{1, 1, 1, 1, 3, 1, 1, 1, 2},
std::vector<int32_t>{1, 1, 1, 1, 3, 1, 1, 1, 2}};
std::vector<int32_t> unique_token_lens_vec_1{9, 9};

SamplingParameters sampling_parameters_1;
sampling_parameters_1.init(
std::vector<const RequestSamplingParam*>{&request_1, &request_2},
selected_token_idxes_1,
sample_idxes_1,
unique_token_ids_vec_1,
unique_token_counts_vec_1,
unique_token_lens_vec_1);

// construct sampling_parameters_2
RequestSamplingParam request_3, request_4;
std::vector<int32_t> selected_token_idxes_2{13, 28};
std::vector<int32_t> sample_idxes_2{0, 1};
std::vector<std::vector<int64_t>> unique_token_ids_vec_2{
std::vector<int64_t>{151645,
119414,
100287,
26288,
101239,
198,
77091,
106055,
872,
220,
151644},
std::vector<int64_t>{0,
62112,
9370,
107425,
151645,
99489,
106309,
198,
77091,
71618,
872,
220,
151644}};
std::vector<std::vector<int32_t>> unique_token_counts_vec_2{
std::vector<int32_t>{1, 1, 1, 1, 1, 3, 1, 1, 1, 1, 2},
std::vector<int32_t>{0, 1, 1, 1, 1, 1, 1, 3, 1, 1, 1, 1, 2}};
std::vector<int32_t> unique_token_lens_vec_2{11, 12};

SamplingParameters sampling_parameters_2;
sampling_parameters_2.init(
std::vector<const RequestSamplingParam*>{&request_3, &request_4},
selected_token_idxes_2,
sample_idxes_2,
unique_token_ids_vec_2,
unique_token_counts_vec_2,
unique_token_lens_vec_2);

// construct expected output
torch::Tensor result_selected_token_idxes = torch::tensor({11, 23, 37, 52});
torch::Tensor result_sample_idxes = torch::tensor({0, 1, 2, 3});

// execute concat
sampling_parameters_1.concat(sampling_parameters_2);

// check results
EXPECT_TRUE(torch::equal(sampling_parameters_1.selected_token_idxes,
result_selected_token_idxes));
EXPECT_TRUE(
torch::equal(sampling_parameters_1.sample_idxes, result_sample_idxes));
}

TEST(SamplingParamsTest, AbnormalConcat) {
// construct two default SamplingParameters
SamplingParameters sampling_parameters_1, sampling_parameters_2;

// execute concat
sampling_parameters_1.concat(sampling_parameters_2);

// check results
EXPECT_FALSE(sampling_parameters_1.selected_token_idxes.defined());
EXPECT_FALSE(sampling_parameters_1.sample_idxes.defined());
}

} // namespace xllm
12 changes: 12 additions & 0 deletions xllm/core/util/tensor_helper.h
@@ -119,4 +119,16 @@ inline bool file_exists(const std::string& path) {
return file.good();
}

+inline torch::Tensor safe_concat(const torch::Tensor& t1,
+                                 const torch::Tensor& t2,
+                                 const uint32_t dim) {
+  if (t1.defined() && t2.defined()) {
+    return torch::cat({t1, t2}, dim);
+  } else if (!t1.defined()) {
+    return t2;
+  } else {
+    return t1;
+  }
+}

} // namespace xllm
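
A quick usage sketch of the pass-through behavior (the include path is an assumption):

```cpp
#include <torch/torch.h>

#include "tensor_helper.h"

int main() {
  torch::Tensor a;                      // undefined tensor
  torch::Tensor b = torch::tensor({1, 2, 3});
  auto c = xllm::safe_concat(a, b, 0);  // a is undefined -> returns b as-is
  auto d = xllm::safe_concat(b, b, 0);  // {1, 2, 3, 1, 2, 3}
  return 0;
}
```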
1 change: 1 addition & 0 deletions xllm/pybind/args.py
@@ -35,6 +35,7 @@ def __init__(self):
self.parser.add_argument('--enable_disagg_pd', action='store_true', help='Enable disaggregated prefill and decode execution.')
self.parser.add_argument('--enable_schedule_overlap', action='store_true', help='Whether to enable schedule overlap.')
self.parser.add_argument('--kv_cache_transfer_mode', type=str, default='PUSH', help='The mode of kv cache transfer(e.g. PUSH, PULL).')
+self.parser.add_argument('--enable_multi_stream_parallel', action='store_true', help='Whether to enable computation communication overlap.')

def parse_args(self):
return self.parser.parse_args()
1 change: 1 addition & 0 deletions xllm/xllm.cpp
@@ -150,6 +150,7 @@ int run() {
.store_protocol(FLAGS_store_protocol)
.store_master_server_entry(FLAGS_store_master_server_entry)
.store_metadata_connstring(FLAGS_store_metadata_connstring)
+.enable_multi_stream_parallel(FLAGS_enable_multi_stream_parallel)
.enable_profile_step_time(FLAGS_enable_profile_step_time)
.enable_profile_token_budget(FLAGS_enable_profile_token_budget)
.enable_latency_aware_schedule(FLAGS_enable_latency_aware_schedule)