5 changes: 4 additions & 1 deletion xllm/core/framework/request/sequence_kv_state.cpp
@@ -58,7 +58,6 @@ void KVCacheState::add_shared_kv_blocks(std::vector<Block>&& blocks,
if (blocks.empty()) {
return;
}

// The number of matched blocks may be fewer than the number of blocks held by
// the sequence itself. In this case, try to replace the blocks computed by
// the sequence with blocks from the prefix_cache and release the computed
@@ -86,6 +85,10 @@ void KVCacheState::add_shared_kv_blocks(std::vector<Block>&& blocks,
CHECK_GT(block_size, 0);
num_shared_tokens =
((current_total_num_tokens - 1) / block_size) * block_size;
if (num_owned_shared_blocks_ > 0) {
num_owned_shared_blocks_--;
blocks_.pop_back();
}
}
CHECK_LT(num_shared_tokens, current_total_num_tokens);
// update the kv cache position
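Note on the fix above: when the prefix cache matches every token a sequence holds, num_shared_tokens is rounded down past the last block boundary so at least one token is left for the model to compute; the added lines keep the bookkeeping consistent by also releasing the now-partially-shared block. A standalone sketch of the boundary arithmetic (the harness and concrete numbers are illustrative, not from the PR):

#include <cstdint>
#include <iostream>

int main() {
  const int32_t block_size = 16;
  // Assume the prefix cache matched every token the sequence holds:
  const int32_t current_total_num_tokens = 48;  // exactly 3 full blocks
  // Round down past the last block boundary so at least one token
  // remains for the model to compute.
  const int32_t num_shared_tokens =
      ((current_total_num_tokens - 1) / block_size) * block_size;
  std::cout << num_shared_tokens << '\n';  // 32, not 48
  // Tokens 32..47 are no longer fully shared, so the last shared block
  // must be released -- the num_owned_shared_blocks_-- / pop_back() pair
  // in the diff keeps blocks_ in step with num_shared_tokens.
}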
2 changes: 1 addition & 1 deletion xllm/core/runtime/llm_worker_impl.cpp
@@ -182,7 +182,7 @@ std::optional<ForwardOutput> LLMWorkerImpl::step(
// should be in the same prefill stage, so to judge empty_kv_cache,
// just use micro batch 0 here
if (options_.enable_speculative_decode() && !is_spec_draft_) {
if (input_params_micro_batches[0].q_seq_lens_vec[0] > 1) {
if (check_is_prefill(inputs.micro_inputs[0].input_params.q_seq_lens_vec)) {
output.sample_output.embeddings = hidden_states;
} else if (concated_sampling_params.sample_idxes.defined()) {
// auto sample_idxes =
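The motivation for the new helper: the old test q_seq_lens_vec[0] > 1 inspected only the first sequence, so a mixed micro batch whose first sequence was already decoding would be misclassified as decode. A minimal sketch of the difference (standalone harness, not xLLM code):

#include <vector>

// Same logic as the WorkerImpl::check_is_prefill added in this PR: the
// batch counts as prefill if any sequence has more than one query token.
static bool check_is_prefill(const std::vector<int>& q_seq_lens_vec) {
  for (int q_len : q_seq_lens_vec) {
    if (q_len > 1) return true;
  }
  return false;
}

int main() {
  // First sequence decoding (q_len 1), second still prefilling (q_len 128).
  std::vector<int> q_seq_lens = {1, 128};
  bool old_check = q_seq_lens[0] > 1;             // false: misses the prefill
  bool new_check = check_is_prefill(q_seq_lens);  // true
  return (old_check == new_check) ? 1 : 0;
}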
44 changes: 27 additions & 17 deletions xllm/core/runtime/speculative_worker_impl.cpp
@@ -173,7 +173,7 @@ std::optional<ForwardOutput> SpeculativeWorkerImpl::step(
}

// TODO: support data parallel case
if (inputs.micro_inputs[0].input_params.q_seq_lens_vec[0] > 1) {
if (check_is_prefill(inputs.micro_inputs[0].input_params.q_seq_lens_vec)) {
return step_prefill(inputs);
} else {
return step_decode(inputs);
@@ -182,7 +182,7 @@

std::optional<ForwardOutput> SpeculativeWorkerImpl::step_empty(
const BatchedForwardInputs& inputs) {
if (inputs.micro_inputs[0].input_params.q_seq_lens_vec[0] > 1) {
if (check_is_prefill(inputs.micro_inputs[0].input_params.q_seq_lens_vec)) {
auto output = impl_->step(inputs);
auto draft_output = draft_impl_->step(inputs);
return output;
@@ -230,7 +230,8 @@ std::optional<ForwardOutput> SpeculativeWorkerImpl::step_prefill(
if (token_offset > 0) {
prefill_inputs.micro_inputs[i].input_params.mm_data = MMData(
MMType::EMBEDDING,
{{"embedding", embeddings.narrow(0, token_start_idx, token_offset)}});
{{"embedding",
embeddings.narrow(0, token_start_idx, token_offset).clone()}});
}
if (next_tokens.defined()) {
auto& token_ids = prefill_inputs.micro_inputs[i].token_ids;
@@ -293,6 +294,7 @@ std::optional<ForwardOutput> SpeculativeWorkerImpl::step_prefill(
void SpeculativeWorkerImpl::prepare_prefill_inputs(
const BatchedForwardInputs& inputs,
BatchedForwardInputs& prefill_inputs) {
prefill_inputs.micro_inputs.clear();
prefill_inputs.micro_inputs.reserve(inputs.micro_inputs.size());
for (auto i = 0; i < inputs.micro_inputs.size(); ++i) {
auto& input = inputs.micro_inputs[i];
@@ -308,16 +310,16 @@
int32_t start_idx = 0;
std::vector<int32_t> new_token_ids;
new_token_ids.reserve(input.token_ids.numel());
for (size_t i = 0; i < input_params.num_sequences; ++i) {
for (size_t j = 0; j < input_params.num_sequences; ++j) {
int32_t q_len = 0;
q_len = input_params.q_seq_lens_vec[i];
q_len = input_params.q_seq_lens_vec[j];
Slice<int32_t> tokens_ids_slice_i =
tokens_ids_slice.slice(start_idx + 1, start_idx + q_len);
start_idx += q_len;
new_token_ids.insert(new_token_ids.end(),
tokens_ids_slice_i.begin(),
tokens_ids_slice_i.end());
new_token_ids.emplace_back(extra_token_ids[i]);
new_token_ids.emplace_back(extra_token_ids[j]);
}
prefill_input.token_ids =
torch::tensor(new_token_ids, prefill_input.positions.options());
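The slicing loop above builds the draft (MTP/EAGLE-style) prefill input by shifting each sequence's tokens left by one and appending the token the target model just sampled. A worked illustration with hypothetical values:

#include <cstdint>
#include <vector>

int main() {
  // One sequence with q_len = 5; target-model input tokens:
  const std::vector<int32_t> token_ids = {11, 12, 13, 14, 15};
  const int32_t extra_token_id = 16;  // token the target model just sampled
  // slice(start_idx + 1, start_idx + q_len) drops the first token...
  std::vector<int32_t> new_token_ids(token_ids.begin() + 1, token_ids.end());
  // ...and the freshly sampled token is appended, so every position of the
  // draft input is one step ahead of the corresponding target position.
  new_token_ids.push_back(extra_token_id);
  return new_token_ids == std::vector<int32_t>{12, 13, 14, 15, 16} ? 0 : 1;
}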
@@ -359,7 +361,11 @@ std::optional<ForwardOutput> SpeculativeWorkerImpl::step_decode(
// final step
prepare_validate_inputs(inputs, validate_inputs, true);
} else {
prepare_draft_inputs(draft_inputs, next_step_input, 1, device_);
if (i == 0) {
prepare_draft_inputs(inputs, next_step_input, 1, device_);
} else {
prepare_draft_inputs(draft_inputs, next_step_input, 1, device_);
}
}
draft_outputs.push_back(std::move(future).get().value());
// update input of next step
@@ -368,8 +374,8 @@ std::optional<ForwardOutput> SpeculativeWorkerImpl::step_decode(
auto last_output = draft_outputs.back().sample_output;
auto start_idx = 0;
auto token_start_idx = 0;
for (auto i = 0; i < draft_inputs.micro_inputs.size(); ++i) {
auto& draft_input = draft_inputs.micro_inputs[i];
for (auto j = 0; j < draft_inputs.micro_inputs.size(); ++j) {
auto& draft_input = draft_inputs.micro_inputs[j];
auto offset = draft_input.input_params.num_sequences;
auto token_offset = draft_input.token_ids.size(0);
draft_input.token_ids = safe_to(
@@ -379,6 +385,7 @@ std::optional<ForwardOutput> SpeculativeWorkerImpl::step_decode(
MMType::EMBEDDING,
{{"embedding",
last_output.embeddings.narrow(0, token_start_idx, token_offset)
.clone()
.to(device_)}});
}
start_idx += offset;
@@ -394,9 +401,11 @@ std::optional<ForwardOutput> SpeculativeWorkerImpl::step_decode(
auto next_tokens =
safe_to(draft_output.sample_output.next_tokens, torch::kInt);
int32_t start_idx = 0;
for (auto i = 0; i < validate_inputs.micro_inputs.size(); ++i) {
int32_t offset = draft_inputs.micro_inputs[i].input_params.num_sequences;
auto& validate_input = validate_inputs.micro_inputs[i];
for (auto j = 0; j < validate_inputs.micro_inputs.size(); ++j) {
int32_t offset =
validate_inputs.micro_inputs[j].input_params.num_sequences /
(options_.num_speculative_tokens() + 1);
auto& validate_input = validate_inputs.micro_inputs[j];
auto& token_ids = validate_input.token_ids;
auto mask = (token_ids == -1 * (i + 1));
token_ids.masked_scatter_(mask, next_tokens.narrow(0, start_idx, offset));
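The new offset computation deserves a worked example. As the division in the diff suggests, the validate micro batch counts each of the num_speculative_tokens() + 1 query positions per sequence toward num_sequences, so dividing that factor back out recovers how many next_tokens the draft pass produced per step (numbers below are illustrative, not from the PR):

#include <cassert>

int main() {
  const int num_speculative_tokens = 2;  // k draft steps
  const int real_sequences = 4;          // sequences in this micro batch
  // Validation scores k+1 positions per sequence, so its sequence count
  // is inflated by that factor:
  const int validate_num_sequences =
      real_sequences * (num_speculative_tokens + 1);  // 12
  // The scatter offset must be the per-step draft token count, i.e. the
  // real sequence count -- divide the expansion factor back out:
  const int offset =
      validate_num_sequences / (num_speculative_tokens + 1);  // 4
  assert(offset == real_sequences);
  return 0;
}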
@@ -447,9 +456,10 @@ void SpeculativeWorkerImpl::prepare_draft_inputs(
const int64_t offset,
const torch::Device device) {
// prepare input for MTP in decoding phase (Like Eagle).
draft_inputs.micro_inputs.clear();
draft_inputs.micro_inputs.reserve(inputs.micro_inputs.size());
for (auto i = 0; i < inputs.micro_inputs.size(); ++i) {
auto& input = inputs.micro_inputs[i];
for (auto idx = 0; idx < inputs.micro_inputs.size(); ++idx) {
auto& input = inputs.micro_inputs[idx];
ForwardInput draft_input = input.to(device, dtype_);

auto& input_params = draft_input.input_params;
@@ -504,8 +514,8 @@ void SpeculativeWorkerImpl::prepare_validate_inputs(
BatchedForwardInputs& validate_inputs,
bool enable_schedule_overlap) {
validate_inputs.micro_inputs.reserve(inputs.micro_inputs.size());
for (auto i = 0; i < inputs.micro_inputs.size(); ++i) {
auto& input = inputs.micro_inputs[i];
for (auto idx = 0; idx < inputs.micro_inputs.size(); ++idx) {
auto& input = inputs.micro_inputs[idx];

ForwardInput validate_input = input.to(device_, dtype_);
auto& input_params = validate_input.input_params;
@@ -823,7 +833,7 @@ void SpeculativeWorkerImpl::update_sampling_params(
void SpeculativeWorkerImpl::prepare_work_before_execute(
const BatchedForwardInputs& inputs,
BatchedForwardInputs& processed_inputs) {
if (inputs.micro_inputs[0].input_params.q_seq_lens_vec[0] > 1) {
if (check_is_prefill(inputs.micro_inputs[0].input_params.q_seq_lens_vec)) {
WorkerImpl::prepare_work_before_execute(inputs, processed_inputs);
} else {
if (enable_schedule_overlap()) {
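The recurring .clone() additions in this file all guard against the same pitfall: torch narrow() returns a view that aliases the source tensor's storage, so stashing a narrowed embedding and then letting a later step overwrite the buffer would silently corrupt the stashed copy. A standalone LibTorch illustration (not xLLM code):

#include <torch/torch.h>
#include <iostream>

int main() {
  torch::Tensor embeddings = torch::zeros({4, 8});
  torch::Tensor view = embeddings.narrow(0, 0, 2);          // aliases storage
  torch::Tensor copy = embeddings.narrow(0, 0, 2).clone();  // owns storage
  embeddings.fill_(1.0);  // e.g. the next step reuses the buffer
  std::cout << view.sum().item<float>() << '\n';  // 16 -- the view changed
  std::cout << copy.sum().item<float>() << '\n';  // 0  -- the clone did not
}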
9 changes: 9 additions & 0 deletions xllm/core/runtime/worker_impl.cpp
@@ -759,4 +759,13 @@ int64_t WorkerImpl::get_active_activation_memory() {
.active_activation_memory;
}

bool WorkerImpl::check_is_prefill(const std::vector<int>& q_seq_lens_vec) {
for (auto q_len : q_seq_lens_vec) {
if (q_len > 1) {
return true;
}
}
return false;
}

} // namespace xllm
2 changes: 2 additions & 0 deletions xllm/core/runtime/worker_impl.h
@@ -166,6 +166,8 @@ class WorkerImpl {

torch::ScalarType dtype() const { return dtype_; }

bool check_is_prefill(const std::vector<int>& q_seq_lens_vec);

int32_t hidden_size() const {
return context_.get_model_args().hidden_size();
}