diff --git a/engine/cli/commands/model_start_cmd.cc b/engine/cli/commands/model_start_cmd.cc
index 5daf02f2a..8f2549dcb 100644
--- a/engine/cli/commands/model_start_cmd.cc
+++ b/engine/cli/commands/model_start_cmd.cc
@@ -1,11 +1,23 @@
 #include "model_start_cmd.h"
+#include "config/yaml_config.h"
+#include "cortex_upd_cmd.h"
+#include "database/models.h"
 #include "httplib.h"
+#include "run_cmd.h"
 #include "server_start_cmd.h"
+#include "utils/cli_selection_utils.h"
 #include "utils/logging_utils.h"
 
 namespace commands {
 bool ModelStartCmd::Exec(const std::string& host, int port,
                          const std::string& model_handle) {
+  std::optional<std::string> model_id =
+      SelectLocalModel(model_service_, model_handle);
+
+  if (!model_id.has_value()) {
+    return false;
+  }
+
   // Start server if server is not started yet
   if (!commands::IsServerAlive(host, port)) {
     CLI_LOG("Starting server ...");
@@ -17,14 +29,16 @@ bool ModelStartCmd::Exec(const std::string& host, int port,
   // Call API to start model
   httplib::Client cli(host + ":" + std::to_string(port));
   Json::Value json_data;
-  json_data["model"] = model_handle;
+  json_data["model"] = model_id.value();
   auto data_str = json_data.toStyledString();
   cli.set_read_timeout(std::chrono::seconds(60));
   auto res = cli.Post("/v1/models/start", httplib::Headers(), data_str.data(),
                       data_str.size(), "application/json");
   if (res) {
     if (res->status == httplib::StatusCode::OK_200) {
-      CLI_LOG("Model loaded!");
+      CLI_LOG(model_id.value() << " model started successfully. Use `"
+              << commands::GetCortexBinary() << " run "
+              << *model_id << "` for interactive chat shell");
       return true;
     } else {
       CTL_ERR("Model failed to load with status code: " << res->status);
diff --git a/engine/cli/commands/run_cmd.cc b/engine/cli/commands/run_cmd.cc
index 3f501fdbb..73aa5c362 100644
--- a/engine/cli/commands/run_cmd.cc
+++ b/engine/cli/commands/run_cmd.cc
@@ -10,6 +10,59 @@
 
 namespace commands {
 
+std::optional<std::string> SelectLocalModel(ModelService& model_service,
+                                            const std::string& model_handle) {
+  std::optional<std::string> model_id = model_handle;
+  cortex::db::Models modellist_handler;
+
+  if (model_handle.empty()) {
+    auto all_local_models = modellist_handler.LoadModelList();
+    if (all_local_models.has_error() || all_local_models.value().empty()) {
+      CLI_LOG("No local models available!");
+      return std::nullopt;
+    }
+
+    if (all_local_models.value().size() == 1) {
+      model_id = all_local_models.value().front().model;
+    } else {
+      std::vector<std::string> model_id_list{};
+      for (const auto& model : all_local_models.value()) {
+        model_id_list.push_back(model.model);
+      }
+
+      auto selection = cli_selection_utils::PrintSelection(
+          model_id_list, "Please select an option");
+      if (!selection.has_value()) {
+        return std::nullopt;
+      }
+      model_id = selection.value();
+      CLI_LOG("Selected: " << selection.value());
+    }
+  } else {
+    auto related_models_ids = modellist_handler.FindRelatedModel(model_handle);
+    if (related_models_ids.has_error() || related_models_ids.value().empty()) {
+      auto result = model_service.DownloadModel(model_handle);
+      if (result.has_error()) {
+        CLI_LOG("Model " << model_handle << " not found!");
+        return std::nullopt;
+      }
+      model_id = result.value();
+      CTL_INF("model_id: " << model_id.value());
+    } else if (related_models_ids.value().size() == 1) {
+      model_id = related_models_ids.value().front();
+    } else {  // multiple models with nearly same name found
+      auto selection = cli_selection_utils::PrintSelection(
+          related_models_ids.value(), "Local Models: (press enter to select)");
+      if (!selection.has_value()) {
+        return std::nullopt;
+      }
+      model_id = selection.value();
+      CLI_LOG("Selected: " << selection.value());
+    }
+  }
+  return model_id;
+}
+
 namespace {
 std::string Repo2Engine(const std::string& r) {
   if (r == kLlamaRepo) {
@@ -24,63 +77,16 @@ std::string Repo2Engine(const std::string& r) {
 }  // namespace
 
 void RunCmd::Exec(bool run_detach) {
-  std::optional<std::string> model_id = model_handle_;
-
+  std::optional<std::string> model_id =
+      SelectLocalModel(model_service_, model_handle_);
+  if (!model_id.has_value()) {
+    return;
+  }
+
   cortex::db::Models modellist_handler;
   config::YamlHandler yaml_handler;
   auto address = host_ + ":" + std::to_string(port_);
 
-  {
-    if (model_handle_.empty()) {
-      auto all_local_models = modellist_handler.LoadModelList();
-      if (all_local_models.has_error() || all_local_models.value().empty()) {
-        CLI_LOG("No local models available!");
-        return;
-      }
-
-      if (all_local_models.value().size() == 1) {
-        model_id = all_local_models.value().front().model;
-      } else {
-        std::vector<std::string> model_id_list{};
-        for (const auto& model : all_local_models.value()) {
-          model_id_list.push_back(model.model);
-        }
-
-        auto selection = cli_selection_utils::PrintSelection(
-            model_id_list, "Please select an option");
-        if (!selection.has_value()) {
-          return;
-        }
-        model_id = selection.value();
-        CLI_LOG("Selected: " << selection.value());
-      }
-    } else {
-      auto related_models_ids =
-          modellist_handler.FindRelatedModel(model_handle_);
-      if (related_models_ids.has_error() ||
-          related_models_ids.value().empty()) {
-        auto result = model_service_.DownloadModel(model_handle_);
-        if (result.has_error()) {
-          CLI_LOG("Model " << model_handle_ << " not found!");
-          return;
-        }
-        model_id = result.value();
-        CTL_INF("model_id: " << model_id.value());
-      } else if (related_models_ids.value().size() == 1) {
-        model_id = related_models_ids.value().front();
-      } else {  // multiple models with nearly same name found
-        auto selection = cli_selection_utils::PrintSelection(
-            related_models_ids.value(),
-            "Local Models: (press enter to select)");
-        if (!selection.has_value()) {
-          return;
-        }
-        model_id = selection.value();
-        CLI_LOG("Selected: " << selection.value());
-      }
-    }
-  }
-
   try {
     namespace fs = std::filesystem;
     namespace fmu = file_manager_utils;
@@ -148,7 +154,7 @@ void RunCmd::Exec(bool run_detach) {
   // Chat
   if (run_detach) {
     CLI_LOG(*model_id << " model started successfully. Use `"
-            << commands::GetCortexBinary() << " chat " << *model_id
+            << commands::GetCortexBinary() << " run " << *model_id
             << "` for interactive chat shell");
   } else {
     ChatCompletionCmd(model_service_).Exec(host_, port_, *model_id, mc, "");
diff --git a/engine/cli/commands/run_cmd.h b/engine/cli/commands/run_cmd.h
index b035a54d9..4a0d68078 100644
--- a/engine/cli/commands/run_cmd.h
+++ b/engine/cli/commands/run_cmd.h
@@ -5,6 +5,10 @@
 #include "services/model_service.h"
 
 namespace commands {
+
+std::optional<std::string> SelectLocalModel(ModelService& model_service,
+                                            const std::string& model_handle);
+
 class RunCmd {
  public:
   explicit RunCmd(std::string host, int port, std::string model_handle,