diff --git a/engine/commands/model_import_cmd.cc b/engine/commands/model_import_cmd.cc
new file mode 100644
index 000000000..830a1fdd7
--- /dev/null
+++ b/engine/commands/model_import_cmd.cc
@@ -0,0 +1,69 @@
+#include "model_import_cmd.h"
+#include <exception>
+#include <filesystem>
+#include <system_error>
+#include <utility>
+#include "config/gguf_parser.h"
+#include "config/yaml_config.h"
+#include "trantor/utils/Logger.h"
+#include "utils/file_manager_utils.h"
+#include "utils/logging_utils.h"
+#include "utils/modellist_utils.h"
+
+namespace commands {
+
+ModelImportCmd::ModelImportCmd(std::string model_handle, std::string model_path)
+    : model_handle_(std::move(model_handle)),
+      model_path_(std::move(model_path)) {}
+
+// Imports the GGUF file at model_path_ under the id model_handle_: parses the
+// file, registers the id in the model list and writes the generated YAML to
+// <models container>/imported/<model_handle_>.yml. Every outcome is reported
+// through CLI_LOG; std::exception-derived errors are caught, not propagated.
+void ModelImportCmd::Exec() {
+  config::GGUFHandler gguf_handler;
+  config::YamlHandler yaml_handler;
+  modellist_utils::ModelListUtils modellist_utils_obj;
+
+  const std::filesystem::path yaml_file =
+      file_manager_utils::GetModelsContainerPath() /
+      std::filesystem::path("imported") /
+      std::filesystem::path(model_handle_ + ".yml");
+  const std::string model_yaml_path = yaml_file.string();
+  modellist_utils::ModelEntry model_entry{
+      model_handle_, "local", "imported",
+      model_yaml_path, model_handle_, modellist_utils::ModelStatus::READY};
+
+  // True only once this call has registered the id. Until then any YAML at
+  // model_yaml_path belongs to an earlier import of the same id and must
+  // survive a failure here (e.g. when Parse() throws).
+  bool entry_added = false;
+  try {
+    std::filesystem::create_directories(yaml_file.parent_path());
+    gguf_handler.Parse(model_path_);
+    auto model_config = gguf_handler.GetModelConfig();
+    model_config.files.push_back(model_path_);
+    model_config.model = model_handle_;
+    yaml_handler.UpdateModelConfig(model_config);
+
+    entry_added = modellist_utils_obj.AddModelEntry(model_entry);
+    if (!entry_added) {
+      CLI_LOG("Fail to import model, model_id '" + model_handle_ +
+              "' already exists!");
+      return;
+    }
+    yaml_handler.WriteYamlFile(model_yaml_path);
+    CLI_LOG("Model is imported successfully!");
+  } catch (const std::exception& e) {
+    // Only discard a YAML this call wrote, never one from an older import.
+    // NOTE(review): the model-list entry added above is not rolled back here;
+    // confirm whether ModelListUtils offers a way to undo it.
+    std::error_code ec;
+    if (entry_added) {
+      std::filesystem::remove(yaml_file, ec);
+    }
+    CLI_LOG("Error importing model path '" + model_path_ + "' with model_id '" +
+            model_handle_ + "': " + e.what());
+  }
+}
+}  // namespace commands
diff --git 
a/engine/commands/model_import_cmd.h b/engine/commands/model_import_cmd.h
new file mode 100644
index 000000000..d4248281f
--- /dev/null
+++ b/engine/commands/model_import_cmd.h
@@ -0,0 +1,23 @@
+#pragma once
+
+#include <string>
+
+namespace commands {
+
+/// Implements the `models import` CLI command: registers a local GGUF file as
+/// an imported model.
+class ModelImportCmd {
+ public:
+  /// @param model_handle Id the model is registered under (also names its YAML).
+  /// @param model_path   Path of the .gguf file to import.
+  ModelImportCmd(std::string model_handle, std::string model_path);
+
+  /// Parses the GGUF file, adds the model to the model list and writes its
+  /// YAML config. The outcome is reported through CLI_LOG.
+  void Exec();
+
+ private:
+  std::string model_handle_;
+  std::string model_path_;
+};
+}  // namespace commands
diff --git a/engine/controllers/command_line_parser.cc b/engine/controllers/command_line_parser.cc
index 6357a11ff..25557fd41 100644
--- a/engine/controllers/command_line_parser.cc
+++ b/engine/controllers/command_line_parser.cc
@@ -9,6 +9,7 @@
 #include "commands/model_alias_cmd.h"
 #include "commands/model_del_cmd.h"
 #include "commands/model_get_cmd.h"
