diff --git a/xllm/core/framework/request/sequence_kv_state.cpp b/xllm/core/framework/request/sequence_kv_state.cpp
index cba11cb99..883eff85b 100644
--- a/xllm/core/framework/request/sequence_kv_state.cpp
+++ b/xllm/core/framework/request/sequence_kv_state.cpp
@@ -58,7 +58,6 @@ void KVCacheState::add_shared_kv_blocks(std::vector&& blocks,
   if (blocks.empty()) {
     return;
   }
-
   // The number of matched blocks may be fewer than the number of blocks held by
   // the sequence itself. In this case, try to replace the blocks computed by
   // the sequence with blocks from the prefix_cache and release the computed
@@ -86,6 +85,10 @@ void KVCacheState::add_shared_kv_blocks(std::vector&& blocks,
     CHECK_GT(block_size, 0);
     num_shared_tokens =
         ((current_total_num_tokens - 1) / block_size) * block_size;
+    if (num_owned_shared_blocks_ > 0) {
+      num_owned_shared_blocks_--;
+      blocks_.pop_back();
+    }
   }
   CHECK_LT(num_shared_tokens, current_total_num_tokens);
   // update the kv cache position
diff --git a/xllm/core/runtime/llm_worker_impl.cpp b/xllm/core/runtime/llm_worker_impl.cpp
index 9202abc7f..401a84bae 100644
--- a/xllm/core/runtime/llm_worker_impl.cpp
+++ b/xllm/core/runtime/llm_worker_impl.cpp
@@ -182,7 +182,7 @@ std::optional LLMWorkerImpl::step(
   // should be in same prefill stage, so, to judge empty_kv_cache,
   // just use micro batch 0 here
   if (options_.enable_speculative_decode() && !is_spec_draft_) {
-    if (input_params_micro_batches[0].q_seq_lens_vec[0] > 1) {
+    if (check_is_prefill(inputs.micro_inputs[0].input_params.q_seq_lens_vec)) {
      output.sample_output.embeddings = hidden_states;
    } else if (concated_sampling_params.sample_idxes.defined()) {
      // auto sample_idxes =
diff --git a/xllm/core/runtime/speculative_worker_impl.cpp b/xllm/core/runtime/speculative_worker_impl.cpp
index 8e2c5a06f..37b8d9bc5 100644
--- a/xllm/core/runtime/speculative_worker_impl.cpp
+++ b/xllm/core/runtime/speculative_worker_impl.cpp
@@ -173,7 +173,7 @@ std::optional SpeculativeWorkerImpl::step(
   }
 
   // TODO: support data parallel case
-  if (inputs.micro_inputs[0].input_params.q_seq_lens_vec[0] > 1) {
+  if (check_is_prefill(inputs.micro_inputs[0].input_params.q_seq_lens_vec)) {
    return step_prefill(inputs);
  } else {
    return step_decode(inputs);
@@ -182,7 +182,7 @@ std::optional SpeculativeWorkerImpl::step(
 
 std::optional SpeculativeWorkerImpl::step_empty(
     const BatchedForwardInputs& inputs) {
-  if (inputs.micro_inputs[0].input_params.q_seq_lens_vec[0] > 1) {
+  if (check_is_prefill(inputs.micro_inputs[0].input_params.q_seq_lens_vec)) {
    auto output = impl_->step(inputs);
    auto draft_output = draft_impl_->step(inputs);
    return output;
@@ -230,7 +230,8 @@ std::optional SpeculativeWorkerImpl::step_prefill(
    if (token_offset > 0) {
      prefill_inputs.micro_inputs[i].input_params.mm_data = MMData(
          MMType::EMBEDDING,
-          {{"embedding", embeddings.narrow(0, token_start_idx, token_offset)}});
+          {{"embedding",
+            embeddings.narrow(0, token_start_idx, token_offset).clone()}});
    }
    if (next_tokens.defined()) {
      auto& token_ids = prefill_inputs.micro_inputs[i].token_ids;
@@ -293,6 +294,7 @@ std::optional SpeculativeWorkerImpl::step_prefill(
 void SpeculativeWorkerImpl::prepare_prefill_inputs(
     const BatchedForwardInputs& inputs,
     BatchedForwardInputs& prefill_inputs) {
+  prefill_inputs.micro_inputs.clear();
  prefill_inputs.micro_inputs.reserve(inputs.micro_inputs.size());
  for (auto i = 0; i < inputs.micro_inputs.size(); ++i) {
    auto& input = inputs.micro_inputs[i];
@@ -308,16 +310,16 @@ void SpeculativeWorkerImpl::prepare_prefill_inputs(
    int32_t start_idx = 0;
    std::vector new_token_ids;
    new_token_ids.reserve(input.token_ids.numel());
-    for (size_t i = 0; i < input_params.num_sequences; ++i) {
+    for (size_t j = 0; j < input_params.num_sequences; ++j) {
      int32_t q_len = 0;
-      q_len = input_params.q_seq_lens_vec[i];
+      q_len = input_params.q_seq_lens_vec[j];
      Slice tokens_ids_slice_i =
          tokens_ids_slice.slice(start_idx + 1, start_idx + q_len);
      start_idx += q_len;
      new_token_ids.insert(new_token_ids.end(),
                           tokens_ids_slice_i.begin(),
                           tokens_ids_slice_i.end());
-      new_token_ids.emplace_back(extra_token_ids[i]);
+      new_token_ids.emplace_back(extra_token_ids[j]);
    }
    prefill_input.token_ids =
        torch::tensor(new_token_ids, prefill_input.positions.options());
@@ -359,7 +361,11 @@ std::optional SpeculativeWorkerImpl::step_decode(
        // final step
        prepare_validate_inputs(inputs, validate_inputs, true);
      } else {
-        prepare_draft_inputs(draft_inputs, next_step_input, 1, device_);
+        if (i == 0) {
+          prepare_draft_inputs(inputs, next_step_input, 1, device_);
+        } else {
+          prepare_draft_inputs(draft_inputs, next_step_input, 1, device_);
+        }
      }
      draft_outputs.push_back(std::move(future).get().value());
      // update input of next step
@@ -368,8 +374,8 @@ std::optional SpeculativeWorkerImpl::step_decode(
      auto last_output = draft_outputs.back().sample_output;
      auto start_idx = 0;
      auto token_start_idx = 0;
-      for (auto i = 0; i < draft_inputs.micro_inputs.size(); ++i) {
-        auto& draft_input = draft_inputs.micro_inputs[i];
+      for (auto j = 0; j < draft_inputs.micro_inputs.size(); ++j) {
+        auto& draft_input = draft_inputs.micro_inputs[j];
        auto offset = draft_input.input_params.num_sequences;
        auto token_offset = draft_input.token_ids.size(0);
        draft_input.token_ids = safe_to(
@@ -379,6 +385,7 @@ std::optional SpeculativeWorkerImpl::step_decode(
            MMType::EMBEDDING,
            {{"embedding",
              last_output.embeddings.narrow(0, token_start_idx, token_offset)
+                  .clone()
                  .to(device_)}});
        }
        start_idx += offset;
@@ -394,9 +401,11 @@ std::optional SpeculativeWorkerImpl::step_decode(
      auto next_tokens =
          safe_to(draft_output.sample_output.next_tokens, torch::kInt);
      int32_t start_idx = 0;
-      for (auto i = 0; i < validate_inputs.micro_inputs.size(); ++i) {
-        int32_t offset = draft_inputs.micro_inputs[i].input_params.num_sequences;
-        auto& validate_input = validate_inputs.micro_inputs[i];
+      for (auto j = 0; j < validate_inputs.micro_inputs.size(); ++j) {
+        int32_t offset =
+            validate_inputs.micro_inputs[j].input_params.num_sequences /
+            (options_.num_speculative_tokens() + 1);
+        auto& validate_input = validate_inputs.micro_inputs[j];
        auto& token_ids = validate_input.token_ids;
        auto mask = (token_ids == -1 * (i + 1));
        token_ids.masked_scatter_(mask, next_tokens.narrow(0, start_idx, offset));
@@ -447,9 +456,10 @@ void SpeculativeWorkerImpl::prepare_draft_inputs(
    const int64_t offset,
    const torch::Device device) {
  // prepare input for MTP in decoding phase (Like Eagle).
+  draft_inputs.micro_inputs.clear();
  draft_inputs.micro_inputs.reserve(inputs.micro_inputs.size());
-  for (auto i = 0; i < inputs.micro_inputs.size(); ++i) {
-    auto& input = inputs.micro_inputs[i];
+  for (auto idx = 0; idx < inputs.micro_inputs.size(); ++idx) {
+    auto& input = inputs.micro_inputs[idx];
    ForwardInput draft_input = input.to(device, dtype_);
 
    auto& input_params = draft_input.input_params;
@@ -504,8 +514,8 @@ void SpeculativeWorkerImpl::prepare_validate_inputs(
    BatchedForwardInputs& validate_inputs,
    bool enable_schedule_overlap) {
  validate_inputs.micro_inputs.reserve(inputs.micro_inputs.size());
-  for (auto i = 0; i < inputs.micro_inputs.size(); ++i) {
-    auto& input = inputs.micro_inputs[i];
+  for (auto idx = 0; idx < inputs.micro_inputs.size(); ++idx) {
+    auto& input = inputs.micro_inputs[idx];
    ForwardInput validate_input = input.to(device_, dtype_);
 
    auto& input_params = validate_input.input_params;
@@ -823,7 +833,7 @@ void SpeculativeWorkerImpl::update_sampling_params(
 void SpeculativeWorkerImpl::prepare_work_before_execute(
     const BatchedForwardInputs& inputs,
     BatchedForwardInputs& processed_inputs) {
-  if (inputs.micro_inputs[0].input_params.q_seq_lens_vec[0] > 1) {
+  if (check_is_prefill(inputs.micro_inputs[0].input_params.q_seq_lens_vec)) {
    WorkerImpl::prepare_work_before_execute(inputs, processed_inputs);
  } else {
    if (enable_schedule_overlap()) {
diff --git a/xllm/core/runtime/worker_impl.cpp b/xllm/core/runtime/worker_impl.cpp
index 47a781653..c7e144553 100644
--- a/xllm/core/runtime/worker_impl.cpp
+++ b/xllm/core/runtime/worker_impl.cpp
@@ -759,4 +759,13 @@ int64_t WorkerImpl::get_active_activation_memory() {
      .active_activation_memory;
 }
 
+bool WorkerImpl::check_is_prefill(const std::vector& q_seq_lens_vec) {
+  for (auto q_len : q_seq_lens_vec) {
+    if (q_len > 1) {
+      return true;
+    }
+  }
+  return false;
+}
+
 }  // namespace xllm
diff --git a/xllm/core/runtime/worker_impl.h b/xllm/core/runtime/worker_impl.h
index 63b1560e0..6640ecf44 100644
--- a/xllm/core/runtime/worker_impl.h
+++ b/xllm/core/runtime/worker_impl.h
@@ -166,6 +166,8 @@ class WorkerImpl {
 
  torch::ScalarType dtype() const { return dtype_; }
 
+  bool check_is_prefill(const std::vector& q_seq_lens_vec);
+
  int32_t hidden_size() const {
    return context_.get_model_args().hidden_size();
  }