Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 5 additions & 14 deletions xllm/core/distributed_runtime/worker_service.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -136,24 +136,15 @@ void WorkerService::step(ForwardInput& fwd_input,
}
}
} else {
auto int_options = torch::TensorOptions().device(torch::kCPU);
if (worker_->is_driver()) {
// construct fake output tensor
auto options =
torch::TensorOptions().dtype(torch::kInt32).device(torch::kCPU);
auto total_prefill_seq_len = 0;
auto total_num_sequences = 0;

total_num_sequences += fwd_input.input_params.num_sequences;
total_prefill_seq_len += fwd_input.input_params.prefill_seq_len;

next_tokens =
torch::arange(-1,
-1 * (total_num_sequences - total_prefill_seq_len + 1),
-1,
options);
int32_t num_decode_seqs = fwd_input.sampling_params.sample_idxes.size(0);
next_tokens = torch::arange(
-1, -1 * (num_decode_seqs + 1), -1, int_options.dtype(torch::kInt32));
std::move(future).deferValue([](auto&&) {});
}
expert_load_data = torch::zeros({1, 1}).to(torch::kInt64).contiguous();
expert_load_data = torch::zeros({1, 1}, int_options.dtype(torch::kInt64));
}
}

Expand Down
6 changes: 2 additions & 4 deletions xllm/core/framework/batch/batch.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -196,9 +196,7 @@ std::map<uint32_t, uint32_t> Batch::cal_seq_exchange_index(
return index_shift;
}

