diff --git a/.gitignore b/.gitignore index 9e05a5d..fc9e663 100644 --- a/.gitignore +++ b/.gitignore @@ -21,3 +21,4 @@ # rust Cargo.lock +/local \ No newline at end of file diff --git a/xllm_service/proto/xllm/chat.proto b/xllm_service/proto/xllm/chat.proto index 1fdecd2..f97ab97 100644 --- a/xllm_service/proto/xllm/chat.proto +++ b/xllm_service/proto/xllm/chat.proto @@ -113,6 +113,8 @@ message ChatRequest { repeated int32 token_ids = 26; Routing routing = 27; + + optional bool offline = 28; } message ChatLogProbData { diff --git a/xllm_service/proto/xllm/completion.proto b/xllm_service/proto/xllm/completion.proto index 9b1c78f..77ad486 100644 --- a/xllm_service/proto/xllm/completion.proto +++ b/xllm_service/proto/xllm/completion.proto @@ -90,6 +90,8 @@ message CompletionRequest { repeated int32 token_ids = 26; Routing routing = 27; + + optional bool offline = 28; } message LogProbs { diff --git a/xllm_service/proto/xllm_rpc_service.proto b/xllm_service/proto/xllm_rpc_service.proto index c61389d..a0e5094 100644 --- a/xllm_service/proto/xllm_rpc_service.proto +++ b/xllm_service/proto/xllm_rpc_service.proto @@ -144,6 +144,7 @@ service XllmRpcService { rpc GetInstanceInfo(InstanceID) returns (InstanceMetaInfo) {} rpc Heartbeat(HeartbeatRequest) returns (Status) {} rpc GetStaticDecodeList(InstanceID) returns (InstanceIDs) {} + rpc GetStaticPrefillList(InstanceID) returns (InstanceIDs) {} rpc GetConfig(Empty) returns (ServiceConfig) {} // xllm service receive response from decode instance directly in disagg pd mode. diff --git a/xllm_service/request/request.h b/xllm_service/request/request.h index 3637059..14e821a 100644 --- a/xllm_service/request/request.h +++ b/xllm_service/request/request.h @@ -35,6 +35,8 @@ struct Request { // whether to return usage bool include_usage = false; + bool offline = false; + // input prompt std::string prompt; diff --git a/xllm_service/rpc_service/service.cpp b/xllm_service/rpc_service/service.cpp index cc19893..c64fe83 100644 --- a/xllm_service/rpc_service/service.cpp +++ b/xllm_service/rpc_service/service.cpp @@ -49,6 +49,11 @@ std::vector XllmRpcServiceImpl::get_static_decode_list( return scheduler_->get_static_decode_list(instance_name); } +std::vector XllmRpcServiceImpl::get_static_prefill_list( + const std::string& instance_name) { + return scheduler_->get_static_prefill_list(instance_name); +} + bool XllmRpcServiceImpl::handle_generation( const llm::RequestOutput& request_output) { return scheduler_->handle_generation(request_output); @@ -126,6 +131,19 @@ void XllmRpcService::GetStaticDecodeList( } } +void XllmRpcService::GetStaticPrefillList( + google::protobuf::RpcController* cntl_base, + const proto::InstanceID* req, + proto::InstanceIDs* resp, + google::protobuf::Closure* done) { + brpc::ClosureGuard done_guard(done); + std::vector prefill_list = + xllm_rpc_service_impl_->get_static_prefill_list(req->name()); + for (auto& p : prefill_list) { + *(resp->mutable_names()->Add()) = std::move(p); + } +} + void XllmRpcService::Generations(google::protobuf::RpcController* cntl_base, const proto::DisaggStreamGenerations* req, proto::StatusSet* resp, diff --git a/xllm_service/rpc_service/service.h b/xllm_service/rpc_service/service.h index 52b4e2c..87c1206 100644 --- a/xllm_service/rpc_service/service.h +++ b/xllm_service/rpc_service/service.h @@ -52,6 +52,9 @@ class XllmRpcServiceImpl final { std::vector get_static_decode_list( const std::string& prefill_name); + std::vector get_static_prefill_list( + const std::string& decode_name); + public: // handle generations from prefill/decode instance bool handle_generation(const llm::RequestOutput& request_output); @@ -103,6 +106,11 @@ class XllmRpcService : public proto::XllmRpcService { proto::InstanceIDs* resp, google::protobuf::Closure* done) override; + virtual void GetStaticPrefillList(google::protobuf::RpcController* cntl_base, + const proto::InstanceID* req, + proto::InstanceIDs* resp, + google::protobuf::Closure* done) override; + // xllm service receive response from decode instance directly in disagg pd // mode. This can eliminate the cost brought by forwarding through prefill. virtual void Generations(google::protobuf::RpcController* cntl_base, diff --git a/xllm_service/scheduler/managers/instance_mgr.cpp b/xllm_service/scheduler/managers/instance_mgr.cpp index 33c6194..68f9f33 100644 --- a/xllm_service/scheduler/managers/instance_mgr.cpp +++ b/xllm_service/scheduler/managers/instance_mgr.cpp @@ -169,6 +169,21 @@ std::vector InstanceMgr::get_static_decode_list( return decode_list; } +// TODO: refactor later, currently return all prefill instances +std::vector InstanceMgr::get_static_prefill_list( + const std::string& instance_name) { + std::vector prefill_list; + std::shared_lock lock(inst_mutex_); + for (auto& inst : instances_) { + if (inst.second.type == InstanceType::PREFILL || + inst.second.type == InstanceType::DEFAULT) { + prefill_list.emplace_back(inst.second.name); + } + } + + return prefill_list; +} + void InstanceMgr::get_load_metrics(LoadBalanceInfos* infos) { std::shared_lock inst_lock(inst_mutex_); std::shared_lock metric_lock(load_metric_mutex_); diff --git a/xllm_service/scheduler/managers/instance_mgr.h b/xllm_service/scheduler/managers/instance_mgr.h index 9598cdc..1c47af0 100644 --- a/xllm_service/scheduler/managers/instance_mgr.h +++ b/xllm_service/scheduler/managers/instance_mgr.h @@ -48,6 +48,9 @@ class InstanceMgr final { std::vector get_static_decode_list( const std::string& instance_name); + std::vector get_static_prefill_list( + const std::string& instance_name); + void get_load_metrics(LoadBalanceInfos* infos); std::shared_ptr get_channel(const std::string& instance_name); diff --git a/xllm_service/scheduler/scheduler.cpp b/xllm_service/scheduler/scheduler.cpp index 3f3a57c..b317762 100644 --- a/xllm_service/scheduler/scheduler.cpp +++ b/xllm_service/scheduler/scheduler.cpp @@ -152,6 +152,11 @@ std::vector Scheduler::get_static_decode_list( return instance_mgr_->get_static_decode_list(instance_name); } +std::vector Scheduler::get_static_prefill_list( + const std::string& instance_name) { + return instance_mgr_->get_static_prefill_list(instance_name); +} + Tokenizer* Scheduler::get_tls_tokenizer() { thread_local std::unique_ptr tls_tokenizer(tokenizer_->clone()); return tls_tokenizer.get(); diff --git a/xllm_service/scheduler/scheduler.h b/xllm_service/scheduler/scheduler.h index b999fc8..3852261 100644 --- a/xllm_service/scheduler/scheduler.h +++ b/xllm_service/scheduler/scheduler.h @@ -46,6 +46,9 @@ class Scheduler final { std::vector get_static_decode_list( const std::string& instance_name); + std::vector get_static_prefill_list( + const std::string& instance_name); + void handle_instance_heartbeat(const proto::HeartbeatRequest* req); void exited() { exited_ = true; }