diff --git a/engine/cli/commands/engine_install_cmd.cc b/engine/cli/commands/engine_install_cmd.cc index 2424513eb..1f712d10c 100644 --- a/engine/cli/commands/engine_install_cmd.cc +++ b/engine/cli/commands/engine_install_cmd.cc @@ -36,8 +36,13 @@ bool EngineInstallCmd::Exec(const std::string& engine, DownloadProgress dp; dp.Connect(host_, port_); // engine can be small, so need to start ws first - auto dp_res = std::async(std::launch::deferred, [&dp, &engine] { - return dp.Handle(DownloadType::Engine); + auto dp_res = std::async(std::launch::deferred, [&dp] { + bool need_cuda_download = !system_info_utils::GetCudaVersion().empty(); + if (need_cuda_download) { + return dp.Handle({DownloadType::Engine, DownloadType::CudaToolkit}); + } else { + return dp.Handle({DownloadType::Engine}); + } }); auto versions_url = url_parser::Url{ @@ -133,12 +138,6 @@ bool EngineInstallCmd::Exec(const std::string& engine, if (!dp_res.get()) return false; - bool check_cuda_download = !system_info_utils::GetCudaVersion().empty(); - if (check_cuda_download) { - if (!dp.Handle(DownloadType::CudaToolkit)) - return false; - } - CLI_LOG("Engine " << engine << " downloaded successfully!") return true; } @@ -147,8 +146,14 @@ bool EngineInstallCmd::Exec(const std::string& engine, DownloadProgress dp; dp.Connect(host_, port_); // engine can be small, so need to start ws first - auto dp_res = std::async(std::launch::deferred, - [&dp] { return dp.Handle(DownloadType::Engine); }); + auto dp_res = std::async(std::launch::deferred, [&dp] { + bool need_cuda_download = !system_info_utils::GetCudaVersion().empty(); + if (need_cuda_download) { + return dp.Handle({DownloadType::Engine, DownloadType::CudaToolkit}); + } else { + return dp.Handle({DownloadType::Engine}); + } + }); auto install_url = url_parser::Url{ .protocol = "http", @@ -183,12 +188,6 @@ bool EngineInstallCmd::Exec(const std::string& engine, if (!dp_res.get()) return false; - bool check_cuda_download = !system_info_utils::GetCudaVersion().empty(); - if (check_cuda_download) { - if (!dp.Handle(DownloadType::CudaToolkit)) - return false; - } - CLI_LOG("Engine " << engine << " downloaded successfully!") return true; } diff --git a/engine/cli/commands/engine_update_cmd.cc b/engine/cli/commands/engine_update_cmd.cc index ee3a526c3..9717ddb15 100644 --- a/engine/cli/commands/engine_update_cmd.cc +++ b/engine/cli/commands/engine_update_cmd.cc @@ -24,8 +24,13 @@ bool EngineUpdateCmd::Exec(const std::string& host, int port, DownloadProgress dp; dp.Connect(host, port); // engine can be small, so need to start ws first - auto dp_res = std::async(std::launch::deferred, [&dp, &engine] { - return dp.Handle(DownloadType::Engine); + auto dp_res = std::async(std::launch::deferred, [&dp] { + bool need_cuda_download = !system_info_utils::GetCudaVersion().empty(); + if (need_cuda_download) { + return dp.Handle({DownloadType::Engine, DownloadType::CudaToolkit}); + } else { + return dp.Handle({DownloadType::Engine}); + } }); auto update_url = url_parser::Url{ @@ -48,12 +53,6 @@ bool EngineUpdateCmd::Exec(const std::string& host, int port, if (!dp_res.get()) return false; - bool check_cuda_download = !system_info_utils::GetCudaVersion().empty(); - if (check_cuda_download) { - if (!dp.Handle(DownloadType::CudaToolkit)) - return false; - } - CLI_LOG("Engine " << engine << " updated successfully!") return true; } diff --git a/engine/cli/commands/model_pull_cmd.cc b/engine/cli/commands/model_pull_cmd.cc index 5793c2e09..d769b667a 100644 --- a/engine/cli/commands/model_pull_cmd.cc +++ b/engine/cli/commands/model_pull_cmd.cc @@ -143,7 +143,7 @@ std::optional ModelPullCmd::Exec(const std::string& host, int port, reinterpret_cast(console_ctrl_handler), true); #endif dp.Connect(host, port); - if (!dp.Handle(DownloadType::Model)) + if (!dp.Handle({DownloadType::Model})) return std::nullopt; if (force_stop) return std::nullopt; diff --git a/engine/cli/utils/download_progress.cc b/engine/cli/utils/download_progress.cc index 017d71d0e..e085a660e 100644 --- a/engine/cli/utils/download_progress.cc +++ b/engine/cli/utils/download_progress.cc @@ -34,7 +34,8 @@ bool DownloadProgress::Connect(const std::string& host, int port) { return true; } -bool DownloadProgress::Handle(const DownloadType& event_type) { +bool DownloadProgress::Handle( + const std::unordered_set& event_type) { assert(!!ws_); #if defined(_WIN32) HANDLE h_out = GetStdHandle(STD_OUTPUT_HANDLE); @@ -50,10 +51,14 @@ bool DownloadProgress::Handle(const DownloadType& event_type) { } } #endif - status_ = DownloadStatus::DownloadStarted; + for (auto et : event_type) { + status_[et] = DownloadStatus::DownloadStarted; + } std::unique_ptr> bars; - std::vector> items; + std::unordered_map>> + items; indicators::show_console_cursor(false); auto start = std::chrono::steady_clock::now(); auto handle_message = [this, &bars, &items, event_type, @@ -78,22 +83,28 @@ bool DownloadProgress::Handle(const DownloadType& event_type) { auto ev = cortex::event::GetDownloadEventFromJson( json_helper::ParseJsonString(message)); // Ignore other task type - if (ev.download_task_.type != event_type) { + if (event_type.find(ev.download_task_.type) == event_type.end()) { return; } auto now = std::chrono::steady_clock::now(); if (!bars) { bars = std::make_unique< indicators::DynamicProgress>(); - for (auto& i : ev.download_task_.items) { - items.emplace_back(std::make_unique( - indicators::option::BarWidth{50}, indicators::option::Start{"["}, - indicators::option::Fill{"="}, indicators::option::Lead{">"}, - indicators::option::End{"]"}, - indicators::option::PrefixText{pad_string(Repo2Engine(i.id))}, - indicators::option::ForegroundColor{indicators::Color::white}, - indicators::option::ShowRemainingTime{false})); - bars->push_back(*(items.back())); + } + for (auto& i : ev.download_task_.items) { + if (items.find(i.id) == items.end()) { + auto idx = items.size(); + items[i.id] = std::pair( + idx, + std::make_unique( + indicators::option::BarWidth{50}, + indicators::option::Start{"["}, indicators::option::Fill{"="}, + indicators::option::Lead{">"}, indicators::option::End{"]"}, + indicators::option::PrefixText{pad_string(Repo2Engine(i.id))}, + indicators::option::ForegroundColor{indicators::Color::white}, + indicators::option::ShowRemainingTime{false})); + + bars->push_back(*(items.at(i.id).second)); } } for (int i = 0; i < ev.download_task_.items.size(); i++) { @@ -113,32 +124,36 @@ bool DownloadProgress::Handle(const DownloadType& event_type) { (total - downloaded) / bytes_per_sec); } - (*bars)[i].set_option(indicators::option::PrefixText{ - pad_string(Repo2Engine(it.id)) + - std::to_string(int(static_cast(downloaded) / total * 100)) + - '%'}); - (*bars)[i].set_progress( + (*bars)[items.at(it.id).first].set_option( + indicators::option::PrefixText{ + pad_string(Repo2Engine(it.id)) + + std::to_string( + int(static_cast(downloaded) / total * 100)) + + '%'}); + (*bars)[items.at(it.id).first].set_progress( int(static_cast(downloaded) / total * 100)); - (*bars)[i].set_option(indicators::option::PostfixText{ - time_remaining + " " + - format_utils::BytesToHumanReadable(downloaded) + "/" + - format_utils::BytesToHumanReadable(total)}); + (*bars)[items.at(it.id).first].set_option( + indicators::option::PostfixText{ + time_remaining + " " + + format_utils::BytesToHumanReadable(downloaded) + "/" + + format_utils::BytesToHumanReadable(total)}); } else if (ev.type_ == DownloadStatus::DownloadSuccess) { uint64_t total = it.bytes.value_or(std::numeric_limits::max()); - (*bars)[i].set_progress(100); + (*bars)[items.at(it.id).first].set_progress(100); auto total_str = format_utils::BytesToHumanReadable(total); - (*bars)[i].set_option(indicators::option::PostfixText{ - "00m:00s " + total_str + "/" + total_str}); - (*bars)[i].set_option(indicators::option::PrefixText{ - pad_string(Repo2Engine(it.id)) + "100%"}); - (*bars)[i].set_progress(100); + (*bars)[items.at(it.id).first].set_option( + indicators::option::PostfixText{"00m:00s " + total_str + "/" + + total_str}); + (*bars)[items.at(it.id).first].set_option( + indicators::option::PrefixText{pad_string(Repo2Engine(it.id)) + + "100%"}); + (*bars)[items.at(it.id).first].set_progress(100); CTL_INF("Download success"); } + status_[ev.download_task_.type] = ev.type_; } - - status_ = ev.type_; }; while (ws_->getReadyState() != easywsclient::WebSocket::CLOSED && @@ -152,7 +167,9 @@ bool DownloadProgress::Handle(const DownloadType& event_type) { SetConsoleMode(h_out, dw_original_out_mode); } #endif - if (status_ == DownloadStatus::DownloadError) - return false; + for (auto const& [_, v] : status_) { + if (v == DownloadStatus::DownloadError) + return false; + } return true; } diff --git a/engine/cli/utils/download_progress.h b/engine/cli/utils/download_progress.h index 98fe85654..6ea764ec4 100644 --- a/engine/cli/utils/download_progress.h +++ b/engine/cli/utils/download_progress.h @@ -2,6 +2,7 @@ #include #include #include +#include #include "common/event.h" #include "easywsclient.hpp" @@ -10,19 +11,25 @@ class DownloadProgress { public: bool Connect(const std::string& host, int port); - bool Handle(const DownloadType& event_type); + bool Handle(const std::unordered_set& event_type); void ForceStop() { force_stop_ = true; } private: bool should_stop() const { - return (status_ != DownloadStatus::DownloadStarted && - status_ != DownloadStatus::DownloadUpdated) || - force_stop_; + bool should_stop = true; + for (auto const& [_, v] : status_) { + should_stop &= (v == DownloadStatus::DownloadSuccess); + } + for (auto const& [_, v] : status_) { + should_stop |= (v == DownloadStatus::DownloadError || + v == DownloadStatus::DownloadStopped); + } + return should_stop || force_stop_; } private: std::unique_ptr ws_; - std::atomic status_ = DownloadStatus::DownloadStarted; + std::unordered_map> status_; std::atomic force_stop_ = false; }; \ No newline at end of file diff --git a/engine/e2e-test/test_api_model_start.py b/engine/e2e-test/test_api_model_start.py index 830d32da8..d6a98a78b 100644 --- a/engine/e2e-test/test_api_model_start.py +++ b/engine/e2e-test/test_api_model_start.py @@ -12,7 +12,7 @@ def setup_and_teardown(self): success = start_server() if not success: raise Exception("Failed to start server") - requests.post("http://localhost:3928/v1/engines/llama-cpp") + run("Install engine", ["engines", "install", "llama-cpp"], 5 * 60) run("Delete model", ["models", "delete", "tinyllama:gguf"]) run( "Pull model", @@ -27,5 +27,7 @@ def setup_and_teardown(self): def test_models_start_should_be_successful(self): json_body = {"model": "tinyllama:gguf"} - response = requests.post("http://localhost:3928/v1/models/start", json=json_body) + response = requests.post( + "http://localhost:3928/v1/models/start", json=json_body + ) assert response.status_code == 200, f"status_code: {response.status_code}" diff --git a/engine/e2e-test/test_api_model_stop.py b/engine/e2e-test/test_api_model_stop.py index 97bec671e..dc3b6b77b 100644 --- a/engine/e2e-test/test_api_model_stop.py +++ b/engine/e2e-test/test_api_model_stop.py @@ -1,6 +1,6 @@ import pytest import requests -from test_runner import start_server, stop_server +from test_runner import run, start_server, stop_server class TestApiModelStop: @@ -13,16 +13,18 @@ def setup_and_teardown(self): if not success: raise Exception("Failed to start server") - requests.post("http://localhost:3928/engines/llama-cpp") + run("Install engine", ["engines", "install", "llama-cpp"], 5 * 60) yield - requests.delete("http://localhost:3928/engines/llama-cpp") + run("Uninstall engine", ["engines", "uninstall", "llama-cpp"]) # Teardown stop_server() def test_models_stop_should_be_successful(self): json_body = {"model": "tinyllama:gguf"} - response = requests.post("http://localhost:3928/v1/models/start", json=json_body) + response = requests.post( + "http://localhost:3928/v1/models/start", json=json_body + ) assert response.status_code == 200, f"status_code: {response.status_code}" response = requests.post("http://localhost:3928/v1/models/stop", json=json_body) assert response.status_code == 200, f"status_code: {response.status_code}" diff --git a/engine/services/download_service.cc b/engine/services/download_service.cc index d125f8ef0..08e366151 100644 --- a/engine/services/download_service.cc +++ b/engine/services/download_service.cc @@ -16,6 +16,42 @@ size_t WriteCallback(char* ptr, size_t size, size_t nmemb, void* userdata) { size_t written = fwrite(ptr, size, nmemb, (FILE*)userdata); return written; } + +cpp::result ProcessCompletedTransfers(CURLM* multi_handle) { + CURLMsg* msg; + int msgs_left; + while ((msg = curl_multi_info_read(multi_handle, &msgs_left))) { + if (msg->msg == CURLMSG_DONE) { + auto handle = msg->easy_handle; + auto result = msg->data.result; + + char* url = nullptr; + curl_easy_getinfo(handle, CURLINFO_EFFECTIVE_URL, &url); + + if (result != CURLE_OK) { + CTL_ERR("Transfer failed for URL: " << url << " Error: " + << curl_easy_strerror(result)); + // download failed + return cpp::fail("Transfer failed for URL: " + std::string(url) + + " Error: " + curl_easy_strerror(result)); + } else { + long response_code; + curl_easy_getinfo(handle, CURLINFO_RESPONSE_CODE, &response_code); + if (response_code == 200) { + CTL_INF("Transfer completed for URL: " << url); + } else { + CTL_ERR("Transfer failed with HTTP code: " << response_code + << " for URL: " << url); + // download failed + return cpp::fail("Transfer failed with HTTP code: " + + std::to_string(response_code) + + " for URL: " + std::string(url)); + } + } + } + } + return {}; +} } // namespace cpp::result DownloadService::AddDownloadTask( @@ -179,195 +215,229 @@ cpp::result DownloadService::Download( curl_easy_cleanup(curl); return true; } -void DownloadService::WorkerThread() { + +cpp::result DownloadService::StopTask( + const std::string& task_id) { + // First try to cancel in queue + CTL_INF("Stopping task: " << task_id); + auto cancelled = task_queue_.cancelTask(task_id); + if (cancelled) { + EmitTaskStopped(task_id); + return task_id; + } + CTL_INF("Not found in pending task, try to find task " + task_id + + " in active tasks"); + // Check if task is currently being processed + std::lock_guard lock(active_tasks_mutex_); + if (auto it = active_tasks_.find(task_id); it != active_tasks_.end()) { + CTL_INF("Found task " + task_id + " in active tasks"); + it->second->status = DownloadTask::Status::Cancelled; + { + std::lock_guard lock(stop_mutex_); + tasks_to_stop_.insert(task_id); + } + EmitTaskStopped(task_id); + return task_id; + } + + CTL_WRN("Task not found"); + return cpp::fail("Task not found"); +} + +void DownloadService::InitializeWorkers() { + for (auto i = 0; i < MAX_CONCURRENT_TASKS; ++i) { + auto worker_data = std::make_unique(); + worker_data->multi_handle = curl_multi_init(); + worker_data_.push_back(std::move(worker_data)); + + worker_threads_.emplace_back([this, i]() { this->WorkerThread(i); }); + CTL_INF("Starting worker thread: " << i); + } +} + +void DownloadService::Shutdown() { + stop_flag_ = true; + task_cv_.notify_all(); // Wake up all waiting threads + + for (auto& thread : worker_threads_) { + if (thread.joinable()) { + thread.join(); + } + } + + for (auto& worker_data : worker_data_) { + curl_multi_cleanup(worker_data->multi_handle); + } + + // Clean up any remaining callbacks + std::lock_guard lock(callbacks_mutex_); + callbacks_.clear(); +} + +void DownloadService::WorkerThread(int worker_id) { + auto& worker_data = worker_data_[worker_id]; + while (!stop_flag_) { - DownloadTask task; + std::unique_lock lock(task_mutex_); + + // Wait for a task or stop signal + task_cv_.wait(lock, [this] { + auto pending_task = task_queue_.getNextPendingTask(); + return pending_task.has_value() || stop_flag_; + }); + + if (stop_flag_) { + break; + } + + auto maybe_task = task_queue_.pop(); + lock.unlock(); + + if (!maybe_task || maybe_task->status == DownloadTask::Status::Cancelled) { + continue; + } + + auto task = std::move(maybe_task.value()); + + // Register active task { - std::unique_lock lock(queue_mutex_); - queue_condition_.wait( - lock, [this] { return !task_queue_.empty() || stop_flag_; }); - if (stop_flag_) { - CTL_INF("Stopping download service.."); - break; - } - task = std::move(task_queue_.front()); - task_queue_.pop(); + std::lock_guard active_lock(active_tasks_mutex_); + active_tasks_[task.id] = std::make_shared(task); + } + + ProcessTask(task, worker_id); + + // Remove from active tasks + { + std::lock_guard active_lock(active_tasks_mutex_); + active_tasks_.erase(task.id); } - ProcessTask(task); } } -void DownloadService::ProcessTask(DownloadTask& task) { - CTL_INF("Processing task: " + task.id); +void DownloadService::ProcessTask(DownloadTask& task, int worker_id) { + auto& worker_data = worker_data_[worker_id]; std::vector> task_handles; - active_task_ = std::make_shared(task); - + task.status = DownloadTask::Status::InProgress; for (const auto& item : task.items) { auto handle = curl_easy_init(); - if (handle == nullptr) { - // skip the task + if (!handle) { CTL_ERR("Failed to init curl!"); return; } - auto file = fopen(item.localPath.string().c_str(), "wb"); if (!file) { CTL_ERR("Failed to open output file " + item.localPath.string()); + curl_easy_cleanup(handle); return; } - auto dl_data_ptr = std::make_shared(DownloadingData{ + .task_id = task.id, .item_id = item.id, .download_service = this, }); - downloading_data_map_.insert(std::make_pair(item.id, dl_data_ptr)); - - auto headers = curl_utils::GetHeaders(item.downloadUrl); - if (headers.has_value()) { - curl_slist* curl_headers = nullptr; - - for (const auto& [key, value] : headers.value()) { - auto header = key + ": " + value; - curl_headers = curl_slist_append(curl_headers, header.c_str()); - } + worker_data->downloading_data_map[item.id] = dl_data_ptr; - curl_easy_setopt(handle, CURLOPT_HTTPHEADER, curl_headers); - } - - curl_easy_setopt(handle, CURLOPT_URL, item.downloadUrl.c_str()); - curl_easy_setopt(handle, CURLOPT_WRITEFUNCTION, WriteCallback); - curl_easy_setopt(handle, CURLOPT_WRITEDATA, file); - curl_easy_setopt(handle, CURLOPT_FOLLOWLOCATION, 1L); - curl_easy_setopt(handle, CURLOPT_NOPROGRESS, 0L); - curl_easy_setopt(handle, CURLOPT_XFERINFOFUNCTION, ProgressCallback); - curl_easy_setopt(handle, CURLOPT_XFERINFODATA, dl_data_ptr.get()); - - curl_multi_add_handle(multi_handle_, handle); + SetUpCurlHandle(handle, item, file, dl_data_ptr.get()); + curl_multi_add_handle(worker_data->multi_handle, handle); task_handles.push_back(std::make_pair(handle, file)); - CTL_INF("Adding item to multi curl: " + item.ToString()); } - event_queue_->enqueue( - EventType::DownloadEvent, - DownloadEvent{.type_ = DownloadEventType::DownloadStarted, - .download_task_ = task}); + EmitTaskStarted(task); - auto still_running = 0; - auto is_terminated = false; - do { - curl_multi_perform(multi_handle_, &still_running); - curl_multi_wait(multi_handle_, NULL, 0, MAX_WAIT_MSECS, NULL); + auto result = + ProcessMultiDownload(task, worker_data->multi_handle, task_handles); - if (IsTaskTerminated(task.id) || stop_flag_) { - CTL_INF("IsTaskTerminated " + std::to_string(IsTaskTerminated(task.id))); - CTL_INF("stop_flag_ " + std::to_string(stop_flag_)); - - is_terminated = true; - break; - } - } while (still_running); - - if (stop_flag_) { - CTL_INF("Download service is stopping.."); - - // try to close file - for (auto pair : task_handles) { - fclose(pair.second); - } - - active_task_.reset(); - downloading_data_map_.clear(); - return; - } - - ProcessCompletedTransfers(); - for (const auto& pair : task_handles) { - curl_multi_remove_handle(multi_handle_, pair.first); - curl_easy_cleanup(pair.first); - fclose(pair.second); + // clean up + for (auto& [handle, file] : task_handles) { + curl_multi_remove_handle(worker_data->multi_handle, handle); + curl_easy_cleanup(handle); + fclose(file); } - downloading_data_map_.clear(); - auto copied_task = *active_task_; - active_task_.reset(); - - RemoveTaskFromStopList(task.id); - // if terminate by API calling and not from process stopping, we emit - // DownloadStopped event - if (is_terminated) { - event_queue_->enqueue( - EventType::DownloadEvent, - DownloadEvent{.type_ = DownloadEventType::DownloadStopped, - .download_task_ = copied_task}); - } else { - CTL_INF("Executing callback.."); + if (!result.has_error()) { + // if the download has error, we are not run the callback ExecuteCallback(task); - - // set all items to done - for (auto& item : copied_task.items) { - item.downloadedBytes = item.bytes; + EmitTaskCompleted(task.id); + { + std::lock_guard lock(event_emit_map_mutex); + event_emit_map_.erase(task.id); } - - event_queue_->enqueue( - EventType::DownloadEvent, - DownloadEvent{.type_ = DownloadEventType::DownloadSuccess, - .download_task_ = copied_task}); } -} -cpp::result DownloadService::StopTask( - const std::string& task_id) { - std::lock_guard lock(stop_mutex_); - - tasks_to_stop_.insert(task_id); - CTL_INF("Added task to stop list: " << task_id); - return task_id; + worker_data->downloading_data_map.clear(); } -void DownloadService::ProcessCompletedTransfers() { - CURLMsg* msg; - int remaining_msg_count; - - while ((msg = curl_multi_info_read(multi_handle_, &remaining_msg_count))) { - if (msg->msg == CURLMSG_DONE) { - auto handle = msg->easy_handle; +cpp::result DownloadService::ProcessMultiDownload( + DownloadTask& task, CURLM* multi_handle, + const std::vector>& handles) { + auto still_running = 0; + do { + curl_multi_perform(multi_handle, &still_running); + curl_multi_wait(multi_handle, nullptr, 0, MAX_WAIT_MSECS, nullptr); - auto return_code = msg->data.result; - char* url; - curl_easy_getinfo(handle, CURLINFO_EFFECTIVE_URL, &url); - if (return_code != CURLE_OK) { - CTL_ERR("Download failed for: " << url << " - " - << curl_easy_strerror(return_code)); - continue; + auto result = ProcessCompletedTransfers(multi_handle); + if (result.has_error()) { + EmitTaskError(task.id); + { + std::lock_guard lock(event_emit_map_mutex); + event_emit_map_.erase(task.id); } + return cpp::fail(result.error()); + } - auto http_status_code = 0; - curl_easy_getinfo(handle, CURLINFO_RESPONSE_CODE, &http_status_code); - if (http_status_code == 200) { - CTL_INF("Download completed successfully for: " << url); - } else { - CTL_ERR("Download failed for: " << url << " - HTTP status code: " - << http_status_code); + if (IsTaskTerminated(task.id) || stop_flag_) { + CTL_INF("IsTaskTerminated " + std::to_string(IsTaskTerminated(task.id))); + CTL_INF("stop_flag_ " + std::to_string(stop_flag_)); + { + std::lock_guard lock(event_emit_map_mutex); + event_emit_map_.erase(task.id); } + CTL_INF("Emit task stopped: " << task.id); + EmitTaskStopped(task.id); + RemoveTaskFromStopList(task.id); + return cpp::fail("Task " + task.id + " cancelled"); + } + } while (still_running); + return {}; +} + +void DownloadService::SetUpCurlHandle(CURL* handle, const DownloadItem& item, + FILE* file, DownloadingData* dl_data) { + curl_easy_setopt(handle, CURLOPT_URL, item.downloadUrl.c_str()); + curl_easy_setopt(handle, CURLOPT_WRITEFUNCTION, WriteCallback); + curl_easy_setopt(handle, CURLOPT_WRITEDATA, file); + curl_easy_setopt(handle, CURLOPT_FOLLOWLOCATION, 1L); + curl_easy_setopt(handle, CURLOPT_NOPROGRESS, 0L); + curl_easy_setopt(handle, CURLOPT_XFERINFOFUNCTION, ProgressCallback); + curl_easy_setopt(handle, CURLOPT_XFERINFODATA, dl_data); + + auto headers = curl_utils::GetHeaders(item.downloadUrl); + if (headers) { + curl_slist* curl_headers = nullptr; + for (const auto& [key, value] : headers.value()) { + curl_headers = + curl_slist_append(curl_headers, (key + ": " + value).c_str()); } + curl_easy_setopt(handle, CURLOPT_HTTPHEADER, curl_headers); } } cpp::result DownloadService::AddTask( DownloadTask& task, std::function callback) { - { + { // adding item to callback map std::lock_guard lock(callbacks_mutex_); callbacks_[task.id] = std::move(callback); } - { + { // adding task to queue std::lock_guard lock(queue_mutex_); task_queue_.push(task); CTL_INF("Task added to queue: " << task.id); } - queue_condition_.notify_one(); + task_cv_.notify_one(); return task; } @@ -382,6 +452,44 @@ void DownloadService::RemoveTaskFromStopList(const std::string& task_id) { tasks_to_stop_.erase(task_id); } +void DownloadService::EmitTaskStarted(const DownloadTask& task) { + event_queue_->enqueue( + EventType::DownloadEvent, + DownloadEvent{.type_ = DownloadEventType::DownloadStarted, + .download_task_ = task}); +} + +void DownloadService::EmitTaskStopped(const std::string& task_id) { + if (auto it = active_tasks_.find(task_id); it != active_tasks_.end()) { + event_queue_->enqueue( + EventType::DownloadEvent, + DownloadEvent{.type_ = DownloadEventType::DownloadStopped, + .download_task_ = *it->second}); + } +} + +void DownloadService::EmitTaskError(const std::string& task_id) { + if (auto it = active_tasks_.find(task_id); it != active_tasks_.end()) { + event_queue_->enqueue( + EventType::DownloadEvent, + DownloadEvent{.type_ = DownloadEventType::DownloadError, + .download_task_ = *it->second}); + } +} + +void DownloadService::EmitTaskCompleted(const std::string& task_id) { + std::lock_guard lock(active_tasks_mutex_); + if (auto it = active_tasks_.find(task_id); it != active_tasks_.end()) { + for (auto& item : it->second->items) { + item.downloadedBytes = item.bytes; + } + event_queue_->enqueue( + EventType::DownloadEvent, + DownloadEvent{.type_ = DownloadEventType::DownloadSuccess, + .download_task_ = *it->second}); + } +} + void DownloadService::ExecuteCallback(const DownloadTask& task) { std::lock_guard lock(callbacks_mutex_); auto it = callbacks_.find(task.id); diff --git a/engine/services/download_service.h b/engine/services/download_service.h index 3fa74a4c7..42073fc43 100644 --- a/engine/services/download_service.h +++ b/engine/services/download_service.h @@ -4,14 +4,65 @@ #include #include #include -#include #include #include +#include "common/download_task_queue.h" #include "common/event.h" -#include "utils/logging_utils.h" #include "utils/result.hpp" class DownloadService { + private: + static constexpr int MAX_CONCURRENT_TASKS = 4; + + struct DownloadingData { + std::string task_id; + std::string item_id; + DownloadService* download_service; + }; + + // Each worker represents a thread. Each worker will have its own multi_handle + struct WorkerData { + CURLM* multi_handle; + std::unordered_map> + downloading_data_map; + }; + std::vector> worker_data_; + + std::vector worker_threads_; + + // Store all download tasks in a queue + DownloadTaskQueue task_queue_; + + // Flag to stop Download service. Will stop all worker threads + std::atomic stop_flag_{false}; + + // Track active tasks across all workers + std::mutex active_tasks_mutex_; + std::unordered_map> active_tasks_; + + // sync primitives + std::condition_variable task_cv_; + std::mutex task_mutex_; + + void WorkerThread(int worker_id); + + void ProcessTask(DownloadTask& task, int worker_id); + + cpp::result ProcessMultiDownload( + DownloadTask& task, CURLM* multi_handle, + const std::vector>& handles); + + void SetUpCurlHandle(CURL* handle, const DownloadItem& item, FILE* file, + DownloadingData* dl_data); + + void EmitTaskStarted(const DownloadTask& task); + + void EmitTaskStopped(const std::string& task_id); + + void EmitTaskCompleted(const std::string& task_id); + + void EmitTaskError(const std::string& task_id); + public: using OnDownloadTaskSuccessfully = std::function; @@ -27,29 +78,13 @@ class DownloadService { explicit DownloadService(std::shared_ptr event_queue) : event_queue_{event_queue} { - curl_global_init(CURL_GLOBAL_ALL); - stop_flag_ = false; - multi_handle_ = curl_multi_init(); - worker_thread_ = std::thread(&DownloadService::WorkerThread, this); + InitializeWorkers(); }; - ~DownloadService() { - if (event_queue_ != nullptr) { - stop_flag_ = true; - queue_condition_.notify_one(); - - CTL_INF("DownloadService is being destroyed."); - curl_multi_cleanup(multi_handle_); - curl_global_cleanup(); - - if (worker_thread_.joinable()) { - worker_thread_.join(); - } - CTL_INF("DownloadService is destroyed."); - } - } + ~DownloadService() { Shutdown(); } DownloadService(const DownloadService&) = delete; + DownloadService& operator=(const DownloadService&) = delete; /** @@ -77,10 +112,9 @@ class DownloadService { cpp::result StopTask(const std::string& task_id); private: - struct DownloadingData { - std::string item_id; - DownloadService* download_service; - }; + void InitializeWorkers(); + + void Shutdown(); cpp::result Download( const std::string& download_id, @@ -90,10 +124,7 @@ class DownloadService { CURLM* multi_handle_; std::thread worker_thread_; - std::atomic stop_flag_; - // task queue - std::queue task_queue_; std::mutex queue_mutex_; std::condition_variable queue_condition_; @@ -106,16 +137,13 @@ class DownloadService { callbacks_; std::mutex callbacks_mutex_; - std::shared_ptr active_task_; - std::unordered_map> - downloading_data_map_; + std::unordered_map> + event_emit_map_; + std::mutex event_emit_map_mutex; void WorkerThread(); - void ProcessCompletedTransfers(); - - void ProcessTask(DownloadTask& task); - bool IsTaskTerminated(const std::string& task_id); void RemoveTaskFromStopList(const std::string& task_id); @@ -130,49 +158,69 @@ class DownloadService { if (downloading_data == nullptr) { return 0; } - const auto dl_item_id = downloading_data->item_id; - if (dltotal <= 0) { - return 0; - } auto dl_srv = downloading_data->download_service; - auto active_task = dl_srv->active_task_; - if (active_task == nullptr) { + if (dl_srv == nullptr) { return 0; } - for (auto& item : active_task->items) { - if (item.id == dl_item_id) { - item.downloadedBytes = dlnow; + // Lock during the update and event emission + std::lock_guard lock(dl_srv->active_tasks_mutex_); + + // Find and update the task + if (auto task_it = dl_srv->active_tasks_.find(downloading_data->task_id); + task_it != dl_srv->active_tasks_.end()) { + auto& task = task_it->second; + // Find the specific item in the task + for (auto& item : task->items) { + if (item.id != downloading_data->item_id) { + // not the item we are looking for + continue; + } + item.bytes = dltotal; - break; - } - } + item.downloadedBytes = dlnow; - auto all_items_bytes_greater_than_zero = - std::all_of(active_task->items.begin(), active_task->items.end(), - [](const DownloadItem& item) { return item.bytes > 0; }); - if (!all_items_bytes_greater_than_zero) { - return 0; - } + if (item.bytes == 0 || item.bytes == item.downloadedBytes) { + break; + } + + // Emit the event + { + std::lock_guard event_lock(dl_srv->event_emit_map_mutex); + // find the task id in event_emit_map + // if not found, add it to the map + auto should_emit_event{false}; + if (dl_srv->event_emit_map_.find(task->id) != + dl_srv->event_emit_map_.end()) { + auto last_event_time = dl_srv->event_emit_map_[task->id]; + auto current_time = std::chrono::steady_clock::now(); + + auto time_since_last_event = + std::chrono::duration_cast( + current_time - last_event_time) + .count(); + if (time_since_last_event >= 1000) { + // if the time since last event is more than 1 sec, emit the event + should_emit_event = true; + } + } else { + // if the task id is not found in the map, emit + should_emit_event = true; + } + + if (should_emit_event) { + dl_srv->event_queue_->enqueue( + EventType::DownloadEvent, + DownloadEvent{.type_ = DownloadEventType::DownloadUpdated, + .download_task_ = *task}); + dl_srv->event_emit_map_[task->id] = + std::chrono::steady_clock::now(); + } + } - // Check if one second has passed since the last event - static auto last_event_time = std::chrono::steady_clock::now(); - auto current_time = std::chrono::steady_clock::now(); - auto time_since_last_event = - std::chrono::duration_cast(current_time - - last_event_time) - .count(); - - // throttle event by 1 sec - if (time_since_last_event >= 1000) { - dl_srv->event_queue_->enqueue( - EventType::DownloadEvent, - DownloadEvent{.type_ = DownloadEventType::DownloadUpdated, - .download_task_ = *active_task}); - - // Update the last event time - last_event_time = current_time; + break; + } } return 0;