diff --git a/engine/commands/model_get_cmd.cc b/engine/commands/model_get_cmd.cc
index acbce742c..d4f19ffa8 100644
--- a/engine/commands/model_get_cmd.cc
+++ b/engine/commands/model_get_cmd.cc
@@ -4,21 +4,16 @@
 #include
 #include "cmd_info.h"
 #include "config/yaml_config.h"
-#include "trantor/utils/Logger.h"
-#include "utils/cortex_utils.h"
 #include "utils/file_manager_utils.h"
 #include "utils/logging_utils.h"
 
 namespace commands {
 
-ModelGetCmd::ModelGetCmd(std::string model_handle)
-    : model_handle_(std::move(model_handle)) {}
-
-void ModelGetCmd::Exec() {
+void ModelGetCmd::Exec(const std::string& model_handle) {
   auto models_path = file_manager_utils::GetModelsContainerPath();
   if (std::filesystem::exists(models_path) &&
       std::filesystem::is_directory(models_path)) {
-    CmdInfo ci(model_handle_);
+    CmdInfo ci(model_handle);
     std::string model_file =
         ci.branch == "main" ? ci.model_name : ci.model_name + "-" + ci.branch;
     bool found_model = false;
@@ -149,4 +144,4 @@ void ModelGetCmd::Exec() {
     CLI_LOG("Model not found!");
   }
 }
-}; // namespace commands
\ No newline at end of file
+}; // namespace commands
diff --git a/engine/commands/model_get_cmd.h b/engine/commands/model_get_cmd.h
index 9bd9d2213..1836f7d99 100644
--- a/engine/commands/model_get_cmd.h
+++ b/engine/commands/model_get_cmd.h
@@ -1,17 +1,10 @@
 #pragma once
-
-#include <cmath>  // For std::isnan
 #include <string>
 
 namespace commands {
 
 class ModelGetCmd {
  public:
-
-  ModelGetCmd(std::string model_handle);
-  void Exec();
-
- private:
-  std::string model_handle_;
+  void Exec(const std::string& model_handle);
 };
-} // namespace commands
\ No newline at end of file
+} // namespace commands
diff --git a/engine/commands/run_cmd.cc b/engine/commands/run_cmd.cc
index 1fb3706d7..cb60822ad 100644
--- a/engine/commands/run_cmd.cc
+++ b/engine/commands/run_cmd.cc
@@ -2,13 +2,8 @@
 #include "chat_cmd.h"
 #include "cmd_info.h"
 #include "config/yaml_config.h"
-#include "engine_install_cmd.h"
-#include "httplib.h"
-#include "model_pull_cmd.h"
 #include "model_start_cmd.h"
 #include "server_start_cmd.h"
