From be385b7142b2d3c8e2b589a84963e6be4f3d8a99 Mon Sep 17 00:00:00 2001
From: vansangpfiev
Date: Thu, 16 Jan 2025 10:11:52 +0700
Subject: [PATCH] fix: check model status before inferencing

---
 engine/services/inference_service.cc | 82 +++++++++++++++++-----------
 engine/services/inference_service.h  |  4 +-
 2 files changed, 52 insertions(+), 34 deletions(-)

diff --git a/engine/services/inference_service.cc b/engine/services/inference_service.cc
index 057b6f716..4ea9ebdfd 100644
--- a/engine/services/inference_service.cc
+++ b/engine/services/inference_service.cc
@@ -14,6 +14,21 @@ cpp::result<void, InferResult> InferenceService::HandleChatCompletion(
   }
   function_calling_utils::PreprocessRequest(json_body);
   auto tool_choice = json_body->get("tool_choice", Json::Value::null);
+  auto model_id = json_body->get("model", "").asString();
+  if (saved_models_.find(model_id) != saved_models_.end()) {
+    // Check whether the model is already started; if not, load it first.
+    Json::Value root;
+    root["model"] = model_id;
+    root["engine"] = engine_type;
+    auto ir = GetModelStatus(std::make_shared<Json::Value>(root));
+    auto status = std::get<0>(ir)["status_code"].asInt();
+    if (status != drogon::k200OK) {
+      CTL_INF("Model is not loaded, start loading it: " << model_id);
+      auto res = LoadModel(saved_models_.at(model_id));
+      // Ignore the result; a load failure surfaces via the engine check below.
+    }
+  }
+
   auto engine_result = engine_service_->GetLoadedEngine(engine_type);
   if (engine_result.has_error()) {
     Json::Value res;
@@ -23,45 +38,42 @@ cpp::result<void, InferResult> InferenceService::HandleChatCompletion(
     LOG_WARN << "Engine is not loaded yet";
     return cpp::fail(std::make_pair(stt, res));
   }
+
+  if (!model_id.empty()) {
+    if (auto model_service = model_service_.lock()) {
+      auto metadata_ptr = model_service->GetCachedModelMetadata(model_id);
+      if (metadata_ptr != nullptr &&
+          !metadata_ptr->tokenizer->chat_template.empty()) {
+        auto tokenizer = metadata_ptr->tokenizer;
+        auto messages = (*json_body)["messages"];
+        Json::Value messages_jsoncpp(Json::arrayValue);
+        for (auto message : messages) {
+          messages_jsoncpp.append(message);
+        }
 
-  {
-    auto model_id = json_body->get("model", "").asString();
-    if (!model_id.empty()) {
-      if (auto model_service = model_service_.lock()) {
-        auto metadata_ptr = model_service->GetCachedModelMetadata(model_id);
-        if (metadata_ptr != nullptr &&
-            !metadata_ptr->tokenizer->chat_template.empty()) {
-          auto tokenizer = metadata_ptr->tokenizer;
-          auto messages = (*json_body)["messages"];
-          Json::Value messages_jsoncpp(Json::arrayValue);
-          for (auto message : messages) {
-            messages_jsoncpp.append(message);
-          }
-
-          Json::Value tools(Json::arrayValue);
-          Json::Value template_data_json;
-          template_data_json["messages"] = messages_jsoncpp;
-          // template_data_json["tools"] = tools;
-
-          auto prompt_result = jinja::RenderTemplate(
-              tokenizer->chat_template, template_data_json,
-              tokenizer->bos_token, tokenizer->eos_token,
-              tokenizer->add_bos_token, tokenizer->add_eos_token,
-              tokenizer->add_generation_prompt);
-          if (prompt_result.has_value()) {
-            (*json_body)["prompt"] = prompt_result.value();
-            Json::Value stops(Json::arrayValue);
-            stops.append(tokenizer->eos_token);
-            (*json_body)["stop"] = stops;
-          } else {
-            CTL_ERR("Failed to render prompt: " + prompt_result.error());
-          }
+        Json::Value tools(Json::arrayValue);
+        Json::Value template_data_json;
+        template_data_json["messages"] = messages_jsoncpp;
+        // template_data_json["tools"] = tools;
+
+        auto prompt_result = jinja::RenderTemplate(
+            tokenizer->chat_template, template_data_json, tokenizer->bos_token,
+            tokenizer->eos_token, tokenizer->add_bos_token,
+            tokenizer->add_eos_token, tokenizer->add_generation_prompt);
+        if (prompt_result.has_value()) {
+          (*json_body)["prompt"] = prompt_result.value();
+          Json::Value stops(Json::arrayValue);
+          stops.append(tokenizer->eos_token);
+          (*json_body)["stop"] = stops;
+        } else {
+          CTL_ERR("Failed to render prompt: " + prompt_result.error());
+        }
       }
     }
   }
-  CTL_INF("Json body inference: " + json_body->toStyledString());
+
+  CTL_DBG("Json body inference: " + json_body->toStyledString());
 
   auto cb = [q, tool_choice](Json::Value status, Json::Value res) {
     if (!tool_choice.isNull()) {
@@ -205,6 +217,10 @@ InferResult InferenceService::LoadModel(
     std::get<EngineI*>(engine_result.value())
         ->LoadModel(json_body, std::move(cb));
   }
+  if (!engine_service_->IsRemoteEngine(engine_type)) {
+    auto model_id = json_body->get("model", "").asString();
+    saved_models_[model_id] = json_body;
+  }
   return std::make_pair(stt, r);
 }
 
diff --git a/engine/services/inference_service.h b/engine/services/inference_service.h
index 794110f99..726275bba 100644
--- a/engine/services/inference_service.h
+++ b/engine/services/inference_service.h
@@ -47,7 +47,7 @@ class InferenceService {
   cpp::result<void, InferResult> HandleRouteRequest(
       std::shared_ptr<SyncQueue> q, std::shared_ptr<Json::Value> json_body);
-  
+
   InferResult LoadModel(std::shared_ptr<Json::Value> json_body);
 
   InferResult UnloadModel(const std::string& engine,
@@ -74,4 +74,6 @@ class InferenceService {
  private:
   std::shared_ptr<EngineService> engine_service_;
   std::weak_ptr<ModelService> model_service_;
+  using SavedModel = std::shared_ptr<Json::Value>;
+  std::unordered_map<std::string, SavedModel> saved_models_;
 };
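
For reference, the following is a minimal, self-contained sketch of the flow this patch adds: LoadModel caches the request body for a model, and HandleChatCompletion checks whether that model is still running before inferencing, replaying the saved request if it is not. The class name, the Request alias, and the model/engine strings below are simplified stand-ins for illustration only; they are not the real InferenceService, EngineService, or Json::Value APIs.

#include <iostream>
#include <map>
#include <memory>
#include <set>
#include <string>
#include <unordered_map>

// Stand-in for the Json::Value request bodies used by the real service.
using Request = std::map<std::string, std::string>;

class InferenceGateway {
 public:
  // Mirrors LoadModel(): start the model and cache the request body so it
  // can be replayed later (the patch stores it in saved_models_[model_id]).
  void LoadModel(const std::shared_ptr<Request>& body) {
    const std::string& model_id = body->at("model");
    running_.insert(model_id);
    saved_models_[model_id] = body;
    std::cout << "loaded: " << model_id << "\n";
  }

  void UnloadModel(const std::string& model_id) { running_.erase(model_id); }

  // Mirrors HandleChatCompletion(): if a saved load request exists for the
  // model but the model is not currently running, load it transparently and
  // ignore the result, as the patch does before inferencing.
  void HandleChatCompletion(const std::shared_ptr<Request>& body) {
    const std::string& model_id = body->at("model");
    auto it = saved_models_.find(model_id);
    if (it != saved_models_.end() && running_.count(model_id) == 0) {
      std::cout << "model not running, auto-loading: " << model_id << "\n";
      LoadModel(it->second);
    }
    std::cout << "inference on: " << model_id << "\n";
  }

 private:
  // running_ stands in for the GetModelStatus() call; the real code asks the
  // engine for the model's status instead of keeping a local set.
  std::set<std::string> running_;
  std::unordered_map<std::string, std::shared_ptr<Request>> saved_models_;
};

int main() {
  InferenceGateway gw;
  auto load = std::make_shared<Request>(
      Request{{"model", "my-model"}, {"engine", "llama-cpp"}});
  gw.LoadModel(load);
  gw.UnloadModel("my-model");  // simulate the model being stopped later

  // The chat request still works: the saved load request is replayed first.
  auto chat = std::make_shared<Request>(Request{{"model", "my-model"}});
  gw.HandleChatCompletion(chat);
  return 0;
}

Running the sketch prints the auto-load line before the inference line, which is the behavior the patch is after: a chat completion against a previously loaded model succeeds even if the model has since been stopped.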