From e5b7be2a2fd7d9f5b5a7cb3c8858550789515e42 Mon Sep 17 00:00:00 2001 From: James Date: Sun, 13 Oct 2024 20:42:49 +0700 Subject: [PATCH 1/6] remove some redundant code --- engine/commands/cortex_upd_cmd.cc | 5 ++--- engine/commands/cortex_upd_cmd.h | 8 +++++--- engine/controllers/command_line_parser.cc | 8 +++++--- engine/controllers/command_line_parser.h | 2 ++ 4 files changed, 14 insertions(+), 9 deletions(-) diff --git a/engine/commands/cortex_upd_cmd.cc b/engine/commands/cortex_upd_cmd.cc index fa3f0497a..2d2caea8a 100644 --- a/engine/commands/cortex_upd_cmd.cc +++ b/engine/commands/cortex_upd_cmd.cc @@ -2,7 +2,6 @@ #include "httplib.h" #include "nlohmann/json.hpp" #include "server_stop_cmd.h" -#include "services/download_service.h" #include "utils/archive_utils.h" #include "utils/file_manager_utils.h" #include "utils/logging_utils.h" @@ -395,7 +394,7 @@ bool CortexUpdCmd::HandleGithubRelease(const nlohmann::json& assets, .localPath = local_path, }}}}; - auto result = DownloadService().AddDownloadTask( + auto result = download_service_->AddDownloadTask( download_task, [](const DownloadTask& finishedTask) { // try to unzip the downloaded file CTL_INF("Downloaded engine path: " @@ -460,7 +459,7 @@ bool CortexUpdCmd::GetNightly(const std::string& v) { .localPath = localPath, }}}; - auto result = DownloadService().AddDownloadTask( + auto result = download_service_->AddDownloadTask( download_task, [](const DownloadTask& finishedTask) { // try to unzip the downloaded file CTL_INF("Downloaded engine path: " diff --git a/engine/commands/cortex_upd_cmd.h b/engine/commands/cortex_upd_cmd.h index cdd8816f1..e7d5cf289 100644 --- a/engine/commands/cortex_upd_cmd.h +++ b/engine/commands/cortex_upd_cmd.h @@ -5,10 +5,7 @@ #include #endif -#include "httplib.h" -#include "nlohmann/json.hpp" #include "utils/file_manager_utils.h" -#include "utils/logging_utils.h" namespace commands { #ifndef CORTEX_VARIANT @@ -81,9 +78,14 @@ bool ReplaceBinaryInflight(const std::filesystem::path& src, // - Nightly: Enables retrieval of the latest nightly build and specific versions using the -v flag class CortexUpdCmd { public: + explicit CortexUpdCmd(std::shared_ptr download_service) + : download_service_{download_service} {}; + void Exec(const std::string& v); private: + std::shared_ptr download_service_; + bool GetStable(const std::string& v); bool GetBeta(const std::string& v); bool HandleGithubRelease(const nlohmann::json& assets, diff --git a/engine/controllers/command_line_parser.cc b/engine/controllers/command_line_parser.cc index 2ee3157af..df83db688 100644 --- a/engine/controllers/command_line_parser.cc +++ b/engine/controllers/command_line_parser.cc @@ -32,10 +32,12 @@ constexpr const auto kEngineGroup = "Engines"; constexpr const auto kSystemGroup = "System"; constexpr const auto kSubcommands = "Subcommands"; } // namespace + CommandLineParser::CommandLineParser() : app_("Cortex.cpp CLI"), - model_service_{ModelService(std::make_shared())}, - engine_service_{EngineService(std::make_shared())} {} + download_service_{std::make_shared()}, + model_service_{ModelService(download_service_)}, + engine_service_{EngineService(download_service_)} {} bool CommandLineParser::SetupCommand(int argc, char** argv) { app_.usage("Usage:\n" + commands::GetCortexBinary() + @@ -452,7 +454,7 @@ void CommandLineParser::SetupSystemCommands() { return; } #endif - commands::CortexUpdCmd cuc; + auto cuc = commands::CortexUpdCmd(download_service_); cuc.Exec(cml_data_.cortex_version); cml_data_.check_upd = false; }); diff --git a/engine/controllers/command_line_parser.h b/engine/controllers/command_line_parser.h index 9e150605c..96cf886bf 100644 --- a/engine/controllers/command_line_parser.h +++ b/engine/controllers/command_line_parser.h @@ -1,5 +1,6 @@ #pragma once +#include #include "CLI/CLI.hpp" #include "services/engine_service.h" #include "services/model_service.h" @@ -29,6 +30,7 @@ class CommandLineParser { void ModelUpdate(CLI::App* parent); CLI::App app_; + std::shared_ptr download_service_; EngineService engine_service_; ModelService model_service_; From 9ec8566eec7059f0d34c403a29ab97b29b067a8c Mon Sep 17 00:00:00 2001 From: James Date: Mon, 14 Oct 2024 01:58:06 +0700 Subject: [PATCH 2/6] add stop download api --- engine/commands/server_stop_cmd.cc | 3 +- engine/commands/server_stop_cmd.h | 5 +- engine/common/download_task.h | 23 ++ engine/config/gguf_parser.cc | 3 +- engine/controllers/models.cc | 41 ++- engine/controllers/models.h | 4 + engine/controllers/processManager.h | 19 -- .../{processManager.cc => process_manager.cc} | 13 +- engine/controllers/process_manager.h | 23 ++ engine/main.cc | 3 + engine/services/download_service.cc | 245 ++++++++++++------ engine/services/download_service.h | 116 ++++++--- engine/services/model_service.cc | 166 ++++++++++-- engine/services/model_service.h | 16 +- 14 files changed, 505 insertions(+), 175 deletions(-) delete mode 100644 engine/controllers/processManager.h rename engine/controllers/{processManager.cc => process_manager.cc} (57%) create mode 100644 engine/controllers/process_manager.h diff --git a/engine/commands/server_stop_cmd.cc b/engine/commands/server_stop_cmd.cc index f3d83d6d2..4f48506cb 100644 --- a/engine/commands/server_stop_cmd.cc +++ b/engine/commands/server_stop_cmd.cc @@ -1,6 +1,5 @@ #include "server_stop_cmd.h" #include "httplib.h" -#include "trantor/utils/Logger.h" #include "utils/logging_utils.h" namespace commands { @@ -18,4 +17,4 @@ void ServerStopCmd::Exec() { } } -}; // namespace commands \ No newline at end of file +}; // namespace commands diff --git a/engine/commands/server_stop_cmd.h b/engine/commands/server_stop_cmd.h index 4beb0d05f..31add4b5b 100644 --- a/engine/commands/server_stop_cmd.h +++ b/engine/commands/server_stop_cmd.h @@ -1,9 +1,10 @@ #pragma once + #include namespace commands { -class ServerStopCmd{ +class ServerStopCmd { public: ServerStopCmd(std::string host, int port); void Exec(); @@ -12,4 +13,4 @@ class ServerStopCmd{ std::string host_; int port_; }; -} // namespace commands \ No newline at end of file +} // namespace commands diff --git a/engine/common/download_task.h b/engine/common/download_task.h index dd049ad17..60da3ea86 100644 --- a/engine/common/download_task.h +++ b/engine/common/download_task.h @@ -1,5 +1,6 @@ #pragma once +#include #include #include #include @@ -70,6 +71,28 @@ struct DownloadTask { return output.str(); } + Json::Value ToJsonCpp() const { + Json::Value root; + root["id"] = id; + root["type"] = DownloadTypeToString(type); + + Json::Value itemsArray(Json::arrayValue); + for (const auto& item : items) { + Json::Value itemObj; + itemObj["id"] = item.id; + itemObj["downloadUrl"] = item.downloadUrl; + itemObj["localPath"] = item.localPath.string(); + itemObj["checksum"] = item.checksum.value_or("N/A"); + itemObj["bytes"] = Json::Value::UInt64(item.bytes.value_or(0)); + itemObj["downloadedBytes"] = + Json::Value::UInt64(item.downloadedBytes.value_or(0)); + itemsArray.append(itemObj); + } + root["items"] = itemsArray; + + return root; + } + json ToJson() const { json dl_items = json::array(); diff --git a/engine/config/gguf_parser.cc b/engine/config/gguf_parser.cc index 3324077c3..560c36eb8 100644 --- a/engine/config/gguf_parser.cc +++ b/engine/config/gguf_parser.cc @@ -588,7 +588,8 @@ void GGUFHandler::ModelConfigFromMetadata() { model_config_.model = name; model_config_.id = name; model_config_.version = std::to_string(version); - model_config_.max_tokens = std::min(kDefaultMaxContextLength, max_tokens); + model_config_.max_tokens = + std::min(kDefaultMaxContextLength, max_tokens); model_config_.ctx_len = std::min(kDefaultMaxContextLength, max_tokens); model_config_.ngl = ngl; } diff --git a/engine/controllers/models.cc b/engine/controllers/models.cc index eefc0a941..29ec8c0ed 100644 --- a/engine/controllers/models.cc +++ b/engine/controllers/models.cc @@ -27,14 +27,14 @@ void Models::PullModel(const HttpRequestPtr& req, } auto handle_model_input = - [&, model_handle]() -> cpp::result { + [&, model_handle]() -> cpp::result { CTL_INF("Handle model input, model handle: " + model_handle); if (string_utils::StartsWith(model_handle, "https")) { - return model_service_->HandleUrl(model_handle, true); + return model_service_->HandleDownloadUrlAsync(model_handle); } else if (model_handle.find(":") != std::string::npos) { auto model_and_branch = string_utils::SplitBy(model_handle, ":"); - return model_service_->DownloadModelFromCortexso( - model_and_branch[0], model_and_branch[1], true); + return model_service_->DownloadModelFromCortexsoAsync( + model_and_branch[0], model_and_branch[1]); } return cpp::fail("Invalid model handle or not supported!"); @@ -50,6 +50,39 @@ void Models::PullModel(const HttpRequestPtr& req, } else { Json::Value ret; ret["message"] = "Model start downloading!"; + ret["task"] = result.value().ToJsonCpp(); + auto resp = cortex_utils::CreateCortexHttpJsonResponse(ret); + resp->setStatusCode(k200OK); + callback(resp); + } +} + +void Models::AbortPullModel( + const HttpRequestPtr& req, + std::function&& callback) { + if (!http_util::HasFieldInReq(req, callback, "taskId")) { + return; + } + auto task_id = (*(req->getJsonObject())).get("taskId", "").asString(); + if (task_id.empty()) { + Json::Value ret; + ret["result"] = "Bad Request"; + auto resp = cortex_utils::CreateCortexHttpJsonResponse(ret); + resp->setStatusCode(k400BadRequest); + callback(resp); + return; + } + + auto result = model_service_->AbortDownloadModel(task_id); + if (result.has_error()) { + Json::Value ret; + ret["message"] = result.error(); + auto resp = cortex_utils::CreateCortexHttpJsonResponse(ret); + resp->setStatusCode(k400BadRequest); + callback(resp); + } else { + Json::Value ret; + ret["message"] = "Task stopped!"; auto resp = cortex_utils::CreateCortexHttpJsonResponse(ret); resp->setStatusCode(k200OK); callback(resp); diff --git a/engine/controllers/models.h b/engine/controllers/models.h index 2aa99f5e9..7542b4e43 100644 --- a/engine/controllers/models.h +++ b/engine/controllers/models.h @@ -10,6 +10,7 @@ class Models : public drogon::HttpController { public: METHOD_LIST_BEGIN METHOD_ADD(Models::PullModel, "/pull", Post); + METHOD_ADD(Models::AbortPullModel, "/pull", Delete); METHOD_ADD(Models::ListModel, "", Get); METHOD_ADD(Models::GetModel, "/{1}", Get); METHOD_ADD(Models::UpdateModel, "/{1}", Patch); @@ -20,6 +21,7 @@ class Models : public drogon::HttpController { METHOD_ADD(Models::StopModel, "/stop", Post); ADD_METHOD_TO(Models::PullModel, "/v1/models/pull", Post); + ADD_METHOD_TO(Models::PullModel, "/v1/models/pull", Delete); ADD_METHOD_TO(Models::ListModel, "/v1/models", Get); ADD_METHOD_TO(Models::GetModel, "/v1/models/{1}", Get); ADD_METHOD_TO(Models::UpdateModel, "/v1/models/{1}", Patch); @@ -35,6 +37,8 @@ class Models : public drogon::HttpController { void PullModel(const HttpRequestPtr& req, std::function&& callback); + void AbortPullModel(const HttpRequestPtr& req, + std::function&& callback); void ListModel(const HttpRequestPtr& req, std::function&& callback) const; void GetModel(const HttpRequestPtr& req, diff --git a/engine/controllers/processManager.h b/engine/controllers/processManager.h deleted file mode 100644 index 7abfbe2d4..000000000 --- a/engine/controllers/processManager.h +++ /dev/null @@ -1,19 +0,0 @@ -#pragma once - -#include -#include - -using namespace drogon; - -class processManager : public drogon::HttpController { -public: - METHOD_LIST_BEGIN - - METHOD_ADD(processManager::destroy, "/destroy", - Delete); // path is /processManager/{arg1}/{arg2}/list - - METHOD_LIST_END - - void destroy(const HttpRequestPtr &req, - std::function &&callback); -}; diff --git a/engine/controllers/processManager.cc b/engine/controllers/process_manager.cc similarity index 57% rename from engine/controllers/processManager.cc rename to engine/controllers/process_manager.cc index 15c213453..efacc3c50 100644 --- a/engine/controllers/processManager.cc +++ b/engine/controllers/process_manager.cc @@ -1,12 +1,21 @@ -#include "processManager.h" +#include "process_manager.h" #include "utils/cortex_utils.h" +#include "utils/logging_utils.h" #include #include -void processManager::destroy( +void ProcessManager::destroy( const HttpRequestPtr& req, std::function&& callback) { + + auto destroy_result = download_service_->Destroy(); + if (destroy_result.has_error()) { + CTL_ERR("Failed to destroy download service: " + destroy_result.error()); + } else { + CTL_INF("Download service stopped!"); + } + app().quit(); Json::Value ret; ret["message"] = "Program is exitting, goodbye!"; diff --git a/engine/controllers/process_manager.h b/engine/controllers/process_manager.h new file mode 100644 index 000000000..4cf4f81d4 --- /dev/null +++ b/engine/controllers/process_manager.h @@ -0,0 +1,23 @@ +#pragma once + +#include +#include +#include "services/download_service.h" + +using namespace drogon; + +class ProcessManager : public drogon::HttpController { + public: + METHOD_LIST_BEGIN + METHOD_ADD(ProcessManager::destroy, "/destroy", Delete); + METHOD_LIST_END + + explicit ProcessManager(std::shared_ptr download_service) + : download_service_{download_service} {} + + void destroy(const HttpRequestPtr& req, + std::function&& callback); + + private: + std::shared_ptr download_service_; +}; diff --git a/engine/main.cc b/engine/main.cc index 186e64c18..2667c1202 100644 --- a/engine/main.cc +++ b/engine/main.cc @@ -5,6 +5,7 @@ #include "controllers/engines.h" #include "controllers/events.h" #include "controllers/models.h" +#include "controllers/process_manager.h" #include "cortex-common/cortexpythoni.h" #include "services/model_service.h" #include "utils/archive_utils.h" @@ -92,10 +93,12 @@ void RunServer() { auto engine_ctl = std::make_shared(engine_service); auto model_ctl = std::make_shared(model_service); auto event_ctl = std::make_shared(event_queue_ptr); + auto pm_ctl = std::make_shared(download_service); drogon::app().registerController(engine_ctl); drogon::app().registerController(model_ctl); drogon::app().registerController(event_ctl); + drogon::app().registerController(pm_ctl); LOG_INFO << "Server started, listening at: " << config.apiServerHost << ":" << config.apiServerPort; diff --git a/engine/services/download_service.cc b/engine/services/download_service.cc index e1964acc8..a60f4b9e7 100644 --- a/engine/services/download_service.cc +++ b/engine/services/download_service.cc @@ -4,9 +4,9 @@ #include #include #include +#include #include #include -#include #include "download_service.h" #include "utils/format_utils.h" #include "utils/logging_utils.h" @@ -21,8 +21,8 @@ #endif namespace { -size_t WriteCallback(void* ptr, size_t size, size_t nmemb, FILE* stream) { - size_t written = fwrite(ptr, size, nmemb, stream); +size_t WriteCallback(char* ptr, size_t size, size_t nmemb, void* userdata) { + size_t written = fwrite(ptr, size, nmemb, (FILE*)userdata); return written; } } // namespace @@ -68,7 +68,7 @@ cpp::result DownloadService::AddDownloadTask( bool has_task_done = false; for (const auto& item : task.items) { CLI_LOG("Start downloading: " + item.localPath.filename().string()); - auto result = Download(task.id, item, true); + auto result = Download(task.id, item); if (result.has_error()) { dl_err_msg = result.error(); break; @@ -112,62 +112,9 @@ cpp::result DownloadService::GetFileSize( return content_length; } -cpp::result DownloadService::AddAsyncDownloadTask( - DownloadTask& task, - std::optional callback) noexcept { - - if (std::find(download_task_list_.begin(), download_task_list_.end(), - task.id) != download_task_list_.end()) { - return cpp::fail("Download task already exists: " + task.id); - } - - download_task_list_.push_back(task.id); - download_task_map_.insert({task.id, task}); - - { - // verify download task - auto result = VerifyDownloadTask(task); - if (result.has_error()) { - CleanUp(task.id); - return cpp::fail(result.error()); - } - } - - auto execute_download_async = [&, task, callback]() { - active_download_task_id_ = task.id; - std::optional err_msg = std::nullopt; - for (const auto& item : task.items) { - active_download_item_id_ = item.id; - CTL_INF("Start downloading: " + item.localPath.filename().string()); - auto result = Download(task.id, item, false); - if (result.has_error()) { - err_msg = result.error(); - break; - } - } - - if (err_msg.has_value()) { - CTL_ERR(err_msg.value()); - CleanUp(task.id); - return; - } - - if (callback.has_value()) { - CTL_INF("Download success, executing post download lambda!"); - callback.value()(task); - } - CleanUp(task.id); - }; - - std::thread t(execute_download_async); - t.detach(); - - return true; -} - cpp::result DownloadService::Download( - const std::string& download_id, const DownloadItem& download_item, - bool allow_resume) noexcept { + const std::string& download_id, + const DownloadItem& download_item) noexcept { CTL_INF("Absolute file output: " << download_item.localPath.string()); CURL* curl; @@ -180,7 +127,7 @@ cpp::result DownloadService::Download( } std::string mode = "wb"; - if (allow_resume && std::filesystem::exists(download_item.localPath) && + if (std::filesystem::exists(download_item.localPath) && download_item.bytes.has_value()) { curl_off_t existing_file_size = GetLocalFileSize(download_item.localPath); if (existing_file_size == -1) { @@ -235,9 +182,6 @@ cpp::result DownloadService::Download( curl_easy_setopt(curl, CURLOPT_NOPROGRESS, 0L); curl_easy_setopt(curl, CURLOPT_FOLLOWLOCATION, 1L); - curl_easy_setopt(curl, CURLOPT_PROGRESSFUNCTION, ProgressCallback); - curl_easy_setopt(curl, CURLOPT_PROGRESSDATA, this); - if (mode == "ab") { auto local_file_size = GetLocalFileSize(download_item.localPath); if (local_file_size != -1) { @@ -276,16 +220,167 @@ curl_off_t DownloadService::GetLocalFileSize( return file_size; } -void DownloadService::CleanUp(const std::string& task_id) { - CTL_INF("Cleaning up download task: " << task_id); - // TODO: might need to be wrap with mutex - // remove from task list - download_task_list_.erase(std::remove(download_task_list_.begin(), - download_task_list_.end(), task_id), - download_task_list_.end()); - // remove from task map - download_task_map_.erase(task_id); - - active_download_task_id_ = std::nullopt; - active_download_item_id_ = std::nullopt; +void DownloadService::WorkerThread() { + while (!stop_flag_) { + DownloadTask task; + { + std::unique_lock lock(queue_mutex_); + queue_condition_.wait( + lock, [this] { return !task_queue_.empty() || stop_flag_; }); + if (stop_flag_) { + CTL_INF("Stopping download service.."); + break; + } + if (!task_queue_.empty()) { + task = std::move(task_queue_.front()); + task_queue_.pop(); + } + } + ProcessTask(task); + } +} + +void DownloadService::ProcessTask(const DownloadTask& task) { + CTL_INF("Processing task: " + task.id); + std::vector task_handles; + + for (auto& item : task.items) { + auto handle = curl_easy_init(); + if (handle == nullptr) { + // skip the task + CTL_ERR("Failed to init curl!"); + return; + } + + FILE* file; + file = fopen(item.localPath.string().c_str(), "wb"); + if (!file) { + CTL_ERR("Failed to open output file " + item.localPath.string()); + return; + } + + curl_easy_setopt(handle, CURLOPT_URL, item.downloadUrl.c_str()); + curl_easy_setopt(handle, CURLOPT_WRITEFUNCTION, WriteCallback); + curl_easy_setopt(handle, CURLOPT_WRITEDATA, file); + curl_easy_setopt(handle, CURLOPT_FOLLOWLOCATION, 1L); + curl_multi_add_handle(multi_handle_, handle); + task_handles.push_back(handle); + CTL_INF("Adding item to multi curl: " + item.ToString()); + } + + int still_running = 0; + bool is_terminated = false; + do { + curl_multi_perform(multi_handle_, &still_running); + curl_multi_wait(multi_handle_, NULL, 0, MAX_WAIT_MSECS, NULL); + + if (IsTaskTerminated(task.id) || stop_flag_) { + CTL_INF("IsTaskTerminated " + std::to_string(IsTaskTerminated(task.id))); + CTL_INF("stop_flag_ " + std::to_string(stop_flag_)); + + is_terminated = true; + break; + } + } while (still_running); + + ProcessCompletedTransfers(); + + for (auto handle : task_handles) { + curl_multi_remove_handle(multi_handle_, handle); + curl_easy_cleanup(handle); + } + + if (!is_terminated) { + RemoveTaskFromStopList(task.id); + CTL_INF("Executing callback.."); + ExecuteCallback(task); + } +} + +cpp::result DownloadService::StopTask( + const std::string& task_id) { + std::lock_guard lock(stop_mutex_); + + tasks_to_stop_.insert(task_id); + CTL_INF("Added task to stop list: " << task_id); + return {}; +} + +void DownloadService::ProcessCompletedTransfers() { + CURLMsg* msg; + int remaining_msg_count; + + while ((msg = curl_multi_info_read(multi_handle_, &remaining_msg_count))) { + if (msg->msg == CURLMSG_DONE) { + auto handle = msg->easy_handle; + + auto return_code = msg->data.result; + char* url; + curl_easy_getinfo(handle, CURLINFO_EFFECTIVE_URL, &url); + if (return_code != CURLE_OK) { + CTL_ERR("Download failed for: " << url << " - " + << curl_easy_strerror(return_code)); + continue; + } + + auto http_status_code = 0; + curl_easy_getinfo(handle, CURLINFO_RESPONSE_CODE, &http_status_code); + if (http_status_code == 200) { + CTL_INF("Download completed successfully for: " << url); + } else { + CTL_ERR("Download failed for: " << url << " - HTTP status code: " + << http_status_code); + } + } + } +} + +cpp::result DownloadService::AddTask( + DownloadTask& task, std::function callback) { + auto validate_result = VerifyDownloadTask(task); + if (validate_result.has_error()) { + return cpp::fail(validate_result.error()); + } + + { + std::lock_guard lock(callbacks_mutex_); + callbacks_[task.id] = std::move(callback); + } + + { + std::lock_guard lock(queue_mutex_); + task_queue_.push(task); + CTL_INF("Task added to queue: " << task.id); + } + + queue_condition_.notify_one(); + return task; +} + +bool DownloadService::IsTaskTerminated(const std::string& task_id) { + // can use shared mutex lock here? + std::lock_guard lock(stop_mutex_); + return tasks_to_stop_.find(task_id) != tasks_to_stop_.end(); +} + +void DownloadService::RemoveTaskFromStopList(const std::string& task_id) { + std::lock_guard lock(stop_mutex_); + tasks_to_stop_.erase(task_id); +} + +void DownloadService::ExecuteCallback(const DownloadTask& task) { + std::lock_guard lock(callbacks_mutex_); + auto it = callbacks_.find(task.id); + if (it != callbacks_.end()) { + it->second(task); + callbacks_.erase(it); + } +} + +cpp::result DownloadService::Destroy() { + // CTL_INF("Destroying download service.."); + stop_flag_ = true; + queue_condition_.notify_one(); + // CTL_INF("Destroying download service.. notified"); + return {}; } diff --git a/engine/services/download_service.h b/engine/services/download_service.h index aab3473c2..211511913 100644 --- a/engine/services/download_service.h +++ b/engine/services/download_service.h @@ -5,12 +5,18 @@ #include #include #include -#include +#include +#include +#include #include "common/event.h" +#include "utils/logging_utils.h" #include "utils/result.hpp" class DownloadService { public: + using OnDownloadTaskSuccessfully = + std::function; + using DownloadEventType = cortex::event::DownloadEventType; using DownloadEvent = cortex::event::DownloadEvent; using EventType = cortex::event::EventType; @@ -21,19 +27,41 @@ class DownloadService { explicit DownloadService() = default; explicit DownloadService(std::shared_ptr event_queue) - : event_queue_{event_queue} {}; + : event_queue_{event_queue} { + curl_global_init(CURL_GLOBAL_ALL); + stop_flag_ = false; + multi_handle_ = curl_multi_init(); + worker_thread_ = std::thread(&DownloadService::WorkerThread, this); + }; + + ~DownloadService() { + if (event_queue_ != nullptr) { + stop_flag_ = true; + queue_condition_.notify_one(); + + CTL_INF("DownloadService is being destroyed."); + curl_multi_cleanup(multi_handle_); + curl_global_cleanup(); + + worker_thread_.join(); + CTL_INF("DownloadService is destroyed.") + } + } - using OnDownloadTaskSuccessfully = - std::function; + /** + * Adding new download task to the queue. Asynchronously. This function should + * be used by HTTP API. + */ + cpp::result AddTask( + DownloadTask& task, std::function callback); + /** + * Start download task synchronously. + */ cpp::result AddDownloadTask( DownloadTask& task, std::optional callback = std::nullopt) noexcept; - cpp::result AddAsyncDownloadTask( - DownloadTask& task, std::optional callback = - std::nullopt) noexcept; - /** * Getting file size for a provided url. Can be used to validating the download url. * @@ -42,28 +70,53 @@ class DownloadService { cpp::result GetFileSize( const std::string& url) const noexcept; + cpp::result StopTask(const std::string& task_id); + + cpp::result Destroy(); + private: cpp::result VerifyDownloadTask( DownloadTask& task) const noexcept; - cpp::result Download(const std::string& download_id, - const DownloadItem& download_item, - bool allow_resume) noexcept; + cpp::result Download( + const std::string& download_id, + const DownloadItem& download_item) noexcept; curl_off_t GetLocalFileSize(const std::filesystem::path& path) const; std::shared_ptr event_queue_; - std::vector download_task_list_; - std::unordered_map download_task_map_; + CURLM* multi_handle_; + std::thread worker_thread_; + std::atomic stop_flag_; - std::optional active_download_task_id_; - std::optional active_download_item_id_; + // task queue + std::queue task_queue_; + std::mutex queue_mutex_; + std::condition_variable queue_condition_; - /** - * Invoked when download is completed (both failed or success) - */ - void CleanUp(const std::string& task_id); + // stop tasks + std::unordered_set tasks_to_stop_; + std::mutex stop_mutex_; + + // callbacks + std::unordered_map> + callbacks_; + std::mutex callbacks_mutex_; + + void WorkerThread(); + + void ProcessCompletedTransfers(); + + void ProcessTask(const DownloadTask& task); + + bool IsTaskTerminated(const std::string& task_id); + + void RemoveTaskFromStopList(const std::string& task_id); + + void ExecuteCallback(const DownloadTask& task); + + constexpr static auto MAX_WAIT_MSECS = 1000; static int ProgressCallback(void* ptr, double dltotal, double dlnow, double ultotal, double ulnow) { @@ -72,27 +125,10 @@ class DownloadService { return 0; } - auto active_task_id = service->active_download_task_id_; - auto active_item_id = service->active_download_item_id_; - if (!active_task_id.has_value() || !active_item_id.has_value()) { - return 0; - } - - auto task = service->download_task_map_[active_task_id.value()]; - - // loop through download items, find the active one and update it - for (auto& item : task.items) { - if (item.id == active_item_id.value()) { - item.downloadedBytes = dlnow; - item.bytes = dltotal; - break; - } - } - - service->event_queue_->enqueue( - EventType::DownloadEvent, - DownloadEvent{.type_ = DownloadEventType::DownloadUpdated, - .download_task_ = task}); + // service->event_queue_->enqueue( + // EventType::DownloadEvent, + // DownloadEvent{.type_ = DownloadEventType::DownloadUpdated, + // .download_task_ = task}); return 0; } }; diff --git a/engine/services/model_service.cc b/engine/services/model_service.cc index e20a13bb5..ddfc4a0c8 100644 --- a/engine/services/model_service.cc +++ b/engine/services/model_service.cc @@ -109,14 +109,14 @@ cpp::result GetDownloadTask( } // namespace cpp::result ModelService::DownloadModel( - const std::string& input, bool async) { + const std::string& input) { if (input.empty()) { return cpp::fail( "Input must be Cortex Model Hub handle or HuggingFace url!"); } if (string_utils::StartsWith(input, "https://")) { - return HandleUrl(input, async); + return HandleUrl(input); } if (input.find(":") != std::string::npos) { @@ -124,7 +124,7 @@ cpp::result ModelService::DownloadModel( if (parsed.size() != 2) { return cpp::fail("Invalid model handle: " + input); } - return DownloadModelFromCortexso(parsed[0], parsed[1], false); + return DownloadModelFromCortexso(parsed[0], parsed[1]); } if (input.find("/") != std::string::npos) { @@ -139,8 +139,7 @@ cpp::result ModelService::DownloadModel( return HandleCortexsoModel(model_name); } - return DownloadHuggingFaceGgufModel(author, model_name, std::nullopt, - async); + return DownloadHuggingFaceGgufModel(author, model_name, std::nullopt); } return HandleCortexsoModel(input); @@ -193,7 +192,7 @@ cpp::result ModelService::HandleCortexsoModel( CLI_LOG("Selected: " << selection.value()); auto branch_name = selection.value().substr(modelName.size() + 1); - return DownloadModelFromCortexso(modelName, branch_name, false); + return DownloadModelFromCortexso(modelName, branch_name); } std::optional ModelService::GetDownloadedModel( @@ -222,6 +221,67 @@ std::optional ModelService::GetDownloadedModel( return std::nullopt; } +cpp::result ModelService::HandleDownloadUrlAsync( + const std::string& url) { + auto url_obj = url_parser::FromUrlString(url); + + if (url_obj.host == kHuggingFaceHost) { + if (url_obj.pathParams[2] == "blob") { + url_obj.pathParams[2] = "resolve"; + } + } + auto author{url_obj.pathParams[0]}; + auto model_id{url_obj.pathParams[1]}; + auto file_name{url_obj.pathParams.back()}; + + if (author == "cortexso") { + return DownloadModelFromCortexsoAsync(model_id); + } + + if (url_obj.pathParams.size() < 5) { + return cpp::fail("Invalid url: " + url); + } + + std::string huggingFaceHost{kHuggingFaceHost}; + std::string unique_model_id{author + ":" + model_id + ":" + file_name}; + + cortex::db::Models modellist_handler; + auto model_entry = modellist_handler.GetModelInfo(unique_model_id); + + if (model_entry.has_value()) { + CLI_LOG("Model already downloaded: " << unique_model_id); + return cpp::fail("Please delete the model before downloading again"); + } + + auto local_path{file_manager_utils::GetModelsContainerPath() / + "huggingface.co" / author / model_id / file_name}; + + try { + std::filesystem::create_directories(local_path.parent_path()); + } catch (const std::filesystem::filesystem_error& e) { + // if file exist, remove it + std::filesystem::remove(local_path.parent_path()); + std::filesystem::create_directories(local_path.parent_path()); + } + + auto download_url = url_parser::FromUrl(url_obj); + // this assume that the model being downloaded is a single gguf file + auto downloadTask{DownloadTask{.id = model_id, + .type = DownloadType::Model, + .items = {DownloadItem{ + .id = unique_model_id, + .downloadUrl = download_url, + .localPath = local_path, + }}}}; + + auto on_finished = [&](const DownloadTask& finishedTask) { + auto gguf_download_item = finishedTask.items[0]; + ParseGguf(gguf_download_item, author); + }; + + return download_service_->AddTask(downloadTask, on_finished); +} + cpp::result ModelService::HandleUrl( const std::string& url, bool async) { auto url_obj = url_parser::FromUrlString(url); @@ -243,7 +303,7 @@ cpp::result ModelService::HandleUrl( if (url_obj.pathParams.size() < 2) { return cpp::fail("Invalid url: " + url); } - return DownloadHuggingFaceGgufModel(author, model_id, std::nullopt, async); + return DownloadHuggingFaceGgufModel(author, model_id, std::nullopt); } std::string huggingFaceHost{kHuggingFaceHost}; @@ -284,8 +344,7 @@ cpp::result ModelService::HandleUrl( }; if (async) { - auto result = - download_service_->AddAsyncDownloadTask(downloadTask, on_finished); + auto result = download_service_->AddTask(downloadTask, on_finished); if (result.has_error()) { CTL_ERR(result.error()); @@ -304,8 +363,9 @@ cpp::result ModelService::HandleUrl( } } -cpp::result ModelService::DownloadModelFromCortexso( - const std::string& name, const std::string& branch, bool async) { +cpp::result +ModelService::DownloadModelFromCortexsoAsync(const std::string& name, + const std::string& branch) { auto download_task = GetDownloadTask(name, branch); if (download_task.has_error()) { @@ -346,28 +406,78 @@ cpp::result ModelService::DownloadModelFromCortexso( .branch_name = branch, .path_to_model_yaml = rel.string(), .model_alias = model_id}; - modellist_utils_obj.AddModelEntry(model_entry); + auto result = modellist_utils_obj.AddModelEntry(model_entry); + if (result.has_error()) { + CTL_ERR("Error adding model to modellist: " + result.error()); + } }; - auto result = async ? download_service_->AddAsyncDownloadTask( - download_task.value(), on_finished) - : download_service_->AddDownloadTask( - download_task.value(), on_finished); + return download_service_->AddTask(download_task.value(), on_finished); +} + +cpp::result ModelService::DownloadModelFromCortexso( + const std::string& name, const std::string& branch) { + + auto download_task = GetDownloadTask(name, branch); + if (download_task.has_error()) { + return cpp::fail(download_task.error()); + } + + std::string model_id{name + ":" + branch}; + auto on_finished = [&, model_id](const DownloadTask& finishedTask) { + const DownloadItem* model_yml_item = nullptr; + auto need_parse_gguf = true; + + for (const auto& item : finishedTask.items) { + if (item.localPath.filename().string() == "model.yml") { + model_yml_item = &item; + } + } + + if (model_yml_item == nullptr) { + CTL_WRN("model.yml not found in the downloaded files for " + model_id); + return; + } + auto url_obj = url_parser::FromUrlString(model_yml_item->downloadUrl); + CTL_INF("Adding model to modellist with branch: " << branch); + config::YamlHandler yaml_handler; + yaml_handler.ModelConfigFromFile(model_yml_item->localPath.string()); + auto mc = yaml_handler.GetModelConfig(); + mc.model = model_id; + yaml_handler.UpdateModelConfig(mc); + yaml_handler.WriteYamlFile(model_yml_item->localPath.string()); + + auto rel = + file_manager_utils::ToRelativeCortexDataPath(model_yml_item->localPath); + CTL_INF("path_to_model_yaml: " << rel.string()); + cortex::db::Models modellist_utils_obj; + cortex::db::ModelEntry model_entry{.model = model_id, + .author_repo_id = "cortexso", + .branch_name = branch, + .path_to_model_yaml = rel.string(), + .model_alias = model_id}; + auto result = modellist_utils_obj.AddModelEntry(model_entry); + if (result.has_error()) { + CTL_ERR("Error adding model to modellist: " + result.error()); + } + }; + + auto result = + download_service_->AddDownloadTask(download_task.value(), on_finished); if (result.has_error()) { return cpp::fail(result.error()); } else if (result && result.value()) { CLI_LOG("Model " << model_id << " downloaded successfully!") + return model_id; } - - return model_id; + return cpp::fail("Failed to download model " + model_id); } cpp::result -ModelService::DownloadHuggingFaceGgufModel(const std::string& author, - const std::string& modelName, - std::optional fileName, - bool async) { +ModelService::DownloadHuggingFaceGgufModel( + const std::string& author, const std::string& modelName, + std::optional fileName) { auto repo_info = huggingface_utils::GetHuggingFaceModelRepoInfo(author, modelName); @@ -491,7 +601,7 @@ cpp::result ModelService::StartModel( } else { CTL_ERR("Model failed to load with status code: " << res->status); return cpp::fail("Model failed to load with status code: " + - res->status); + std::to_string(res->status)); } } else { auto err = res.error(); @@ -540,7 +650,7 @@ cpp::result ModelService::StopModel( } else { CTL_ERR("Model failed to unload with status code: " << res->status); return cpp::fail("Model failed to unload with status code: " + - res->status); + std::to_string(res->status)); } } else { auto err = res.error(); @@ -589,7 +699,7 @@ cpp::result ModelService::GetModelStatus( CTL_INF("Model failed to get model status with status code: " << res->status); return cpp::fail("Model failed to get model status with status code: " + - res->status); + std::to_string(res->status)); } } else { auto err = res.error(); @@ -601,3 +711,9 @@ cpp::result ModelService::GetModelStatus( "': " + e.what()); } } + +cpp::result ModelService::AbortDownloadModel( + const std::string& task_id) { + auto result = download_service_->StopTask(task_id); + return result; +} diff --git a/engine/services/model_service.h b/engine/services/model_service.h index 319260584..e0c828420 100644 --- a/engine/services/model_service.h +++ b/engine/services/model_service.h @@ -15,12 +15,15 @@ class ModelService { /** * Return model id if download successfully */ - cpp::result DownloadModel(const std::string& input, - bool async = false); + cpp::result DownloadModel(const std::string& input); + + cpp::result AbortDownloadModel(const std::string& task_id); cpp::result DownloadModelFromCortexso( - const std::string& name, const std::string& branch = "main", - bool async = false); + const std::string& name, const std::string& branch = "main"); + + cpp::result DownloadModelFromCortexsoAsync( + const std::string& name, const std::string& branch = "main"); std::optional GetDownloadedModel( const std::string& modelId) const; @@ -43,13 +46,16 @@ class ModelService { cpp::result HandleUrl(const std::string& url, bool async = false); + cpp::result HandleDownloadUrlAsync( + const std::string& url); + private: /** * Handle downloading model which have following pattern: author/model_name */ cpp::result DownloadHuggingFaceGgufModel( const std::string& author, const std::string& modelName, - std::optional fileName, bool async = false); + std::optional fileName); /** * Handling cortexso models. Will look through cortexso's HF repository and From afeda206f61fdd2aa745b364c7ee952f9ee64c65 Mon Sep 17 00:00:00 2001 From: James Date: Tue, 15 Oct 2024 19:13:13 +0700 Subject: [PATCH 3/6] update progress --- engine/common/event.h | 2 +- engine/services/download_service.cc | 37 +++++++++++++++--- engine/services/download_service.h | 58 +++++++++++++++++++++++------ 3 files changed, 80 insertions(+), 17 deletions(-) diff --git a/engine/common/event.h b/engine/common/event.h index 0fdce7770..fe68bd04e 100644 --- a/engine/common/event.h +++ b/engine/common/event.h @@ -34,7 +34,7 @@ std::string DownloadEventTypeToString(DownloadEventType type) { case DownloadEventType::DownloadStarted: return "DownloadStarted"; case DownloadEventType::DownloadStopped: - return "DownloadPaused"; + return "DownloadStopped"; case DownloadEventType::DownloadUpdated: return "DownloadUpdated"; case DownloadEventType::DownloadSuccess: diff --git a/engine/services/download_service.cc b/engine/services/download_service.cc index a60f4b9e7..baafaf8cc 100644 --- a/engine/services/download_service.cc +++ b/engine/services/download_service.cc @@ -2,14 +2,12 @@ #include #include #include -#include #include #include #include #include #include "download_service.h" #include "utils/format_utils.h" -#include "utils/logging_utils.h" #include "utils/result.hpp" #ifdef _WIN32 @@ -240,7 +238,7 @@ void DownloadService::WorkerThread() { } } -void DownloadService::ProcessTask(const DownloadTask& task) { +void DownloadService::ProcessTask(DownloadTask& task) { CTL_INF("Processing task: " + task.id); std::vector task_handles; @@ -259,15 +257,30 @@ void DownloadService::ProcessTask(const DownloadTask& task) { return; } + downloading_data_ = std::make_shared(DownloadingData{ + .download_item = &item, + .download_task = &task, + .event_queue = event_queue_.get(), + }); + curl_easy_setopt(handle, CURLOPT_URL, item.downloadUrl.c_str()); curl_easy_setopt(handle, CURLOPT_WRITEFUNCTION, WriteCallback); curl_easy_setopt(handle, CURLOPT_WRITEDATA, file); curl_easy_setopt(handle, CURLOPT_FOLLOWLOCATION, 1L); + curl_easy_setopt(handle, CURLOPT_NOPROGRESS, 0L); + curl_easy_setopt(handle, CURLOPT_XFERINFOFUNCTION, ProgressCallback); + curl_easy_setopt(handle, CURLOPT_XFERINFODATA, downloading_data_.get()); + curl_multi_add_handle(multi_handle_, handle); task_handles.push_back(handle); CTL_INF("Adding item to multi curl: " + item.ToString()); } + event_queue_->enqueue( + EventType::DownloadEvent, + DownloadEvent{.type_ = DownloadEventType::DownloadStarted, + .download_task_ = task}); + int still_running = 0; bool is_terminated = false; do { @@ -288,12 +301,27 @@ void DownloadService::ProcessTask(const DownloadTask& task) { for (auto handle : task_handles) { curl_multi_remove_handle(multi_handle_, handle); curl_easy_cleanup(handle); + downloading_data_.reset(); + } + + // if terminate by API calling and not from process stopping, we emit + // DownloadStopped event + if (is_terminated && !stop_flag_) { + event_queue_->enqueue( + EventType::DownloadEvent, + DownloadEvent{.type_ = DownloadEventType::DownloadStopped, + .download_task_ = task}); } if (!is_terminated) { RemoveTaskFromStopList(task.id); CTL_INF("Executing callback.."); ExecuteCallback(task); + + event_queue_->enqueue( + EventType::DownloadEvent, + DownloadEvent{.type_ = DownloadEventType::DownloadSuccess, + .download_task_ = task}); } } @@ -378,9 +406,8 @@ void DownloadService::ExecuteCallback(const DownloadTask& task) { } cpp::result DownloadService::Destroy() { - // CTL_INF("Destroying download service.."); + CTL_INF("Destroying download service.."); stop_flag_ = true; queue_condition_.notify_one(); - // CTL_INF("Destroying download service.. notified"); return {}; } diff --git a/engine/services/download_service.h b/engine/services/download_service.h index 211511913..1ccec6fdf 100644 --- a/engine/services/download_service.h +++ b/engine/services/download_service.h @@ -4,6 +4,7 @@ #include #include #include +#include #include #include #include @@ -44,7 +45,7 @@ class DownloadService { curl_global_cleanup(); worker_thread_.join(); - CTL_INF("DownloadService is destroyed.") + CTL_INF("DownloadService is destroyed."); } } @@ -75,6 +76,12 @@ class DownloadService { cpp::result Destroy(); private: + struct DownloadingData { + DownloadItem* download_item; + DownloadTask* download_task; + EventQueue* event_queue; + }; + cpp::result VerifyDownloadTask( DownloadTask& task) const noexcept; @@ -104,11 +111,16 @@ class DownloadService { callbacks_; std::mutex callbacks_mutex_; + std::shared_ptr downloading_data_; + + // active task that being downloaded atm + std::unordered_map active_tasks_; + void WorkerThread(); void ProcessCompletedTransfers(); - void ProcessTask(const DownloadTask& task); + void ProcessTask(DownloadTask& task); bool IsTaskTerminated(const std::string& task_id); @@ -118,17 +130,41 @@ class DownloadService { constexpr static auto MAX_WAIT_MSECS = 1000; - static int ProgressCallback(void* ptr, double dltotal, double dlnow, - double ultotal, double ulnow) { - auto service = static_cast(ptr); - if (service->event_queue_ == nullptr) { - return 0; + static int ProgressCallback(void* ptr, curl_off_t dltotal, curl_off_t dlnow, + curl_off_t ultotal, curl_off_t ulnow) { + auto* downloading_data = static_cast(ptr); + auto& event_queue = *downloading_data->event_queue; + auto& download_item = *downloading_data->download_item; + auto& download_task = *downloading_data->download_task; + + // update the download task with corresponding download item + for (auto& item : download_task.items) { + if (item.id == download_item.id) { + item.downloadedBytes = dlnow; + item.bytes = dltotal; + break; + } + } + + // Check if one second has passed since the last event + static auto last_event_time = std::chrono::steady_clock::now(); + auto current_time = std::chrono::steady_clock::now(); + auto time_since_last_event = + std::chrono::duration_cast(current_time - + last_event_time) + .count(); + + // throttle event by 1 sec + if (time_since_last_event >= 1000) { + event_queue.enqueue( + EventType::DownloadEvent, + DownloadEvent{.type_ = DownloadEventType::DownloadUpdated, + .download_task_ = download_task}); + + // Update the last event time + last_event_time = current_time; } - // service->event_queue_->enqueue( - // EventType::DownloadEvent, - // DownloadEvent{.type_ = DownloadEventType::DownloadUpdated, - // .download_task_ = task}); return 0; } }; From a813b5c054543149ac87ff41cc4880ab0837d936 Mon Sep 17 00:00:00 2001 From: James Date: Wed, 16 Oct 2024 01:39:09 +0700 Subject: [PATCH 4/6] fix crash issue when download with cortexso model --- engine/services/download_service.cc | 14 +++++++------- engine/services/download_service.h | 6 ++---- engine/services/model_service.cc | 30 ++++++++++++----------------- engine/services/model_service.h | 3 +-- 4 files changed, 22 insertions(+), 31 deletions(-) diff --git a/engine/services/download_service.cc b/engine/services/download_service.cc index baafaf8cc..920fd9565 100644 --- a/engine/services/download_service.cc +++ b/engine/services/download_service.cc @@ -242,6 +242,12 @@ void DownloadService::ProcessTask(DownloadTask& task) { CTL_INF("Processing task: " + task.id); std::vector task_handles; + downloading_data_ = std::make_shared(DownloadingData{ + .item_id = "", + .download_task = &task, + .event_queue = event_queue_.get(), + }); + for (auto& item : task.items) { auto handle = curl_easy_init(); if (handle == nullptr) { @@ -256,13 +262,7 @@ void DownloadService::ProcessTask(DownloadTask& task) { CTL_ERR("Failed to open output file " + item.localPath.string()); return; } - - downloading_data_ = std::make_shared(DownloadingData{ - .download_item = &item, - .download_task = &task, - .event_queue = event_queue_.get(), - }); - + downloading_data_->item_id = item.id; curl_easy_setopt(handle, CURLOPT_URL, item.downloadUrl.c_str()); curl_easy_setopt(handle, CURLOPT_WRITEFUNCTION, WriteCallback); curl_easy_setopt(handle, CURLOPT_WRITEDATA, file); diff --git a/engine/services/download_service.h b/engine/services/download_service.h index 1ccec6fdf..eca1da332 100644 --- a/engine/services/download_service.h +++ b/engine/services/download_service.h @@ -4,7 +4,6 @@ #include #include #include -#include #include #include #include @@ -77,7 +76,7 @@ class DownloadService { private: struct DownloadingData { - DownloadItem* download_item; + std::string item_id; DownloadTask* download_task; EventQueue* event_queue; }; @@ -134,12 +133,11 @@ class DownloadService { curl_off_t ultotal, curl_off_t ulnow) { auto* downloading_data = static_cast(ptr); auto& event_queue = *downloading_data->event_queue; - auto& download_item = *downloading_data->download_item; auto& download_task = *downloading_data->download_task; // update the download task with corresponding download item for (auto& item : download_task.items) { - if (item.id == download_item.id) { + if (item.id == downloading_data->item_id) { item.downloadedBytes = dlnow; item.bytes = dltotal; break; diff --git a/engine/services/model_service.cc b/engine/services/model_service.cc index ddfc4a0c8..201622104 100644 --- a/engine/services/model_service.cc +++ b/engine/services/model_service.cc @@ -279,11 +279,12 @@ cpp::result ModelService::HandleDownloadUrlAsync( ParseGguf(gguf_download_item, author); }; + downloadTask.id = unique_model_id; return download_service_->AddTask(downloadTask, on_finished); } cpp::result ModelService::HandleUrl( - const std::string& url, bool async) { + const std::string& url) { auto url_obj = url_parser::FromUrlString(url); if (url_obj.host == kHuggingFaceHost) { @@ -343,24 +344,14 @@ cpp::result ModelService::HandleUrl( ParseGguf(gguf_download_item, author); }; - if (async) { - auto result = download_service_->AddTask(downloadTask, on_finished); - - if (result.has_error()) { - CTL_ERR(result.error()); - return cpp::fail(result.error()); - } - return unique_model_id; - } else { - auto result = download_service_->AddDownloadTask(downloadTask, on_finished); - if (result.has_error()) { - CTL_ERR(result.error()); - return cpp::fail(result.error()); - } else if (result && result.value()) { - CLI_LOG("Model " << model_id << " downloaded successfully!") - } - return unique_model_id; + auto result = download_service_->AddDownloadTask(downloadTask, on_finished); + if (result.has_error()) { + CTL_ERR(result.error()); + return cpp::fail(result.error()); + } else if (result && result.value()) { + CLI_LOG("Model " << model_id << " downloaded successfully!") } + return unique_model_id; } cpp::result @@ -412,6 +403,9 @@ ModelService::DownloadModelFromCortexsoAsync(const std::string& name, } }; + auto task = download_task.value(); + task.id = model_id; + return download_service_->AddTask(download_task.value(), on_finished); } diff --git a/engine/services/model_service.h b/engine/services/model_service.h index e0c828420..ca33a7796 100644 --- a/engine/services/model_service.h +++ b/engine/services/model_service.h @@ -43,8 +43,7 @@ class ModelService { cpp::result GetModelStatus( const std::string& host, int port, const std::string& model_handle); - cpp::result HandleUrl(const std::string& url, - bool async = false); + cpp::result HandleUrl(const std::string& url); cpp::result HandleDownloadUrlAsync( const std::string& url); From 43549bc5acad5eb8ce8d98f5e1bbb3db1b32bd64 Mon Sep 17 00:00:00 2001 From: James Date: Wed, 16 Oct 2024 08:53:42 +0700 Subject: [PATCH 5/6] Remove redundant code --- engine/services/download_service.cc | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/engine/services/download_service.cc b/engine/services/download_service.cc index 920fd9565..a1905eefb 100644 --- a/engine/services/download_service.cc +++ b/engine/services/download_service.cc @@ -229,10 +229,8 @@ void DownloadService::WorkerThread() { CTL_INF("Stopping download service.."); break; } - if (!task_queue_.empty()) { - task = std::move(task_queue_.front()); - task_queue_.pop(); - } + task = std::move(task_queue_.front()); + task_queue_.pop(); } ProcessTask(task); } From 9da98d66c38976b309ec35073b3cca3f6dceabb0 Mon Sep 17 00:00:00 2001 From: James Date: Wed, 16 Oct 2024 09:02:27 +0700 Subject: [PATCH 6/6] Update comments --- engine/services/download_service.h | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/engine/services/download_service.h b/engine/services/download_service.h index eca1da332..605c517d0 100644 --- a/engine/services/download_service.h +++ b/engine/services/download_service.h @@ -43,7 +43,9 @@ class DownloadService { curl_multi_cleanup(multi_handle_); curl_global_cleanup(); - worker_thread_.join(); + if (worker_thread_.joinable()) { + worker_thread_.join(); + } CTL_INF("DownloadService is destroyed."); } } @@ -112,9 +114,6 @@ class DownloadService { std::shared_ptr downloading_data_; - // active task that being downloaded atm - std::unordered_map active_tasks_; - void WorkerThread(); void ProcessCompletedTransfers();