RawForwardInput Batch::prepare_forward_input(uint32_t start_idx,
uint32_t end_idx,
const ModelArgs& args,
RawForwardInput Batch::prepare_forward_input(const ModelArgs& args,
ThreadPool* thread_pool) {
dp_balance_shuffle_seqs();
BatchInputBuilder builder(sequences_,
Expand All @@ -210,7 +208,7 @@ RawForwardInput Batch::prepare_forward_input(uint32_t start_idx,
&args,
batch_forward_type_,
thread_pool);
return builder.build_raw_forward_input(start_idx, end_idx);
return builder.build_raw_forward_input();
}

void Batch::process_sample_output(const RawForwardOutput& raw_output,
Expand Down
4 changes: 1 addition & 3 deletions xllm/core/framework/batch/batch.h
Original file line number Diff line number Diff line change
Expand Up @@ -85,9 +85,7 @@ class Batch {
const ModelArgs& args);

// Convert Batch to pb type, which will be pass to remote worker.
RawForwardInput prepare_forward_input(uint32_t start_idx,
uint32_t end_idx,
const ModelArgs& args,
RawForwardInput prepare_forward_input(const ModelArgs& args,
ThreadPool* thread_pool);

// process output
Expand Down
37 changes: 13 additions & 24 deletions xllm/core/framework/batch/batch_input_builder.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ BatchInputBuilder::BatchInputBuilder(
mm_data_vec_(mm_data_vec),
args_(args),
thread_pool_(thread_pool),
num_sequences_(static_cast<int32_t>(sequences.size())),
num_sequences_(sequences.size()),
swap_block_transfer_infos_(swap_block_transfer_infos),
batch_id_(batch_id) {
// Reserve space for better performance
Expand All @@ -72,35 +72,31 @@ BatchInputBuilder::BatchInputBuilder(
ForwardInput BatchInputBuilder::build_forward_input(
uint32_t num_decoding_tokens,
uint32_t min_decoding_batch_size) {
process_sequences(0, static_cast<uint32_t>(num_sequences_));
process_sequences();
padding_decode_batch_size(num_decoding_tokens, min_decoding_batch_size);

return state_to_forward_input();
}

RawForwardInput BatchInputBuilder::build_raw_forward_input(uint32_t start_idx,
uint32_t end_idx) {
if (!thread_pool_ ||
end_idx - start_idx < static_cast<uint32_t>(thread_pool_->size())) {
process_sequences(start_idx, end_idx);
RawForwardInput BatchInputBuilder::build_raw_forward_input() {
if (!thread_pool_ || num_sequences_ < thread_pool_->size()) {
process_sequences();
} else {
process_sequences_multithreaded(start_idx, end_idx);
process_sequences_multithreaded();
}
return state_to_raw_forward_input();
}

void BatchInputBuilder::process_sequences(uint32_t start_idx,
uint32_t end_idx) {
for (int32_t i = start_idx; i < end_idx; ++i) {
// Build per-sequence input state for every sequence in the batch, in order.
// Single-threaded path; build_raw_forward_input() selects this when no
// thread pool is available or the batch is smaller than the pool size.
void BatchInputBuilder::process_sequences() {
  // num_sequences_ is uint32_t — use an unsigned index to avoid a
  // signed/unsigned comparison (-Wsign-compare) in the loop condition.
  for (uint32_t i = 0; i < num_sequences_; ++i) {
    process_single_sequence(i);
  }
}

void BatchInputBuilder::process_sequences_multithreaded(uint32_t start_idx,
uint32_t end_idx) {
void BatchInputBuilder::process_sequences_multithreaded() {
const size_t threads_num = thread_pool_->size();
const size_t sequences_per_thread =
(end_idx - start_idx + threads_num - 1) / threads_num;
(num_sequences_ + threads_num - 1) / threads_num;

BlockingCounter counter(threads_num);

Expand All @@ -117,17 +113,17 @@ void BatchInputBuilder::process_sequences_multithreaded(uint32_t start_idx,
BuilderState& state,
std::unordered_set<int32_t>& write_block_ids) {
for (size_t i = thread_start_idx;
i < thread_end_idx && i < static_cast<size_t>(end_idx);
i < thread_end_idx && i < static_cast<size_t>(num_sequences_);
++i) {
process_single_sequence(i, &state, &write_block_ids);
}
};

// Start parallel tasks
for (size_t thread_idx = 0; thread_idx < threads_num; ++thread_idx) {
size_t thread_start_idx = start_idx + thread_idx * sequences_per_thread;
size_t thread_start_idx = thread_idx * sequences_per_thread;
size_t thread_end_idx = std::min(thread_start_idx + sequences_per_thread,
static_cast<size_t>(end_idx));
static_cast<size_t>(num_sequences_));

thread_pool_->schedule([process_sequences_range,
thread_start_idx,
Expand Down Expand Up @@ -214,7 +210,6 @@ void BatchInputBuilder::process_sequences_multithreaded(uint32_t start_idx,
state_.new_token_slot_ids.insert(state_.new_token_slot_ids.end(),
state.new_token_slot_ids.begin(),
state.new_token_slot_ids.end());
state_.prefill_seq_len += state.prefill_seq_len;
state_.embedding_ids.insert(state_.embedding_ids.end(),
state.embedding_ids.begin(),
state.embedding_ids.end());
Expand Down Expand Up @@ -306,11 +301,6 @@ void BatchInputBuilder::process_single_sequence(
sequence, n_kv_cache_tokens, seq_len, q_seq_len, state_ptr);
}

// Track prefill sequences
if (sequence->is_chunked_prefill_stage()) {
state.prefill_seq_len++;
}

// Input for beam search kernel
if (FLAGS_enable_beam_search_kernel && sequence->check_beam_search() &&
sequence->num_generated_tokens() > 0) {
Expand Down Expand Up @@ -658,7 +648,6 @@ RawForwardInput BatchInputBuilder::state_to_raw_forward_input() {
raw_forward_input.num_sequences = num_sequences_;
// raw_forward_input.dp_global_token_nums = ;
raw_forward_input.transfer_kv_infos = std::move(state_.transfer_kv_infos);
raw_forward_input.prefill_seq_len = state_.prefill_seq_len;

// for flashinfer
raw_forward_input.paged_kv_indptr = std::move(state_.paged_kv_indptr);
Expand Down
9 changes: 4 additions & 5 deletions xllm/core/framework/batch/batch_input_builder.h
Original file line number Diff line number Diff line change
Expand Up @@ -47,12 +47,12 @@ class BatchInputBuilder {
ForwardInput build_forward_input(uint32_t num_decoding_tokens,
uint32_t min_decoding_batch_size);

RawForwardInput build_raw_forward_input(uint32_t start_idx, uint32_t end_idx);
RawForwardInput build_raw_forward_input();

private:
// Core building methods
void process_sequences(uint32_t start_idx, uint32_t end_idx);
void process_sequences_multithreaded(uint32_t start_idx, uint32_t end_idx);
void process_sequences();
void process_sequences_multithreaded();
void padding_decode_batch_size(uint32_t num_decoding_tokens,
uint32_t min_decoding_batch_size);
ForwardInput state_to_forward_input();
Expand Down Expand Up @@ -100,7 +100,6 @@ class BatchInputBuilder {
// Additional data
std::vector<int32_t> embedding_ids;
std::vector<int32_t> extra_token_ids;
uint32_t prefill_seq_len = 0;
std::vector<TransferKVInfo> transfer_kv_infos;

// for continuous kvcache
Expand Down Expand Up @@ -153,7 +152,7 @@ class BatchInputBuilder {

// Configuration
bool use_mrope_ = false;
int32_t num_sequences_ = 0;
uint32_t num_sequences_ = 0;

// copy in and out cache contents
std::unordered_set<int32_t> write_block_ids_;
Expand Down
7 changes: 1 addition & 6 deletions xllm/core/framework/model/model_input_params.h
Original file line number Diff line number Diff line change
Expand Up @@ -110,7 +110,6 @@ struct ModelInputParams {

params.mm_data = MMData::to(mm_data, device);
params.dp_global_token_nums = dp_global_token_nums;
params.prefill_seq_len = prefill_seq_len;
params.embedding_ids = std::move(embedding_ids);
params.extra_token_ids = std::move(extra_token_ids);
params.dp_ep_padding_data = dp_ep_padding_data;
Expand Down Expand Up @@ -151,8 +150,7 @@ struct ModelInputParams {
<< " , global_empty_kv_cache is " << global_empty_kv_cache
<< " , num_sequences is " << num_sequences
<< " , kv_max_seq_len is " << kv_max_seq_len
<< " , q_max_seq_len is " << q_max_seq_len
<< " , prefill_seq_len is " << prefill_seq_len;
<< " , q_max_seq_len is " << q_max_seq_len;
LOG(INFO) << "ModelInputParams: kv_seq_lens_vec is " << kv_seq_lens_vec;
LOG(INFO) << "ModelInputParams: q_seq_lens_vec is " << q_seq_lens_vec;
LOG(INFO) << "ModelInputParams: decode_seq_range is " << decode_seq_range;
Expand Down Expand Up @@ -209,9 +207,6 @@ struct ModelInputParams {
// whether the kv-cache is empty for all sequences,mainly used for dp case
bool global_empty_kv_cache = true;

// num of prefill sequence in chunked prefill case
uint32_t prefill_seq_len = 0;

// embedding ids of each sequence
std::vector<int32_t> embedding_ids;

Expand Down
2 changes: 0 additions & 2 deletions xllm/core/runtime/forward_params.h
Original file line number Diff line number Diff line change
Expand Up @@ -165,8 +165,6 @@ struct RawForwardInput {
// chunked prefill case of speculative decoding
// extra token ids for each sequence, and -1 for last chunk
std::vector<int32_t> extra_token_ids;
// num of prefill sequence in chunked prefill case
uint32_t prefill_seq_len;
// embedding ids of each sequence
std::vector<int> embedding_ids;
// swap
Expand Down
12 changes: 4 additions & 8 deletions xllm/core/runtime/forward_shared_memory_manager.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -149,11 +149,10 @@ INLINE size_t calculate_raw_forward_input_size(const RawForwardInput& input) {
total += type_size<uint64_t> +
input.swap_blocks.size() * swap_block_info_fixed_size();

total += type_size<bool> * 2 // empty_kv_cache + global_empty_kv_cache
+ type_size<int32_t> // batch_forward_type
+ type_size<uint32_t> *
3 // max_seq_len + q_max_seq_len + prefill_seq_len
+ type_size<int32_t> // num_sequences
total += type_size<bool> * 2 // empty_kv_cache + global_empty_kv_cache
+ type_size<int32_t> // batch_forward_type
+ type_size<uint32_t> * 2 // max_seq_len + q_max_seq_len
+ type_size<int32_t> // num_sequences
+ get_eplb_info_size(input.eplb_info);
// m_position
total += get_2d_vector_size(input.m_positions_vec);
Expand Down Expand Up @@ -577,7 +576,6 @@ INLINE void deserialize_raw_forward_input(
read_data(buffer, input.q_max_seq_len);
read_data(buffer, input.num_sequences);
read_eplb_info(buffer, input.eplb_info);
read_data(buffer, input.prefill_seq_len);
read_2d_vector(buffer, input.m_positions_vec);
read_mm_data(buffer, input.mm_data);
}
Expand Down Expand Up @@ -630,7 +628,6 @@ INLINE void serialize_raw_forward_input(const RawForwardInput& input,
write_data(buffer, input.q_max_seq_len);
write_data(buffer, input.num_sequences);
write_eplb_info(buffer, input.eplb_info);
write_data(buffer, input.prefill_seq_len);
write_2d_vector(buffer, input.m_positions_vec);
write_mm_data(buffer, input.mm_data);
}
Expand Down Expand Up @@ -832,7 +829,6 @@ void convert_raw_forward_input_to_forward_input(RawForwardInput& raw_input,
input_params.num_sequences = raw_input.num_sequences;
input_params.kv_max_seq_len = raw_input.max_seq_len;
input_params.q_max_seq_len = raw_input.q_max_seq_len;
input_params.prefill_seq_len = raw_input.prefill_seq_len;
input_params.embedding_ids = std::move(raw_input.embedding_ids);
input_params.dp_global_token_nums = std::move(raw_input.dp_global_token_nums);

Expand Down
4 changes: 2 additions & 2 deletions xllm/core/runtime/llm_engine.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -876,8 +876,8 @@ std::vector<RawForwardInput> LLMEngine::prepare_inputs(

// build model input for every single micro batch
for (auto dp_rank = 0; dp_rank < dp_size_; ++dp_rank) {
batched_inputs.emplace_back(std::move(batch[dp_rank].prepare_forward_input(
0, batch[dp_rank].size(), args_, threadpool_.get())));
batched_inputs.emplace_back(std::move(
batch[dp_rank].prepare_forward_input(args_, threadpool_.get())));
dp_global_token_nums[dp_rank] =
batched_inputs[dp_rank].flatten_tokens_vec.size();
global_empty_kv_cache =
Expand Down
2 changes: 0 additions & 2 deletions xllm/core/runtime/params_utils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -191,7 +191,6 @@ void proto_to_forward_input(const proto::ForwardInput* pb_forward_input,
BatchForwardType(pb_forward_input->batch_forward_type());
input_params.num_sequences = block_tables_vec.size();
assert(input_params.num_sequences == pb_forward_input->num_sequences());
input_params.prefill_seq_len = pb_forward_input->prefill_seq_len();
input_params.kv_max_seq_len = pb_forward_input->max_seq_len();
input_params.q_max_seq_len = pb_forward_input->q_max_seq_len();
input_params.kv_seq_lens = torch::tensor(seq_lens, tensor_options);
Expand Down Expand Up @@ -455,7 +454,6 @@ void forward_input_to_proto(const RawForwardInput& inputs,
*pb_forward_input->mutable_embeds()->Add() = embeds;
}

pb_forward_input->set_prefill_seq_len(inputs.prefill_seq_len);
ADD_VECTOR_TO_PROTO(pb_forward_input->mutable_embedding_ids(),
inputs.embedding_ids);
ADD_VECTOR_TO_PROTO(pb_forward_input->mutable_extra_token_ids(),
Expand Down
4 changes: 2 additions & 2 deletions xllm/core/runtime/vlm_engine.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -430,8 +430,8 @@ std::vector<RawForwardInput> VLMEngine::prepare_inputs(
bool global_empty_kv_cache = true;

for (auto dp_rank = 0; dp_rank < dp_size_; ++dp_rank) {
batched_inputs.emplace_back(std::move(batch[dp_rank].prepare_forward_input(
0, batch[dp_rank].size(), args_, threadpool_.get())));
batched_inputs.emplace_back(std::move(
batch[dp_rank].prepare_forward_input(args_, threadpool_.get())));
dp_global_token_nums[dp_rank] =
batched_inputs[dp_rank].flatten_tokens_vec.size();
global_empty_kv_cache =
Expand Down
2 changes: 1 addition & 1 deletion xllm/proto/worker.proto
Original file line number Diff line number Diff line change
Expand Up @@ -185,7 +185,7 @@ message ForwardInput {
bool global_empty_kv_cache = 21;
repeated TransferKVInfo transfer_kv_infos = 22;
repeated Embeddings embeds = 23;
uint32 prefill_seq_len = 24;
// uint32 prefill_seq_len = 24;
repeated int32 embedding_ids = 25;
repeated int32 extra_token_ids = 26;
EplbInfo eplb_info =27;
Expand Down