33 changes: 28 additions & 5 deletions engine/controllers/models.cc
@@ -345,8 +345,31 @@ void Models::StartModel(
     return;
   auto config = file_manager_utils::GetCortexConfig();
   auto model_handle = (*(req->getJsonObject())).get("model", "").asString();
-  auto custom_prompt_template =
-      (*(req->getJsonObject())).get("prompt_template", "").asString();
+  StartParameterOverride params_override;
+  if (auto& o = (*(req->getJsonObject()))["prompt_template"]; !o.isNull()) {
+    params_override.custom_prompt_template = o.asString();
+  }
+
+  if (auto& o = (*(req->getJsonObject()))["cache_enabled"]; !o.isNull()) {
+    params_override.cache_enabled = o.asBool();
+  }
+
+  if (auto& o = (*(req->getJsonObject()))["ngl"]; !o.isNull()) {
+    params_override.ngl = o.asInt();
+  }
+
+  if (auto& o = (*(req->getJsonObject()))["n_parallel"]; !o.isNull()) {
+    params_override.n_parallel = o.asInt();
+  }
+
+  if (auto& o = (*(req->getJsonObject()))["ctx_len"]; !o.isNull()) {
+    params_override.ctx_len = o.asInt();
+  }
+
+  if (auto& o = (*(req->getJsonObject()))["cache_type"]; !o.isNull()) {
+    params_override.cache_type = o.asString();
+  }
+
   auto model_entry = model_service_->GetDownloadedModel(model_handle);
   if (!model_entry.has_value()) {
     Json::Value ret;
@@ -375,9 +398,9 @@ void Models::StartModel(
     return;
   }

-  auto result = model_service_->StartModel(
-      config.apiServerHost, std::stoi(config.apiServerPort), model_handle,
-      custom_prompt_template);
+  auto result = model_service_->StartModel(config.apiServerHost,
+                                           std::stoi(config.apiServerPort),
+                                           model_handle, params_override);
   if (result.has_error()) {
     Json::Value ret;
    ret["message"] = result.error();
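For context, a request body that exercises all of the new overrides could look like the sketch below, built with jsoncpp as the controller does; the model handle and every value shown are illustrative assumptions, not taken from this diff:

    // Hypothetical client-side body for the start-model endpoint. Every key
    // other than "model" is optional; absent keys leave params_override unset.
    Json::Value body;
    body["model"] = "my-model";                              // illustrative handle
    body["prompt_template"] = "{system_message} {prompt}";   // illustrative template
    body["cache_enabled"] = true;
    body["ngl"] = 33;                                        // illustrative GPU layer count
    body["n_parallel"] = 2;
    body["ctx_len"] = 4096;
    body["cache_type"] = "f16";                              // illustrative cache type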
20 changes: 16 additions & 4 deletions engine/services/model_service.cc
@@ -570,7 +570,7 @@ cpp::result<void, std::string> ModelService::DeleteModel(

 cpp::result<bool, std::string> ModelService::StartModel(
     const std::string& host, int port, const std::string& model_handle,
-    std::optional<std::string> custom_prompt_template) {
+    const StartParameterOverride& params_override) {
   namespace fs = std::filesystem;
   namespace fmu = file_manager_utils;
   cortex::db::Models modellist_handler;
@@ -600,9 +600,9 @@ cpp::result<bool, std::string> ModelService::StartModel(
     return false;
   }
   json_data["model"] = model_handle;
-  if (!custom_prompt_template.value_or("").empty()) {
-    auto parse_prompt_result =
-        string_utils::ParsePrompt(custom_prompt_template.value());
+  if (auto& cpt = params_override.custom_prompt_template;
+      !cpt.value_or("").empty()) {
+    auto parse_prompt_result = string_utils::ParsePrompt(cpt.value());
     json_data["system_prompt"] = parse_prompt_result.system_prompt;
     json_data["user_prompt"] = parse_prompt_result.user_prompt;
     json_data["ai_prompt"] = parse_prompt_result.ai_prompt;
@@ -612,6 +612,18 @@ cpp::result<bool, std::string> ModelService::StartModel(
     json_data["ai_prompt"] = mc.ai_template;
   }

+#define ASSIGN_IF_PRESENT(json_obj, param_override, param_name) \
+  if (param_override.param_name) {                              \
+    json_obj[#param_name] = param_override.param_name.value();  \
+  }
+
+  ASSIGN_IF_PRESENT(json_data, params_override, cache_enabled);
+  ASSIGN_IF_PRESENT(json_data, params_override, ngl);
+  ASSIGN_IF_PRESENT(json_data, params_override, n_parallel);
+  ASSIGN_IF_PRESENT(json_data, params_override, ctx_len);
+  ASSIGN_IF_PRESENT(json_data, params_override, cache_type);
+#undef ASSIGN_IF_PRESENT
+
   CTL_INF(json_data.toStyledString());
   assert(!!inference_svc_);
   auto ir =
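As a quick check on the macro, a single invocation such as ASSIGN_IF_PRESENT(json_data, params_override, ngl); expands to roughly the following (#param_name stringizes the token, so the JSON key becomes "ngl"):

    // Approximate expansion of ASSIGN_IF_PRESENT(json_data, params_override, ngl);
    if (params_override.ngl) {  // std::optional converts to bool: true if a value is set
      json_data["ngl"] = params_override.ngl.value();
    };  // the trailing ';' from the invocation is an empty statement

Because the macro body is a bare if-statement, that stray semicolon is harmless here; wrapping the body in do { ... } while (0) would be the more defensive macro idiom.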
10 changes: 9 additions & 1 deletion engine/services/model_service.h
@@ -7,6 +7,14 @@
 #include "services/download_service.h"
 #include "services/inference_service.h"

+struct StartParameterOverride {
+  std::optional<bool> cache_enabled;
+  std::optional<int> ngl;
+  std::optional<int> n_parallel;
+  std::optional<int> ctx_len;
+  std::optional<std::string> custom_prompt_template;
+  std::optional<std::string> cache_type;
+};
 class ModelService {
  public:
   constexpr auto static kHuggingFaceHost = "huggingface.co";
Expand Down Expand Up @@ -46,7 +54,7 @@ class ModelService {

cpp::result<bool, std::string> StartModel(
const std::string& host, int port, const std::string& model_handle,
std::optional<std::string> custom_prompt_template = std::nullopt);
const StartParameterOverride& params_override);

cpp::result<bool, std::string> StopModel(const std::string& host, int port,
const std::string& model_handle);
Expand Down
Loading
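Taken together, a caller-side sketch of the new signature; StartParameterOverride and StartModel come from this diff, while the host, port, and every field value below are illustrative assumptions:

    // Hypothetical caller: set only the fields the request supplied; fields
    // left as std::nullopt fall back to the model's own configuration.
    StartParameterOverride params_override;
    params_override.ngl = 33;        // illustrative
    params_override.ctx_len = 4096;  // illustrative
    auto result = model_service_->StartModel("127.0.0.1", 3928, "my-model",
                                             params_override);
    if (result.has_error()) {
      CTL_INF(result.error());  // logging macro used elsewhere in this diff
    }

Passing a struct of std::optional fields keeps the StartModel signature stable as new overrides are added, and cleanly distinguishes "not supplied" from "supplied with a default-looking value".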