+#include "commands/model_import_cmd.h"
 #include "commands/model_list_cmd.h"
 #include "commands/model_pull_cmd.h"
 #include "commands/model_start_cmd.h"
@@ -166,6 +167,21 @@ bool CommandLineParser::SetupCommand(int argc, char** argv) {
   auto model_update_cmd =
       models_cmd->add_subcommand("update", "Update configuration of a model");
 
+  std::string model_path;
+  auto model_import_cmd = models_cmd->add_subcommand(
+      "import", "Import a gguf model from local file");
+  model_import_cmd->add_option("--model_id", model_id,
+                               "Id to register the imported model under");
+  model_import_cmd->add_option("--model_path", model_path,
+                               "Absolute path to .gguf model, the path should "
+                               "include the gguf file name");
+  // Two options are required, i.e. both --model_id and --model_path.
+  model_import_cmd->require_option(2);
+  model_import_cmd->callback([&model_id, &model_path]() {
+    commands::ModelImportCmd command(model_id, model_path);
+    command.Exec();
+  });
+
   // Default version is latest
   std::string version{"latest"};
   // engines group commands
diff --git a/engine/controllers/models.cc b/engine/controllers/models.cc
index 45a5bf60e..2d2434d6d 100644
--- a/engine/controllers/models.cc
+++ b/engine/controllers/models.cc
@@ -194,6 +194,78 @@ void 
Models::DeleteModel(const HttpRequestPtr& req,
   }
 }
 