-#include "trantor/utils/Logger.h"
-#include "utils/cortex_utils.h"
 #include "utils/file_manager_utils.h"
 
 namespace commands {
@@ -46,7 +41,7 @@ void RunCmd::Exec() {
   if (!commands::IsServerAlive(host_, port_)) {
     CLI_LOG("Starting server ...");
     commands::ServerStartCmd ssc;
-    if(!ssc.Exec(host_, port_)) {
+    if (!ssc.Exec(host_, port_)) {
       return;
     }
   }
diff --git a/engine/controllers/command_line_parser.cc b/engine/controllers/command_line_parser.cc
index b55887ebd..0d74a749b 100644
--- a/engine/controllers/command_line_parser.cc
+++ b/engine/controllers/command_line_parser.cc
@@ -138,10 +138,8 @@ bool CommandLineParser::SetupCommand(int argc, char** argv) {
       models_cmd->add_subcommand("get", "Get info of {model_id} locally");
   get_models_cmd->add_option("model_id", model_id, "");
   get_models_cmd->require_option();
-  get_models_cmd->callback([&model_id]() {
-    commands::ModelGetCmd command(model_id);
-    command.Exec();
-  });
+  get_models_cmd->callback(
+      [&model_id]() { commands::ModelGetCmd().Exec(model_id); });
 
   auto model_del_cmd =
       models_cmd->add_subcommand("delete", "Delete a model by ID locally");
@@ -238,7 +236,7 @@
   auto ps_cmd =
       app_.add_subcommand("ps", "Show running models and their status");
   ps_cmd->group(kSystemGroup);
-  
+
   CLI11_PARSE(app_, argc, argv);
   if (argc == 1) {
     CLI_LOG(app_.help());
diff --git a/engine/e2e-test/test_cli_model_pull_cortexso_with_selection.py b/engine/e2e-test/test_cli_model_pull_cortexso_with_selection.py
new file mode 100644
index 000000000..619833e16
--- /dev/null
+++ b/engine/e2e-test/test_cli_model_pull_cortexso_with_selection.py
@@ -0,0 +1,12 @@
+from test_runner import popen
+
+
+class TestCliModelPullCortexsoWithSelection:
+
+    def test_pull_model_from_cortexso_should_display_list_and_allow_user_to_choose(
+        self,
+    ):
+        stdout, stderr, return_code = popen(["pull", "tinyllama"], "1\n")
+
+        assert "Model tinyllama downloaded successfully!" in stdout
+        assert return_code == 0
diff --git a/engine/e2e-test/test_cli_model_pull_direct_url.py b/engine/e2e-test/test_cli_model_pull_direct_url.py
index 7d6fa677b..4907ced1f 100644
--- a/engine/e2e-test/test_cli_model_pull_direct_url.py
+++ b/engine/e2e-test/test_cli_model_pull_direct_url.py
@@ -1,17 +1,18 @@
-import platform
-
-import pytest
 from test_runner import run
 
 
 class TestCliModelPullDirectUrl:
-    @pytest.mark.skipif(True, reason="Expensive test. Only test when needed.")
     def test_model_pull_with_direct_url_should_be_success(self):
         exit_code, output, error = run(
-            "Pull model", ["pull", "https://huggingface.co/TheBloke/TinyLlama-1.1B-Chat-v0.3-GGUF/blob/main/tinyllama-1.1b-chat-v0.3.Q2_K.gguf"],
-            timeout=None
+            "Pull model",
+            [
+                "pull",
+                "https://huggingface.co/TheBloke/TinyLlama-1.1B-Chat-v0.3-GGUF/blob/main/tinyllama-1.1b-chat-v0.3.Q2_K.gguf",
+            ],
+            timeout=None,
         )
 
         assert exit_code == 0, f"Model pull failed with error: {error}"
         # TODO: verify that the model has been pull successfully
-        # TODO: skip this test. since download model is taking too long
\ No newline at end of file
+        # TODO: skip this test. since download model is taking too long
+
diff --git a/engine/e2e-test/test_cli_model_pull_from_cortexso.py b/engine/e2e-test/test_cli_model_pull_from_cortexso.py
index e5651ce2f..c9c3f4c40 100644
--- a/engine/e2e-test/test_cli_model_pull_from_cortexso.py
+++ b/engine/e2e-test/test_cli_model_pull_from_cortexso.py
@@ -1,17 +1,16 @@
-import platform
-
 import pytest
 from test_runner import run
 
 
 class TestCliModelPullCortexso:
-    @pytest.mark.skipif(True, reason="Expensive test. Only test when needed.")
     def test_model_pull_with_direct_url_should_be_success(self):
         exit_code, output, error = run(
-            "Pull model", ["pull", "tinyllama"],
-            timeout=None
+            "Pull model",
+            ["pull", "tinyllama"],
+            timeout=None,
         )
 
         assert exit_code == 0, f"Model pull failed with error: {error}"
         # TODO: verify that the model has been pull successfully
-        # TODO: skip this test. since download model is taking too long
\ No newline at end of file
+        # TODO: skip this test. since download model is taking too long
+
diff --git a/engine/e2e-test/test_cli_model_pull_hugging_face_repository.py b/engine/e2e-test/test_cli_model_pull_hugging_face_repository.py
new file mode 100644
index 000000000..50b7e832b
--- /dev/null
+++ b/engine/e2e-test/test_cli_model_pull_hugging_face_repository.py
@@ -0,0 +1,28 @@
+import pytest
+from test_runner import popen
+
+
+class TestCliModelPullHuggingFaceRepository:
+
+    def test_model_pull_hugging_face_repository(self):
+        """
+        Test pulling model pervll/bge-reranker-v2-gemma-Q4_K_M-GGUF from issue #1017
+        """
+
+        stdout, stderr, return_code = popen(
+            ["pull", "pervll/bge-reranker-v2-gemma-Q4_K_M-GGUF"], "1\n"
+        )
+
+        assert "downloaded successfully!" in stdout
+        assert return_code == 0
+
+    def test_model_pull_hugging_face_not_gguf_should_failed_gracefully(self):
+        """
+        When pulling a model that is not GGUF, we stop and show a message to the user
+        """
+
+        stdout, stderr, return_code = popen(["pull", "BAAI/bge-reranker-v2-m3"], "")
+        assert (
+            "Not a GGUF model. Currently, only GGUF single file is supported." in stdout
+        )
+        assert return_code == 0
diff --git a/engine/e2e-test/test_runner.py b/engine/e2e-test/test_runner.py
index dd634d747..7716c55a0 100644
--- a/engine/e2e-test/test_runner.py
+++ b/engine/e2e-test/test_runner.py
@@ -38,6 +38,26 @@ def run(test_name: str, arguments: List[str], timeout=timeout) -> (int, str, str):
     return result.returncode, result.stdout, result.stderr
 
 
+def popen(arguments: List[str], user_input: str) -> (int, str, str):
+    # Start the process
+    executable_path = getExecutablePath()
+    process = subprocess.Popen(
+        [executable_path] + arguments,
+        stdin=subprocess.PIPE,
+        stdout=subprocess.PIPE,
+        stderr=subprocess.PIPE,
+        text=True,  # This ensures the input and output are treated as text
+    )
+
+    # Send input and get output
+    stdout, stderr = process.communicate(input=user_input)
+
+    # Get the return code
+    return_code = process.returncode
+
+    return stdout, stderr, return_code
+
+
 # Start the API server
 # Wait for `Server started` message or failed
 def start_server() -> bool:
@@ -50,10 +70,10 @@
 def start_server_nix() -> bool:
     executable = getExecutablePath()
     process = subprocess.Popen(
-        [executable] + ['start', '-p', '3928'],
-        stdout=subprocess.PIPE,
-        stderr=subprocess.PIPE,
-        text=True
+        [executable] + ["start", "-p", "3928"],
+        stdout=subprocess.PIPE,
+        stderr=subprocess.PIPE,
+        text=True,
     )
 
     start_time = time.time()
@@ -80,7 +100,7 @@
 def start_server_windows() -> bool:
     executable = getExecutablePath()
     process = subprocess.Popen(
-        [executable] + ['start', '-p', '3928'],
+        [executable] + ["start", "-p", "3928"],
         stdout=subprocess.PIPE,
         stderr=subprocess.PIPE,
         text=True,
diff --git a/engine/services/model_service.cc b/engine/services/model_service.cc
index 7943cace4..29575dfab 100644
--- a/engine/services/model_service.cc
+++ b/engine/services/model_service.cc
@@ -1,12 +1,14 @@
 #include "model_service.h"
 #include
 #include
-#include "commands/cmd_info.h"
+#include <optional>
+#include "utils/cli_selection_utils.h"
 #include "utils/cortexso_parser.h"
 #include "utils/file_manager_utils.h"
+#include "utils/huggingface_utils.h"
 #include "utils/logging_utils.h"
 #include "utils/model_callback_utils.h"
-#include "utils/url_parser.h"
+#include "utils/string_utils.h"
 
 void ModelService::DownloadModel(const std::string& input) {
   if (input.empty()) {
@@ -14,18 +16,48 @@ void ModelService::DownloadModel(const std::string& input) {
         "Input must be Cortex Model Hub handle or HuggingFace url!");
   }
 
-  // case input is a direct url
-  auto url_obj = url_parser::FromUrlString(input);
-  // TODO: handle case user paste url from cortexso
-  if (url_obj.protocol == "https") {
-    if (url_obj.host != kHuggingFaceHost) {
-      CLI_LOG("Only huggingface.co is supported for now");
+  if (string_utils::StartsWith(input, "https://")) {
+    return DownloadModelByDirectUrl(input);
+  }
+
+  if (input.find("/") != std::string::npos) {
+    auto parsed = string_utils::SplitBy(input, "/");
+    if (parsed.size() != 2) {
+      throw std::runtime_error("Invalid model handle: " + input);
+    }
+
+    auto author = parsed[0];
+    auto model_name = parsed[1];
+    if (author == "cortexso") {
+      return DownloadModelByModelName(model_name);
+    }
+
+    DownloadHuggingFaceGgufModel(author, model_name, std::nullopt);
+    CLI_LOG("Model " << model_name << " downloaded successfully!")
+    return;
+  }
+
+  return DownloadModelByModelName(input);
+}
+
+void ModelService::DownloadModelByModelName(const std::string& modelName) {
+  try {
+    auto branches =
+        huggingface_utils::GetModelRepositoryBranches("cortexso", modelName);
+    std::vector<std::string> options{};
+    for (const auto& branch : branches) {
+      if (branch.name != "main") {
+        options.emplace_back(branch.name);
+      }
+    }
+    if (options.empty()) {
+      CLI_LOG("No variant found");
       return;
     }
-    return DownloadModelByDirectUrl(input);
-  } else {
-    commands::CmdInfo ci(input);
-    return DownloadModelFromCortexso(ci.model_name, ci.branch);
+    auto selection = cli_selection_utils::PrintSelection(options);
+    DownloadModelFromCortexso(modelName, selection.value());
+  } catch (const std::runtime_error& e) {
+    CLI_LOG("Error downloading model, " << e.what());
   }
 }
 
@@ -56,20 +88,14 @@ std::optional ModelService::GetDownloadedModel(
 }
 
 void ModelService::DownloadModelByDirectUrl(const std::string& url) {
-  // check for malformed url
-  // question: What if the url is from cortexso itself
-  // answer: then route to download from cortexso
   auto url_obj = url_parser::FromUrlString(url);
 
   if (url_obj.host == kHuggingFaceHost) {
-    // goto hugging face parser to normalize the url
-    // loop through path params, replace blob to resolve if any
     if (url_obj.pathParams[2] == "blob") {
       url_obj.pathParams[2] = "resolve";
    }
   }
-  // should separate this function out
 
   auto model_id{url_obj.pathParams[1]};
   auto file_name{url_obj.pathParams.back()};
@@ -86,7 +112,7 @@ void ModelService::DownloadModelByDirectUrl(const std::string& url) {
   auto download_url = url_parser::FromUrl(url_obj);
 
   // this assume that the model being downloaded is a single gguf file
-  auto downloadTask{DownloadTask{.id = url_obj.pathParams.back(),
+  auto downloadTask{DownloadTask{.id = model_id,
                                  .type = DownloadType::Model,
                                  .items = {DownloadItem{
                                      .id = url_obj.pathParams.back(),
@@ -95,7 +121,7 @@ void ModelService::DownloadModelByDirectUrl(const std::string& url) {
                                  }}}};
 
   auto on_finished = [](const DownloadTask& finishedTask) {
-    std::cout << "Download success" << std::endl;
+    CLI_LOG("Model " << finishedTask.id << " downloaded successfully!")
     auto gguf_download_item = finishedTask.items[0];
     model_callback_utils::ParseGguf(gguf_download_item);
@@ -109,8 +135,38 @@ void ModelService::DownloadModelFromCortexso(const std::string& name,
   if (downloadTask.has_value()) {
     DownloadService().AddDownloadTask(downloadTask.value(),
                                       model_callback_utils::DownloadModelCb);
-    CTL_INF("Download finished");
+    CLI_LOG("Model " << name << " downloaded successfully!")
   } else {
     CTL_ERR("Model not found");
   }
 }
+
+void ModelService::DownloadHuggingFaceGgufModel(
+    const std::string& author, const std::string& modelName,
+    std::optional<std::string> fileName) {
+  auto repo_info =
+      huggingface_utils::GetHuggingFaceModelRepoInfo(author, modelName);
+  if (!repo_info.has_value()) {
+    // throw is better?
+    CTL_ERR("Model not found");
+    return;
+  }
+
+  if (!repo_info->gguf.has_value()) {
+    throw std::runtime_error(
+        "Not a GGUF model. Currently, only GGUF single file is supported.");
+  }
+
+  std::vector<std::string> options{};
+  for (const auto& sibling : repo_info->siblings) {
+    if (string_utils::EndsWith(sibling.rfilename, ".gguf")) {
+      options.push_back(sibling.rfilename);
+    }
+  }
+  auto selection = cli_selection_utils::PrintSelection(options);
+  std::cout << "Selected: " << selection.value() << std::endl;
+
+  auto download_url = huggingface_utils::GetDownloadableUrl(author, modelName,
+                                                            selection.value());
+  DownloadModelByDirectUrl(download_url);
+}
diff --git a/engine/services/model_service.h b/engine/services/model_service.h
index 81ec4e4b3..06212aaee 100644
--- a/engine/services/model_service.h
+++ b/engine/services/model_service.h
@@ -17,7 +17,16 @@ class ModelService {
   void DownloadModelByDirectUrl(const std::string& url);
 
   void DownloadModelFromCortexso(const std::string& name,
-                                 const std::string& branch);
+                                 const std::string& branch = "main");
+
+  /**
+   * Handle downloading models that have the following pattern: author/model_name
+   */
+  void DownloadHuggingFaceGgufModel(const std::string& author,
+                                    const std::string& modelName,
+                                    std::optional<std::string> fileName);
+
+  void DownloadModelByModelName(const std::string& modelName);
 
   DownloadService download_service_;
diff --git a/engine/test/components/CMakeLists.txt b/engine/test/components/CMakeLists.txt
index db810ad26..fa1c5477e 100644
--- a/engine/test/components/CMakeLists.txt
+++ b/engine/test/components/CMakeLists.txt
@@ -9,9 +9,12 @@ find_package(Drogon CONFIG REQUIRED)
 find_package(GTest CONFIG REQUIRED)
 find_package(yaml-cpp CONFIG REQUIRED)
 find_package(jinja2cpp CONFIG REQUIRED)
+find_package(httplib CONFIG REQUIRED)
 
 target_link_libraries(${PROJECT_NAME} PRIVATE Drogon::Drogon GTest::gtest GTest::gtest_main yaml-cpp::yaml-cpp jinja2cpp ${CMAKE_THREAD_LIBS_INIT})
+
+target_link_libraries(${PROJECT_NAME} PRIVATE httplib::httplib)
 
 target_include_directories(${PROJECT_NAME} PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/../../)
 
 add_test(NAME ${PROJECT_NAME}
diff --git a/engine/test/components/test_huggingface_utils.cc b/engine/test/components/test_huggingface_utils.cc
new file mode 100644
index 000000000..b1d949d22
--- /dev/null
+++ b/engine/test/components/test_huggingface_utils.cc
@@ -0,0 +1,88 @@
+#include "gtest/gtest.h"
+#include "utils/huggingface_utils.h"
+
+class HuggingFaceUtilTestSuite : public ::testing::Test {};
+
+TEST_F(HuggingFaceUtilTestSuite, TestGetModelRepositoryBranches) {
+  auto branches =
+      huggingface_utils::GetModelRepositoryBranches("cortexso", "tinyllama");
+
+  EXPECT_EQ(branches.size(), 3);
+  EXPECT_EQ(branches[0].name, "gguf");
+  EXPECT_EQ(branches[0].ref, "refs/heads/gguf");
+  EXPECT_EQ(branches[1].name, "1b-gguf");
+  EXPECT_EQ(branches[1].ref, "refs/heads/1b-gguf");
+  EXPECT_EQ(branches[2].name, "main");
+  EXPECT_EQ(branches[2].ref, "refs/heads/main");
+}
+
+TEST_F(HuggingFaceUtilTestSuite, TestGetHuggingFaceModelRepoInfoSuccessfully) {
+  auto model_info =
+      huggingface_utils::GetHuggingFaceModelRepoInfo("cortexso", "tinyllama");
+  auto not_null = model_info.has_value();
+
+  EXPECT_TRUE(not_null);
+  EXPECT_EQ(model_info->id, "cortexso/tinyllama");
+  EXPECT_EQ(model_info->modelId, "cortexso/tinyllama");
+  EXPECT_EQ(model_info->author, "cortexso");
+  EXPECT_EQ(model_info->disabled, false);
+  EXPECT_EQ(model_info->gated, false);
+
+  auto tag_contains_gguf =
+      std::find(model_info->tags.begin(), model_info->tags.end(), "gguf") !=
+      model_info->tags.end();
+  EXPECT_TRUE(tag_contains_gguf);
+
+  auto contain_gguf_info = model_info->gguf.has_value();
+  EXPECT_TRUE(contain_gguf_info);
+
+  auto sibling_not_empty = !model_info->siblings.empty();
+  EXPECT_TRUE(sibling_not_empty);
+}
+
+TEST_F(HuggingFaceUtilTestSuite,
+       TestGetHuggingFaceModelRepoInfoReturnNullGgufInfoWhenNotAGgufModel) {
+  auto model_info = huggingface_utils::GetHuggingFaceModelRepoInfo(
+      "BAAI", "bge-reranker-v2-m3");
+  auto not_null = model_info.has_value();
+
+  EXPECT_TRUE(not_null);
+  EXPECT_EQ(model_info->disabled, false);
+  EXPECT_EQ(model_info->gated, false);
+
+  auto tag_not_contain_gguf =
+      std::find(model_info->tags.begin(), model_info->tags.end(), "gguf") ==
+      model_info->tags.end();
+  EXPECT_TRUE(tag_not_contain_gguf);
+
+  auto contain_gguf_info = model_info->gguf.has_value();
+  EXPECT_TRUE(!contain_gguf_info);
+
+  auto sibling_not_empty = !model_info->siblings.empty();
+  EXPECT_TRUE(sibling_not_empty);
+}
+
+TEST_F(HuggingFaceUtilTestSuite,
+       TestGetHuggingFaceDownloadUrlWithoutBranchName) {
+  auto downloadable_url = huggingface_utils::GetDownloadableUrl(
+      "pervll", "bge-reranker-v2-gemma-Q4_K_M-GGUF",
+      "bge-reranker-v2-gemma-q4_k_m.gguf");
+
+  auto expected_url{
+      "https://huggingface.co/pervll/bge-reranker-v2-gemma-Q4_K_M-GGUF/resolve/"
+      "main/bge-reranker-v2-gemma-q4_k_m.gguf"};
+
+  EXPECT_EQ(downloadable_url, expected_url);
+}
+
+TEST_F(HuggingFaceUtilTestSuite, TestGetHuggingFaceDownloadUrlWithBranchName) {
+  auto downloadable_url = huggingface_utils::GetDownloadableUrl(
+      "pervll", "bge-reranker-v2-gemma-Q4_K_M-GGUF",
+      "bge-reranker-v2-gemma-q4_k_m.gguf", "1b-gguf");
+
+  auto expected_url{
+      "https://huggingface.co/pervll/bge-reranker-v2-gemma-Q4_K_M-GGUF/resolve/"
+      "1b-gguf/bge-reranker-v2-gemma-q4_k_m.gguf"};
+
+  EXPECT_EQ(downloadable_url, expected_url);
+}
diff --git a/engine/test/components/test_string_utils.cc b/engine/test/components/test_string_utils.cc
new file mode 100644
index 000000000..7a51b4f58
--- /dev/null
+++ b/engine/test/components/test_string_utils.cc
@@ -0,0 +1,79 @@
+#include "gtest/gtest.h"
+#include "utils/string_utils.h"
+
+class StringUtilsTestSuite : public ::testing::Test {};
+
+TEST_F(StringUtilsTestSuite, TestSplitBy) {
+  auto input = "this is a test";
+  std::string delimiter{' '};
+  auto result = string_utils::SplitBy(input, delimiter);
+
+  EXPECT_EQ(result.size(), 4);
+  EXPECT_EQ(result[0], "this");
+  EXPECT_EQ(result[1], "is");
+  EXPECT_EQ(result[2], "a");
+  EXPECT_EQ(result[3], "test");
+}
+
+TEST_F(StringUtilsTestSuite, TestSplitByWithEmptyString) {
+  auto input = "";
+  std::string delimiter{' '};
+  auto result = string_utils::SplitBy(input, delimiter);
+
+  EXPECT_EQ(result.size(), 0);
+}
+
+TEST_F(StringUtilsTestSuite, TestSplitModelHandle) {
+  auto input = "cortexso/tinyllama";
+  std::string delimiter{'/'};
+  auto result = string_utils::SplitBy(input, delimiter);
+
+  EXPECT_EQ(result.size(), 2);
+  EXPECT_EQ(result[0], "cortexso");
+  EXPECT_EQ(result[1], "tinyllama");
+}
+
+TEST_F(StringUtilsTestSuite, TestSplitModelHandleWithEmptyModelName) {
+  auto input = "cortexso/";
+  std::string delimiter{'/'};
+  auto result = string_utils::SplitBy(input, delimiter);
+
+  EXPECT_EQ(result.size(), 1);
+  EXPECT_EQ(result[0], "cortexso");
+}
+
+TEST_F(StringUtilsTestSuite, TestStartsWith) {
+  auto input = "this is a test";
+  auto prefix = "this";
+  EXPECT_TRUE(string_utils::StartsWith(input, prefix));
+}
+
+TEST_F(StringUtilsTestSuite, TestStartsWithWithEmptyString) {
+  auto input = "";
+  auto prefix = "this";
+  EXPECT_FALSE(string_utils::StartsWith(input, prefix));
+}
+
+TEST_F(StringUtilsTestSuite, TestStartsWithWithEmptyPrefix) {
+  auto input = "this is a test";
+  auto prefix = "";
+  EXPECT_TRUE(string_utils::StartsWith(input, prefix));
+}
+
+TEST_F(StringUtilsTestSuite, TestEndsWith) {
+  auto input = "this is a test";
+  auto suffix = "test";
+  EXPECT_TRUE(string_utils::EndsWith(input, suffix));
+}
+
+TEST_F(StringUtilsTestSuite, TestEndsWithWithEmptyString) {
+  auto input = "";
+  auto suffix = "test";
+  EXPECT_FALSE(string_utils::EndsWith(input, suffix));
+}
+
+TEST_F(StringUtilsTestSuite, TestEndsWithWithEmptySuffix) {
+  auto input = "this is a test";
+  auto suffix = "";
+  EXPECT_TRUE(string_utils::EndsWith(input, suffix));
+}
diff --git a/engine/test/components/test_url_parser.cc b/engine/test/components/test_url_parser.cc
index decffcbf8..cee6cb6ed 100644
--- a/engine/test/components/test_url_parser.cc
+++ b/engine/test/components/test_url_parser.cc
@@ -1,4 +1,3 @@
-#include
 #include "gtest/gtest.h"
 #include "utils/url_parser.h"
 
diff --git a/engine/utils/cli_selection_utils.h b/engine/utils/cli_selection_utils.h
new file mode 100644
index 000000000..d3848c5bb
--- /dev/null
+++ b/engine/utils/cli_selection_utils.h
@@ -0,0 +1,35 @@
+#include <iostream>
+#include <optional>
+#include <string>
+#include <vector>
+
+namespace cli_selection_utils {
+inline void PrintMenu(const std::vector<std::string>& options) {
+  auto index{1};
+  for (const auto& option : options) {
+    std::cout << index << ". " << option << "\n";
+    index++;
+  }
+  std::endl(std::cout);
+}
+
+inline std::optional<std::string> PrintSelection(
+    const std::vector<std::string>& options,
+    const std::string& title = "Select an option") {
+  std::cout << title << "\n";
+  std::string selection{""};
+  PrintMenu(options);
+  std::cout << "Select an option (" << 1 << "-" << options.size() << "): ";
+  std::cin >> selection;
+
+  if (selection.empty()) {
+    return std::nullopt;
+  }
+
+  if (std::stoi(selection) > options.size() || std::stoi(selection) < 1) {
+    return std::nullopt;
+  }
+
+  return options[std::stoi(selection) - 1];
+}
+}  // namespace cli_selection_utils
diff --git a/engine/utils/huggingface_utils.h b/engine/utils/huggingface_utils.h
new file mode 100644
index 000000000..3699c38b8
--- /dev/null
+++ b/engine/utils/huggingface_utils.h
@@ -0,0 +1,157 @@
+#pragma once
+
+#include <cstdint>
+#include <httplib.h>
+#include <optional>
+#include <string>
+#include <vector>
+#include "utils/json.hpp"
+#include "utils/url_parser.h"
+
+namespace huggingface_utils {
+
+constexpr static auto kHuggingfaceHost{"huggingface.co"};
+
+struct HuggingFaceBranch {
+  std::string name;
+  std::string ref;
+  std::string targetCommit;
+};
+
+struct HuggingFaceFileSibling {
+  std::string rfilename;
+};
+
+struct HuggingFaceGgufInfo {
+  uint64_t total;
+  std::string architecture;
+};
+
+struct HuggingFaceModelRepoInfo {
+  std::string id;
+  std::string modelId;
+  std::string author;
+  std::string sha;
+  std::string lastModified;
+
+  bool isPrivate;
+  bool disabled;
+  bool gated;
+  std::vector<std::string> tags;
+  int downloads;
+
+  int likes;
+  std::optional<HuggingFaceGgufInfo> gguf;
+  std::vector<HuggingFaceFileSibling> siblings;
+  std::vector<std::string> spaces;
+  std::string createdAt;
+};
+
+inline std::vector<HuggingFaceBranch> GetModelRepositoryBranches(
+    const std::string& author, const std::string& modelName) {
+  if (author.empty() || modelName.empty()) {
+    throw std::runtime_error("Author and model name cannot be empty");
+  }
+  auto url_obj = url_parser::Url{
+      .protocol = "https",
+      .host = kHuggingfaceHost,
+      .pathParams = {"api", "models", author, modelName, "refs"}};
+
+  httplib::Client cli(url_obj.GetProtocolAndHost());
+  auto res = cli.Get(url_obj.GetPathAndQuery());
+  if (res->status != httplib::StatusCode::OK_200) {
+    throw std::runtime_error(
+ "Failed to get model repository branches: " + author + "/" + modelName); + } + + using json = nlohmann::json; + auto body = json::parse(res->body); + auto branches_json = body["branches"]; + + std::vector branches{}; + + for (const auto& branch : branches_json) { + branches.push_back(HuggingFaceBranch{ + .name = branch["name"], + .ref = branch["ref"], + .targetCommit = branch["targetCommit"], + }); + } + + return branches; +} + +// only support gguf for now +inline std::optional GetHuggingFaceModelRepoInfo( + const std::string& author, const std::string& modelName) { + if (author.empty() || modelName.empty()) { + throw std::runtime_error("Author and model name cannot be empty"); + } + auto url_obj = + url_parser::Url{.protocol = "https", + .host = kHuggingfaceHost, + .pathParams = {"api", "models", author, modelName}}; + + httplib::Client cli(url_obj.GetProtocolAndHost()); + auto res = cli.Get(url_obj.GetPathAndQuery()); + if (res->status != httplib::StatusCode::OK_200) { + throw std::runtime_error("Failed to get model repository info: " + author + + "/" + modelName); + } + + using json = nlohmann::json; + auto body = json::parse(res->body); + + std::optional gguf = std::nullopt; + auto gguf_info = body["gguf"]; + if (!gguf_info.is_null()) { + gguf = HuggingFaceGgufInfo{ + .total = gguf_info["total"], + .architecture = gguf_info["architecture"], + }; + } + + std::vector siblings{}; + auto siblings_info = body["siblings"]; + for (const auto& sibling : siblings_info) { + auto sibling_info = HuggingFaceFileSibling{ + .rfilename = sibling["rfilename"], + }; + siblings.push_back(sibling_info); + } + + auto model_repo_info = HuggingFaceModelRepoInfo{ + .id = body["id"], + .modelId = body["modelId"], + .author = body["author"], + .sha = body["sha"], + .lastModified = body["lastModified"], + + .isPrivate = body["private"], + .disabled = body["disabled"], + .gated = body["gated"], + .tags = body["tags"], + .downloads = body["downloads"], + + .likes = body["likes"], + .gguf = gguf, + .siblings = siblings, + .spaces = body["spaces"], + .createdAt = body["createdAt"], + }; + + return model_repo_info; +} + +inline std::string GetDownloadableUrl(const std::string& author, + const std::string& modelName, + const std::string& fileName, + const std::string& branch = "main") { + auto url_obj = url_parser::Url{ + .protocol = "https", + .host = kHuggingfaceHost, + .pathParams = {author, modelName, "resolve", branch, fileName}, + }; + return url_parser::FromUrl(url_obj); +} +} // namespace huggingface_utils diff --git a/engine/utils/string_utils.h b/engine/utils/string_utils.h new file mode 100644 index 000000000..150b8a61f --- /dev/null +++ b/engine/utils/string_utils.h @@ -0,0 +1,32 @@ +#include +#include + +namespace string_utils { +inline bool StartsWith(const std::string& str, const std::string& prefix) { + return str.rfind(prefix, 0) == 0; +} + +inline bool EndsWith(const std::string& str, const std::string& suffix) { + if (str.length() >= suffix.length()) { + return (0 == str.compare(str.length() - suffix.length(), suffix.length(), + suffix)); + } + return false; +} + +inline std::vector SplitBy(const std::string& str, + const std::string& delimiter) { + std::vector tokens; + size_t prev = 0, pos = 0; + do { + pos = str.find(delimiter, prev); + if (pos == std::string::npos) + pos = str.length(); + std::string token = str.substr(prev, pos - prev); + if (!token.empty()) + tokens.push_back(token); + prev = pos + delimiter.length(); + } while (pos < str.length() && prev < str.length()); + return tokens; 
+}
+}  // namespace string_utils