43 changes: 36 additions & 7 deletions xllm/core/framework/batch/batch.cpp
@@ -52,12 +52,39 @@ void Batch::add(Sequence* sequence, uint32_t allowed_max_token) {
input_embeddings_vec_.emplace_back(input_embedding);

const auto& mm_data = sequence->get_mm_data();
// if (sequence->is_prefill_stage() && mm_data.valid()) // TODO:Compatible
// With Chunked Prefill
if ((sequence->kv_state().kv_cache_tokens_num() <
sequence->num_prompt_tokens()) &&
mm_data.valid())
// if (sequence->is_chunked_prefill_stage() && mm_data.valid())
// TODO:Compatible With Chunked Prefill
if ((sequence->stage() == SequenceStage::PREFILL) && mm_data.valid()) {
mm_data_vec_.emplace_back(mm_data);
}
}

void Batch::update_forward_type(Sequence* sequence) {
auto stage = sequence->stage();
switch (batch_forward_type_.value()) {
case BatchForwardType::PREFILL:
if (stage == SequenceStage::CHUNKED_PREFILL) {
batch_forward_type_ = BatchForwardType::CHUNKED_PREFILL;
} else if (stage == SequenceStage::DECODE) {
batch_forward_type_ = BatchForwardType::MIXED;
}
break;
case BatchForwardType::CHUNKED_PREFILL:
if (stage == SequenceStage::DECODE) {
batch_forward_type_ = BatchForwardType::MIXED;
}
break;
case BatchForwardType::DECODE:
if (stage != SequenceStage::DECODE) {
batch_forward_type_ = BatchForwardType::MIXED;
}
break;
case BatchForwardType::MIXED:
break;
case BatchForwardType::EMPTY:
batch_forward_type_ = BatchForwardType(static_cast<int32_t>(stage));
break;
}
}

void Batch::add(const std::vector<Sequence*>& sequences) {
@@ -75,7 +102,8 @@ ForwardInput Batch::prepare_forward_input(uint32_t num_decoding_tokens,
mm_data_vec_,
swap_block_transfer_infos_,
batch_id_,
&args);
&args,
batch_forward_type_);
return builder.build_forward_input(num_decoding_tokens,
min_decoding_batch_size);
}
@@ -180,6 +208,7 @@ RawForwardInput Batch::prepare_forward_input(uint32_t start_idx,
swap_block_transfer_infos_,
batch_id_,
&args,
batch_forward_type_,
thread_pool);
return builder.build_raw_forward_input(start_idx, end_idx);
}
@@ -282,7 +311,7 @@ bool Batch::update_sequence_state(Sequence* seq, bool replace_fake_token) {
// prefill-or-not state of last stage, otherwise, we need the state
// of current stage.
if (FLAGS_enable_chunked_prefill) {
if (!replace_fake_token && seq->is_prefill_stage()) {
if (!replace_fake_token && seq->is_chunked_prefill_stage()) {
seq->pre_scheduled_step_prefill_queue().push(true);
// if not replace_fake_token, pop out here to avoid endless growth
if (seq->pre_scheduled_step_prefill_queue().size() > 2) {
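Note on Batch::update_forward_type() above: it folds each added sequence's stage into a batch-level forward type, and the EMPTY case seeds the type directly from the stage via static_cast, which is only valid because SequenceStage and BatchForwardType share the values PREFILL = 0, CHUNKED_PREFILL = 1, DECODE = 2. A standalone sketch of that fold (simplified stand-in enums, not the repo classes):

#include <cassert>
#include <cstdint>
#include <vector>

enum class Stage : int32_t { PREFILL = 0, CHUNKED_PREFILL = 1, DECODE = 2 };
enum class ForwardType : int32_t { PREFILL = 0, CHUNKED_PREFILL = 1, DECODE = 2, MIXED = 3, EMPTY = 4 };

// Mirrors the switch in Batch::update_forward_type(): MIXED is absorbing,
// EMPTY is seeded from the first sequence's stage.
ForwardType fold(ForwardType batch, Stage stage) {
  switch (batch) {
    case ForwardType::EMPTY:
      return static_cast<ForwardType>(static_cast<int32_t>(stage));
    case ForwardType::PREFILL:
      if (stage == Stage::CHUNKED_PREFILL) return ForwardType::CHUNKED_PREFILL;
      if (stage == Stage::DECODE) return ForwardType::MIXED;
      return batch;
    case ForwardType::CHUNKED_PREFILL:
      return stage == Stage::DECODE ? ForwardType::MIXED : batch;
    case ForwardType::DECODE:
      return stage != Stage::DECODE ? ForwardType::MIXED : batch;
    default:
      return batch;  // MIXED stays MIXED
  }
}

int main() {
  // One prefill sequence plus one decode sequence classify the batch as MIXED.
  ForwardType t = ForwardType::EMPTY;
  for (Stage s : std::vector<Stage>{Stage::PREFILL, Stage::DECODE}) t = fold(t, s);
  assert(t == ForwardType::MIXED);
  return 0;
}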
12 changes: 4 additions & 8 deletions xllm/core/framework/batch/batch.h
@@ -23,6 +23,7 @@ limitations under the License.
#include <limits>
#include <vector>

#include "framework/batch/batch_forward_type.h"
#include "framework/request/mm_data.h"
#include "framework/request/request.h"
#include "framework/request/sequence.h"
@@ -53,6 +54,8 @@ class Batch {
sequence_groups_.push_back(sequence_group);
}

void update_forward_type(Sequence* sequence);

void set_swap_block_transfer_infos(
std::vector<BlockTransferInfo>* swap_block_transfer_infos) {
swap_block_transfer_infos_ = swap_block_transfer_infos;
@@ -113,12 +116,6 @@ class Batch {
return allowed_max_tokens_;
}

void set_batch_prefill_status(const bool all_seqs_in_prefill) {
all_seqs_in_prefill_ = all_seqs_in_prefill;
}

bool get_batch_prefill_status() const { return all_seqs_in_prefill_; }

std::map<uint32_t, uint32_t> cal_seq_exchange_index_test(
std::vector<uint32_t>& kv_cache_tokens_num) {
return cal_seq_exchange_index(kv_cache_tokens_num);
@@ -152,8 +149,7 @@ class Batch {
// mm_data in the batch
std::vector<MMData> mm_data_vec_;

// all sequences in this batch are in prefill stage
bool all_seqs_in_prefill_ = false;
BatchForwardType batch_forward_type_;

uint64_t batch_id_ = UNINITIALIZED_BATCH_ID;
};
5 changes: 1 addition & 4 deletions xllm/core/framework/batch/batch_factory.cpp
@@ -56,10 +56,7 @@ std::vector<Batch> BatchFactory::create_batches(
// if dp enabled, each sequence is required to
// dispatch to the same rank in the whole lifetime
batches[sequence->dp_rank()].add(sequence, token_budget);
if (!((sequence->stage() == SequenceStage::DECODE) &&
(sequence->kv_state().kv_cache_tokens_num() > 0))) {
batches[sequence->dp_rank()].set_batch_prefill_status(true);
}
batches[sequence->dp_rank()].update_forward_type(sequence);
}

if (is_beam_search(running_requests)) {
87 changes: 87 additions & 0 deletions xllm/core/framework/batch/batch_forward_type.h
@@ -0,0 +1,87 @@
/* Copyright 2025 The xLLM Authors. All Rights Reserved.
Copyright 2024 The ScaleLLM Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

https://github.com/jd-opensource/xllm/blob/main/LICENSE

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#pragma once

#include <cstdint>
#include <string>

namespace xllm {

class BatchForwardType {
public:
enum Value : int32_t {
// Prefill without using kv cache.
PREFILL = 0,
// Chunked prefill using kv cache.
// No decode sequence in this type.
CHUNKED_PREFILL = 1,
// Decode one token.
// No prefill sequence in this type.
DECODE = 2,
// Mixed prefill and decode in one batch when doing chunked prefill.
MIXED = 3,
// No sequence to forward.
EMPTY = 4,
};

BatchForwardType() : value_(EMPTY) {}

BatchForwardType(int32_t v) : value_(static_cast<Value>(v)) {}

constexpr BatchForwardType(Value v) : value_(v) {}

BatchForwardType& operator=(Value v) {
value_ = v;
return *this;
}

int32_t value() const { return value_; }

bool is_prefill() const { return (value_ == PREFILL); }

bool is_chunked_prefill() const { return (value_ == CHUNKED_PREFILL); }

bool no_decode() const {
return (value_ == PREFILL || value_ == CHUNKED_PREFILL);
}

bool has_decode() const { return (value_ == DECODE || value_ == MIXED); }

bool is_decode() const { return (value_ == DECODE); }

bool is_mixed() const { return (value_ == MIXED); }

bool is_empty() const { return (value_ == EMPTY); }

std::string to_string() const {
switch (value_) {
case PREFILL:
return "PREFILL";
case CHUNKED_PREFILL:
return "CHUNKED_PREFILL";
case DECODE:
return "DECODE";
case MIXED:
return "MIXED";
case EMPTY:
return "EMPTY";
default:
return "UNKNOWN";
}
}

private:
Value value_;
};
} // namespace xllm
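A hedged usage sketch for the new BatchForwardType (the dispatch function and path names are hypothetical, not part of this PR); it shows the intended reading of the predicate helpers:

#include <iostream>
#include <string>

#include "framework/batch/batch_forward_type.h"

// Hypothetical dispatcher: branch on the batch-level type the way a model
// runner might when picking an attention path.
void dispatch(const xllm::BatchForwardType& type) {
  if (type.is_empty()) {
    return;  // nothing to forward
  }
  if (type.no_decode()) {
    // PREFILL or CHUNKED_PREFILL: every sequence is still consuming prompt tokens.
    std::cout << "prefill path (" << type.to_string() << ")\n";
  } else if (type.is_decode()) {
    // Pure decode: each sequence appends exactly one token this step.
    std::cout << "decode path\n";
  } else {
    // MIXED: chunked-prefill and decode sequences share one batch.
    std::cout << "mixed path\n";
  }
}

int main() {
  dispatch(xllm::BatchForwardType(xllm::BatchForwardType::CHUNKED_PREFILL));
  dispatch(xllm::BatchForwardType(xllm::BatchForwardType::MIXED));
  return 0;
}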
7 changes: 5 additions & 2 deletions xllm/core/framework/batch/batch_input_builder.cpp
@@ -45,6 +45,7 @@ BatchInputBuilder::BatchInputBuilder(
std::vector<BlockTransferInfo>* swap_block_transfer_infos,
const uint64_t batch_id,
const ModelArgs* args,
BatchForwardType batch_forward_type,
ThreadPool* thread_pool)
: sequences_(sequences),
allowed_max_tokens_(allowed_max_tokens),
@@ -65,6 +66,7 @@ BatchInputBuilder::BatchInputBuilder(
use_mrope_ = (args_->rope_scaling_rope_type() == "mrope");
}
write_block_ids_.clear();
state_.batch_forward_type = batch_forward_type;
}

ForwardInput BatchInputBuilder::build_forward_input(
@@ -305,7 +307,7 @@ void BatchInputBuilder::process_single_sequence(
}

// Track prefill sequences
if (sequence->is_prefill_stage()) {
if (sequence->is_chunked_prefill_stage()) {
state.prefill_seq_len++;
}

@@ -552,6 +554,7 @@ ForwardInput BatchInputBuilder::state_to_forward_input() {

auto& input_params = forward_input.input_params;
input_params.empty_kv_cache = state_.empty_kv_cache;
input_params.batch_forward_type = state_.batch_forward_type;
input_params.num_sequences = state_.block_tables_vec.size();
input_params.kv_max_seq_len = state_.max_seq_len;
input_params.q_max_seq_len = state_.q_max_seq_len;
@@ -645,7 +648,7 @@ RawForwardInput BatchInputBuilder::state_to_raw_forward_input() {
raw_forward_input.unique_token_lens_vec =
std::move(state_.unique_token_lens_vec);
raw_forward_input.empty_kv_cache = state_.empty_kv_cache;
// raw_forward_input.global_empty_kv_cache = ;
raw_forward_input.batch_forward_type = state_.batch_forward_type;
raw_forward_input.max_seq_len = state_.max_seq_len;
raw_forward_input.q_max_seq_len = state_.q_max_seq_len;
raw_forward_input.seq_lens = std::move(state_.seq_lens);
2 changes: 2 additions & 0 deletions xllm/core/framework/batch/batch_input_builder.h
@@ -41,6 +41,7 @@ class BatchInputBuilder {
std::vector<BlockTransferInfo>* swap_block_transfer_infos,
const uint64_t batch_id,
const ModelArgs* args,
BatchForwardType batch_forward_type,
ThreadPool* thread_pool = nullptr);

ForwardInput build_forward_input(uint32_t num_decoding_tokens,
@@ -77,6 +78,7 @@ class BatchInputBuilder {
std::vector<int32_t> unique_token_lens_vec;

// Sequence metadata
BatchForwardType batch_forward_type;
bool empty_kv_cache = true;
uint32_t max_seq_len = 0;
uint32_t q_max_seq_len = 0;
2 changes: 1 addition & 1 deletion xllm/core/framework/batch/mposition.cpp
@@ -20,7 +20,7 @@ limitations under the License.
namespace xllm {

torch::Tensor MPositionHelper::get_positions() {
// if (seq_.is_prefill_stage()) {
// if (seq_.is_chunked_prefill_stage()) {
if (seq_.kv_state().kv_cache_tokens_num() < seq_.num_prompt_tokens()) {
auto& mm_data = seq_.get_mm_data();

3 changes: 3 additions & 0 deletions xllm/core/framework/model/model_input_params.h
@@ -21,6 +21,7 @@ limitations under the License.
#if defined(USE_NPU)
#include "platform/npu/npu_layer_synchronizer.h"
#endif
#include "framework/batch/batch_forward_type.h"
#include "framework/request/mm_data.h"
#include "npu_dp_ep_padding.h"
#include "util/tensor_helper.h"
@@ -88,6 +89,7 @@ struct ModelInputParams {
ModelInputParams params;
params.empty_kv_cache = empty_kv_cache;
params.global_empty_kv_cache = global_empty_kv_cache;
params.batch_forward_type = batch_forward_type;
params.num_sequences = num_sequences;
params.kv_max_seq_len = kv_max_seq_len;
params.q_max_seq_len = q_max_seq_len;
@@ -163,6 +165,7 @@ struct ModelInputParams {
}
// whether the kv-cache is empty for all sequences.
bool empty_kv_cache = true;
BatchForwardType batch_forward_type;

// total number of sequences in the batch
int32_t num_sequences = 0;
4 changes: 3 additions & 1 deletion xllm/core/framework/request/request.h
@@ -92,7 +92,9 @@ class Request : public RequestBase {
return state_.sampling_param.beam_width > 1;
}

bool is_prefill_stage() const { return sequences_group_->is_prefill_stage(); }
bool is_chunked_prefill_stage() const {
return sequences_group_->is_chunked_prefill_stage();
}

private:
RequestState state_;
2 changes: 1 addition & 1 deletion xllm/core/framework/request/sequence.cpp
@@ -101,7 +101,7 @@ void Sequence::append_token(const Token& token) {
CHECK_LT(num_tokens_, tokens_.size())
<< "exceed the token capacity of the sequence";
CHECK(!finished_) << "cannot append token to a finished sequence";
CHECK(kv_state_.kv_cache_tokens_num() > 0 && !is_prefill_stage())
CHECK(kv_state_.kv_cache_tokens_num() > 0 && !is_chunked_prefill_stage())
<< "cannot append token to a prefill sequence";

if (!sequence_params_.enable_schedule_overlap) {
22 changes: 17 additions & 5 deletions xllm/core/framework/request/sequence.h
@@ -37,7 +37,14 @@ limitations under the License.

namespace xllm {

enum class SequenceStage : int8_t { PREFILL = 0, DECODE = 1 };
enum class SequenceStage : int8_t {
// Prefill without using kv cache.
PREFILL = 0,
// Chunked prefill using kv cache.
CHUNKED_PREFILL = 1,
// Decode one token.
DECODE = 2
};

struct SequenceParams {
// max tokens count in the sequence.
@@ -96,12 +103,17 @@ class Sequence final {
}

// check if in chunked prefill stage
bool is_prefill_stage() const { return stage() == SequenceStage::PREFILL; }
bool is_chunked_prefill_stage() const {
return stage() == SequenceStage::CHUNKED_PREFILL;
}

// get the sequence stage
SequenceStage stage() const {
if ((kv_state_.kv_cache_tokens_num() <
std::max(volatile_num_prompt_tokens_, num_prompt_tokens())) &&
kv_state_.kv_cache_tokens_num() > 0) {
if (kv_state_.kv_cache_tokens_num() <
std::max(volatile_num_prompt_tokens_, num_prompt_tokens())) {
if (kv_state_.kv_cache_tokens_num() > 0) {
return SequenceStage::CHUNKED_PREFILL;
}
return SequenceStage::PREFILL;
}
return SequenceStage::DECODE;
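The reworked Sequence::stage() above derives the three stages from kv-cache progress against prompt length. A simplified standalone sketch of that rule (volatile_num_prompt_tokens_ omitted for brevity):

#include <cassert>
#include <cstddef>
#include <cstdint>

enum class Stage : int8_t { PREFILL = 0, CHUNKED_PREFILL = 1, DECODE = 2 };

Stage stage_of(size_t kv_cache_tokens, size_t num_prompt_tokens) {
  if (kv_cache_tokens < num_prompt_tokens) {
    // Prompt not fully cached yet: the first chunk runs as PREFILL, later
    // chunks (some prompt tokens already in the kv cache) as CHUNKED_PREFILL.
    return kv_cache_tokens > 0 ? Stage::CHUNKED_PREFILL : Stage::PREFILL;
  }
  // Whole prompt cached: the sequence decodes one token per step.
  return Stage::DECODE;
}

int main() {
  assert(stage_of(0, 100) == Stage::PREFILL);           // fresh prompt
  assert(stage_of(64, 100) == Stage::CHUNKED_PREFILL);  // mid-chunk
  assert(stage_of(100, 100) == Stage::DECODE);          // prompt fully cached
  return 0;
}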
4 changes: 3 additions & 1 deletion xllm/core/framework/request/sequences_group.h
@@ -58,7 +58,9 @@ class SequencesGroup {

int32_t dp_rank() { return sequences_[0]->dp_rank(); }

bool is_prefill_stage() const { return sequences_[0]->is_prefill_stage(); }
bool is_chunked_prefill_stage() const {
return sequences_[0]->is_chunked_prefill_stage();
}

private:
void add();
1 change: 1 addition & 0 deletions xllm/core/runtime/forward_params.h
@@ -148,6 +148,7 @@ struct RawForwardInput {
std::vector<int32_t> unique_token_lens_vec;
bool empty_kv_cache = true;
bool global_empty_kv_cache = true;
BatchForwardType batch_forward_type;
uint32_t max_seq_len;
uint32_t q_max_seq_len;
std::vector<int32_t> seq_lens;