diff --git a/xllm/core/framework/batch/batch_input_builder.cpp b/xllm/core/framework/batch/batch_input_builder.cpp
index 37a55bbf..f3e83a5d 100755
--- a/xllm/core/framework/batch/batch_input_builder.cpp
+++ b/xllm/core/framework/batch/batch_input_builder.cpp
@@ -554,8 +554,6 @@ ForwardInput BatchInputBuilder::state_to_forward_input() {
   input_params.q_seq_lens_vec = std::move(state_.q_seq_lens);
   input_params.new_cache_slots =
       torch::tensor(state_.new_token_slot_ids, torch::kInt);
-  input_params.decode_seq_range =
-      util::find_ones_indices(input_params.q_seq_lens_vec);
 
   // for flashinfer
   input_params.paged_kv_indptr =
diff --git a/xllm/core/framework/model/model_input_params.h b/xllm/core/framework/model/model_input_params.h
index 857f05f1..88306a2b 100644
--- a/xllm/core/framework/model/model_input_params.h
+++ b/xllm/core/framework/model/model_input_params.h
@@ -101,7 +101,6 @@ struct ModelInputParams {
     params.block_tables = safe_to(block_tables, device, true);
     params.kv_seq_lens_vec = kv_seq_lens_vec;
     params.q_seq_lens_vec = q_seq_lens_vec;
-    params.decode_seq_range = decode_seq_range;
 
     params.input_embedding = safe_to(input_embedding, device);
 
@@ -153,7 +152,8 @@ struct ModelInputParams {
               << " , q_max_seq_len is " << q_max_seq_len;
     LOG(INFO) << "ModelInputParams: kv_seq_lens_vec is " << kv_seq_lens_vec;
     LOG(INFO) << "ModelInputParams: q_seq_lens_vec is " << q_seq_lens_vec;
-    LOG(INFO) << "ModelInputParams: decode_seq_range is " << decode_seq_range;
+    LOG(INFO) << "ModelInputParams: batch_forward_type is "
+              << batch_forward_type.to_string();
     print_tensor(kv_seq_lens, "ModelInputParams: kv_seq_lens", 4);
     print_tensor(q_seq_lens, "ModelInputParams: q_seq_lens", 4);
     print_tensor(new_cache_slots, "ModelInputParams: new_cache_slots", 4);
@@ -172,15 +172,7 @@ struct ModelInputParams {
   torch::Tensor kv_seq_lens;
   std::vector<int> kv_seq_lens_vec;
   std::vector<int> q_seq_lens_vec;
-  // Range of decode sequence indices in the batch [start, end].
-  // Decode sequences are identified by q_seq_lens == 1,
-  // prefill sequences by q_seq_lens > 1 .
-  // Used to determine whether to use prefill_node_ or
-  // decode_node_ in NPU layers
-  // Values: {-1, -1} if no decode requests (all prefill),
-  // {0, batch_size-1} if all decode requests,
-  // {start_idx, end_idx} if mixed prefill/decode requests
-  std::pair<int, int> decode_seq_range;
+  // max length for qkv.
   int32_t kv_max_seq_len = 0;
   int32_t q_max_seq_len = 0;
diff --git a/xllm/core/layers/npu/npu_glm4_moe_decoder_layer.cpp b/xllm/core/layers/npu/npu_glm4_moe_decoder_layer.cpp
index cbadf038..16d51fde 100644
--- a/xllm/core/layers/npu/npu_glm4_moe_decoder_layer.cpp
+++ b/xllm/core/layers/npu/npu_glm4_moe_decoder_layer.cpp
@@ -1090,9 +1090,7 @@ torch::Tensor Glm4MoeDecoderImpl::forward(torch::Tensor& x,
                                           std::atomic<bool>* event_flag,
                                           int node_id) {
   atb::Status st;
-  bool is_prefill = input_params.decode_seq_range.second !=
-                    input_params.q_seq_lens.size(0) - 1;
-  if (is_prefill) {
+  if (!input_params.batch_forward_type.is_decode()) {
     build_node_variant_pack(prefill_node_,
                             x,
                             cos_pos,
diff --git a/xllm/core/layers/npu/npu_llama_decoder_layer_impl.cpp b/xllm/core/layers/npu/npu_llama_decoder_layer_impl.cpp
index 9696353c..eb7794c2 100644
--- a/xllm/core/layers/npu/npu_llama_decoder_layer_impl.cpp
+++ b/xllm/core/layers/npu/npu_llama_decoder_layer_impl.cpp
@@ -277,8 +277,7 @@ torch::Tensor NpuLlamaDecoderLayerImpl::forward(torch::Tensor& x,
                                                 int node_id) {
   atb::Status st;
 
-  if (input_params.decode_seq_range.second !=
-      input_params.q_seq_lens.size(0) - 1) {
+  if (!input_params.batch_forward_type.is_decode()) {
     build_node_variant_pack(prefill_node_,
                             x,
                             cos_pos,
diff --git a/xllm/core/layers/npu/npu_qwen2_decoder_layer_impl.cpp b/xllm/core/layers/npu/npu_qwen2_decoder_layer_impl.cpp
index ffbf45bd..9b4bb071 100644
--- a/xllm/core/layers/npu/npu_qwen2_decoder_layer_impl.cpp
+++ b/xllm/core/layers/npu/npu_qwen2_decoder_layer_impl.cpp
@@ -404,8 +404,7 @@ torch::Tensor NpuQwen2DecoderLayerImpl::forward(torch::Tensor& x,
                                                 std::atomic<bool>* event_flag,
                                                 int node_id) {
   atb::Status st;
-  if (input_params.decode_seq_range.second !=
-      input_params.q_seq_lens.size(0) - 1) {
+  if (!input_params.batch_forward_type.is_decode()) {
     // mstxRangeId id = mstxRangeStartA("prefill build variant", nullptr);
     build_node_variant_pack(prefill_node_,
                             x,
diff --git a/xllm/core/layers/npu/npu_qwen3_decoder_layer_impl.cpp b/xllm/core/layers/npu/npu_qwen3_decoder_layer_impl.cpp
index f7ae8923..4966c62b 100644
--- a/xllm/core/layers/npu/npu_qwen3_decoder_layer_impl.cpp
+++ b/xllm/core/layers/npu/npu_qwen3_decoder_layer_impl.cpp
@@ -519,8 +519,7 @@ torch::Tensor NpuQwen3DecoderLayerImpl::forward(torch::Tensor& x,
                                                 std::atomic<bool>* event_flag,
                                                 int node_id) {
   atb::Status st;
-  if (input_params.decode_seq_range.second !=
-      input_params.q_seq_lens.size(0) - 1) {
+  if (!input_params.batch_forward_type.is_decode()) {
     // if (input_params.empty_kv_cache) {
     // mstxRangeId id = mstxRangeStartA("prefill build variant", nullptr);
     build_node_variant_pack(prefill_node_,
diff --git a/xllm/core/runtime/acl_graph_executor_impl.cpp b/xllm/core/runtime/acl_graph_executor_impl.cpp
index 9556b1fa..56fffe4e 100644
--- a/xllm/core/runtime/acl_graph_executor_impl.cpp
+++ b/xllm/core/runtime/acl_graph_executor_impl.cpp
@@ -629,7 +629,7 @@ bool AclGraph::capture(CausalLM* model,
     graph_params.q_seq_lens_vec[i] = 1;
   }
   graph_params.num_sequences = num_tokens_;
-  graph_params.decode_seq_range = {0, num_tokens_ - 1};
+  graph_params.batch_forward_type = BatchForwardType::DECODE;
   graph_params.new_cache_slots =
       persistent_param_.persistent_new_cache_slots(num_tokens_);
diff --git a/xllm/core/runtime/forward_shared_memory_manager.cpp b/xllm/core/runtime/forward_shared_memory_manager.cpp
index ed531a35..2139bc12 100755
--- a/xllm/core/runtime/forward_shared_memory_manager.cpp
+++ b/xllm/core/runtime/forward_shared_memory_manager.cpp
@@ -816,12 +816,7 @@ void convert_raw_forward_input_to_forward_input(RawForwardInput& raw_input,
     forward_input.positions =
         create_2d_tensor(std::move(raw_input.m_positions_vec), torch::kInt);
   }
-  std::pair<int, int> decode_seq_range{0, 0};
-#if defined(USE_NPU)
-  if (raw_input.q_seq_lens.size() >= 1) {
-    decode_seq_range = util::find_ones_indices(raw_input.q_seq_lens);
-  }
-#endif
+
   auto& input_params = forward_input.input_params;
   input_params.empty_kv_cache = raw_input.empty_kv_cache;
   input_params.global_empty_kv_cache = raw_input.global_empty_kv_cache;
@@ -841,7 +836,7 @@ void convert_raw_forward_input_to_forward_input(RawForwardInput& raw_input,
   input_params.new_cache_slots =
       torch::tensor(std::move(raw_input.new_token_slot_ids), tensor_options);
 
-  input_params.decode_seq_range = decode_seq_range;
+
   util::pad_2d_vector(raw_input.block_tables_vec, 0);
   input_params.block_tables =
       create_2d_tensor(std::move(raw_input.block_tables_vec), torch::kInt);
diff --git a/xllm/core/runtime/params_utils.cpp b/xllm/core/runtime/params_utils.cpp
index 077b725c..54b6dc85 100644
--- a/xllm/core/runtime/params_utils.cpp
+++ b/xllm/core/runtime/params_utils.cpp
@@ -177,12 +177,7 @@ void proto_to_forward_input(const proto::ForwardInput* pb_forward_input,
   forward_inputs.acc_logprob = torch::tensor(
       acc_logprob_vec,
       torch::dtype(torch::kFloat32).device(torch::kCPU).pinned_memory(true));
-  std::pair<int, int> decode_seq_range{0, 0};
-#if defined(USE_NPU)
-  if (q_seq_lens.size() >= 1) {
-    decode_seq_range = util::find_ones_indices(q_seq_lens);
-  }
-#endif
+
   auto& input_params = forward_inputs.input_params;
   input_params.empty_kv_cache = pb_forward_input->empty_kv_cache();
   input_params.global_empty_kv_cache =
@@ -206,7 +201,6 @@ void proto_to_forward_input(const proto::ForwardInput* pb_forward_input,
   input_params.new_cache_slots =
       torch::tensor(new_token_slot_ids, tensor_options);
 
-  input_params.decode_seq_range = decode_seq_range;
   util::pad_2d_vector(block_tables_vec, /*pad_value=*/0);
   input_params.block_tables =
diff --git a/xllm/core/runtime/speculative_worker_impl.cpp b/xllm/core/runtime/speculative_worker_impl.cpp
index 9975be01..b98259d8 100644
--- a/xllm/core/runtime/speculative_worker_impl.cpp
+++ b/xllm/core/runtime/speculative_worker_impl.cpp
@@ -559,11 +559,13 @@ void SpeculativeWorkerImpl::prepare_validate_inputs(
   input_params.q_seq_lens = torch::tensor(q_seq_lens_vec, int_options);
   input_params.new_cache_slots = torch::tensor(new_token_slot_ids, int_options);
   if (!FLAGS_enable_atb_spec_kernel) {
+    input_params.batch_forward_type = BatchForwardType::CHUNKED_PREFILL;
     util::pad_2d_vector(block_tables_vec, /*pad_value=*/0);
     input_params.block_tables =
         create_2d_tensor(block_tables_vec, torch::kInt).to(device_);
+  } else {
+    input_params.batch_forward_type = BatchForwardType::DECODE;
   }
-  input_params.decode_seq_range.second = input_params.num_sequences - 1;
 
   // update the sampling_params
   update_sampling_params(
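
Note: the BatchForwardType that replaces decode_seq_range is defined elsewhere in this patch and does not appear in this excerpt. From the call sites above we only know it exposes is_decode() and to_string() and carries at least the values DECODE and CHUNKED_PREFILL. A minimal sketch consistent with those uses follows; the PREFILL value, the Value enum layout, and the value_ member are illustrative assumptions, not the actual definition.

// Sketch only: a BatchForwardType consistent with the call sites above.
// Everything beyond is_decode(), to_string(), DECODE, and CHUNKED_PREFILL
// is an assumption made for illustration.
#include <cstdint>
#include <string>

class BatchForwardType {
 public:
  enum Value : int32_t {
    PREFILL = 0,          // assumption: pure-prefill batches
    DECODE = 1,           // used by AclGraph::capture above
    CHUNKED_PREFILL = 2,  // used by prepare_validate_inputs above
  };

  BatchForwardType() = default;
  constexpr BatchForwardType(Value v) : value_(v) {}

  // Decode-only batches take the decode_node_ path in the NPU layers;
  // everything else (prefill or mixed) takes prefill_node_.
  bool is_decode() const { return value_ == DECODE; }

  std::string to_string() const {
    switch (value_) {
      case PREFILL:
        return "PREFILL";
      case DECODE:
        return "DECODE";
      case CHUNKED_PREFILL:
        return "CHUNKED_PREFILL";
    }
    return "UNKNOWN";
  }

 private:
  Value value_ = PREFILL;
};

Compared with the old check (decode_seq_range.second != q_seq_lens.size(0) - 1, i.e. "not every sequence in the batch is a single-token decode"), dispatching on an explicit batch type makes the prefill/decode split self-describing and removes the per-forward util::find_ones_indices scan over q_seq_lens.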