diff --git a/engine/controllers/models.cc b/engine/controllers/models.cc
index be6d45ceb..f812e896d 100644
--- a/engine/controllers/models.cc
+++ b/engine/controllers/models.cc
@@ -385,7 +385,7 @@ void Models::StartModel(
 
   }
   auto model_entry = model_service_->GetDownloadedModel(model_handle);
-  if (!model_entry.has_value()) {
+  if (!model_entry.has_value() && !params_override.bypass_model_check()) {
     Json::Value ret;
     ret["message"] = "Cannot find model: " + model_handle;
     auto resp = cortex_utils::CreateCortexHttpJsonResponse(ret);
@@ -393,7 +393,9 @@ void Models::StartModel(
     callback(resp);
     return;
   }
-  auto engine_name = model_entry.value().engine;
+  std::string engine_name = params_override.bypass_model_check()
+                                ? kLlamaEngine
+                                : model_entry.value().engine;
   auto engine_entry = engine_service_->GetEngineInfo(engine_name);
   if (engine_entry.has_error()) {
     Json::Value ret;
diff --git a/engine/services/model_service.cc b/engine/services/model_service.cc
index 03ea512e3..050dbaa4d 100644
--- a/engine/services/model_service.cc
+++ b/engine/services/model_service.cc
@@ -15,7 +15,6 @@
 #include "utils/logging_utils.h"
 #include "utils/result.hpp"
 #include "utils/string_utils.h"
-#include "utils/json_helper.h"
 
 namespace {
 void ParseGguf(const DownloadItem& ggufDownloadItem,
@@ -577,28 +576,37 @@ cpp::result<bool, std::string> ModelService::StartModel(
   config::YamlHandler yaml_handler;
 
   try {
-    auto model_entry = modellist_handler.GetModelInfo(model_handle);
-    if (model_entry.has_error()) {
-      CTL_WRN("Error: " + model_entry.error());
-      return cpp::fail(model_entry.error());
-    }
-    yaml_handler.ModelConfigFromFile(
-        fmu::ToAbsoluteCortexDataPath(
-            fs::path(model_entry.value().path_to_model_yaml))
-            .string());
-    auto mc = yaml_handler.GetModelConfig();
-
-    httplib::Client cli(host + ":" + std::to_string(port));
+    Json::Value json_data;
+    // Currently we don't support download vision models, so we need to bypass check
+    if (!params_override.bypass_model_check()) {
+      auto model_entry = modellist_handler.GetModelInfo(model_handle);
+      if (model_entry.has_error()) {
+        CTL_WRN("Error: " + model_entry.error());
+        return cpp::fail(model_entry.error());
+      }
+      yaml_handler.ModelConfigFromFile(
+          fmu::ToAbsoluteCortexDataPath(
+              fs::path(model_entry.value().path_to_model_yaml))
+              .string());
+      auto mc = yaml_handler.GetModelConfig();
 
-    Json::Value json_data = mc.ToJson();
-    if (mc.files.size() > 0) {
-      // TODO(sang) support multiple files
-      json_data["model_path"] =
-          fmu::ToAbsoluteCortexDataPath(fs::path(mc.files[0])).string();
+      json_data = mc.ToJson();
+      if (mc.files.size() > 0) {
+        // TODO(sang) support multiple files
+        json_data["model_path"] =
+            fmu::ToAbsoluteCortexDataPath(fs::path(mc.files[0])).string();
+      } else {
+        LOG_WARN << "model_path is empty";
+        return false;
+      }
+      json_data["system_prompt"] = mc.system_template;
+      json_data["user_prompt"] = mc.user_template;
+      json_data["ai_prompt"] = mc.ai_template;
     } else {
-      LOG_WARN << "model_path is empty";
-      return false;
+      bypass_stop_check_set_.insert(model_handle);
     }
+    httplib::Client cli(host + ":" + std::to_string(port));
+
     json_data["model"] = model_handle;
     if (auto& cpt = params_override.custom_prompt_template;
         !cpt.value_or("").empty()) {
@@ -606,10 +614,6 @@
       json_data["system_prompt"] = parse_prompt_result.system_prompt;
       json_data["user_prompt"] = parse_prompt_result.user_prompt;
       json_data["ai_prompt"] = parse_prompt_result.ai_prompt;
-    } else {
-      json_data["system_prompt"] = mc.system_template;
-      json_data["user_prompt"] = mc.user_template;
-      json_data["ai_prompt"] = mc.ai_template;
     }
 
 #define ASSIGN_IF_PRESENT(json_obj, param_override, param_name) \
@@ -655,22 +659,28 @@ cpp::result<bool, std::string> ModelService::StopModel(
   config::YamlHandler yaml_handler;
 
   try {
-    auto model_entry = modellist_handler.GetModelInfo(model_handle);
-    if (model_entry.has_error()) {
-      CTL_WRN("Error: " + model_entry.error());
-      return cpp::fail(model_entry.error());
+    auto bypass_check = (bypass_stop_check_set_.find(model_handle) !=
+                         bypass_stop_check_set_.end());
+    Json::Value json_data;
+    if (!bypass_check) {
+      auto model_entry = modellist_handler.GetModelInfo(model_handle);
+      if (model_entry.has_error()) {
+        CTL_WRN("Error: " + model_entry.error());
+        return cpp::fail(model_entry.error());
+      }
+      yaml_handler.ModelConfigFromFile(
+          fmu::ToAbsoluteCortexDataPath(
+              fs::path(model_entry.value().path_to_model_yaml))
+              .string());
+      auto mc = yaml_handler.GetModelConfig();
+      json_data["engine"] = mc.engine;
     }
-    yaml_handler.ModelConfigFromFile(
-        fmu::ToAbsoluteCortexDataPath(
-            fs::path(model_entry.value().path_to_model_yaml))
-            .string());
-    auto mc = yaml_handler.GetModelConfig();
 
     httplib::Client cli(host + ":" + std::to_string(port));
-
-    Json::Value json_data;
     json_data["model"] = model_handle;
-    json_data["engine"] = mc.engine;
+    if (bypass_check) {
+      json_data["engine"] = kLlamaEngine;
+    }
     CTL_INF(json_data.toStyledString());
     assert(inference_svc_);
     auto ir =
@@ -678,6 +688,9 @@
     auto status = std::get<0>(ir)["status_code"].asInt();
     auto data = std::get<1>(ir);
     if (status == httplib::StatusCode::OK_200) {
+      if (bypass_check) {
+        bypass_stop_check_set_.erase(model_handle);
+      }
       return true;
     } else {
       CTL_ERR("Model failed to stop with status code: " << status);
diff --git a/engine/services/model_service.h b/engine/services/model_service.h
index 3270542c5..cdae6c6f1 100644
--- a/engine/services/model_service.h
+++ b/engine/services/model_service.h
@@ -8,14 +8,15 @@
 #include "services/inference_service.h"
 
 struct StartParameterOverride {
-std::optional<bool> cache_enabled;
-std::optional<int> ngl;
-std::optional<int> n_parallel;
-std::optional<int> ctx_len;
-std::optional<std::string> custom_prompt_template;
-std::optional<std::string> cache_type;
-std::optional<std::string> mmproj;
-std::optional<std::string> model_path;
+  std::optional<bool> cache_enabled;
+  std::optional<int> ngl;
+  std::optional<int> n_parallel;
+  std::optional<int> ctx_len;
+  std::optional<std::string> custom_prompt_template;
+  std::optional<std::string> cache_type;
+  std::optional<std::string> mmproj;
+  std::optional<std::string> model_path;
+  bool bypass_model_check() const { return mmproj.has_value(); }
 };
 class ModelService {
  public:
@@ -86,4 +87,5 @@ class ModelService {
 
   std::shared_ptr<DownloadService> download_service_;
   std::shared_ptr<InferenceService> inference_svc_;
+  std::unordered_set<std::string> bypass_stop_check_set_;
 };
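
Reviewer note: taken together, the three files add a "bypass" path for vision models started with an explicit mmproj override. Such models have no entry in the downloaded-models DB and no model.yml, so StartModel skips the DB/YAML lookup, defaults the engine to kLlamaEngine, and records the handle in bypass_stop_check_set_ so the later StopModel call can skip the lookup too. Below is a minimal, self-contained sketch of that lifecycle; ModelServiceSketch, LookUpEngineFromModelYaml, the paths, and the main() driver are illustrative stand-ins rather than code from this PR, and the kLlamaEngine value is assumed.

#include <iostream>
#include <optional>
#include <string>
#include <unordered_set>

// Assumed engine identifier; the real constant is defined elsewhere in cortex.
constexpr const char* kLlamaEngine = "llama-cpp";

struct StartParameterOverride {
  std::optional<std::string> mmproj;  // path to a multimodal projector file
  // A vision model started via mmproj has no local DB entry, so the mere
  // presence of mmproj is what toggles the bypass (mirrors the PR).
  bool bypass_model_check() const { return mmproj.has_value(); }
};

// Hypothetical stand-in for ModelService; the real class returns
// cpp::result and talks to the engine process over HTTP.
class ModelServiceSketch {
 public:
  void Start(const std::string& handle, const StartParameterOverride& p) {
    std::string engine;
    if (p.bypass_model_check()) {
      // Bypass: nothing to look up, so default the engine and remember the
      // handle for the matching Stop() call.
      engine = kLlamaEngine;
      bypass_stop_check_set_.insert(handle);
    } else {
      engine = LookUpEngineFromModelYaml(handle);  // normal path
    }
    std::cout << "start " << handle << " on " << engine << "\n";
  }

  void Stop(const std::string& handle) {
    const bool bypass = bypass_stop_check_set_.count(handle) > 0;
    const std::string engine =
        bypass ? kLlamaEngine : LookUpEngineFromModelYaml(handle);
    std::cout << "stop " << handle << " on " << engine << "\n";
    if (bypass) bypass_stop_check_set_.erase(handle);  // one-shot entry
  }

 private:
  // Placeholder for the DB + model.yml lookup done in the real service.
  std::string LookUpEngineFromModelYaml(const std::string&) {
    return "engine-from-model-yml";
  }
  std::unordered_set<std::string> bypass_stop_check_set_;
};

int main() {
  ModelServiceSketch svc;
  StartParameterOverride vision;
  vision.mmproj = "/models/llava/mmproj.gguf";  // hypothetical path
  svc.Start("llava-7b", vision);  // takes the bypass path
  svc.Stop("llava-7b");           // uses, then erases, the recorded entry
}

One design consequence worth noting: StopModel receives no params_override, so the decision made at start time has to be persisted somewhere; the unordered_set keyed by model handle is the smallest state that survives between the two calls, and erasing the handle only after a successful stop (OK_200) keeps retries working.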