From 9170ca7f670e8ff0c1f5c7027f267d80cb77a3b2 Mon Sep 17 00:00:00 2001 From: vansangpfiev Date: Thu, 24 Oct 2024 11:00:34 +0700 Subject: [PATCH 1/2] feat: add customize parameters to /v1/models/start --- engine/controllers/models.cc | 33 +++++++++++++++++++++++++++----- engine/services/model_service.cc | 20 +++++++++++++++---- engine/services/model_service.h | 10 +++++++++- 3 files changed, 53 insertions(+), 10 deletions(-) diff --git a/engine/controllers/models.cc b/engine/controllers/models.cc index 174c89184..45e213e59 100644 --- a/engine/controllers/models.cc +++ b/engine/controllers/models.cc @@ -345,8 +345,31 @@ void Models::StartModel( return; auto config = file_manager_utils::GetCortexConfig(); auto model_handle = (*(req->getJsonObject())).get("model", "").asString(); - auto custom_prompt_template = - (*(req->getJsonObject())).get("prompt_template", "").asString(); + StartParameterOverride params_override; + if (auto& o = (*(req->getJsonObject()))["prompt_template"]; !o.isNull()) { + params_override.custom_prompt_template = o.asString(); + } + + if (auto& o = (*(req->getJsonObject()))["cache_enabled"]; !o.isNull()) { + params_override.cache_enabled = o.asBool(); + } + + if (auto& o = (*(req->getJsonObject()))["ngl"]; !o.isNull()) { + o.asInt(); + } + + if (auto& o = (*(req->getJsonObject()))["n_parallel"]; !o.isNull()) { + params_override.n_parallel = o.asInt(); + } + + if (auto& o = (*(req->getJsonObject()))["ctx_len"]; !o.isNull()) { + params_override.ctx_len = o.asInt(); + } + + if (auto& o = (*(req->getJsonObject()))["cache_type"]; !o.isNull()) { + params_override.cache_type = o.asString(); + } + auto model_entry = model_service_->GetDownloadedModel(model_handle); if (!model_entry.has_value()) { Json::Value ret; @@ -375,9 +398,9 @@ void Models::StartModel( return; } - auto result = model_service_->StartModel( - config.apiServerHost, std::stoi(config.apiServerPort), model_handle, - custom_prompt_template); + auto result = model_service_->StartModel(config.apiServerHost, + std::stoi(config.apiServerPort), + model_handle, params_override); if (result.has_error()) { Json::Value ret; ret["message"] = result.error(); diff --git a/engine/services/model_service.cc b/engine/services/model_service.cc index ae3316c12..c79e3c4f1 100644 --- a/engine/services/model_service.cc +++ b/engine/services/model_service.cc @@ -570,7 +570,7 @@ cpp::result ModelService::DeleteModel( cpp::result ModelService::StartModel( const std::string& host, int port, const std::string& model_handle, - std::optional custom_prompt_template) { + const StartParameterOverride& params_override) { namespace fs = std::filesystem; namespace fmu = file_manager_utils; cortex::db::Models modellist_handler; @@ -600,9 +600,9 @@ cpp::result ModelService::StartModel( return false; } json_data["model"] = model_handle; - if (!custom_prompt_template.value_or("").empty()) { - auto parse_prompt_result = - string_utils::ParsePrompt(custom_prompt_template.value()); + if (auto& cpt = params_override.custom_prompt_template; + !cpt.value_or("").empty()) { + auto parse_prompt_result = string_utils::ParsePrompt(cpt.value()); json_data["system_prompt"] = parse_prompt_result.system_prompt; json_data["user_prompt"] = parse_prompt_result.user_prompt; json_data["ai_prompt"] = parse_prompt_result.ai_prompt; @@ -612,6 +612,18 @@ cpp::result ModelService::StartModel( json_data["ai_prompt"] = mc.ai_template; } +#define ASSIGN_IF_PRESENT(json_obj, param_override, param_name) \ + if (param_override.param_name) { \ + json_obj[#param_name] = param_override.param_name.value(); \ + } + + ASSIGN_IF_PRESENT(json_data, params_override, cache_enabled); + ASSIGN_IF_PRESENT(json_data, params_override, ngl); + ASSIGN_IF_PRESENT(json_data, params_override, n_parallel); + ASSIGN_IF_PRESENT(json_data, params_override, ctx_len); + ASSIGN_IF_PRESENT(json_data, params_override, cache_type); +#undef ASSIGN_IF_PRESENT; + CTL_INF(json_data.toStyledString()); assert(!!inference_svc_); auto ir = diff --git a/engine/services/model_service.h b/engine/services/model_service.h index 5adc5a01e..4e6101eaa 100644 --- a/engine/services/model_service.h +++ b/engine/services/model_service.h @@ -7,6 +7,14 @@ #include "services/download_service.h" #include "services/inference_service.h" +struct StartParameterOverride { +std::optional cache_enabled; +std::optional ngl; +std::optional n_parallel; +std::optional ctx_len; +std::optional custom_prompt_template; +std::optional cache_type; +}; class ModelService { public: constexpr auto static kHuggingFaceHost = "huggingface.co"; @@ -46,7 +54,7 @@ class ModelService { cpp::result StartModel( const std::string& host, int port, const std::string& model_handle, - std::optional custom_prompt_template = std::nullopt); + const StartParameterOverride& params_override); cpp::result StopModel(const std::string& host, int port, const std::string& model_handle); From fe4be3474c658241227e7fe63ece7960420683c4 Mon Sep 17 00:00:00 2001 From: vansangpfiev Date: Thu, 24 Oct 2024 11:03:30 +0700 Subject: [PATCH 2/2] fix: ngl --- engine/controllers/models.cc | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/engine/controllers/models.cc b/engine/controllers/models.cc index 45e213e59..0dac515fc 100644 --- a/engine/controllers/models.cc +++ b/engine/controllers/models.cc @@ -355,7 +355,7 @@ void Models::StartModel( } if (auto& o = (*(req->getJsonObject()))["ngl"]; !o.isNull()) { - o.asInt(); + params_override.ngl = o.asInt(); } if (auto& o = (*(req->getJsonObject()))["n_parallel"]; !o.isNull()) { @@ -369,7 +369,7 @@ void Models::StartModel( if (auto& o = (*(req->getJsonObject()))["cache_type"]; !o.isNull()) { params_override.cache_type = o.asString(); } - + auto model_entry = model_service_->GetDownloadedModel(model_handle); if (!model_entry.has_value()) { Json::Value ret;