2 changes: 0 additions & 2 deletions xllm/core/framework/batch/batch_input_builder.cpp
@@ -554,8 +554,6 @@ ForwardInput BatchInputBuilder::state_to_forward_input() {
   input_params.q_seq_lens_vec = std::move(state_.q_seq_lens);
   input_params.new_cache_slots =
       torch::tensor(state_.new_token_slot_ids, torch::kInt);
-  input_params.decode_seq_range =
-      util::find_ones_indices(input_params.q_seq_lens_vec);
 
   // for flashinfer
   input_params.paged_kv_indptr =
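Note: util::find_ones_indices itself is not shown in this PR; its call sites here and in forward_shared_memory_manager.cpp and params_utils.cpp below are simply deleted. Going by the decode_seq_range comment removed from model_input_params.h (next file), a minimal sketch of the contract it presumably implemented is:

#include <utility>
#include <vector>

// Sketch only: return the [first, last] indices of decode sequences
// (q_seq_len == 1), or {-1, -1} when the batch is all prefill.
std::pair<int, int> find_ones_indices(const std::vector<int>& q_seq_lens) {
  std::pair<int, int> range{-1, -1};
  for (int i = 0; i < static_cast<int>(q_seq_lens.size()); ++i) {
    if (q_seq_lens[i] == 1) {
      if (range.first == -1) range.first = i;
      range.second = i;
    }
  }
  return range;
}

The explicit BatchForwardType flag introduced by this PR makes this per-batch scan unnecessary.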
14 changes: 3 additions & 11 deletions xllm/core/framework/model/model_input_params.h
@@ -101,7 +101,6 @@ struct ModelInputParams {
     params.block_tables = safe_to(block_tables, device, true);
     params.kv_seq_lens_vec = kv_seq_lens_vec;
     params.q_seq_lens_vec = q_seq_lens_vec;
-    params.decode_seq_range = decode_seq_range;
 
     params.input_embedding = safe_to(input_embedding, device);
 
@@ -153,7 +152,8 @@ struct ModelInputParams {
            << " , q_max_seq_len is " << q_max_seq_len;
     LOG(INFO) << "ModelInputParams: kv_seq_lens_vec is " << kv_seq_lens_vec;
     LOG(INFO) << "ModelInputParams: q_seq_lens_vec is " << q_seq_lens_vec;
-    LOG(INFO) << "ModelInputParams: decode_seq_range is " << decode_seq_range;
+    LOG(INFO) << "ModelInputParams: batch_forward_type is "
+              << batch_forward_type.to_string();
     print_tensor(kv_seq_lens, "ModelInputParams: kv_seq_lens", 4);
     print_tensor(q_seq_lens, "ModelInputParams: q_seq_lens", 4);
     print_tensor(new_cache_slots, "ModelInputParams: new_cache_slots", 4);
@@ -172,15 +172,7 @@ struct ModelInputParams {
   torch::Tensor kv_seq_lens;
   std::vector<int> kv_seq_lens_vec;
   std::vector<int> q_seq_lens_vec;
-  // Range of decode sequence indices in the batch [start, end].
-  // Decode sequences are identified by q_seq_lens == 1,
-  // prefill sequences by q_seq_lens > 1 .
-  // Used to determine whether to use prefill_node_ or
-  // decode_node_ in NPU layers
-  // Values: {-1, -1} if no decode requests (all prefill),
-  // {0, batch_size-1} if all decode requests,
-  // {start_idx, end_idx} if mixed prefill/decode requests
-  std::pair<int, int> decode_seq_range;
+
   // max length for qkv.
   int32_t kv_max_seq_len = 0;
   int32_t q_max_seq_len = 0;
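Note: the BatchForwardType definition that replaces the field is not included in this diff. From the usages visible here (DECODE, CHUNKED_PREFILL, is_decode(), to_string()), a minimal sketch might look as follows; the real class likely carries more states, so treat everything beyond those four names as assumption:

// Sketch under assumptions: only DECODE, CHUNKED_PREFILL, is_decode() and
// to_string() appear in this diff; PREFILL and the exact layout are guesses.
class BatchForwardType {
 public:
  enum Value { PREFILL, CHUNKED_PREFILL, DECODE };
  BatchForwardType() = default;
  BatchForwardType(Value v) : value_(v) {}
  bool is_decode() const { return value_ == DECODE; }
  const char* to_string() const {
    switch (value_) {
      case PREFILL:
        return "PREFILL";
      case CHUNKED_PREFILL:
        return "CHUNKED_PREFILL";
      case DECODE:
        return "DECODE";
    }
    return "UNKNOWN";
  }

 private:
  Value value_ = PREFILL;
};

A single enum-like flag also logs more readably than the old pair, as the to_string() call in the logging hunk above shows.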
4 changes: 1 addition & 3 deletions xllm/core/layers/npu/npu_glm4_moe_decoder_layer.cpp
@@ -1090,9 +1090,7 @@ torch::Tensor Glm4MoeDecoderImpl::forward(torch::Tensor& x,
                                           std::atomic<bool>* event_flag,
                                           int node_id) {
   atb::Status st;
-  bool is_prefill = input_params.decode_seq_range.second !=
-                    input_params.q_seq_lens.size(0) - 1;
-  if (is_prefill) {
+  if (!input_params.batch_forward_type.is_decode()) {
     build_node_variant_pack(prefill_node_,
                             x,
                             cos_pos,
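Note: the same substitution is applied to the llama, qwen2, and qwen3 decoder layers below. The removed test classified the batch positionally: it took the decode path exactly when decode_seq_range.second reached the last sequence index. The new test states the classification outright. A condensed comparison of the two predicates, with the old field's semantics taken from the comment removed in model_input_params.h:

#include <cstdint>
#include <utility>

// Old (removed): decode path iff the decode range extends to the batch tail.
bool old_takes_decode_path(std::pair<int, int> decode_seq_range,
                           int64_t num_seqs) {
  return decode_seq_range.second == num_seqs - 1;
}
// New: decode path iff input_params.batch_forward_type.is_decode().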
3 changes: 1 addition & 2 deletions xllm/core/layers/npu/npu_llama_decoder_layer_impl.cpp
@@ -277,8 +277,7 @@ torch::Tensor NpuLlamaDecoderLayerImpl::forward(torch::Tensor& x,
                                                 int node_id) {
   atb::Status st;
 
-  if (input_params.decode_seq_range.second !=
-      input_params.q_seq_lens.size(0) - 1) {
+  if (!input_params.batch_forward_type.is_decode()) {
     build_node_variant_pack(prefill_node_,
                             x,
                             cos_pos,
3 changes: 1 addition & 2 deletions xllm/core/layers/npu/npu_qwen2_decoder_layer_impl.cpp
@@ -404,8 +404,7 @@ torch::Tensor NpuQwen2DecoderLayerImpl::forward(torch::Tensor& x,
                                                 std::atomic<bool>* event_flag,
                                                 int node_id) {
   atb::Status st;
-  if (input_params.decode_seq_range.second !=
-      input_params.q_seq_lens.size(0) - 1) {
+  if (!input_params.batch_forward_type.is_decode()) {
     // mstxRangeId id = mstxRangeStartA("prefill build variant", nullptr);
     build_node_variant_pack(prefill_node_,
                             x,
3 changes: 1 addition & 2 deletions xllm/core/layers/npu/npu_qwen3_decoder_layer_impl.cpp
@@ -519,8 +519,7 @@ torch::Tensor NpuQwen3DecoderLayerImpl::forward(torch::Tensor& x,
                                                 std::atomic<bool>* event_flag,
                                                 int node_id) {
   atb::Status st;
-  if (input_params.decode_seq_range.second !=
-      input_params.q_seq_lens.size(0) - 1) {
+  if (!input_params.batch_forward_type.is_decode()) {
     // if (input_params.empty_kv_cache) {
     // mstxRangeId id = mstxRangeStartA("prefill build variant", nullptr);
     build_node_variant_pack(prefill_node_,
2 changes: 1 addition & 1 deletion xllm/core/runtime/acl_graph_executor_impl.cpp
@@ -629,7 +629,7 @@ bool AclGraph::capture(CausalLM* model,
     graph_params.q_seq_lens_vec[i] = 1;
   }
   graph_params.num_sequences = num_tokens_;
-  graph_params.decode_seq_range = {0, num_tokens_ - 1};
+  graph_params.batch_forward_type = BatchForwardType::DECODE;
 
   graph_params.new_cache_slots =
       persistent_param_.persistent_new_cache_slots(num_tokens_);
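Note: capture runs over a dummy batch in which every sequence is forced to q_seq_len == 1 (the loop above), so the captured parameters are now tagged DECODE explicitly instead of synthesizing a full-batch decode range. A hypothetical replay-side guard, not part of this PR, illustrates how the flag would be consumed:

// Hypothetical sketch (assumes ModelInputParams from model_input_params.h):
// a graph captured for decode should only be replayed for pure-decode
// batches that fit within the captured size.
bool can_replay(const ModelInputParams& params, int64_t captured_tokens) {
  return params.batch_forward_type.is_decode() &&
         params.num_sequences <= captured_tokens;
}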
9 changes: 2 additions & 7 deletions xllm/core/runtime/forward_shared_memory_manager.cpp
@@ -816,12 +816,7 @@ void convert_raw_forward_input_to_forward_input(RawForwardInput& raw_input,
     forward_input.positions =
         create_2d_tensor(std::move(raw_input.m_positions_vec), torch::kInt);
   }
-  std::pair<int, int> decode_seq_range{0, 0};
-#if defined(USE_NPU)
-  if (raw_input.q_seq_lens.size() >= 1) {
-    decode_seq_range = util::find_ones_indices(raw_input.q_seq_lens);
-  }
-#endif
+
   auto& input_params = forward_input.input_params;
   input_params.empty_kv_cache = raw_input.empty_kv_cache;
   input_params.global_empty_kv_cache = raw_input.global_empty_kv_cache;
@@ -841,7 +836,7 @@ void convert_raw_forward_input_to_forward_input(RawForwardInput& raw_input,
 
   input_params.new_cache_slots =
       torch::tensor(std::move(raw_input.new_token_slot_ids), tensor_options);
-  input_params.decode_seq_range = decode_seq_range;
+
   util::pad_2d_vector(raw_input.block_tables_vec, 0);
   input_params.block_tables =
       create_2d_tensor(std::move(raw_input.block_tables_vec), torch::kInt);
8 changes: 1 addition & 7 deletions xllm/core/runtime/params_utils.cpp
@@ -177,12 +177,7 @@ void proto_to_forward_input(const proto::ForwardInput* pb_forward_input,
   forward_inputs.acc_logprob = torch::tensor(
       acc_logprob_vec,
       torch::dtype(torch::kFloat32).device(torch::kCPU).pinned_memory(true));
-  std::pair<int, int> decode_seq_range{0, 0};
-#if defined(USE_NPU)
-  if (q_seq_lens.size() >= 1) {
-    decode_seq_range = util::find_ones_indices(q_seq_lens);
-  }
-#endif
+
   auto& input_params = forward_inputs.input_params;
   input_params.empty_kv_cache = pb_forward_input->empty_kv_cache();
   input_params.global_empty_kv_cache =
@@ -206,7 +201,6 @@ void proto_to_forward_input(const proto::ForwardInput* pb_forward_input,
 
   input_params.new_cache_slots =
       torch::tensor(new_token_slot_ids, tensor_options);
-  input_params.decode_seq_range = decode_seq_range;
 
   util::pad_2d_vector(block_tables_vec, /*pad_value=*/0);
   input_params.block_tables =
4 changes: 3 additions & 1 deletion xllm/core/runtime/speculative_worker_impl.cpp
@@ -559,11 +559,13 @@ void SpeculativeWorkerImpl::prepare_validate_inputs(
   input_params.q_seq_lens = torch::tensor(q_seq_lens_vec, int_options);
   input_params.new_cache_slots = torch::tensor(new_token_slot_ids, int_options);
   if (!FLAGS_enable_atb_spec_kernel) {
+    input_params.batch_forward_type = BatchForwardType::CHUNKED_PREFILL;
     util::pad_2d_vector(block_tables_vec, /*pad_value=*/0);
     input_params.block_tables =
         create_2d_tensor(block_tables_vec, torch::kInt).to(device_);
+  } else {
+    input_params.batch_forward_type = BatchForwardType::DECODE;
   }
-  input_params.decode_seq_range.second = input_params.num_sequences - 1;
 
   // update the sampling_params
   update_sampling_params(
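Note: during speculative validation each sequence carries several query tokens (typically the draft tokens plus one), so the batch only counts as decode when the ATB spec kernel, which presumably verifies multiple tokens per step natively, is enabled; otherwise the verify pass is classified as chunked prefill. A condensed restatement of the branch above, reusing the BatchForwardType sketch given earlier:

// Names from the diff; the rationale in the comments is a presumption.
BatchForwardType validate_pass_type(bool enable_atb_spec_kernel) {
  // Spec kernel: multi-token verification stays on the decode path;
  // otherwise the pass runs like a chunked prefill over the draft tokens.
  return enable_atb_spec_kernel ? BatchForwardType::DECODE
                                : BatchForwardType::CHUNKED_PREFILL;
}

This also fixes a subtle asymmetry: the old code set decode_seq_range to {0, num_sequences - 1} (all decode) in both branches, while the new code actually distinguishes the two cases.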