diff --git a/engine/commands/model_import_cmd.cc b/engine/commands/model_import_cmd.cc index 193b2488b..3fb047a9d 100644 --- a/engine/commands/model_import_cmd.cc +++ b/engine/commands/model_import_cmd.cc @@ -1,10 +1,8 @@ #include "model_import_cmd.h" #include -#include #include #include "config/gguf_parser.h" #include "config/yaml_config.h" -#include "trantor/utils/Logger.h" #include "utils/file_manager_utils.h" #include "utils/logging_utils.h" #include "utils/modellist_utils.h" @@ -45,7 +43,7 @@ void ModelImportCmd::Exec() { } } catch (const std::exception& e) { - // don't need to remove yml file here, because it's written only if model entry is successfully added, + // don't need to remove yml file here, because it's written only if model entry is successfully added, // remove file here can make it fail with edge case when user try to import new model with existed model_id CLI_LOG("Error importing model path '" + model_path_ + "' with model_id '" + model_handle_ + "': " + e.what()); diff --git a/engine/services/model_service.cc b/engine/services/model_service.cc index 29575dfab..dc6fc3f68 100644 --- a/engine/services/model_service.cc +++ b/engine/services/model_service.cc @@ -95,12 +95,16 @@ void ModelService::DownloadModelByDirectUrl(const std::string& url) { url_obj.pathParams[2] = "resolve"; } } - + auto author{url_obj.pathParams[0]}; auto model_id{url_obj.pathParams[1]}; auto file_name{url_obj.pathParams.back()}; - auto local_path = - file_manager_utils::GetModelsContainerPath() / model_id / model_id; + if (author == "cortexso") { + return DownloadModelFromCortexso(model_id); + } + + auto local_path{file_manager_utils::GetModelsContainerPath() / + "huggingface.co" / author / model_id / file_name}; try { std::filesystem::create_directories(local_path.parent_path()); @@ -120,10 +124,10 @@ void ModelService::DownloadModelByDirectUrl(const std::string& url) { .localPath = local_path, }}}}; - auto on_finished = [](const DownloadTask& finishedTask) { + auto on_finished = [&author](const DownloadTask& finishedTask) { CLI_LOG("Model " << finishedTask.id << " downloaded successfully!") auto gguf_download_item = finishedTask.items[0]; - model_callback_utils::ParseGguf(gguf_download_item); + model_callback_utils::ParseGguf(gguf_download_item, author); }; download_service_.AddDownloadTask(downloadTask, on_finished); diff --git a/engine/utils/cortexso_parser.h b/engine/utils/cortexso_parser.h index d4e85bee9..af3372022 100644 --- a/engine/utils/cortexso_parser.h +++ b/engine/utils/cortexso_parser.h @@ -1,5 +1,4 @@ #include -#include #include #include @@ -7,57 +6,57 @@ #include #include "httplib.h" #include "utils/file_manager_utils.h" +#include "utils/huggingface_utils.h" #include "utils/logging_utils.h" namespace cortexso_parser { -constexpr static auto kHuggingFaceHost = "https://huggingface.co"; +constexpr static auto kHuggingFaceHost = "huggingface.co"; inline std::optional getDownloadTask( const std::string& modelId, const std::string& branch = "main") { using namespace nlohmann; - std::ostringstream oss; - oss << "/api/models/cortexso/" << modelId << "/tree/" << branch; - const std::string url = oss.str(); + url_parser::Url url = { + .protocol = "https", + .host = kHuggingFaceHost, + .pathParams = {"api", "models", "cortexso", modelId, "tree", branch}}; - std::ostringstream repoAndModelId; - repoAndModelId << "cortexso/" << modelId; - const std::string repoAndModelIdStr = repoAndModelId.str(); - - httplib::Client cli(kHuggingFaceHost); - if (auto res = cli.Get(url)) { + httplib::Client cli(url.GetProtocolAndHost()); + if (auto res = cli.Get(url.GetPathAndQuery())) { if (res->status == httplib::StatusCode::OK_200) { try { auto jsonResponse = json::parse(res->body); - std::vector downloadItems{}; - std::filesystem::path model_container_path = - file_manager_utils::GetModelsContainerPath() / modelId; + std::vector download_items{}; + auto model_container_path = + file_manager_utils::GetModelsContainerPath() / "cortex.so" / + modelId / branch; file_manager_utils::CreateDirectoryRecursively( model_container_path.string()); for (const auto& [key, value] : jsonResponse.items()) { - std::ostringstream downloadUrlOutput; auto path = value["path"].get(); if (path == ".gitattributes" || path == ".gitignore" || path == "README.md") { continue; } - downloadUrlOutput << kHuggingFaceHost << "/" << repoAndModelIdStr - << "/resolve/" << branch << "/" << path; - const std::string download_url = downloadUrlOutput.str(); - auto local_path = model_container_path / path; + url_parser::Url download_url = { + .protocol = "https", + .host = kHuggingFaceHost, + .pathParams = {"cortexso", modelId, "resolve", branch, path}}; - downloadItems.push_back(DownloadItem{.id = path, - .downloadUrl = download_url, - .localPath = local_path}); + auto local_path = model_container_path / path; + download_items.push_back( + DownloadItem{.id = path, + .downloadUrl = download_url.ToFullPath(), + .localPath = local_path}); } - DownloadTask downloadTask{ + DownloadTask download_tasks{ .id = branch == "main" ? modelId : modelId + "-" + branch, .type = DownloadType::Model, - .items = downloadItems}; + .items = download_items}; - return downloadTask; + return download_tasks; } catch (const json::parse_error& e) { CTL_ERR("JSON parse error: {}" << e.what()); } diff --git a/engine/utils/model_callback_utils.h b/engine/utils/model_callback_utils.h index 3a3b0f288..c6e98dd48 100644 --- a/engine/utils/model_callback_utils.h +++ b/engine/utils/model_callback_utils.h @@ -6,27 +6,14 @@ #include "config/gguf_parser.h" #include "config/yaml_config.h" #include "services/download_service.h" +#include "utils/huggingface_utils.h" #include "utils/logging_utils.h" +#include "utils/modellist_utils.h" namespace model_callback_utils { -inline void WriteYamlOutput(const DownloadItem& modelYmlDownloadItem) { - config::YamlHandler handler; - handler.ModelConfigFromFile(modelYmlDownloadItem.localPath.string()); - config::ModelConfig model_config = handler.GetModelConfig(); - model_config.id = - modelYmlDownloadItem.localPath.parent_path().filename().string(); - - CTL_INF("Updating model config in " - << modelYmlDownloadItem.localPath.string()); - handler.UpdateModelConfig(model_config); - std::string yaml_filename{model_config.id + ".yaml"}; - std::filesystem::path yaml_output = - modelYmlDownloadItem.localPath.parent_path().parent_path() / - yaml_filename; - handler.WriteYamlFile(yaml_output.string()); -} -inline void ParseGguf(const DownloadItem& ggufDownloadItem) { +inline void ParseGguf(const DownloadItem& ggufDownloadItem, + std::optional author = nullptr) { config::GGUFHandler gguf_handler; config::YamlHandler yaml_handler; gguf_handler.Parse(ggufDownloadItem.localPath.string()); @@ -36,17 +23,27 @@ inline void ParseGguf(const DownloadItem& ggufDownloadItem) { model_config.files = {ggufDownloadItem.localPath.string()}; yaml_handler.UpdateModelConfig(model_config); - std::string yaml_filename{model_config.id + ".yaml"}; - std::filesystem::path yaml_output = - ggufDownloadItem.localPath.parent_path().parent_path() / yaml_filename; - std::filesystem::path yaml_path(ggufDownloadItem.localPath.parent_path() / - "model.yml"); - if (!std::filesystem::exists(yaml_output)) { // if model.yml doesn't exist - yaml_handler.WriteYamlFile(yaml_output.string()); - } + auto yaml_path{ggufDownloadItem.localPath}; + auto yaml_name = yaml_path.replace_extension(".yml"); + if (!std::filesystem::exists(yaml_path)) { yaml_handler.WriteYamlFile(yaml_path.string()); } + + auto url_obj = url_parser::FromUrlString(ggufDownloadItem.downloadUrl); + auto branch = url_obj.pathParams[3]; + CTL_INF("Adding model to modellist with branch: " << branch); + + auto author_id = author.has_value() ? author.value() : "cortexso"; + modellist_utils::ModelListUtils modellist_utils_obj; + modellist_utils::ModelEntry model_entry{ + .model_id = model_config.id, + .author_repo_id = author_id, + .branch_name = branch, + .path_to_model_yaml = yaml_name.string(), + .model_alias = model_config.id, + .status = modellist_utils::ModelStatus::READY}; + modellist_utils_obj.AddModelEntry(model_entry); } inline void DownloadModelCb(const DownloadTask& finishedTask) { @@ -67,12 +64,27 @@ inline void DownloadModelCb(const DownloadTask& finishedTask) { } } - if (model_yml_di != nullptr) { - WriteYamlOutput(*model_yml_di); - } - if (need_parse_gguf && gguf_di != nullptr) { ParseGguf(*gguf_di); } + + if (model_yml_di != nullptr) { + auto url_obj = url_parser::FromUrlString(model_yml_di->downloadUrl); + auto branch = url_obj.pathParams[3]; + CTL_INF("Adding model to modellist with branch: " << branch); + config::YamlHandler yaml_handler; + yaml_handler.ModelConfigFromFile(model_yml_di->localPath.string()); + auto mc = yaml_handler.GetModelConfig(); + + modellist_utils::ModelListUtils modellist_utils_obj; + modellist_utils::ModelEntry model_entry{ + .model_id = mc.name, + .author_repo_id = "cortexso", + .branch_name = branch, + .path_to_model_yaml = model_yml_di->localPath.string(), + .model_alias = mc.name, + .status = modellist_utils::ModelStatus::READY}; + modellist_utils_obj.AddModelEntry(model_entry); + } } } // namespace model_callback_utils diff --git a/engine/utils/modellist_utils.cc b/engine/utils/modellist_utils.cc index 261bf58d5..7e1a43833 100644 --- a/engine/utils/modellist_utils.cc +++ b/engine/utils/modellist_utils.cc @@ -3,10 +3,10 @@ #include #include #include -#include #include #include #include "file_manager_utils.h" + namespace modellist_utils { const std::string ModelListUtils::kModelListPath = (file_manager_utils::GetModelsContainerPath() / @@ -208,7 +208,8 @@ bool ModelListUtils::UpdateModelAlias(const std::string& model_id, }); bool check_alias_unique = std::none_of( entries.begin(), entries.end(), [&](const ModelEntry& entry) { - return (entry.model_id == new_model_alias && entry.model_id != model_id) || + return (entry.model_id == new_model_alias && + entry.model_id != model_id) || entry.model_alias == new_model_alias; }); if (it != entries.end() && check_alias_unique) { @@ -237,4 +238,4 @@ bool ModelListUtils::DeleteModelEntry(const std::string& identifier) { } return false; // Entry not found or not in READY state } -} // namespace modellist_utils \ No newline at end of file +} // namespace modellist_utils diff --git a/engine/utils/modellist_utils.h b/engine/utils/modellist_utils.h index 75a41d880..b7aaca81a 100644 --- a/engine/utils/modellist_utils.h +++ b/engine/utils/modellist_utils.h @@ -1,9 +1,10 @@ #pragma once + #include #include #include #include -#include "logging_utils.h" + namespace modellist_utils { enum class ModelStatus { READY, RUNNING }; @@ -22,7 +23,7 @@ class ModelListUtils { private: mutable std::mutex mutex_; // For thread safety - bool IsUnique(const std::vector& entries, + bool IsUnique(const std::vector& entries, const std::string& model_id, const std::string& model_alias) const; void SaveModelList(const std::vector& entries) const; @@ -40,6 +41,7 @@ class ModelListUtils { bool UpdateModelEntry(const std::string& identifier, const ModelEntry& updated_entry); bool DeleteModelEntry(const std::string& identifier); - bool UpdateModelAlias(const std::string& model_id, const std::string& model_alias); + bool UpdateModelAlias(const std::string& model_id, + const std::string& model_alias); }; -} // namespace modellist_utils \ No newline at end of file +} // namespace modellist_utils diff --git a/engine/utils/url_parser.h b/engine/utils/url_parser.h index 6a6e01179..97d499a97 100644 --- a/engine/utils/url_parser.h +++ b/engine/utils/url_parser.h @@ -54,6 +54,10 @@ struct Url { } return path; }; + + std::string ToFullPath() const { + return GetProtocolAndHost() + GetPathAndQuery(); + } }; const std::regex url_regex(