+// Handler for the "/import" (Post) route: registers a local GGUF file as an
+// imported model. Reads "modelId" and "modelPath" from the JSON body; replies
+// 200/"OK" on success, 400/"Import failed!" on a duplicate id or an exception.
+void Models::ImportModel(
+    const HttpRequestPtr& req,
+    std::function<void(const HttpResponsePtr&)>&& callback) const {
+  if (!http_util::HasFieldInReq(req, callback, "modelId") ||
+      !http_util::HasFieldInReq(req, callback, "modelPath")) {
+    return;
+  }
+  auto modelHandle = (*(req->getJsonObject())).get("modelId", "").asString();
+  auto modelPath = (*(req->getJsonObject())).get("modelPath", "").asString();
+
+  // Builds the JSON reply shared by every outcome and hands it to the caller.
+  auto reply = [&](auto code, const char* result,
+                   const std::string& message) {
+    Json::Value ret;
+    ret["result"] = result;
+    ret["modelHandle"] = modelHandle;
+    ret["message"] = message;
+    auto resp = cortex_utils::CreateCortexHttpJsonResponse(ret);
+    resp->setStatusCode(code);
+    callback(resp);
+  };
+
+  config::GGUFHandler gguf_handler;
+  config::YamlHandler yaml_handler;
+  modellist_utils::ModelListUtils modellist_utils_obj;
+  const std::filesystem::path yaml_file =
+      file_manager_utils::GetModelsContainerPath() /
+      std::filesystem::path("imported") /
+      std::filesystem::path(modelHandle + ".yml");
+  const std::string model_yaml_path = yaml_file.string();
+  modellist_utils::ModelEntry model_entry{
+      modelHandle, "local", "imported",
+      model_yaml_path, modelHandle, modellist_utils::ModelStatus::READY};
+
+  // True only once this request has registered the id; until then any YAML at
+  // model_yaml_path belongs to an earlier import and must not be deleted.
+  bool entry_added = false;
+  try {
+    std::filesystem::create_directories(yaml_file.parent_path());
+    gguf_handler.Parse(modelPath);
+    config::ModelConfig model_config = gguf_handler.GetModelConfig();
+    model_config.files.push_back(modelPath);
+    // Same field the CLI import (ModelImportCmd::Exec) fills with the handle.
+    model_config.model = modelHandle;
+    yaml_handler.UpdateModelConfig(model_config);
+    entry_added = modellist_utils_obj.AddModelEntry(model_entry);
+    if (!entry_added) {
+      const std::string error_message = "Fail to import model, model_id '" +
+                                        modelHandle + "' already exists!";
+      LOG_ERROR << error_message;
+      reply(k400BadRequest, "Import failed!", error_message);
+      return;
+    }
+    yaml_handler.WriteYamlFile(model_yaml_path);
+    const std::string success_message = "Model is imported successfully!";
+    LOG_INFO << success_message;
+    reply(k200OK, "OK", success_message);
+  } catch (const std::exception& e) {
+    // Only discard a YAML this request wrote, never one from an older import.
+    std::error_code ec;
+    if (entry_added) std::filesystem::remove(yaml_file, ec);
+    const std::string error_message = "Error importing model path '" +
+                                      modelPath + "' with model_id '" +
+                                      modelHandle + "': " + e.what();
+    LOG_ERROR << error_message;
+    reply(k400BadRequest, "Import failed!", error_message);
+  }
+}
+
 void Models::SetModelAlias(
     const HttpRequestPtr& req,
     std::function<void(const HttpResponsePtr&)>&& callback) const {
diff --git a/engine/controllers/models.h b/engine/controllers/models.h
index 65a3641f3..4ae1ff41f 100644
--- a/engine/controllers/models.h
+++ b/engine/controllers/models.h
@@ -15,6 +15,7 @@ class Models : public drogon::HttpController<Models> {
   METHOD_ADD(Models::PullModel, "/pull", Post);
   METHOD_ADD(Models::ListModel, "/list", Get);
   METHOD_ADD(Models::GetModel, "/get", Post);
+  METHOD_ADD(Models::ImportModel, "/import", Post);
   METHOD_ADD(Models::DeleteModel, "/{1}", Delete);
   METHOD_ADD(Models::SetModelAlias, "/alias", Post);
   METHOD_LIST_END
@@ -25,6 +26,8 @@ class Models : public drogon::HttpController<Models> {
                  std::function<void(const HttpResponsePtr&)>&& callback) const;
   void GetModel(const HttpRequestPtr& req,
                 std::function<void(const HttpResponsePtr&)>&& callback) const;
+  void ImportModel(const HttpRequestPtr& req,
+                   std::function<void(const HttpResponsePtr&)>&& callback) const;
   void DeleteModel(const HttpRequestPtr& req,
                    std::function<void(const HttpResponsePtr&)>&& callback,
                    const std::string& model_id) const;
diff --git a/engine/e2e-test/main.py b/engine/e2e-test/main.py
index 1df424e65..f5a1c65ff 100644
--- a/engine/e2e-test/main.py
+++ b/engine/e2e-test/main.py
@@ -9,6 +9,7 @@
 from test_cli_server_start import TestCliServerStart
 from test_cortex_update import TestCortexUpdate
 from test_create_log_folder import TestCreateLogFolder
+from test_cli_model_import import TestCliModelImport
 
 if __name__ == "__main__":
     pytest.main([__file__, "-v"])
diff --git 
a/engine/e2e-test/test_cli_model_import.py b/engine/e2e-test/test_cli_model_import.py
new file mode 100644
index 000000000..1f54ae511
--- /dev/null
+++ b/engine/e2e-test/test_cli_model_import.py
@@ -0,0 +1,31 @@
+import os
+
+import pytest
+from test_runner import run
+
+# Name of the environment variable that must hold the path of a local .gguf
+# file. The import test is skipped unless it is set.
+GGUF_PATH_ENV = "CORTEX_E2E_GGUF_PATH"
+
+
+class TestCliModelImport:
+    """End-to-end check of the `models import` CLI command."""
+
+    @pytest.mark.skipif(
+        not os.environ.get(GGUF_PATH_ENV),
+        reason="Expensive test. Only test when you have local gguf file "
+        f"(set {GGUF_PATH_ENV} to its path).",
+    )
+    def test_model_import_should_be_success(self):
+        # NOTE(review): ModelImportCmd::Exec catches import errors and only logs
+        # them, so a zero exit code may not prove success; consider asserting on output.
+        exit_code, output, error = run(
+            "Import model",
+            [
+                "models", "import",
+                "--model_id", "test_model",
+                "--model_path", os.environ[GGUF_PATH_ENV],
+            ],
+            timeout=None,
+        )
+        assert exit_code == 0, f"Model import failed with error: {error}"