From effadfb4017e999d26ed00146d95541f2427655a Mon Sep 17 00:00:00 2001 From: James Date: Mon, 16 Sep 2024 02:06:05 +0700 Subject: [PATCH 1/3] feat: download model with direct HF url Signed-off-by: James --- engine/e2e-test/main.py | 1 + 1 file changed, 1 insertion(+) diff --git a/engine/e2e-test/main.py b/engine/e2e-test/main.py index 1df424e65..5366c367b 100644 --- a/engine/e2e-test/main.py +++ b/engine/e2e-test/main.py @@ -9,6 +9,7 @@ from test_cli_server_start import TestCliServerStart from test_cortex_update import TestCortexUpdate from test_create_log_folder import TestCreateLogFolder +from test_cli_model_pull_direct_url import TestCliModelPullDirectUrl if __name__ == "__main__": pytest.main([__file__, "-v"]) From d9779ea2eb54992080af2e694715424c327a0bd7 Mon Sep 17 00:00:00 2001 From: James Date: Tue, 17 Sep 2024 01:55:11 +0700 Subject: [PATCH 2/3] feat: add support for hugging face model handle --- engine/services/model_service.cc | 117 +++++++++++-- engine/services/model_service.h | 7 + engine/test/components/CMakeLists.txt | 3 + .../test/components/test_huggingface_utils.cc | 88 ++++++++++ engine/test/components/test_url_parser.cc | 1 - engine/utils/huggingface_utils.h | 158 ++++++++++++++++++ 6 files changed, 362 insertions(+), 12 deletions(-) create mode 100644 engine/test/components/test_huggingface_utils.cc create mode 100644 engine/utils/huggingface_utils.h diff --git a/engine/services/model_service.cc b/engine/services/model_service.cc index 7943cace4..e1895046d 100644 --- a/engine/services/model_service.cc +++ b/engine/services/model_service.cc @@ -1,12 +1,41 @@ #include "model_service.h" #include #include +#include #include "commands/cmd_info.h" #include "utils/cortexso_parser.h" #include "utils/file_manager_utils.h" +#include "utils/huggingface_utils.h" #include "utils/logging_utils.h" #include "utils/model_callback_utils.h" -#include "utils/url_parser.h" + +void PrintMenu(const std::vector& options) { + auto index{1}; + for (const auto& option : options) { + std::cout << index << ". " << option << "\n"; + index++; + } + std::endl(std::cout); +} + +std::optional PrintSelection( + const std::vector& options) { + std::string selection{""}; + PrintMenu(options); + std::cin >> selection; + + if (selection.empty()) { + return std::nullopt; + } + + // std::cout << "Selection: " << selection << "\n"; + // std::cout << "Int representaion: " << std::stoi(selection) << "\n"; + if (std::stoi(selection) > options.size() || std::stoi(selection) < 1) { + return std::nullopt; + } + + return options[std::stoi(selection) - 1]; +} void ModelService::DownloadModel(const std::string& input) { if (input.empty()) { @@ -14,19 +43,65 @@ void ModelService::DownloadModel(const std::string& input) { "Input must be Cortex Model Hub handle or HuggingFace url!"); } - // case input is a direct url - auto url_obj = url_parser::FromUrlString(input); - // TODO: handle case user paste url from cortexso - if (url_obj.protocol == "https") { - if (url_obj.host != kHuggingFaceHost) { - CLI_LOG("Only huggingface.co is supported for now"); + if (input.starts_with("https://")) { + return DownloadModelByDirectUrl(input); + } + + // if input contains / then handle it differently + if (input.find("/") != std::string::npos) { + // TODO: what if we have more than one /? + // TODO: what if the left size of / is cortexso? + + // split by /. TODO: Move this function to somewhere else + std::string model_input = input; + std::string delimiter{"/"}; + std::string token{""}; + std::vector parsed{}; + std::string author{""}; + std::string model_name{""}; + while (token != model_input) { + token = model_input.substr(0, model_input.find_first_of("/")); + model_input = model_input.substr(model_input.find_first_of("/") + 1); + std::string new_str{token}; + parsed.push_back(new_str); + } + + author = parsed[0]; + model_name = parsed[1]; + auto repo_info = + huggingface_utils::GetHuggingFaceModelRepoInfo(author, model_name); + if (!repo_info.has_value()) { + // throw is better? + CTL_ERR("Model not found"); return; } - return DownloadModelByDirectUrl(input); - } else { - commands::CmdInfo ci(input); - return DownloadModelFromCortexso(ci.model_name, ci.branch); + + if (!repo_info->gguf.has_value()) { + throw std::runtime_error( + "Not a GGUF model. Currently, only GGUF single file is supported."); + } + + std::vector options{}; + for (const auto& sibling : repo_info->siblings) { + if (sibling.rfilename.ends_with(".gguf")) { + options.push_back(sibling.rfilename); + } + } + auto selection = PrintSelection(options); + std::cout << "Selected: " << selection.value() << std::endl; + + auto download_url = huggingface_utils::GetDownloadableUrl( + author, model_name, selection.value()); + + std::cout << "Download url: " << download_url << std::endl; + // TODO: split to this function + // DownloadHuggingFaceGgufModel(author, model_name, nullptr); + return; } + + // user just input a text, seems like a model name only, maybe comes with a branch, using : as delimeter + // handle cortexso here + // separate into another function and the above can route to it if we regconize a cortexso url } std::optional ModelService::GetDownloadedModel( @@ -114,3 +189,23 @@ void ModelService::DownloadModelFromCortexso(const std::string& name, CTL_ERR("Model not found"); } } + +void ModelService::DownloadHuggingFaceGgufModel( + const std::string& author, const std::string& modelName, + std::optional fileName) { + std::cout << author << std::endl; + std::cout << modelName << std::endl; + // if we don't have file name, we must display a list for user to pick + // auto repo_info = + // huggingface_utils::GetHuggingFaceModelRepoInfo(author, modelName); + // + // if (!repo_info.has_value()) { + // // throw is better? + // CTL_ERR("Model not found"); + // return; + // } + // + // for (const auto& sibling : repo_info->siblings) { + // std::cout << sibling.rfilename << "\n"; + // } +} diff --git a/engine/services/model_service.h b/engine/services/model_service.h index 81ec4e4b3..f4ee4e065 100644 --- a/engine/services/model_service.h +++ b/engine/services/model_service.h @@ -19,6 +19,13 @@ class ModelService { void DownloadModelFromCortexso(const std::string& name, const std::string& branch); + /** + * Handle downloading model which have following pattern: author/model_name + */ + void DownloadHuggingFaceGgufModel(const std::string& author, + const std::string& modelName, + std::optional fileName); + DownloadService download_service_; constexpr auto static kHuggingFaceHost = "huggingface.co"; diff --git a/engine/test/components/CMakeLists.txt b/engine/test/components/CMakeLists.txt index db810ad26..fa1c5477e 100644 --- a/engine/test/components/CMakeLists.txt +++ b/engine/test/components/CMakeLists.txt @@ -9,9 +9,12 @@ find_package(Drogon CONFIG REQUIRED) find_package(GTest CONFIG REQUIRED) find_package(yaml-cpp CONFIG REQUIRED) find_package(jinja2cpp CONFIG REQUIRED) +find_package(httplib CONFIG REQUIRED) target_link_libraries(${PROJECT_NAME} PRIVATE Drogon::Drogon GTest::gtest GTest::gtest_main yaml-cpp::yaml-cpp jinja2cpp ${CMAKE_THREAD_LIBS_INIT}) + +target_link_libraries(${PROJECT_NAME} PRIVATE httplib::httplib) target_include_directories(${PROJECT_NAME} PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/../../) add_test(NAME ${PROJECT_NAME} diff --git a/engine/test/components/test_huggingface_utils.cc b/engine/test/components/test_huggingface_utils.cc new file mode 100644 index 000000000..b1d949d22 --- /dev/null +++ b/engine/test/components/test_huggingface_utils.cc @@ -0,0 +1,88 @@ +#include "gtest/gtest.h" +#include "utils/huggingface_utils.h" + +class HuggingFaceUtilTestSuite : public ::testing::Test {}; + +TEST_F(HuggingFaceUtilTestSuite, TestGetModelRepositoryBranches) { + auto branches = + huggingface_utils::GetModelRepositoryBranches("cortexso", "tinyllama"); + + EXPECT_EQ(branches.size(), 3); + EXPECT_EQ(branches[0].name, "gguf"); + EXPECT_EQ(branches[0].ref, "refs/heads/gguf"); + EXPECT_EQ(branches[1].name, "1b-gguf"); + EXPECT_EQ(branches[1].ref, "refs/heads/1b-gguf"); + EXPECT_EQ(branches[2].name, "main"); + EXPECT_EQ(branches[2].ref, "refs/heads/main"); +} + +TEST_F(HuggingFaceUtilTestSuite, TestGetHuggingFaceModelRepoInfoSuccessfully) { + auto model_info = + huggingface_utils::GetHuggingFaceModelRepoInfo("cortexso", "tinyllama"); + auto not_null = model_info.has_value(); + + EXPECT_TRUE(not_null); + EXPECT_EQ(model_info->id, "cortexso/tinyllama"); + EXPECT_EQ(model_info->modelId, "cortexso/tinyllama"); + EXPECT_EQ(model_info->author, "cortexso"); + EXPECT_EQ(model_info->disabled, false); + EXPECT_EQ(model_info->gated, false); + + auto tag_contains_gguf = + std::find(model_info->tags.begin(), model_info->tags.end(), "gguf") != + model_info->tags.end(); + EXPECT_TRUE(tag_contains_gguf); + + auto contain_gguf_info = model_info->gguf.has_value(); + EXPECT_TRUE(contain_gguf_info); + + auto sibling_not_empty = !model_info->siblings.empty(); + EXPECT_TRUE(sibling_not_empty); +} + +TEST_F(HuggingFaceUtilTestSuite, + TestGetHuggingFaceModelRepoInfoReturnNullGgufInfoWhenNotAGgufModel) { + auto model_info = huggingface_utils::GetHuggingFaceModelRepoInfo( + "BAAI", "bge-reranker-v2-m3"); + auto not_null = model_info.has_value(); + + EXPECT_TRUE(not_null); + EXPECT_EQ(model_info->disabled, false); + EXPECT_EQ(model_info->gated, false); + + auto tag_not_contain_gguf = + std::find(model_info->tags.begin(), model_info->tags.end(), "gguf") == + model_info->tags.end(); + EXPECT_TRUE(tag_not_contain_gguf); + + auto contain_gguf_info = model_info->gguf.has_value(); + EXPECT_TRUE(!contain_gguf_info); + + auto sibling_not_empty = !model_info->siblings.empty(); + EXPECT_TRUE(sibling_not_empty); +} + +TEST_F(HuggingFaceUtilTestSuite, + TestGetHuggingFaceDownloadUrlWithoutBranchName) { + auto downloadable_url = huggingface_utils::GetDownloadableUrl( + "pervll", "bge-reranker-v2-gemma-Q4_K_M-GGUF", + "bge-reranker-v2-gemma-q4_k_m.gguf"); + + auto expected_url{ + "https://huggingface.co/pervll/bge-reranker-v2-gemma-Q4_K_M-GGUF/resolve/" + "main/bge-reranker-v2-gemma-q4_k_m.gguf"}; + + EXPECT_EQ(downloadable_url, expected_url); +} + +TEST_F(HuggingFaceUtilTestSuite, TestGetHuggingFaceDownloadUrlWithBranchName) { + auto downloadable_url = huggingface_utils::GetDownloadableUrl( + "pervll", "bge-reranker-v2-gemma-Q4_K_M-GGUF", + "bge-reranker-v2-gemma-q4_k_m.gguf", "1b-gguf"); + + auto expected_url{ + "https://huggingface.co/pervll/bge-reranker-v2-gemma-Q4_K_M-GGUF/resolve/" + "1b-gguf/bge-reranker-v2-gemma-q4_k_m.gguf"}; + + EXPECT_EQ(downloadable_url, expected_url); +} diff --git a/engine/test/components/test_url_parser.cc b/engine/test/components/test_url_parser.cc index decffcbf8..cee6cb6ed 100644 --- a/engine/test/components/test_url_parser.cc +++ b/engine/test/components/test_url_parser.cc @@ -1,4 +1,3 @@ -#include #include "gtest/gtest.h" #include "utils/url_parser.h" diff --git a/engine/utils/huggingface_utils.h b/engine/utils/huggingface_utils.h new file mode 100644 index 000000000..3d104dbb4 --- /dev/null +++ b/engine/utils/huggingface_utils.h @@ -0,0 +1,158 @@ +#pragma once + +#include +#include +#include +#include +#include +#include +#include "utils/json.hpp" +#include "utils/url_parser.h" + +namespace huggingface_utils { + +constexpr static auto kHuggingfaceHost{"huggingface.co"}; + +struct HuggingFaceBranch { + std::string name; + std::string ref; + std::string targetCommit; +}; + +struct HuggingFaceFileSibling { + std::string rfilename; +}; + +struct HuggingFaceGgufInfo { + uint64 total; + std::string architecture; +}; + +struct HuggingFaceModelRepoInfo { + std::string id; + std::string modelId; + std::string author; + std::string sha; + std::string lastModified; + + bool isPrivate; + bool disabled; + bool gated; + std::vector tags; + int downloads; + + int likes; + std::optional gguf; + std::vector siblings; + std::vector spaces; + std::string createdAt; +}; + +inline std::vector GetModelRepositoryBranches( + const std::string& author, const std::string& modelName) { + if (author.empty() || modelName.empty()) { + throw std::runtime_error("Author and model name cannot be empty"); + } + auto url_obj = url_parser::Url{ + .protocol = "https", + .host = kHuggingfaceHost, + .pathParams = {"api", "models", author, modelName, "refs"}}; + + httplib::Client cli(url_obj.GetProtocolAndHost()); + auto res = cli.Get(url_obj.GetPathAndQuery()); + if (res->status != httplib::StatusCode::OK_200) { + throw std::runtime_error( + "Failed to get model repository branches: " + author + "/" + modelName); + } + + using json = nlohmann::json; + auto body = json::parse(res->body); + auto branches_json = body["branches"]; + + std::vector branches{}; + + for (const auto& branch : branches_json) { + branches.push_back(HuggingFaceBranch{ + .name = branch["name"], + .ref = branch["ref"], + .targetCommit = branch["targetCommit"], + }); + } + + return branches; +} + +// only support gguf for now +inline std::optional GetHuggingFaceModelRepoInfo( + const std::string& author, const std::string& modelName) { + if (author.empty() || modelName.empty()) { + throw std::runtime_error("Author and model name cannot be empty"); + } + auto url_obj = + url_parser::Url{.protocol = "https", + .host = kHuggingfaceHost, + .pathParams = {"api", "models", author, modelName}}; + + httplib::Client cli(url_obj.GetProtocolAndHost()); + auto res = cli.Get(url_obj.GetPathAndQuery()); + if (res->status != httplib::StatusCode::OK_200) { + throw std::runtime_error("Failed to get model repository info: " + author + + "/" + modelName); + } + + using json = nlohmann::json; + auto body = json::parse(res->body); + + std::optional gguf = std::nullopt; + auto gguf_info = body["gguf"]; + if (!gguf_info.is_null()) { + gguf = HuggingFaceGgufInfo{ + .total = gguf_info["total"], + .architecture = gguf_info["architecture"], + }; + } + + std::vector siblings{}; + auto siblings_info = body["siblings"]; + for (const auto& sibling : siblings_info) { + auto sibling_info = HuggingFaceFileSibling{ + .rfilename = sibling["rfilename"], + }; + siblings.push_back(sibling_info); + } + + auto model_repo_info = HuggingFaceModelRepoInfo{ + .id = body["id"], + .modelId = body["modelId"], + .author = body["author"], + .sha = body["sha"], + .lastModified = body["lastModified"], + + .isPrivate = body["private"], + .disabled = body["disabled"], + .gated = body["gated"], + .tags = body["tags"], + .downloads = body["downloads"], + + .likes = body["likes"], + .gguf = gguf, + .siblings = siblings, + .spaces = body["spaces"], + .createdAt = body["createdAt"], + }; + + return model_repo_info; +} + +inline std::string GetDownloadableUrl(const std::string& author, + const std::string& modelName, + const std::string& fileName, + const std::string& branch = "main") { + auto url_obj = url_parser::Url{ + .protocol = "https", + .host = kHuggingfaceHost, + .pathParams = {author, modelName, "resolve", branch, fileName}, + }; + return url_parser::FromUrl(url_obj); +} +} // namespace huggingface_utils From 6d090fc349e4f0b4b523621ffa37ac5121fee7cd Mon Sep 17 00:00:00 2001 From: James Date: Wed, 18 Sep 2024 15:34:15 +0700 Subject: [PATCH 3/3] remove redundant import --- engine/commands/model_get_cmd.cc | 11 +- engine/commands/model_get_cmd.h | 11 +- engine/commands/run_cmd.cc | 7 +- engine/controllers/command_line_parser.cc | 8 +- engine/e2e-test/main.py | 1 - ..._cli_model_pull_cortexso_with_selection.py | 12 ++ .../test_cli_model_pull_direct_url.py | 15 +- .../test_cli_model_pull_from_cortexso.py | 11 +- ..._cli_model_pull_hugging_face_repository.py | 28 +++ engine/e2e-test/test_runner.py | 30 +++- engine/services/model_service.cc | 159 +++++++----------- engine/services/model_service.h | 4 +- engine/test/components/test_string_utils.cc | 79 +++++++++ engine/utils/cli_selection_utils.h | 35 ++++ engine/utils/huggingface_utils.h | 3 +- engine/utils/string_utils.h | 32 ++++ 16 files changed, 297 insertions(+), 149 deletions(-) create mode 100644 engine/e2e-test/test_cli_model_pull_cortexso_with_selection.py create mode 100644 engine/e2e-test/test_cli_model_pull_hugging_face_repository.py create mode 100644 engine/test/components/test_string_utils.cc create mode 100644 engine/utils/cli_selection_utils.h create mode 100644 engine/utils/string_utils.h diff --git a/engine/commands/model_get_cmd.cc b/engine/commands/model_get_cmd.cc index acbce742c..d4f19ffa8 100644 --- a/engine/commands/model_get_cmd.cc +++ b/engine/commands/model_get_cmd.cc @@ -4,21 +4,16 @@ #include #include "cmd_info.h" #include "config/yaml_config.h" -#include "trantor/utils/Logger.h" -#include "utils/cortex_utils.h" #include "utils/file_manager_utils.h" #include "utils/logging_utils.h" namespace commands { -ModelGetCmd::ModelGetCmd(std::string model_handle) - : model_handle_(std::move(model_handle)) {} - -void ModelGetCmd::Exec() { +void ModelGetCmd::Exec(const std::string& model_handle) { auto models_path = file_manager_utils::GetModelsContainerPath(); if (std::filesystem::exists(models_path) && std::filesystem::is_directory(models_path)) { - CmdInfo ci(model_handle_); + CmdInfo ci(model_handle); std::string model_file = ci.branch == "main" ? ci.model_name : ci.model_name + "-" + ci.branch; bool found_model = false; @@ -149,4 +144,4 @@ void ModelGetCmd::Exec() { CLI_LOG("Model not found!"); } } -}; // namespace commands \ No newline at end of file +}; // namespace commands diff --git a/engine/commands/model_get_cmd.h b/engine/commands/model_get_cmd.h index 9bd9d2213..1836f7d99 100644 --- a/engine/commands/model_get_cmd.h +++ b/engine/commands/model_get_cmd.h @@ -1,17 +1,10 @@ #pragma once - -#include // For std::isnan #include namespace commands { class ModelGetCmd { public: - - ModelGetCmd(std::string model_handle); - void Exec(); - - private: - std::string model_handle_; + void Exec(const std::string& model_handle); }; -} // namespace commands \ No newline at end of file +} // namespace commands diff --git a/engine/commands/run_cmd.cc b/engine/commands/run_cmd.cc index 1fb3706d7..cb60822ad 100644 --- a/engine/commands/run_cmd.cc +++ b/engine/commands/run_cmd.cc @@ -2,13 +2,8 @@ #include "chat_cmd.h" #include "cmd_info.h" #include "config/yaml_config.h" -#include "engine_install_cmd.h" -#include "httplib.h" -#include "model_pull_cmd.h" #include "model_start_cmd.h" #include "server_start_cmd.h" -#include "trantor/utils/Logger.h" -#include "utils/cortex_utils.h" #include "utils/file_manager_utils.h" namespace commands { @@ -46,7 +41,7 @@ void RunCmd::Exec() { if (!commands::IsServerAlive(host_, port_)) { CLI_LOG("Starting server ..."); commands::ServerStartCmd ssc; - if(!ssc.Exec(host_, port_)) { + if (!ssc.Exec(host_, port_)) { return; } } diff --git a/engine/controllers/command_line_parser.cc b/engine/controllers/command_line_parser.cc index b55887ebd..0d74a749b 100644 --- a/engine/controllers/command_line_parser.cc +++ b/engine/controllers/command_line_parser.cc @@ -138,10 +138,8 @@ bool CommandLineParser::SetupCommand(int argc, char** argv) { models_cmd->add_subcommand("get", "Get info of {model_id} locally"); get_models_cmd->add_option("model_id", model_id, ""); get_models_cmd->require_option(); - get_models_cmd->callback([&model_id]() { - commands::ModelGetCmd command(model_id); - command.Exec(); - }); + get_models_cmd->callback( + [&model_id]() { commands::ModelGetCmd().Exec(model_id); }); auto model_del_cmd = models_cmd->add_subcommand("delete", "Delete a model by ID locally"); @@ -238,7 +236,7 @@ bool CommandLineParser::SetupCommand(int argc, char** argv) { auto ps_cmd = app_.add_subcommand("ps", "Show running models and their status"); ps_cmd->group(kSystemGroup); - + CLI11_PARSE(app_, argc, argv); if (argc == 1) { CLI_LOG(app_.help()); diff --git a/engine/e2e-test/main.py b/engine/e2e-test/main.py index 5366c367b..1df424e65 100644 --- a/engine/e2e-test/main.py +++ b/engine/e2e-test/main.py @@ -9,7 +9,6 @@ from test_cli_server_start import TestCliServerStart from test_cortex_update import TestCortexUpdate from test_create_log_folder import TestCreateLogFolder -from test_cli_model_pull_direct_url import TestCliModelPullDirectUrl if __name__ == "__main__": pytest.main([__file__, "-v"]) diff --git a/engine/e2e-test/test_cli_model_pull_cortexso_with_selection.py b/engine/e2e-test/test_cli_model_pull_cortexso_with_selection.py new file mode 100644 index 000000000..619833e16 --- /dev/null +++ b/engine/e2e-test/test_cli_model_pull_cortexso_with_selection.py @@ -0,0 +1,12 @@ +from test_runner import popen + + +class TestCliModelPullCortexsoWithSelection: + + def test_pull_model_from_cortexso_should_display_list_and_allow_user_to_choose( + self, + ): + stdout, stderr, return_code = popen(["pull", "tinyllama"], "1\n") + + assert "Model tinyllama downloaded successfully!" in stdout + assert return_code == 0 diff --git a/engine/e2e-test/test_cli_model_pull_direct_url.py b/engine/e2e-test/test_cli_model_pull_direct_url.py index 7d6fa677b..4907ced1f 100644 --- a/engine/e2e-test/test_cli_model_pull_direct_url.py +++ b/engine/e2e-test/test_cli_model_pull_direct_url.py @@ -1,17 +1,18 @@ -import platform - -import pytest from test_runner import run class TestCliModelPullDirectUrl: - @pytest.mark.skipif(True, reason="Expensive test. Only test when needed.") def test_model_pull_with_direct_url_should_be_success(self): exit_code, output, error = run( - "Pull model", ["pull", "https://huggingface.co/TheBloke/TinyLlama-1.1B-Chat-v0.3-GGUF/blob/main/tinyllama-1.1b-chat-v0.3.Q2_K.gguf"], - timeout=None + "Pull model", + [ + "pull", + "https://huggingface.co/TheBloke/TinyLlama-1.1B-Chat-v0.3-GGUF/blob/main/tinyllama-1.1b-chat-v0.3.Q2_K.gguf", + ], + timeout=None, ) assert exit_code == 0, f"Model pull failed with error: {error}" # TODO: verify that the model has been pull successfully - # TODO: skip this test. since download model is taking too long \ No newline at end of file + # TODO: skip this test. since download model is taking too long + diff --git a/engine/e2e-test/test_cli_model_pull_from_cortexso.py b/engine/e2e-test/test_cli_model_pull_from_cortexso.py index e5651ce2f..c9c3f4c40 100644 --- a/engine/e2e-test/test_cli_model_pull_from_cortexso.py +++ b/engine/e2e-test/test_cli_model_pull_from_cortexso.py @@ -1,17 +1,16 @@ -import platform - import pytest from test_runner import run class TestCliModelPullCortexso: - @pytest.mark.skipif(True, reason="Expensive test. Only test when needed.") def test_model_pull_with_direct_url_should_be_success(self): exit_code, output, error = run( - "Pull model", ["pull", "tinyllama"], - timeout=None + "Pull model", + ["pull", "tinyllama"], + timeout=None, ) assert exit_code == 0, f"Model pull failed with error: {error}" # TODO: verify that the model has been pull successfully - # TODO: skip this test. since download model is taking too long \ No newline at end of file + # TODO: skip this test. since download model is taking too long + diff --git a/engine/e2e-test/test_cli_model_pull_hugging_face_repository.py b/engine/e2e-test/test_cli_model_pull_hugging_face_repository.py new file mode 100644 index 000000000..50b7e832b --- /dev/null +++ b/engine/e2e-test/test_cli_model_pull_hugging_face_repository.py @@ -0,0 +1,28 @@ +import pytest +from test_runner import popen + + +class TestCliModelPullHuggingFaceRepository: + + def test_model_pull_hugging_face_repository(self): + """ + Test pull model pervll/bge-reranker-v2-gemma-Q4_K_M-GGUF from issue #1017 + """ + + stdout, stderr, return_code = popen( + ["pull", "pervll/bge-reranker-v2-gemma-Q4_K_M-GGUF"], "1\n" + ) + + assert "downloaded successfully!" in stdout + assert return_code == 0 + + def test_model_pull_hugging_face_not_gguf_should_failed_gracefully(self): + """ + When pull a model which is not GGUF, we stop and show a message to user + """ + + stdout, stderr, return_code = popen(["pull", "BAAI/bge-reranker-v2-m3"], "") + assert ( + "Not a GGUF model. Currently, only GGUF single file is supported." in stdout + ) + assert return_code == 0 diff --git a/engine/e2e-test/test_runner.py b/engine/e2e-test/test_runner.py index dd634d747..7716c55a0 100644 --- a/engine/e2e-test/test_runner.py +++ b/engine/e2e-test/test_runner.py @@ -38,6 +38,26 @@ def run(test_name: str, arguments: List[str], timeout=timeout) -> (int, str, str return result.returncode, result.stdout, result.stderr +def popen(arguments: List[str], user_input: str) -> (int, str, str): + # Start the process + executable_path = getExecutablePath() + process = subprocess.Popen( + [executable_path] + arguments, + stdin=subprocess.PIPE, + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + text=True, # This ensures the input and output are treated as text + ) + + # Send input and get output + stdout, stderr = process.communicate(input=user_input) + + # Get the return code + return_code = process.returncode + + return stdout, stderr, return_code + + # Start the API server # Wait for `Server started` message or failed def start_server() -> bool: @@ -50,10 +70,10 @@ def start_server() -> bool: def start_server_nix() -> bool: executable = getExecutablePath() process = subprocess.Popen( - [executable] + ['start', '-p', '3928'], - stdout=subprocess.PIPE, - stderr=subprocess.PIPE, - text=True + [executable] + ["start", "-p", "3928"], + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + text=True, ) start_time = time.time() @@ -80,7 +100,7 @@ def start_server_nix() -> bool: def start_server_windows() -> bool: executable = getExecutablePath() process = subprocess.Popen( - [executable] + ['start', '-p', '3928'], + [executable] + ["start", "-p", "3928"], stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True, diff --git a/engine/services/model_service.cc b/engine/services/model_service.cc index e1895046d..29575dfab 100644 --- a/engine/services/model_service.cc +++ b/engine/services/model_service.cc @@ -2,40 +2,13 @@ #include #include #include -#include "commands/cmd_info.h" +#include "utils/cli_selection_utils.h" #include "utils/cortexso_parser.h" #include "utils/file_manager_utils.h" #include "utils/huggingface_utils.h" #include "utils/logging_utils.h" #include "utils/model_callback_utils.h" - -void PrintMenu(const std::vector& options) { - auto index{1}; - for (const auto& option : options) { - std::cout << index << ". " << option << "\n"; - index++; - } - std::endl(std::cout); -} - -std::optional PrintSelection( - const std::vector& options) { - std::string selection{""}; - PrintMenu(options); - std::cin >> selection; - - if (selection.empty()) { - return std::nullopt; - } - - // std::cout << "Selection: " << selection << "\n"; - // std::cout << "Int representaion: " << std::stoi(selection) << "\n"; - if (std::stoi(selection) > options.size() || std::stoi(selection) < 1) { - return std::nullopt; - } - - return options[std::stoi(selection) - 1]; -} +#include "utils/string_utils.h" void ModelService::DownloadModel(const std::string& input) { if (input.empty()) { @@ -43,65 +16,49 @@ void ModelService::DownloadModel(const std::string& input) { "Input must be Cortex Model Hub handle or HuggingFace url!"); } - if (input.starts_with("https://")) { + if (string_utils::StartsWith(input, "https://")) { return DownloadModelByDirectUrl(input); } - // if input contains / then handle it differently if (input.find("/") != std::string::npos) { - // TODO: what if we have more than one /? - // TODO: what if the left size of / is cortexso? - - // split by /. TODO: Move this function to somewhere else - std::string model_input = input; - std::string delimiter{"/"}; - std::string token{""}; - std::vector parsed{}; - std::string author{""}; - std::string model_name{""}; - while (token != model_input) { - token = model_input.substr(0, model_input.find_first_of("/")); - model_input = model_input.substr(model_input.find_first_of("/") + 1); - std::string new_str{token}; - parsed.push_back(new_str); + auto parsed = string_utils::SplitBy(input, "/"); + if (parsed.size() != 2) { + throw std::runtime_error("Invalid model handle: " + input); } - author = parsed[0]; - model_name = parsed[1]; - auto repo_info = - huggingface_utils::GetHuggingFaceModelRepoInfo(author, model_name); - if (!repo_info.has_value()) { - // throw is better? - CTL_ERR("Model not found"); - return; + auto author = parsed[0]; + auto model_name = parsed[1]; + if (author == "cortexso") { + return DownloadModelByModelName(model_name); } - if (!repo_info->gguf.has_value()) { - throw std::runtime_error( - "Not a GGUF model. Currently, only GGUF single file is supported."); - } + DownloadHuggingFaceGgufModel(author, model_name, std::nullopt); + CLI_LOG("Model " << model_name << " downloaded successfully!") + return; + } + + return DownloadModelByModelName(input); +} +void ModelService::DownloadModelByModelName(const std::string& modelName) { + try { + auto branches = + huggingface_utils::GetModelRepositoryBranches("cortexso", modelName); std::vector options{}; - for (const auto& sibling : repo_info->siblings) { - if (sibling.rfilename.ends_with(".gguf")) { - options.push_back(sibling.rfilename); + for (const auto& branch : branches) { + if (branch.name != "main") { + options.emplace_back(branch.name); } } - auto selection = PrintSelection(options); - std::cout << "Selected: " << selection.value() << std::endl; - - auto download_url = huggingface_utils::GetDownloadableUrl( - author, model_name, selection.value()); - - std::cout << "Download url: " << download_url << std::endl; - // TODO: split to this function - // DownloadHuggingFaceGgufModel(author, model_name, nullptr); - return; + if (options.empty()) { + CLI_LOG("No variant found"); + return; + } + auto selection = cli_selection_utils::PrintSelection(options); + DownloadModelFromCortexso(modelName, selection.value()); + } catch (const std::runtime_error& e) { + CLI_LOG("Error downloading model, " << e.what()); } - - // user just input a text, seems like a model name only, maybe comes with a branch, using : as delimeter - // handle cortexso here - // separate into another function and the above can route to it if we regconize a cortexso url } std::optional ModelService::GetDownloadedModel( @@ -131,20 +88,14 @@ std::optional ModelService::GetDownloadedModel( } void ModelService::DownloadModelByDirectUrl(const std::string& url) { - // check for malformed url - // question: What if the url is from cortexso itself - // answer: then route to download from cortexso auto url_obj = url_parser::FromUrlString(url); if (url_obj.host == kHuggingFaceHost) { - // goto hugging face parser to normalize the url - // loop through path params, replace blob to resolve if any if (url_obj.pathParams[2] == "blob") { url_obj.pathParams[2] = "resolve"; } } - // should separate this function out auto model_id{url_obj.pathParams[1]}; auto file_name{url_obj.pathParams.back()}; @@ -161,7 +112,7 @@ void ModelService::DownloadModelByDirectUrl(const std::string& url) { auto download_url = url_parser::FromUrl(url_obj); // this assume that the model being downloaded is a single gguf file - auto downloadTask{DownloadTask{.id = url_obj.pathParams.back(), + auto downloadTask{DownloadTask{.id = model_id, .type = DownloadType::Model, .items = {DownloadItem{ .id = url_obj.pathParams.back(), @@ -170,7 +121,7 @@ void ModelService::DownloadModelByDirectUrl(const std::string& url) { }}}}; auto on_finished = [](const DownloadTask& finishedTask) { - std::cout << "Download success" << std::endl; + CLI_LOG("Model " << finishedTask.id << " downloaded successfully!") auto gguf_download_item = finishedTask.items[0]; model_callback_utils::ParseGguf(gguf_download_item); }; @@ -184,7 +135,7 @@ void ModelService::DownloadModelFromCortexso(const std::string& name, if (downloadTask.has_value()) { DownloadService().AddDownloadTask(downloadTask.value(), model_callback_utils::DownloadModelCb); - CTL_INF("Download finished"); + CLI_LOG("Model " << name << " downloaded successfully!") } else { CTL_ERR("Model not found"); } @@ -193,19 +144,29 @@ void ModelService::DownloadModelFromCortexso(const std::string& name, void ModelService::DownloadHuggingFaceGgufModel( const std::string& author, const std::string& modelName, std::optional fileName) { - std::cout << author << std::endl; - std::cout << modelName << std::endl; - // if we don't have file name, we must display a list for user to pick - // auto repo_info = - // huggingface_utils::GetHuggingFaceModelRepoInfo(author, modelName); - // - // if (!repo_info.has_value()) { - // // throw is better? - // CTL_ERR("Model not found"); - // return; - // } - // - // for (const auto& sibling : repo_info->siblings) { - // std::cout << sibling.rfilename << "\n"; - // } + auto repo_info = + huggingface_utils::GetHuggingFaceModelRepoInfo(author, modelName); + if (!repo_info.has_value()) { + // throw is better? + CTL_ERR("Model not found"); + return; + } + + if (!repo_info->gguf.has_value()) { + throw std::runtime_error( + "Not a GGUF model. Currently, only GGUF single file is supported."); + } + + std::vector options{}; + for (const auto& sibling : repo_info->siblings) { + if (string_utils::EndsWith(sibling.rfilename, ".gguf")) { + options.push_back(sibling.rfilename); + } + } + auto selection = cli_selection_utils::PrintSelection(options); + std::cout << "Selected: " << selection.value() << std::endl; + + auto download_url = huggingface_utils::GetDownloadableUrl(author, modelName, + selection.value()); + DownloadModelByDirectUrl(download_url); } diff --git a/engine/services/model_service.h b/engine/services/model_service.h index f4ee4e065..06212aaee 100644 --- a/engine/services/model_service.h +++ b/engine/services/model_service.h @@ -17,7 +17,7 @@ class ModelService { void DownloadModelByDirectUrl(const std::string& url); void DownloadModelFromCortexso(const std::string& name, - const std::string& branch); + const std::string& branch = "main"); /** * Handle downloading model which have following pattern: author/model_name @@ -26,6 +26,8 @@ class ModelService { const std::string& modelName, std::optional fileName); + void DownloadModelByModelName(const std::string& modelName); + DownloadService download_service_; constexpr auto static kHuggingFaceHost = "huggingface.co"; diff --git a/engine/test/components/test_string_utils.cc b/engine/test/components/test_string_utils.cc new file mode 100644 index 000000000..7a51b4f58 --- /dev/null +++ b/engine/test/components/test_string_utils.cc @@ -0,0 +1,79 @@ +#include "gtest/gtest.h" +#include "utils/string_utils.h" + +class StringUtilsTestSuite : public ::testing::Test {}; + +TEST_F(StringUtilsTestSuite, TestSplitBy) { + auto input = "this is a test"; + std::string delimiter{' '}; + auto result = string_utils::SplitBy(input, delimiter); + + EXPECT_EQ(result.size(), 4); + EXPECT_EQ(result[0], "this"); + EXPECT_EQ(result[1], "is"); + EXPECT_EQ(result[2], "a"); + EXPECT_EQ(result[3], "test"); +} + +TEST_F(StringUtilsTestSuite, TestSplitByWithEmptyString) { + auto input = ""; + std::string delimiter{' '}; + auto result = string_utils::SplitBy(input, delimiter); + + EXPECT_EQ(result.size(), 0); +} + +TEST_F(StringUtilsTestSuite, TestSplitModelHandle) { + auto input = "cortexso/tinyllama"; + std::string delimiter{'/'}; + auto result = string_utils::SplitBy(input, delimiter); + + EXPECT_EQ(result.size(), 2); + EXPECT_EQ(result[0], "cortexso"); + EXPECT_EQ(result[1], "tinyllama"); +} + +TEST_F(StringUtilsTestSuite, TestSplitModelHandleWithEmptyModelName) { + auto input = "cortexso/"; + std::string delimiter{'/'}; + auto result = string_utils::SplitBy(input, delimiter); + + EXPECT_EQ(result.size(), 1); + EXPECT_EQ(result[0], "cortexso"); +} + +TEST_F(StringUtilsTestSuite, TestStartsWith) { + auto input = "this is a test"; + auto prefix = "this"; + EXPECT_TRUE(string_utils::StartsWith(input, prefix)); +} + +TEST_F(StringUtilsTestSuite, TestStartsWithWithEmptyString) { + auto input = ""; + auto prefix = "this"; + EXPECT_FALSE(string_utils::StartsWith(input, prefix)); +} + +TEST_F(StringUtilsTestSuite, TestStartsWithWithEmptyPrefix) { + auto input = "this is a test"; + auto prefix = ""; + EXPECT_TRUE(string_utils::StartsWith(input, prefix)); +} + +TEST_F(StringUtilsTestSuite, TestEndsWith) { + auto input = "this is a test"; + auto suffix = "test"; + EXPECT_TRUE(string_utils::EndsWith(input, suffix)); +} + +TEST_F(StringUtilsTestSuite, TestEndsWithWithEmptyString) { + auto input = ""; + auto suffix = "test"; + EXPECT_FALSE(string_utils::EndsWith(input, suffix)); +} + +TEST_F(StringUtilsTestSuite, TestEndsWithWithEmptySuffix) { + auto input = "this is a test"; + auto suffix = ""; + EXPECT_TRUE(string_utils::EndsWith(input, suffix)); +} diff --git a/engine/utils/cli_selection_utils.h b/engine/utils/cli_selection_utils.h new file mode 100644 index 000000000..d3848c5bb --- /dev/null +++ b/engine/utils/cli_selection_utils.h @@ -0,0 +1,35 @@ +#include +#include +#include +#include + +namespace cli_selection_utils { +inline void PrintMenu(const std::vector& options) { + auto index{1}; + for (const auto& option : options) { + std::cout << index << ". " << option << "\n"; + index++; + } + std::endl(std::cout); +} + +inline std::optional PrintSelection( + const std::vector& options, + const std::string& title = "Select an option") { + std::cout << title << "\n"; + std::string selection{""}; + PrintMenu(options); + std::cout << "Select an option (" << 1 << "-" << options.size() << "): "; + std::cin >> selection; + + if (selection.empty()) { + return std::nullopt; + } + + if (std::stoi(selection) > options.size() || std::stoi(selection) < 1) { + return std::nullopt; + } + + return options[std::stoi(selection) - 1]; +} +} // namespace cli_selection_utils diff --git a/engine/utils/huggingface_utils.h b/engine/utils/huggingface_utils.h index 3d104dbb4..3699c38b8 100644 --- a/engine/utils/huggingface_utils.h +++ b/engine/utils/huggingface_utils.h @@ -1,6 +1,5 @@ #pragma once -#include #include #include #include @@ -24,7 +23,7 @@ struct HuggingFaceFileSibling { }; struct HuggingFaceGgufInfo { - uint64 total; + uint64_t total; std::string architecture; }; diff --git a/engine/utils/string_utils.h b/engine/utils/string_utils.h new file mode 100644 index 000000000..150b8a61f --- /dev/null +++ b/engine/utils/string_utils.h @@ -0,0 +1,32 @@ +#include +#include + +namespace string_utils { +inline bool StartsWith(const std::string& str, const std::string& prefix) { + return str.rfind(prefix, 0) == 0; +} + +inline bool EndsWith(const std::string& str, const std::string& suffix) { + if (str.length() >= suffix.length()) { + return (0 == str.compare(str.length() - suffix.length(), suffix.length(), + suffix)); + } + return false; +} + +inline std::vector SplitBy(const std::string& str, + const std::string& delimiter) { + std::vector tokens; + size_t prev = 0, pos = 0; + do { + pos = str.find(delimiter, prev); + if (pos == std::string::npos) + pos = str.length(); + std::string token = str.substr(prev, pos - prev); + if (!token.empty()) + tokens.push_back(token); + prev = pos + delimiter.length(); + } while (pos < str.length() && prev < str.length()); + return tokens; +} +} // namespace string_utils