diff --git a/docs/static/openapi/jan.json b/docs/static/openapi/jan.json index 8e05cf597..0f715456d 100644 --- a/docs/static/openapi/jan.json +++ b/docs/static/openapi/jan.json @@ -1651,6 +1651,15 @@ "value": "https://huggingface.co/TheBloke/Mistral-7B-Instruct-v0.1-GGUF/blob/main/mistral-7b-instruct-v0.1.Q2_K.gguf" } ] + }, + "id": { + "type": "string", + "description": "The id which will be used to register the model.", + "examples": [ + { + "value": "my-custom-model-id" + } + ] } } }, diff --git a/engine/cli/commands/engine_get_cmd.cc b/engine/cli/commands/engine_get_cmd.cc index 7d4c66bb5..d1bf26641 100644 --- a/engine/cli/commands/engine_get_cmd.cc +++ b/engine/cli/commands/engine_get_cmd.cc @@ -1,8 +1,9 @@ #include "engine_get_cmd.h" +#include +#include #include #include "httplib.h" -#include "json/json.h" #include "server_start_cmd.h" #include "utils/logging_utils.h" @@ -29,7 +30,6 @@ void EngineGetCmd::Exec(const std::string& host, int port, auto res = cli.Get("/v1/engines/" + engine_name); if (res) { if (res->status == httplib::StatusCode::OK_200) { - // CLI_LOG(res->body); Json::Value v; Json::Reader reader; reader.parse(res->body, v); @@ -39,7 +39,8 @@ void EngineGetCmd::Exec(const std::string& host, int port, v["status"].asString()}); } else { - CLI_LOG_ERROR("Failed to get engine list with status code: " << res->status); + CLI_LOG_ERROR( + "Failed to get engine list with status code: " << res->status); return; } } else { diff --git a/engine/cli/commands/model_list_cmd.cc b/engine/cli/commands/model_list_cmd.cc index a6be44d9d..41fe61d1c 100644 --- a/engine/cli/commands/model_list_cmd.cc +++ b/engine/cli/commands/model_list_cmd.cc @@ -17,8 +17,9 @@ using namespace tabulate; using Row_t = std::vector>; -void ModelListCmd::Exec(const std::string& host, int port, std::string filter, - bool display_engine, bool display_version) { +void ModelListCmd::Exec(const std::string& host, int port, + const std::string& filter, bool display_engine, + bool 
display_version) { // Start server if server is not started yet if (!commands::IsServerAlive(host, port)) { CLI_LOG("Starting server ..."); diff --git a/engine/cli/commands/model_list_cmd.h b/engine/cli/commands/model_list_cmd.h index 4f61c67cc..2e7c446e7 100644 --- a/engine/cli/commands/model_list_cmd.h +++ b/engine/cli/commands/model_list_cmd.h @@ -6,7 +6,7 @@ namespace commands { class ModelListCmd { public: - void Exec(const std::string& host, int port, std::string filter, + void Exec(const std::string& host, int port, const std::string& filter, bool display_engine = false, bool display_version = false); }; } // namespace commands diff --git a/engine/common/download_task.h b/engine/common/download_task.h index 60da3ea86..5994cdaed 100644 --- a/engine/common/download_task.h +++ b/engine/common/download_task.h @@ -7,9 +7,11 @@ #include enum class DownloadType { Model, Engine, Miscellaneous, CudaToolkit, Cortex }; + using namespace nlohmann; struct DownloadItem { + std::string id; std::string downloadUrl; @@ -54,8 +56,12 @@ inline std::string DownloadTypeToString(DownloadType type) { } struct DownloadTask { + enum class Status { Pending, InProgress, Completed, Cancelled, Error }; + std::string id; + Status status; + DownloadType type; std::vector<DownloadItem> items; diff --git a/engine/common/download_task_queue.h b/engine/common/download_task_queue.h new file mode 100644 index 000000000..5991687b7 --- /dev/null +++ b/engine/common/download_task_queue.h @@ -0,0 +1,78 @@ +#pragma once +#include <algorithm> +#include <condition_variable> +#include <list> +#include <mutex> +#include <optional> +#include <shared_mutex> +#include <unordered_map> +#include "common/download_task.h" + +class DownloadTaskQueue { + private: + std::list<DownloadTask> taskQueue; + std::unordered_map<std::string, std::list<DownloadTask>::iterator> + taskMap; + mutable std::shared_mutex mutex; + std::condition_variable_any cv; + + public: + void push(DownloadTask task) { + std::unique_lock lock(mutex); + taskQueue.push_back(std::move(task)); + taskMap[taskQueue.back().id] = std::prev(taskQueue.end()); + cv.notify_one(); + } + + std::optional<DownloadTask> pop() { + 
std::unique_lock lock(mutex); + if (taskQueue.empty()) { + return std::nullopt; + } + DownloadTask task = std::move(taskQueue.front()); + taskQueue.pop_front(); + taskMap.erase(task.id); + return task; + } + + bool cancelTask(const std::string& taskId) { + std::unique_lock lock(mutex); + auto it = taskMap.find(taskId); + if (it != taskMap.end()) { + it->second->status = DownloadTask::Status::Cancelled; + taskQueue.erase(it->second); + taskMap.erase(it); + return true; + } + return false; + } + + bool updateTaskStatus(const std::string& taskId, + DownloadTask::Status newStatus) { + std::unique_lock lock(mutex); + auto it = taskMap.find(taskId); + if (it != taskMap.end()) { + it->second->status = newStatus; + if (newStatus == DownloadTask::Status::Cancelled || + newStatus == DownloadTask::Status::Error) { + taskQueue.erase(it->second); + taskMap.erase(it); + } + return true; + } + return false; + } + + std::optional getNextPendingTask() { + std::shared_lock lock(mutex); + auto it = std::find_if( + taskQueue.begin(), taskQueue.end(), [](const DownloadTask& task) { + return task.status == DownloadTask::Status::Pending; + }); + + if (it != taskQueue.end()) { + return *it; + } + return std::nullopt; + } +}; diff --git a/engine/controllers/models.cc b/engine/controllers/models.cc index b6e03ad40..174c89184 100644 --- a/engine/controllers/models.cc +++ b/engine/controllers/models.cc @@ -1,5 +1,6 @@ #include "database/models.h" #include +#include #include "config/gguf_parser.h" #include "config/yaml_config.h" #include "models.h" @@ -26,15 +27,22 @@ void Models::PullModel(const HttpRequestPtr& req, return; } + std::optional desired_model_id = std::nullopt; + auto id = (*(req->getJsonObject())).get("id", "").asString(); + if (!id.empty()) { + desired_model_id = id; + } + auto handle_model_input = [&, model_handle]() -> cpp::result { CTL_INF("Handle model input, model handle: " + model_handle); if (string_utils::StartsWith(model_handle, "https")) { - return 
model_service_->HandleDownloadUrlAsync(model_handle); + return model_service_->HandleDownloadUrlAsync(model_handle, + desired_model_id); } else if (model_handle.find(":") != std::string::npos) { auto model_and_branch = string_utils::SplitBy(model_handle, ":"); return model_service_->DownloadModelFromCortexsoAsync( - model_and_branch[0], model_and_branch[1]); + model_and_branch[0], model_and_branch[1], desired_model_id); } return cpp::fail("Invalid model handle or not supported!"); @@ -107,7 +115,6 @@ void Models::ListModel( auto list_entry = modellist_handler.LoadModelList(); if (list_entry) { for (const auto& model_entry : list_entry.value()) { - // auto model_entry = modellist_handler.GetModelInfo(model_handle); try { yaml_handler.ModelConfigFromFile( fmu::ToAbsoluteCortexDataPath( @@ -116,7 +123,6 @@ void Models::ListModel( auto model_config = yaml_handler.GetModelConfig(); Json::Value obj = model_config.ToJson(); obj["id"] = model_entry.model; - obj["model_alias"] = model_entry.model_alias; obj["model"] = model_entry.model; data.append(std::move(obj)); yaml_handler.Reset(); @@ -156,7 +162,6 @@ void Models::GetModel(const HttpRequestPtr& req, config::YamlHandler yaml_handler; auto model_entry = modellist_handler.GetModelInfo(model_id); if (model_entry.has_error()) { - // CLI_LOG("Error: " + model_entry.error()); ret["id"] = model_id; ret["object"] = "model"; ret["result"] = "Fail to get model information"; @@ -333,71 +338,6 @@ void Models::ImportModel( } } -void Models::SetModelAlias( - const HttpRequestPtr& req, - std::function&& callback) const { - if (!http_util::HasFieldInReq(req, callback, "model") || - !http_util::HasFieldInReq(req, callback, "modelAlias")) { - return; - } - auto model_handle = (*(req->getJsonObject())).get("model", "").asString(); - auto model_alias = (*(req->getJsonObject())).get("modelAlias", "").asString(); - LOG_DEBUG << "GetModel, Model handle: " << model_handle - << ", Model alias: " << model_alias; - - cortex::db::Models 
modellist_handler; - try { - auto result = modellist_handler.UpdateModelAlias(model_handle, model_alias); - if (result.has_error()) { - std::string message = result.error(); - LOG_ERROR << message; - Json::Value ret; - ret["result"] = "Set alias failed!"; - ret["modelHandle"] = model_handle; - ret["message"] = message; - auto resp = cortex_utils::CreateCortexHttpJsonResponse(ret); - resp->setStatusCode(k400BadRequest); - callback(resp); - } else { - if (result.value()) { - std::string message = "Successfully set model alias '" + model_alias + - "' for modeID '" + model_handle + "'."; - LOG_INFO << message; - Json::Value ret; - ret["result"] = "OK"; - ret["modelHandle"] = model_handle; - ret["message"] = message; - auto resp = cortex_utils::CreateCortexHttpJsonResponse(ret); - resp->setStatusCode(k200OK); - callback(resp); - } else { - std::string message = "Unable to set model alias for modelID '" + - model_handle + "': model alias '" + model_alias + - "' is not unique!"; - LOG_ERROR << message; - Json::Value ret; - ret["result"] = "Set alias failed!"; - ret["modelHandle"] = model_handle; - ret["message"] = message; - auto resp = cortex_utils::CreateCortexHttpJsonResponse(ret); - resp->setStatusCode(k400BadRequest); - callback(resp); - } - } - } catch (const std::exception& e) { - std::string message = "Error when setting model alias ('" + model_alias + - "') for modelID '" + model_handle + "':" + e.what(); - LOG_ERROR << message; - Json::Value ret; - ret["result"] = "Set alias failed!"; - ret["modelHandle"] = model_handle; - ret["message"] = message; - auto resp = cortex_utils::CreateCortexHttpJsonResponse(ret); - resp->setStatusCode(k400BadRequest); - callback(resp); - } -} - void Models::StartModel( const HttpRequestPtr& req, std::function&& callback) { @@ -407,6 +347,34 @@ void Models::StartModel( auto model_handle = (*(req->getJsonObject())).get("model", "").asString(); auto custom_prompt_template = (*(req->getJsonObject())).get("prompt_template", 
"").asString(); + auto model_entry = model_service_->GetDownloadedModel(model_handle); + if (!model_entry.has_value()) { + Json::Value ret; + ret["message"] = "Cannot find model: " + model_handle; + auto resp = cortex_utils::CreateCortexHttpJsonResponse(ret); + resp->setStatusCode(drogon::k400BadRequest); + callback(resp); + return; + } + auto engine_name = model_entry.value().engine; + auto engine_entry = engine_service_->GetEngineInfo(engine_name); + if (engine_entry.has_error()) { + Json::Value ret; + ret["message"] = "Cannot find engine: " + engine_name; + auto resp = cortex_utils::CreateCortexHttpJsonResponse(ret); + resp->setStatusCode(drogon::k400BadRequest); + callback(resp); + return; + } + if (engine_entry->status != "Ready") { + Json::Value ret; + ret["message"] = "Engine is not ready! Please install first!"; + auto resp = cortex_utils::CreateCortexHttpJsonResponse(ret); + resp->setStatusCode(drogon::k400BadRequest); + callback(resp); + return; + } + auto result = model_service_->StartModel( config.apiServerHost, std::stoi(config.apiServerPort), model_handle, custom_prompt_template); diff --git a/engine/controllers/models.h b/engine/controllers/models.h index 4482ebcbd..cacec2e48 100644 --- a/engine/controllers/models.h +++ b/engine/controllers/models.h @@ -2,6 +2,7 @@ #include #include +#include "services/engine_service.h" #include "services/model_service.h" using namespace drogon; @@ -16,7 +17,6 @@ class Models : public drogon::HttpController { METHOD_ADD(Models::UpdateModel, "/{1}", Patch); METHOD_ADD(Models::ImportModel, "/import", Post); METHOD_ADD(Models::DeleteModel, "/{1}", Delete); - METHOD_ADD(Models::SetModelAlias, "/alias", Post); METHOD_ADD(Models::StartModel, "/start", Post); METHOD_ADD(Models::StopModel, "/stop", Post); METHOD_ADD(Models::GetModelStatus, "/status/{1}", Get); @@ -28,14 +28,14 @@ class Models : public drogon::HttpController { ADD_METHOD_TO(Models::UpdateModel, "/v1/models/{1}", Patch); ADD_METHOD_TO(Models::ImportModel, 
"/v1/models/import", Post); ADD_METHOD_TO(Models::DeleteModel, "/v1/models/{1}", Delete); - ADD_METHOD_TO(Models::SetModelAlias, "/v1/models/alias", Post); ADD_METHOD_TO(Models::StartModel, "/v1/models/start", Post); ADD_METHOD_TO(Models::StopModel, "/v1/models/stop", Post); ADD_METHOD_TO(Models::GetModelStatus, "/v1/models/status/{1}", Get); METHOD_LIST_END - explicit Models(std::shared_ptr model_service) - : model_service_{model_service} {} + explicit Models(std::shared_ptr model_service, + std::shared_ptr engine_service) + : model_service_{model_service}, engine_service_{engine_service} {} void PullModel(const HttpRequestPtr& req, std::function&& callback); @@ -71,4 +71,5 @@ class Models : public drogon::HttpController { private: std::shared_ptr model_service_; + std::shared_ptr engine_service_; }; diff --git a/engine/database/models.h b/engine/database/models.h index 3248da788..ebb006b28 100644 --- a/engine/database/models.h +++ b/engine/database/models.h @@ -1,9 +1,9 @@ #pragma once +#include #include #include #include -#include "SQLiteCpp/SQLiteCpp.h" #include "utils/result.hpp" namespace cortex::db { diff --git a/engine/e2e-test/main.py b/engine/e2e-test/main.py index f814c45a5..add2354f3 100644 --- a/engine/e2e-test/main.py +++ b/engine/e2e-test/main.py @@ -11,7 +11,6 @@ from test_api_model_start import TestApiModelStart from test_api_model_stop import TestApiModelStop from test_api_model_get import TestApiModelGet -from test_api_model_alias import TestApiModelAlias from test_api_model_list import TestApiModelList from test_api_model_update import TestApiModelUpdate from test_api_model_delete import TestApiModelDelete diff --git a/engine/e2e-test/test_api_model_alias.py b/engine/e2e-test/test_api_model_alias.py deleted file mode 100644 index 1a17ad9e0..000000000 --- a/engine/e2e-test/test_api_model_alias.py +++ /dev/null @@ -1,24 +0,0 @@ -import pytest -import requests -from test_runner import popen, run -from test_runner import start_server, stop_server - 
- -class TestApiModelAlias: - - @pytest.fixture(autouse=True) - def setup_and_teardown(self): - # Setup - success = start_server() - if not success: - raise Exception("Failed to start server") - - yield - - stop_server() - - def test_models_set_alias_should_be_successful(self): - body_json = {'model': 'tinyllama:gguf', - 'modelAlias': 'tg'} - response = requests.post("http://localhost:3928/models/alias", json = body_json) - assert response.status_code == 200 diff --git a/engine/e2e-test/test_api_model_start.py b/engine/e2e-test/test_api_model_start.py index 906d4b0cf..216fad570 100644 --- a/engine/e2e-test/test_api_model_start.py +++ b/engine/e2e-test/test_api_model_start.py @@ -1,7 +1,6 @@ import pytest import requests -from test_runner import popen -from test_runner import start_server, stop_server, run +from test_runner import run, start_server, stop_server class TestApiModelStart: @@ -12,10 +11,13 @@ def setup_and_teardown(self): success = start_server() if not success: raise Exception("Failed to start server") - - # TODO: using pull with branch for easy testing tinyllama:gguf for example + run("Install Engine", ["engines", "install", "llama-cpp"], timeout=None) run("Delete model", ["models", "delete", "tinyllama:gguf"]) - run("Pull model", ["pull", "tinyllama:gguf"], timeout=None,) + run( + "Pull model", + ["pull", "tinyllama:gguf"], + timeout=None, + ) yield @@ -23,6 +25,6 @@ def setup_and_teardown(self): stop_server() def test_models_start_should_be_successful(self): - json_body = {'model': 'tinyllama:gguf'} - response = requests.post("http://localhost:3928/models/start", json = json_body) + json_body = {"model": "tinyllama:gguf"} + response = requests.post("http://localhost:3928/models/start", json=json_body) assert response.status_code == 200, f"status_code: {response.status_code}" diff --git a/engine/e2e-test/test_api_model_stop.py b/engine/e2e-test/test_api_model_stop.py index a787be276..00d7482fa 100644 --- a/engine/e2e-test/test_api_model_stop.py +++ 
b/engine/e2e-test/test_api_model_stop.py @@ -1,6 +1,6 @@ import pytest import requests -from test_runner import start_server, stop_server +from test_runner import run, start_server, stop_server class TestApiModelStop: @@ -12,14 +12,15 @@ def setup_and_teardown(self): if not success: raise Exception("Failed to start server") + run("Install Engine", ["engines", "install", "llama-cpp"], timeout=None) yield # Teardown stop_server() def test_models_stop_should_be_successful(self): - json_body = {'model': 'tinyllama:gguf'} - response = requests.post("http://localhost:3928/models/start", json = json_body) + json_body = {"model": "tinyllama:gguf"} + response = requests.post("http://localhost:3928/models/start", json=json_body) assert response.status_code == 200, f"status_code: {response.status_code}" - response = requests.post("http://localhost:3928/models/stop", json = json_body) + response = requests.post("http://localhost:3928/models/stop", json=json_body) assert response.status_code == 200, f"status_code: {response.status_code}" diff --git a/engine/main.cc b/engine/main.cc index f66deee0c..0542087c4 100644 --- a/engine/main.cc +++ b/engine/main.cc @@ -95,7 +95,7 @@ void RunServer(std::optional port) { // initialize custom controllers auto engine_ctl = std::make_shared(engine_service); - auto model_ctl = std::make_shared(model_service); + auto model_ctl = std::make_shared(model_service, engine_service); auto event_ctl = std::make_shared(event_queue_ptr); auto pm_ctl = std::make_shared(); diff --git a/engine/services/model_service.cc b/engine/services/model_service.cc index 58051ebc5..57a26c1be 100644 --- a/engine/services/model_service.cc +++ b/engine/services/model_service.cc @@ -198,32 +198,32 @@ cpp::result ModelService::HandleCortexsoModel( std::optional ModelService::GetDownloadedModel( const std::string& modelId) const { - auto models_path = file_manager_utils::GetModelsContainerPath(); - if (!std::filesystem::exists(models_path) || - 
!std::filesystem::is_directory(models_path)) { + + cortex::db::Models modellist_handler; + config::YamlHandler yaml_handler; + auto model_entry = modellist_handler.GetModelInfo(modelId); + if (!model_entry.has_value()) { return std::nullopt; } - for (const auto& entry : std::filesystem::directory_iterator(models_path)) { - if (entry.is_regular_file() && - entry.path().filename().string() == modelId && - entry.path().extension() == ".yaml") { - try { - config::YamlHandler handler; - handler.ModelConfigFromFile(entry.path().string()); - auto model_conf = handler.GetModelConfig(); - return model_conf; - } catch (const std::exception& e) { - LOG_ERROR << "Error reading yaml file '" << entry.path().string() - << "': " << e.what(); - } - } + try { + config::YamlHandler yaml_handler; + namespace fs = std::filesystem; + namespace fmu = file_manager_utils; + yaml_handler.ModelConfigFromFile( + fmu::ToAbsoluteCortexDataPath( + fs::path(model_entry.value().path_to_model_yaml)) + .string()); + return yaml_handler.GetModelConfig(); + } catch (const std::exception& e) { + LOG_ERROR << "Error reading yaml file '" << model_entry->path_to_model_yaml + << "': " << e.what(); + return std::nullopt; } - return std::nullopt; } cpp::result ModelService::HandleDownloadUrlAsync( - const std::string& url) { + const std::string& url, std::optional<std::string> temp_model_id) { auto url_obj = url_parser::FromUrlString(url); if (url_obj.host == kHuggingFaceHost) { @@ -244,11 +244,16 @@ } std::string huggingFaceHost{kHuggingFaceHost}; - std::string unique_model_id{author + ":" + model_id + ":" + file_name}; - cortex::db::Models modellist_handler; - auto model_entry = modellist_handler.GetModelInfo(unique_model_id); + std::string unique_model_id = ""; + if (temp_model_id.has_value()) { + unique_model_id = temp_model_id.value(); + } else { + unique_model_id = author + ":" + model_id + ":" + file_name; + } + cortex::db::Models modellist_handler; + auto model_entry = modellist_handler.GetModelInfo(unique_model_id); if (model_entry.has_value()) { CLI_LOG("Model already downloaded: " << unique_model_id); return cpp::fail("Please delete the model before downloading again"); } @@ -356,16 +360,28 @@ } cpp::result -ModelService::DownloadModelFromCortexsoAsync(const std::string& name, - const std::string& branch) { +ModelService::DownloadModelFromCortexsoAsync( + const std::string& name, const std::string& branch, + std::optional<std::string> temp_model_id) { auto download_task = GetDownloadTask(name, branch); if (download_task.has_error()) { return cpp::fail(download_task.error()); } - std::string model_id{name + ":" + branch}; - auto on_finished = [&, model_id](const DownloadTask& finishedTask) { + cortex::db::Models modellist_handler; + std::string unique_model_id = ""; + if (temp_model_id.has_value()) { + unique_model_id = temp_model_id.value(); + } else { + unique_model_id = name + ":" + branch; + } + + auto model_entry = modellist_handler.GetModelInfo(unique_model_id); + if (model_entry.has_value()) { + return cpp::fail("Please delete the model before downloading again"); + } + auto on_finished = [&, unique_model_id](const DownloadTask& finishedTask) { const DownloadItem* model_yml_item = nullptr; auto need_parse_gguf = true; @@ -376,7 +392,8 @@ } if (model_yml_item == nullptr) { - CTL_WRN("model.yml not found in the downloaded files for " + model_id); + CTL_WRN("model.yml not found in the downloaded files for " + + unique_model_id); return; } auto url_obj = url_parser::FromUrlString(model_yml_item->downloadUrl); @@ -384,7 +401,7 @@ config::YamlHandler yaml_handler; yaml_handler.ModelConfigFromFile(model_yml_item->localPath.string()); auto mc = yaml_handler.GetModelConfig(); - mc.model = model_id; + mc.model = unique_model_id; yaml_handler.UpdateModelConfig(mc); 
yaml_handler.WriteYamlFile(model_yml_item->localPath.string()); @@ -393,11 +410,11 @@ ModelService::DownloadModelFromCortexsoAsync(const std::string& name, CTL_INF("path_to_model_yaml: " << rel.string()); cortex::db::Models modellist_utils_obj; - cortex::db::ModelEntry model_entry{.model = model_id, + cortex::db::ModelEntry model_entry{.model = unique_model_id, .author_repo_id = "cortexso", .branch_name = branch, .path_to_model_yaml = rel.string(), - .model_alias = model_id}; + .model_alias = unique_model_id}; auto result = modellist_utils_obj.AddModelEntry(model_entry); if (result.has_error()) { CTL_ERR("Error adding model to modellist: " + result.error()); @@ -405,9 +422,8 @@ ModelService::DownloadModelFromCortexsoAsync(const std::string& name, }; auto task = download_task.value(); - task.id = model_id; - - return download_service_->AddTask(download_task.value(), on_finished); + task.id = unique_model_id; + return download_service_->AddTask(task, on_finished); } cpp::result ModelService::DownloadModelFromCortexso( diff --git a/engine/services/model_service.h b/engine/services/model_service.h index 43e597c40..822b376ae 100644 --- a/engine/services/model_service.h +++ b/engine/services/model_service.h @@ -24,7 +24,8 @@ class ModelService { const std::string& name, const std::string& branch = "main"); cpp::result DownloadModelFromCortexsoAsync( - const std::string& name, const std::string& branch = "main"); + const std::string& name, const std::string& branch = "main", + std::optional temp_model_id = std::nullopt); std::optional GetDownloadedModel( const std::string& modelId) const; @@ -48,7 +49,7 @@ class ModelService { cpp::result HandleUrl(const std::string& url); cpp::result HandleDownloadUrlAsync( - const std::string& url); + const std::string& url, std::optional temp_model_id); private: /** diff --git a/engine/test/components/test_download_task_queue.cc b/engine/test/components/test_download_task_queue.cc new file mode 100644 index 000000000..526371399 --- 
/dev/null +++ b/engine/test/components/test_download_task_queue.cc @@ -0,0 +1,136 @@ +#include <gtest/gtest.h> +#include <thread> +#include <vector> +#include "common/download_task_queue.h" + +class DownloadTaskQueueTest : public ::testing::Test { + protected: + DownloadTaskQueue queue; +}; + +DownloadTask CreateDownloadTask( + const std::string& id, + DownloadTask::Status status = DownloadTask::Status::Pending) { + return DownloadTask{.id = id, + .status = status, + .type = DownloadType::Model, + .items = {}}; +} + +TEST_F(DownloadTaskQueueTest, PushAndPop) { + queue.push(CreateDownloadTask("task1")); + queue.push(CreateDownloadTask("task2")); + + auto task = queue.pop(); + ASSERT_TRUE(task.has_value()); + EXPECT_EQ(task->id, "task1"); + + task = queue.pop(); + ASSERT_TRUE(task.has_value()); + EXPECT_EQ(task->id, "task2"); + + task = queue.pop(); + EXPECT_FALSE(task.has_value()); +} + +TEST_F(DownloadTaskQueueTest, CancelTask) { + queue.push(CreateDownloadTask("task1")); + queue.push(CreateDownloadTask("task2")); + queue.push(CreateDownloadTask("task3")); + + EXPECT_TRUE(queue.cancelTask("task2")); + EXPECT_FALSE(queue.cancelTask("task4")); + + auto task = queue.pop(); + ASSERT_TRUE(task.has_value()); + EXPECT_EQ(task->id, "task1"); + + task = queue.pop(); + ASSERT_TRUE(task.has_value()); + EXPECT_EQ(task->id, "task3"); + + task = queue.pop(); + EXPECT_FALSE(task.has_value()); +} + +TEST_F(DownloadTaskQueueTest, PopEmptyQueue) { + auto task = queue.pop(); + EXPECT_FALSE(task.has_value()); +} + +TEST_F(DownloadTaskQueueTest, UpdateTaskStatus) { + queue.push(CreateDownloadTask("task1")); + + EXPECT_TRUE( + queue.updateTaskStatus("task1", DownloadTask::Status::InProgress)); + EXPECT_FALSE(queue.updateTaskStatus( + "task2", DownloadTask::Status::Completed)); // Non-existent task + + auto task = queue.getNextPendingTask(); + ASSERT_FALSE(task.has_value()); + + queue.push(CreateDownloadTask("task2")); + task = queue.getNextPendingTask(); + // task2 + EXPECT_EQ(task->id, "task2"); + 
EXPECT_TRUE( + queue.updateTaskStatus("task2", DownloadTask::Status::InProgress)); + + EXPECT_TRUE(queue.updateTaskStatus("task1", DownloadTask::Status::Completed)); + task = queue.pop(); + EXPECT_TRUE(queue.updateTaskStatus("task2", DownloadTask::Status::Completed)); + task = queue.pop(); + task = queue.pop(); + EXPECT_FALSE(task.has_value()); // Task should be removed after completion +} + +TEST_F(DownloadTaskQueueTest, GetNextPendingTask) { + queue.push(CreateDownloadTask("task1")); + queue.push(CreateDownloadTask("task2")); + queue.updateTaskStatus("task1", DownloadTask::Status::InProgress); + + auto task = queue.getNextPendingTask(); + ASSERT_TRUE(task.has_value()); + EXPECT_EQ(task->id, "task2"); + EXPECT_EQ(task->status, DownloadTask::Status::Pending); + + queue.updateTaskStatus("task2", DownloadTask::Status::InProgress); + task = queue.getNextPendingTask(); + EXPECT_FALSE(task.has_value()); +} + +TEST_F(DownloadTaskQueueTest, ConcurrentPushAndPop) { + const int numTasks = 10000; + std::vector pushThreads; + std::vector popThreads; + std::atomic pushedTasks{0}; + std::atomic poppedTasks{0}; + + for (int i = 0; i < 4; ++i) { + pushThreads.emplace_back([this, numTasks, i, &pushedTasks]() { + for (int j = 0; j < numTasks; ++j) { + queue.push(CreateDownloadTask("task_" + std::to_string(i) + "_" + + std::to_string(j))); + pushedTasks++; + } + }); + + popThreads.emplace_back([this, &poppedTasks, &pushedTasks]() { + while (poppedTasks.load() < pushedTasks.load() || + pushedTasks.load() < numTasks * 4) { + if (auto task = queue.pop()) { + poppedTasks++; + } + } + }); + } + + for (auto& t : pushThreads) + t.join(); + for (auto& t : popThreads) + t.join(); + + EXPECT_EQ(pushedTasks.load(), numTasks * 4); + EXPECT_EQ(poppedTasks.load(), numTasks * 4); + EXPECT_FALSE(queue.pop().has_value()); +}