This repository was archived by the owner on Jul 4, 2025. It is now read-only.
Merged
engine/controllers/models.cc: 6 changes (4 additions & 2 deletions)
@@ -385,15 +385,17 @@ void Models::StartModel(
   }
 
   auto model_entry = model_service_->GetDownloadedModel(model_handle);
-  if (!model_entry.has_value()) {
+  if (!model_entry.has_value() && !params_override.bypass_model_check()) {
     Json::Value ret;
     ret["message"] = "Cannot find model: " + model_handle;
     auto resp = cortex_utils::CreateCortexHttpJsonResponse(ret);
     resp->setStatusCode(drogon::k400BadRequest);
     callback(resp);
     return;
   }
-  auto engine_name = model_entry.value().engine;
+  std::string engine_name = params_override.bypass_model_check()
+                                ? kLlamaEngine
+                                : model_entry.value().engine;
   auto engine_entry = engine_service_->GetEngineInfo(engine_name);
   if (engine_entry.has_error()) {
     Json::Value ret;
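For context on this hunk: when the start request carries an mmproj override, bypass_model_check() returns true, the missing-model error is skipped, and the engine falls back to kLlamaEngine. A minimal standalone sketch of that selection logic follows; the kLlamaEngine value, the trimmed-down struct, and the model path are assumptions made for illustration, only the names come from this diff.

#include <iostream>
#include <optional>
#include <string>

// Stand-ins for the real declarations; kLlamaEngine's value is assumed here.
constexpr const char* kLlamaEngine = "llama-cpp";

struct StartParameterOverride {
  std::optional<std::string> mmproj;
  bool bypass_model_check() const { return mmproj.has_value(); }
};

// Mirrors the controller's engine selection: bypassed requests fall back to
// the llama engine, everything else keeps the engine recorded for the model.
std::string SelectEngine(const StartParameterOverride& p,
                         const std::string& entry_engine) {
  return p.bypass_model_check() ? kLlamaEngine : entry_engine;
}

int main() {
  StartParameterOverride vision;
  vision.mmproj = "/models/llava/mmproj.gguf";  // hypothetical path
  StartParameterOverride regular;

  std::cout << SelectEngine(vision, "onnxruntime") << "\n";   // llama-cpp
  std::cout << SelectEngine(regular, "onnxruntime") << "\n";  // onnxruntime
}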
engine/services/model_service.cc: 85 changes (49 additions & 36 deletions)
@@ -15,7 +15,6 @@
 #include "utils/logging_utils.h"
 #include "utils/result.hpp"
 #include "utils/string_utils.h"
-#include "utils/json_helper.h"
 
 namespace {
 void ParseGguf(const DownloadItem& ggufDownloadItem,
@@ -577,39 +576,44 @@ cpp::result<bool, std::string> ModelService::StartModel(
   config::YamlHandler yaml_handler;
 
   try {
-    auto model_entry = modellist_handler.GetModelInfo(model_handle);
-    if (model_entry.has_error()) {
-      CTL_WRN("Error: " + model_entry.error());
-      return cpp::fail(model_entry.error());
-    }
-    yaml_handler.ModelConfigFromFile(
-        fmu::ToAbsoluteCortexDataPath(
-            fs::path(model_entry.value().path_to_model_yaml))
-            .string());
-    auto mc = yaml_handler.GetModelConfig();
-
-    httplib::Client cli(host + ":" + std::to_string(port));
-
-    Json::Value json_data = mc.ToJson();
-    if (mc.files.size() > 0) {
-      // TODO(sang) support multiple files
-      json_data["model_path"] =
-          fmu::ToAbsoluteCortexDataPath(fs::path(mc.files[0])).string();
-    } else {
-      LOG_WARN << "model_path is empty";
-      return false;
-    }
+    Json::Value json_data;
+    // Currently we don't support download vision models, so we need to bypass check
+    if (!params_override.bypass_model_check()) {
+      auto model_entry = modellist_handler.GetModelInfo(model_handle);
+      if (model_entry.has_error()) {
+        CTL_WRN("Error: " + model_entry.error());
+        return cpp::fail(model_entry.error());
+      }
+      yaml_handler.ModelConfigFromFile(
+          fmu::ToAbsoluteCortexDataPath(
+              fs::path(model_entry.value().path_to_model_yaml))
+              .string());
+      auto mc = yaml_handler.GetModelConfig();
+
+      json_data = mc.ToJson();
+      if (mc.files.size() > 0) {
+        // TODO(sang) support multiple files
+        json_data["model_path"] =
+            fmu::ToAbsoluteCortexDataPath(fs::path(mc.files[0])).string();
+      } else {
+        LOG_WARN << "model_path is empty";
+        return false;
+      }
+      json_data["system_prompt"] = mc.system_template;
+      json_data["user_prompt"] = mc.user_template;
+      json_data["ai_prompt"] = mc.ai_template;
+    } else {
+      bypass_stop_check_set_.insert(model_handle);
+    }
+    httplib::Client cli(host + ":" + std::to_string(port));
+
     json_data["model"] = model_handle;
     if (auto& cpt = params_override.custom_prompt_template;
         !cpt.value_or("").empty()) {
       auto parse_prompt_result = string_utils::ParsePrompt(cpt.value());
       json_data["system_prompt"] = parse_prompt_result.system_prompt;
       json_data["user_prompt"] = parse_prompt_result.user_prompt;
       json_data["ai_prompt"] = parse_prompt_result.ai_prompt;
-    } else {
-      json_data["system_prompt"] = mc.system_template;
-      json_data["user_prompt"] = mc.user_template;
-      json_data["ai_prompt"] = mc.ai_template;
     }
 
 #define ASSIGN_IF_PRESENT(json_obj, param_override, param_name) \
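This hunk is cut off at the collapsed ASSIGN_IF_PRESENT macro, whose body is not visible in this view. Purely as an illustration of the pattern the name suggests, copying an optional override into the JSON payload only when the caller set it, a hypothetical helper could look like this (not the project's actual macro):

#include <json/json.h>

#include <optional>
#include <string>

// Hypothetical stand-in for the collapsed macro: assign a key only when the
// optional override carries a value, leaving the model's defaults otherwise.
template <typename T>
void AssignIfPresent(Json::Value& json_obj, const std::optional<T>& value,
                     const std::string& key) {
  if (value.has_value()) {
    json_obj[key] = *value;
  }
}

// Possible usage with fields from StartParameterOverride (illustrative only):
//   AssignIfPresent(json_data, params_override.ctx_len, "ctx_len");
//   AssignIfPresent(json_data, params_override.ngl, "ngl");
//   AssignIfPresent(json_data, params_override.cache_type, "cache_type");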
@@ -655,29 +659,38 @@ cpp::result<bool, std::string> ModelService::StopModel(
   config::YamlHandler yaml_handler;
 
   try {
-    auto model_entry = modellist_handler.GetModelInfo(model_handle);
-    if (model_entry.has_error()) {
-      CTL_WRN("Error: " + model_entry.error());
-      return cpp::fail(model_entry.error());
-    }
-    yaml_handler.ModelConfigFromFile(
-        fmu::ToAbsoluteCortexDataPath(
-            fs::path(model_entry.value().path_to_model_yaml))
-            .string());
-    auto mc = yaml_handler.GetModelConfig();
-
-    httplib::Client cli(host + ":" + std::to_string(port));
-
-    Json::Value json_data;
+    auto bypass_check = (bypass_stop_check_set_.find(model_handle) !=
+                         bypass_stop_check_set_.end());
+    Json::Value json_data;
+    if (!bypass_check) {
+      auto model_entry = modellist_handler.GetModelInfo(model_handle);
+      if (model_entry.has_error()) {
+        CTL_WRN("Error: " + model_entry.error());
+        return cpp::fail(model_entry.error());
+      }
+      yaml_handler.ModelConfigFromFile(
+          fmu::ToAbsoluteCortexDataPath(
+              fs::path(model_entry.value().path_to_model_yaml))
+              .string());
+      auto mc = yaml_handler.GetModelConfig();
+      json_data["engine"] = mc.engine;
+    }
+    httplib::Client cli(host + ":" + std::to_string(port));
+
     json_data["model"] = model_handle;
-    json_data["engine"] = mc.engine;
+    if (bypass_check) {
+      json_data["engine"] = kLlamaEngine;
+    }
     CTL_INF(json_data.toStyledString());
     assert(inference_svc_);
     auto ir =
         inference_svc_->UnloadModel(std::make_shared<Json::Value>(json_data));
     auto status = std::get<0>(ir)["status_code"].asInt();
     auto data = std::get<1>(ir);
     if (status == httplib::StatusCode::OK_200) {
+      if (bypass_check) {
+        bypass_stop_check_set_.erase(model_handle);
+      }
       return true;
     } else {
       CTL_ERR("Model failed to stop with status code: " << status);
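Taken together with the StartModel hunk above, the new bookkeeping is a small lifecycle around bypass_stop_check_set_: remember a handle when it is started with the model check bypassed, consult the set on stop so the YAML lookup can be skipped and the llama engine used, and forget the handle once the unload succeeds. A minimal sketch of that lifecycle in isolation, with a hypothetical wrapper class and model name:

#include <iostream>
#include <string>
#include <unordered_set>

// Minimal sketch of the bypass bookkeeping added in this PR, with the rest of
// ModelService stubbed out: a model started with the check bypassed is
// remembered so StopModel can skip the config lookup and then forget it once
// the unload succeeds.
class BypassTracker {
 public:
  void MarkStartedWithBypass(const std::string& model) {
    bypass_stop_check_set_.insert(model);
  }
  bool ShouldBypassStop(const std::string& model) const {
    return bypass_stop_check_set_.count(model) > 0;
  }
  void OnStopSucceeded(const std::string& model) {
    bypass_stop_check_set_.erase(model);
  }

 private:
  std::unordered_set<std::string> bypass_stop_check_set_;
};

int main() {
  BypassTracker tracker;
  tracker.MarkStartedWithBypass("llava-vision");                 // start with mmproj
  std::cout << tracker.ShouldBypassStop("llava-vision") << "\n";  // 1
  tracker.OnStopSucceeded("llava-vision");                        // unload returned 200
  std::cout << tracker.ShouldBypassStop("llava-vision") << "\n";  // 0
}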
engine/services/model_service.h: 18 changes (10 additions & 8 deletions)
@@ -8,14 +8,15 @@
 #include "services/inference_service.h"
 
 struct StartParameterOverride {
-  std::optional<bool> cache_enabled;
-  std::optional<int> ngl;
-  std::optional<int> n_parallel;
-  std::optional<int> ctx_len;
-  std::optional<std::string> custom_prompt_template;
-  std::optional<std::string> cache_type;
-  std::optional<std::string> mmproj;
-  std::optional<std::string> model_path;
+  std::optional<bool> cache_enabled;
+  std::optional<int> ngl;
+  std::optional<int> n_parallel;
+  std::optional<int> ctx_len;
+  std::optional<std::string> custom_prompt_template;
+  std::optional<std::string> cache_type;
+  std::optional<std::string> mmproj;
+  std::optional<std::string> model_path;
+  bool bypass_model_check() const { return mmproj.has_value(); }
 };
 class ModelService {
  public:
@@ -86,4 +87,5 @@ class ModelService {
 
   std::shared_ptr<DownloadService> download_service_;
   std::shared_ptr<services::InferenceService> inference_svc_;
+  std::unordered_set<std::string> bypass_stop_check_set_;
 };