From 6d6965553e77e187675a19107a2e80e23e623f91 Mon Sep 17 00:00:00 2001 From: liangzhiwei20 Date: Wed, 26 Nov 2025 14:50:28 +0800 Subject: [PATCH] refactor: remove some useless code of batch input builder. --- .../distributed_runtime/worker_service.cpp | 19 +++------- xllm/core/framework/batch/batch.cpp | 6 +-- xllm/core/framework/batch/batch.h | 4 +- .../framework/batch/batch_input_builder.cpp | 37 +++++++------------ .../framework/batch/batch_input_builder.h | 9 ++--- .../core/framework/model/model_input_params.h | 7 +--- xllm/core/runtime/forward_params.h | 2 - .../runtime/forward_shared_memory_manager.cpp | 12 ++---- xllm/core/runtime/llm_engine.cpp | 4 +- xllm/core/runtime/params_utils.cpp | 2 - xllm/core/runtime/vlm_engine.cpp | 4 +- xllm/proto/worker.proto | 2 +- 12 files changed, 35 insertions(+), 73 deletions(-) diff --git a/xllm/core/distributed_runtime/worker_service.cpp b/xllm/core/distributed_runtime/worker_service.cpp index ca8079b4..7fe9acac 100644 --- a/xllm/core/distributed_runtime/worker_service.cpp +++ b/xllm/core/distributed_runtime/worker_service.cpp @@ -136,24 +136,15 @@ void WorkerService::step(ForwardInput& fwd_input, } } } else { + auto int_options = torch::TensorOptions().device(torch::kCPU); if (worker_->is_driver()) { // construct fake output tensor - auto options = - torch::TensorOptions().dtype(torch::kInt32).device(torch::kCPU); - auto total_prefill_seq_len = 0; - auto total_num_sequences = 0; - - total_num_sequences += fwd_input.input_params.num_sequences; - total_prefill_seq_len += fwd_input.input_params.prefill_seq_len; - - next_tokens = - torch::arange(-1, - -1 * (total_num_sequences - total_prefill_seq_len + 1), - -1, - options); + int32_t num_decode_seqs = fwd_input.sampling_params.sample_idxes.size(0); + next_tokens = torch::arange( + -1, -1 * (num_decode_seqs + 1), -1, int_options.dtype(torch::kInt32)); std::move(future).deferValue([](auto&&) {}); } - expert_load_data = torch::zeros({1, 1}).to(torch::kInt64).contiguous(); + expert_load_data = torch::zeros({1, 1}, int_options.dtype(torch::kInt64)); } } diff --git a/xllm/core/framework/batch/batch.cpp b/xllm/core/framework/batch/batch.cpp index 579dcf8f..98a2fab5 100755 --- a/xllm/core/framework/batch/batch.cpp +++ b/xllm/core/framework/batch/batch.cpp @@ -196,9 +196,7 @@ std::map Batch::cal_seq_exchange_index( return index_shift; } -RawForwardInput Batch::prepare_forward_input(uint32_t start_idx, - uint32_t end_idx, - const ModelArgs& args, +RawForwardInput Batch::prepare_forward_input(const ModelArgs& args, ThreadPool* thread_pool) { dp_balance_shuffle_seqs(); BatchInputBuilder builder(sequences_, @@ -210,7 +208,7 @@ RawForwardInput Batch::prepare_forward_input(uint32_t start_idx, &args, batch_forward_type_, thread_pool); - return builder.build_raw_forward_input(start_idx, end_idx); + return builder.build_raw_forward_input(); } void Batch::process_sample_output(const RawForwardOutput& raw_output, diff --git a/xllm/core/framework/batch/batch.h b/xllm/core/framework/batch/batch.h index 4bd0e8de..79833273 100755 --- a/xllm/core/framework/batch/batch.h +++ b/xllm/core/framework/batch/batch.h @@ -85,9 +85,7 @@ class Batch { const ModelArgs& args); // Convert Batch to pb type, which will be pass to remote worker. - RawForwardInput prepare_forward_input(uint32_t start_idx, - uint32_t end_idx, - const ModelArgs& args, + RawForwardInput prepare_forward_input(const ModelArgs& args, ThreadPool* thread_pool); // process output diff --git a/xllm/core/framework/batch/batch_input_builder.cpp b/xllm/core/framework/batch/batch_input_builder.cpp index 0d331259..37a55bbf 100755 --- a/xllm/core/framework/batch/batch_input_builder.cpp +++ b/xllm/core/framework/batch/batch_input_builder.cpp @@ -53,7 +53,7 @@ BatchInputBuilder::BatchInputBuilder( mm_data_vec_(mm_data_vec), args_(args), thread_pool_(thread_pool), - num_sequences_(static_cast(sequences.size())), + num_sequences_(sequences.size()), swap_block_transfer_infos_(swap_block_transfer_infos), batch_id_(batch_id) { // Reserve space for better performance @@ -72,35 +72,31 @@ BatchInputBuilder::BatchInputBuilder( ForwardInput BatchInputBuilder::build_forward_input( uint32_t num_decoding_tokens, uint32_t min_decoding_batch_size) { - process_sequences(0, static_cast(num_sequences_)); + process_sequences(); padding_decode_batch_size(num_decoding_tokens, min_decoding_batch_size); return state_to_forward_input(); } -RawForwardInput BatchInputBuilder::build_raw_forward_input(uint32_t start_idx, - uint32_t end_idx) { - if (!thread_pool_ || - end_idx - start_idx < static_cast(thread_pool_->size())) { - process_sequences(start_idx, end_idx); +RawForwardInput BatchInputBuilder::build_raw_forward_input() { + if (!thread_pool_ || num_sequences_ < thread_pool_->size()) { + process_sequences(); } else { - process_sequences_multithreaded(start_idx, end_idx); + process_sequences_multithreaded(); } return state_to_raw_forward_input(); } -void BatchInputBuilder::process_sequences(uint32_t start_idx, - uint32_t end_idx) { - for (int32_t i = start_idx; i < end_idx; ++i) { +void BatchInputBuilder::process_sequences() { + for (int32_t i = 0; i < num_sequences_; ++i) { process_single_sequence(i); } } -void BatchInputBuilder::process_sequences_multithreaded(uint32_t start_idx, - uint32_t end_idx) { +void BatchInputBuilder::process_sequences_multithreaded() { const size_t threads_num = thread_pool_->size(); const size_t sequences_per_thread = - (end_idx - start_idx + threads_num - 1) / threads_num; + (num_sequences_ + threads_num - 1) / threads_num; BlockingCounter counter(threads_num); @@ -117,7 +113,7 @@ void BatchInputBuilder::process_sequences_multithreaded(uint32_t start_idx, BuilderState& state, std::unordered_set& write_block_ids) { for (size_t i = thread_start_idx; - i < thread_end_idx && i < static_cast(end_idx); + i < thread_end_idx && i < static_cast(num_sequences_); ++i) { process_single_sequence(i, &state, &write_block_ids); } @@ -125,9 +121,9 @@ void BatchInputBuilder::process_sequences_multithreaded(uint32_t start_idx, // Start parallel tasks for (size_t thread_idx = 0; thread_idx < threads_num; ++thread_idx) { - size_t thread_start_idx = start_idx + thread_idx * sequences_per_thread; + size_t thread_start_idx = thread_idx * sequences_per_thread; size_t thread_end_idx = std::min(thread_start_idx + sequences_per_thread, - static_cast(end_idx)); + static_cast(num_sequences_)); thread_pool_->schedule([process_sequences_range, thread_start_idx, @@ -214,7 +210,6 @@ void BatchInputBuilder::process_sequences_multithreaded(uint32_t start_idx, state_.new_token_slot_ids.insert(state_.new_token_slot_ids.end(), state.new_token_slot_ids.begin(), state.new_token_slot_ids.end()); - state_.prefill_seq_len += state.prefill_seq_len; state_.embedding_ids.insert(state_.embedding_ids.end(), state.embedding_ids.begin(), state.embedding_ids.end()); @@ -306,11 +301,6 @@ void BatchInputBuilder::process_single_sequence( sequence, n_kv_cache_tokens, seq_len, q_seq_len, state_ptr); } - // Track prefill sequences - if (sequence->is_chunked_prefill_stage()) { - state.prefill_seq_len++; - } - // Input for beam search kernel if (FLAGS_enable_beam_search_kernel && sequence->check_beam_search() && sequence->num_generated_tokens() > 0) { @@ -658,7 +648,6 @@ RawForwardInput BatchInputBuilder::state_to_raw_forward_input() { raw_forward_input.num_sequences = num_sequences_; // raw_forward_input.dp_global_token_nums = ; raw_forward_input.transfer_kv_infos = std::move(state_.transfer_kv_infos); - raw_forward_input.prefill_seq_len = state_.prefill_seq_len; // for flashinfer raw_forward_input.paged_kv_indptr = std::move(state_.paged_kv_indptr); diff --git a/xllm/core/framework/batch/batch_input_builder.h b/xllm/core/framework/batch/batch_input_builder.h index fdcb960f..8f18d6ac 100644 --- a/xllm/core/framework/batch/batch_input_builder.h +++ b/xllm/core/framework/batch/batch_input_builder.h @@ -47,12 +47,12 @@ class BatchInputBuilder { ForwardInput build_forward_input(uint32_t num_decoding_tokens, uint32_t min_decoding_batch_size); - RawForwardInput build_raw_forward_input(uint32_t start_idx, uint32_t end_idx); + RawForwardInput build_raw_forward_input(); private: // Core building methods - void process_sequences(uint32_t start_idx, uint32_t end_idx); - void process_sequences_multithreaded(uint32_t start_idx, uint32_t end_idx); + void process_sequences(); + void process_sequences_multithreaded(); void padding_decode_batch_size(uint32_t num_decoding_tokens, uint32_t min_decoding_batch_size); ForwardInput state_to_forward_input(); @@ -100,7 +100,6 @@ class BatchInputBuilder { // Additional data std::vector embedding_ids; std::vector extra_token_ids; - uint32_t prefill_seq_len = 0; std::vector transfer_kv_infos; // for continuous kvcache @@ -153,7 +152,7 @@ class BatchInputBuilder { // Configuration bool use_mrope_ = false; - int32_t num_sequences_ = 0; + uint32_t num_sequences_ = 0; // copy in and out cache contents std::unordered_set write_block_ids_; diff --git a/xllm/core/framework/model/model_input_params.h b/xllm/core/framework/model/model_input_params.h index 933ddcc4..857f05f1 100644 --- a/xllm/core/framework/model/model_input_params.h +++ b/xllm/core/framework/model/model_input_params.h @@ -110,7 +110,6 @@ struct ModelInputParams { params.mm_data = MMData::to(mm_data, device); params.dp_global_token_nums = dp_global_token_nums; - params.prefill_seq_len = prefill_seq_len; params.embedding_ids = std::move(embedding_ids); params.extra_token_ids = std::move(extra_token_ids); params.dp_ep_padding_data = dp_ep_padding_data; @@ -151,8 +150,7 @@ struct ModelInputParams { << " , global_empty_kv_cache is " << global_empty_kv_cache << " , num_sequences is " << num_sequences << " , kv_max_seq_len is " << kv_max_seq_len - << " , q_max_seq_len is " << q_max_seq_len - << " , prefill_seq_len is " << prefill_seq_len; + << " , q_max_seq_len is " << q_max_seq_len; LOG(INFO) << "ModelInputParams: kv_seq_lens_vec is " << kv_seq_lens_vec; LOG(INFO) << "ModelInputParams: q_seq_lens_vec is " << q_seq_lens_vec; LOG(INFO) << "ModelInputParams: decode_seq_range is " << decode_seq_range; @@ -209,9 +207,6 @@ struct ModelInputParams { // whether the kv-cache is empty for all sequences,mainly used for dp case bool global_empty_kv_cache = true; - // num of prefill sequence in chunked prefill case - uint32_t prefill_seq_len = 0; - // embedding ids of each sequence std::vector embedding_ids; diff --git a/xllm/core/runtime/forward_params.h b/xllm/core/runtime/forward_params.h index 3bc28d50..18927f69 100755 --- a/xllm/core/runtime/forward_params.h +++ b/xllm/core/runtime/forward_params.h @@ -165,8 +165,6 @@ struct RawForwardInput { // chunked prefill case of speculative decoding // extra token ids for each sequence, and -1 for last chunk std::vector extra_token_ids; - // num of prefill sequence in chunked prefill case - uint32_t prefill_seq_len; // embedding ids of each sequence std::vector embedding_ids; // swap diff --git a/xllm/core/runtime/forward_shared_memory_manager.cpp b/xllm/core/runtime/forward_shared_memory_manager.cpp index 1f3b87ed..ed531a35 100755 --- a/xllm/core/runtime/forward_shared_memory_manager.cpp +++ b/xllm/core/runtime/forward_shared_memory_manager.cpp @@ -149,11 +149,10 @@ INLINE size_t calculate_raw_forward_input_size(const RawForwardInput& input) { total += type_size + input.swap_blocks.size() * swap_block_info_fixed_size(); - total += type_size * 2 // empty_kv_cache + global_empty_kv_cache - + type_size // batch_forward_type - + type_size * - 3 // max_seq_len + q_max_seq_len + prefill_seq_len - + type_size // num_sequences + total += type_size * 2 // empty_kv_cache + global_empty_kv_cache + + type_size // batch_forward_type + + type_size * 2 // max_seq_len + q_max_seq_len + + type_size // num_sequences + get_eplb_info_size(input.eplb_info); // m_position total += get_2d_vector_size(input.m_positions_vec); @@ -577,7 +576,6 @@ INLINE void deserialize_raw_forward_input( read_data(buffer, input.q_max_seq_len); read_data(buffer, input.num_sequences); read_eplb_info(buffer, input.eplb_info); - read_data(buffer, input.prefill_seq_len); read_2d_vector(buffer, input.m_positions_vec); read_mm_data(buffer, input.mm_data); } @@ -630,7 +628,6 @@ INLINE void serialize_raw_forward_input(const RawForwardInput& input, write_data(buffer, input.q_max_seq_len); write_data(buffer, input.num_sequences); write_eplb_info(buffer, input.eplb_info); - write_data(buffer, input.prefill_seq_len); write_2d_vector(buffer, input.m_positions_vec); write_mm_data(buffer, input.mm_data); } @@ -832,7 +829,6 @@ void convert_raw_forward_input_to_forward_input(RawForwardInput& raw_input, input_params.num_sequences = raw_input.num_sequences; input_params.kv_max_seq_len = raw_input.max_seq_len; input_params.q_max_seq_len = raw_input.q_max_seq_len; - input_params.prefill_seq_len = raw_input.prefill_seq_len; input_params.embedding_ids = std::move(raw_input.embedding_ids); input_params.dp_global_token_nums = std::move(raw_input.dp_global_token_nums); diff --git a/xllm/core/runtime/llm_engine.cpp b/xllm/core/runtime/llm_engine.cpp index 6cba8483..0b104ba5 100644 --- a/xllm/core/runtime/llm_engine.cpp +++ b/xllm/core/runtime/llm_engine.cpp @@ -876,8 +876,8 @@ std::vector LLMEngine::prepare_inputs( // build model input for every single micro batch for (auto dp_rank = 0; dp_rank < dp_size_; ++dp_rank) { - batched_inputs.emplace_back(std::move(batch[dp_rank].prepare_forward_input( - 0, batch[dp_rank].size(), args_, threadpool_.get()))); + batched_inputs.emplace_back(std::move( + batch[dp_rank].prepare_forward_input(args_, threadpool_.get()))); dp_global_token_nums[dp_rank] = batched_inputs[dp_rank].flatten_tokens_vec.size(); global_empty_kv_cache = diff --git a/xllm/core/runtime/params_utils.cpp b/xllm/core/runtime/params_utils.cpp index b6709d21..077b725c 100644 --- a/xllm/core/runtime/params_utils.cpp +++ b/xllm/core/runtime/params_utils.cpp @@ -191,7 +191,6 @@ void proto_to_forward_input(const proto::ForwardInput* pb_forward_input, BatchForwardType(pb_forward_input->batch_forward_type()); input_params.num_sequences = block_tables_vec.size(); assert(input_params.num_sequences == pb_forward_input->num_sequences()); - input_params.prefill_seq_len = pb_forward_input->prefill_seq_len(); input_params.kv_max_seq_len = pb_forward_input->max_seq_len(); input_params.q_max_seq_len = pb_forward_input->q_max_seq_len(); input_params.kv_seq_lens = torch::tensor(seq_lens, tensor_options); @@ -455,7 +454,6 @@ void forward_input_to_proto(const RawForwardInput& inputs, *pb_forward_input->mutable_embeds()->Add() = embeds; } - pb_forward_input->set_prefill_seq_len(inputs.prefill_seq_len); ADD_VECTOR_TO_PROTO(pb_forward_input->mutable_embedding_ids(), inputs.embedding_ids); ADD_VECTOR_TO_PROTO(pb_forward_input->mutable_extra_token_ids(), diff --git a/xllm/core/runtime/vlm_engine.cpp b/xllm/core/runtime/vlm_engine.cpp index 6227a6ee..41351bff 100644 --- a/xllm/core/runtime/vlm_engine.cpp +++ b/xllm/core/runtime/vlm_engine.cpp @@ -430,8 +430,8 @@ std::vector VLMEngine::prepare_inputs( bool global_empty_kv_cache = true; for (auto dp_rank = 0; dp_rank < dp_size_; ++dp_rank) { - batched_inputs.emplace_back(std::move(batch[dp_rank].prepare_forward_input( - 0, batch[dp_rank].size(), args_, threadpool_.get()))); + batched_inputs.emplace_back(std::move( + batch[dp_rank].prepare_forward_input(args_, threadpool_.get()))); dp_global_token_nums[dp_rank] = batched_inputs[dp_rank].flatten_tokens_vec.size(); global_empty_kv_cache = diff --git a/xllm/proto/worker.proto b/xllm/proto/worker.proto index 968d0458..b07ac742 100644 --- a/xllm/proto/worker.proto +++ b/xllm/proto/worker.proto @@ -185,7 +185,7 @@ message ForwardInput { bool global_empty_kv_cache = 21; repeated TransferKVInfo transfer_kv_infos = 22; repeated Embeddings embeds = 23; - uint32 prefill_seq_len = 24; + // uint32 prefill_seq_len = 24; repeated int32 embedding_ids = 25; repeated int32 extra_token_ids = 26; EplbInfo eplb_info =27;