From a41001fafdc58ec704caa07365534d02d904a1ae Mon Sep 17 00:00:00 2001 From: James Date: Mon, 11 Nov 2024 22:13:17 +0700 Subject: [PATCH 1/7] feat: simultaneous download --- engine/services/download_service.cc | 337 +++++++++++++++++----------- engine/services/download_service.h | 171 ++++++++------ 2 files changed, 299 insertions(+), 209 deletions(-) diff --git a/engine/services/download_service.cc b/engine/services/download_service.cc index d125f8ef0..3f51dcad1 100644 --- a/engine/services/download_service.cc +++ b/engine/services/download_service.cc @@ -16,6 +16,34 @@ size_t WriteCallback(char* ptr, size_t size, size_t nmemb, void* userdata) { size_t written = fwrite(ptr, size, nmemb, (FILE*)userdata); return written; } + +void ProcessCompletedTransfers(CURLM* multi_handle) { + CURLMsg* msg; + int msgs_left; + while ((msg = curl_multi_info_read(multi_handle, &msgs_left))) { + if (msg->msg == CURLMSG_DONE) { + auto handle = msg->easy_handle; + auto result = msg->data.result; + + char* url = nullptr; + curl_easy_getinfo(handle, CURLINFO_EFFECTIVE_URL, &url); + + if (result != CURLE_OK) { + CTL_ERR("Transfer failed for URL: " << url << " Error: " + << curl_easy_strerror(result)); + } else { + long response_code; + curl_easy_getinfo(handle, CURLINFO_RESPONSE_CODE, &response_code); + if (response_code == 200) { + CTL_INF("Transfer completed for URL: " << url); + } else { + CTL_ERR("Transfer failed with HTTP code: " << response_code + << " for URL: " << url); + } + } + } + } +} } // namespace cpp::result DownloadService::AddDownloadTask( @@ -179,195 +207,201 @@ cpp::result DownloadService::Download( curl_easy_cleanup(curl); return true; } -void DownloadService::WorkerThread() { + +cpp::result DownloadService::StopTask( + const std::string& task_id) { + // First try to cancel in queue + CTL_INF("Stopping task: " << task_id); + auto cancelled = task_queue_.cancelTask(task_id); + if (cancelled) { + EmitTaskStopped(task_id); + return task_id; + } + CTL_INF("Not found in pending task, try to find task " + task_id + + " in active tasks"); + // Check if task is currently being processed + std::lock_guard lock(active_tasks_mutex_); + if (auto it = active_tasks_.find(task_id); it != active_tasks_.end()) { + CTL_INF("Found task " + task_id + " in active tasks"); + it->second->status = DownloadTask::Status::Cancelled; + EmitTaskStopped(task_id); + return task_id; + } + + CTL_WRN("Task not found"); + return cpp::fail("Task not found"); +} + +void DownloadService::InitializeWorkers() { + for (auto i = 0; i < MAX_CONCURRENT_TASKS; ++i) { + auto worker_data = std::make_unique(); + worker_data->multi_handle = curl_multi_init(); + worker_data_.push_back(std::move(worker_data)); + + worker_threads_.emplace_back([this, i]() { this->WorkerThread(i); }); + CTL_INF("Starting worker thread: " << i); + } +} + +void DownloadService::Shutdown() { + stop_flag_ = true; + task_cv_.notify_all(); // Wake up all waiting threads + + for (auto& thread : worker_threads_) { + if (thread.joinable()) { + thread.join(); + } + } + + for (auto& worker_data : worker_data_) { + curl_multi_cleanup(worker_data->multi_handle); + } + + // Clean up any remaining callbacks + std::lock_guard lock(callbacks_mutex_); + callbacks_.clear(); +} + +void DownloadService::WorkerThread(int worker_id) { + auto& worker_data = worker_data_[worker_id]; + while (!stop_flag_) { - DownloadTask task; + std::unique_lock lock(task_mutex_); + + // Wait for a task or stop signal + task_cv_.wait(lock, [this] { + auto pending_task = task_queue_.getNextPendingTask(); + return pending_task.has_value() || stop_flag_; + }); + + if (stop_flag_) { + break; + } + + auto maybe_task = task_queue_.pop(); + lock.unlock(); + + if (!maybe_task || maybe_task->status == DownloadTask::Status::Cancelled) { + continue; + } + + auto task = std::move(maybe_task.value()); + + // Register active task { - std::unique_lock lock(queue_mutex_); - queue_condition_.wait( - lock, [this] { return !task_queue_.empty() || stop_flag_; }); - if (stop_flag_) { - CTL_INF("Stopping download service.."); - break; - } - task = std::move(task_queue_.front()); - task_queue_.pop(); + std::lock_guard active_lock(active_tasks_mutex_); + active_tasks_[task.id] = std::make_shared(task); + } + + ProcessTask(task, worker_id); + + // Remove from active tasks + { + std::lock_guard active_lock(active_tasks_mutex_); + active_tasks_.erase(task.id); } - ProcessTask(task); } } -void DownloadService::ProcessTask(DownloadTask& task) { - CTL_INF("Processing task: " + task.id); +void DownloadService::ProcessTask(DownloadTask& task, int worker_id) { + auto& worker_data = worker_data_[worker_id]; std::vector> task_handles; - active_task_ = std::make_shared(task); - + task.status = DownloadTask::Status::InProgress; for (const auto& item : task.items) { auto handle = curl_easy_init(); - if (handle == nullptr) { - // skip the task + if (!handle) { CTL_ERR("Failed to init curl!"); return; } - auto file = fopen(item.localPath.string().c_str(), "wb"); if (!file) { CTL_ERR("Failed to open output file " + item.localPath.string()); + curl_easy_cleanup(handle); return; } - auto dl_data_ptr = std::make_shared(DownloadingData{ + .task_id = task.id, .item_id = item.id, .download_service = this, }); - downloading_data_map_.insert(std::make_pair(item.id, dl_data_ptr)); + worker_data->downloading_data_map[item.id] = dl_data_ptr; - auto headers = curl_utils::GetHeaders(item.downloadUrl); - if (headers.has_value()) { - curl_slist* curl_headers = nullptr; - - for (const auto& [key, value] : headers.value()) { - auto header = key + ": " + value; - curl_headers = curl_slist_append(curl_headers, header.c_str()); - } - - curl_easy_setopt(handle, CURLOPT_HTTPHEADER, curl_headers); - } - - curl_easy_setopt(handle, CURLOPT_URL, item.downloadUrl.c_str()); - curl_easy_setopt(handle, CURLOPT_WRITEFUNCTION, WriteCallback); - curl_easy_setopt(handle, CURLOPT_WRITEDATA, file); - curl_easy_setopt(handle, CURLOPT_FOLLOWLOCATION, 1L); - curl_easy_setopt(handle, CURLOPT_NOPROGRESS, 0L); - curl_easy_setopt(handle, CURLOPT_XFERINFOFUNCTION, ProgressCallback); - curl_easy_setopt(handle, CURLOPT_XFERINFODATA, dl_data_ptr.get()); - - curl_multi_add_handle(multi_handle_, handle); + SetUpCurlHandle(handle, item, file, dl_data_ptr.get()); + curl_multi_add_handle(worker_data->multi_handle, handle); task_handles.push_back(std::make_pair(handle, file)); - CTL_INF("Adding item to multi curl: " + item.ToString()); } - event_queue_->enqueue( - EventType::DownloadEvent, - DownloadEvent{.type_ = DownloadEventType::DownloadStarted, - .download_task_ = task}); - - auto still_running = 0; - auto is_terminated = false; - do { - curl_multi_perform(multi_handle_, &still_running); - curl_multi_wait(multi_handle_, NULL, 0, MAX_WAIT_MSECS, NULL); - - if (IsTaskTerminated(task.id) || stop_flag_) { - CTL_INF("IsTaskTerminated " + std::to_string(IsTaskTerminated(task.id))); - CTL_INF("stop_flag_ " + std::to_string(stop_flag_)); - - is_terminated = true; - break; - } - } while (still_running); + EmitTaskStarted(task); - if (stop_flag_) { - CTL_INF("Download service is stopping.."); + ProcessMultiDownload(task, worker_data->multi_handle, task_handles); - // try to close file - for (auto pair : task_handles) { - fclose(pair.second); - } - - active_task_.reset(); - downloading_data_map_.clear(); - return; - } - - ProcessCompletedTransfers(); - for (const auto& pair : task_handles) { - curl_multi_remove_handle(multi_handle_, pair.first); - curl_easy_cleanup(pair.first); - fclose(pair.second); + // clean up + for (auto& [handle, file] : task_handles) { + curl_multi_remove_handle(worker_data->multi_handle, handle); + curl_easy_cleanup(handle); + fclose(file); } - downloading_data_map_.clear(); - auto copied_task = *active_task_; - active_task_.reset(); - RemoveTaskFromStopList(task.id); - - // if terminate by API calling and not from process stopping, we emit - // DownloadStopped event - if (is_terminated) { - event_queue_->enqueue( - EventType::DownloadEvent, - DownloadEvent{.type_ = DownloadEventType::DownloadStopped, - .download_task_ = copied_task}); - } else { - CTL_INF("Executing callback.."); - ExecuteCallback(task); + ExecuteCallback(task); + EmitTaskCompleted(task.id); - // set all items to done - for (auto& item : copied_task.items) { - item.downloadedBytes = item.bytes; - } - - event_queue_->enqueue( - EventType::DownloadEvent, - DownloadEvent{.type_ = DownloadEventType::DownloadSuccess, - .download_task_ = copied_task}); - } + worker_data->downloading_data_map.clear(); } -cpp::result DownloadService::StopTask( - const std::string& task_id) { - std::lock_guard lock(stop_mutex_); - - tasks_to_stop_.insert(task_id); - CTL_INF("Added task to stop list: " << task_id); - return task_id; -} +void DownloadService::ProcessMultiDownload( + DownloadTask& task, CURLM* multi_handle, + const std::vector>& handles) { + int still_running = 0; + do { + curl_multi_perform(multi_handle, &still_running); + curl_multi_wait(multi_handle, nullptr, 0, MAX_WAIT_MSECS, nullptr); -void DownloadService::ProcessCompletedTransfers() { - CURLMsg* msg; - int remaining_msg_count; + ProcessCompletedTransfers(multi_handle); - while ((msg = curl_multi_info_read(multi_handle_, &remaining_msg_count))) { - if (msg->msg == CURLMSG_DONE) { - auto handle = msg->easy_handle; + if (task.status == DownloadTask::Status::Cancelled || stop_flag_) { + EmitTaskStopped(task.id); + return; + } - auto return_code = msg->data.result; - char* url; - curl_easy_getinfo(handle, CURLINFO_EFFECTIVE_URL, &url); - if (return_code != CURLE_OK) { - CTL_ERR("Download failed for: " << url << " - " - << curl_easy_strerror(return_code)); - continue; - } + } while (still_running); +} - auto http_status_code = 0; - curl_easy_getinfo(handle, CURLINFO_RESPONSE_CODE, &http_status_code); - if (http_status_code == 200) { - CTL_INF("Download completed successfully for: " << url); - } else { - CTL_ERR("Download failed for: " << url << " - HTTP status code: " - << http_status_code); - } +void DownloadService::SetUpCurlHandle(CURL* handle, const DownloadItem& item, + FILE* file, DownloadingData* dl_data) { + curl_easy_setopt(handle, CURLOPT_URL, item.downloadUrl.c_str()); + curl_easy_setopt(handle, CURLOPT_WRITEFUNCTION, WriteCallback); + curl_easy_setopt(handle, CURLOPT_WRITEDATA, file); + curl_easy_setopt(handle, CURLOPT_FOLLOWLOCATION, 1L); + curl_easy_setopt(handle, CURLOPT_NOPROGRESS, 0L); + curl_easy_setopt(handle, CURLOPT_XFERINFOFUNCTION, ProgressCallback); + curl_easy_setopt(handle, CURLOPT_XFERINFODATA, dl_data); + + auto headers = curl_utils::GetHeaders(item.downloadUrl); + if (headers) { + curl_slist* curl_headers = nullptr; + for (const auto& [key, value] : headers.value()) { + curl_headers = + curl_slist_append(curl_headers, (key + ": " + value).c_str()); } + curl_easy_setopt(handle, CURLOPT_HTTPHEADER, curl_headers); } } cpp::result DownloadService::AddTask( DownloadTask& task, std::function callback) { - { + { // adding item to callback map std::lock_guard lock(callbacks_mutex_); callbacks_[task.id] = std::move(callback); } - { + { // adding task to queue std::lock_guard lock(queue_mutex_); task_queue_.push(task); CTL_INF("Task added to queue: " << task.id); } - queue_condition_.notify_one(); + task_cv_.notify_one(); return task; } @@ -382,6 +416,35 @@ void DownloadService::RemoveTaskFromStopList(const std::string& task_id) { tasks_to_stop_.erase(task_id); } +void DownloadService::EmitTaskStarted(const DownloadTask& task) { + event_queue_->enqueue( + EventType::DownloadEvent, + DownloadEvent{.type_ = DownloadEventType::DownloadStarted, + .download_task_ = task}); +} + +void DownloadService::EmitTaskStopped(const std::string& task_id) { + if (auto it = active_tasks_.find(task_id); it != active_tasks_.end()) { + event_queue_->enqueue( + EventType::DownloadEvent, + DownloadEvent{.type_ = DownloadEventType::DownloadStopped, + .download_task_ = *it->second}); + } +} + +void DownloadService::EmitTaskCompleted(const std::string& task_id) { + std::lock_guard lock(active_tasks_mutex_); + if (auto it = active_tasks_.find(task_id); it != active_tasks_.end()) { + for (auto& item : it->second->items) { + item.downloadedBytes = item.bytes; + } + event_queue_->enqueue( + EventType::DownloadEvent, + DownloadEvent{.type_ = DownloadEventType::DownloadSuccess, + .download_task_ = *it->second}); + } +} + void DownloadService::ExecuteCallback(const DownloadTask& task) { std::lock_guard lock(callbacks_mutex_); auto it = callbacks_.find(task.id); diff --git a/engine/services/download_service.h b/engine/services/download_service.h index 3fa74a4c7..c4336ded5 100644 --- a/engine/services/download_service.h +++ b/engine/services/download_service.h @@ -4,14 +4,63 @@ #include #include #include -#include #include #include +#include "common/download_task_queue.h" #include "common/event.h" -#include "utils/logging_utils.h" #include "utils/result.hpp" class DownloadService { + private: + static constexpr int MAX_CONCURRENT_TASKS = 4; + + struct DownloadingData { + std::string task_id; + std::string item_id; + DownloadService* download_service; + }; + + // Each worker represents a thread. Each worker will have its own multi_handle + struct WorkerData { + CURLM* multi_handle; + std::unordered_map> + downloading_data_map; + }; + std::vector> worker_data_; + + std::vector worker_threads_; + + // Store all download tasks in a queue + DownloadTaskQueue task_queue_; + + // Flag to stop Download service. Will stop all worker threads + std::atomic stop_flag_{false}; + + // Track active tasks across all workers + std::mutex active_tasks_mutex_; + std::unordered_map> active_tasks_; + + // sync primitives + std::condition_variable task_cv_; + std::mutex task_mutex_; + + void WorkerThread(int worker_id); + + void ProcessTask(DownloadTask& task, int worker_id); + + void ProcessMultiDownload( + DownloadTask& task, CURLM* multi_handle, + const std::vector>& handles); + + void SetUpCurlHandle(CURL* handle, const DownloadItem& item, FILE* file, + DownloadingData* dl_data); + + void EmitTaskStarted(const DownloadTask& task); + + void EmitTaskStopped(const std::string& task_id); + + void EmitTaskCompleted(const std::string& task_id); + public: using OnDownloadTaskSuccessfully = std::function; @@ -27,29 +76,13 @@ class DownloadService { explicit DownloadService(std::shared_ptr event_queue) : event_queue_{event_queue} { - curl_global_init(CURL_GLOBAL_ALL); - stop_flag_ = false; - multi_handle_ = curl_multi_init(); - worker_thread_ = std::thread(&DownloadService::WorkerThread, this); + InitializeWorkers(); }; - ~DownloadService() { - if (event_queue_ != nullptr) { - stop_flag_ = true; - queue_condition_.notify_one(); - - CTL_INF("DownloadService is being destroyed."); - curl_multi_cleanup(multi_handle_); - curl_global_cleanup(); - - if (worker_thread_.joinable()) { - worker_thread_.join(); - } - CTL_INF("DownloadService is destroyed."); - } - } + ~DownloadService() { Shutdown(); } DownloadService(const DownloadService&) = delete; + DownloadService& operator=(const DownloadService&) = delete; /** @@ -77,10 +110,9 @@ class DownloadService { cpp::result StopTask(const std::string& task_id); private: - struct DownloadingData { - std::string item_id; - DownloadService* download_service; - }; + void InitializeWorkers(); + + void Shutdown(); cpp::result Download( const std::string& download_id, @@ -90,10 +122,7 @@ class DownloadService { CURLM* multi_handle_; std::thread worker_thread_; - std::atomic stop_flag_; - // task queue - std::queue task_queue_; std::mutex queue_mutex_; std::condition_variable queue_condition_; @@ -106,16 +135,8 @@ class DownloadService { callbacks_; std::mutex callbacks_mutex_; - std::shared_ptr active_task_; - std::unordered_map> - downloading_data_map_; - void WorkerThread(); - void ProcessCompletedTransfers(); - - void ProcessTask(DownloadTask& task); - bool IsTaskTerminated(const std::string& task_id); void RemoveTaskFromStopList(const std::string& task_id); @@ -130,49 +151,55 @@ class DownloadService { if (downloading_data == nullptr) { return 0; } - const auto dl_item_id = downloading_data->item_id; - if (dltotal <= 0) { - return 0; - } - auto dl_srv = downloading_data->download_service; - auto active_task = dl_srv->active_task_; - if (active_task == nullptr) { + if (downloading_data->download_service == nullptr) { return 0; } - for (auto& item : active_task->items) { - if (item.id == dl_item_id) { - item.downloadedBytes = dlnow; + // Lock during the update and event emission + std::lock_guard lock( + downloading_data->download_service->active_tasks_mutex_); + + // Find and update the task + if (auto task_it = downloading_data->download_service->active_tasks_.find( + downloading_data->task_id); + task_it != downloading_data->download_service->active_tasks_.end()) { + auto& task = task_it->second; + // Find the specific item in the task + for (auto& item : task->items) { + if (item.id != downloading_data->item_id) { + // not the item we are looking for + continue; + } + + if (dltotal == 0) { + // if dltotal is 0, we prevent to send the event + break; + } + item.bytes = dltotal; - break; - } - } + item.downloadedBytes = dlnow; - auto all_items_bytes_greater_than_zero = - std::all_of(active_task->items.begin(), active_task->items.end(), - [](const DownloadItem& item) { return item.bytes > 0; }); - if (!all_items_bytes_greater_than_zero) { - return 0; - } + static auto last_event_time = std::chrono::steady_clock::now(); + auto current_time = std::chrono::steady_clock::now(); + auto time_since_last_event = + std::chrono::duration_cast( + current_time - last_event_time) + .count(); + + // throttle event by 1 sec + if (time_since_last_event >= 1000) { + downloading_data->download_service->event_queue_->enqueue( + EventType::DownloadEvent, + DownloadEvent{.type_ = DownloadEventType::DownloadUpdated, + .download_task_ = *task}); - // Check if one second has passed since the last event - static auto last_event_time = std::chrono::steady_clock::now(); - auto current_time = std::chrono::steady_clock::now(); - auto time_since_last_event = - std::chrono::duration_cast(current_time - - last_event_time) - .count(); - - // throttle event by 1 sec - if (time_since_last_event >= 1000) { - dl_srv->event_queue_->enqueue( - EventType::DownloadEvent, - DownloadEvent{.type_ = DownloadEventType::DownloadUpdated, - .download_task_ = *active_task}); - - // Update the last event time - last_event_time = current_time; + // Update the last event time + last_event_time = current_time; + } + + break; + } } return 0; From 6a7d2883335a3954640de6abab4ee2045a7a933b Mon Sep 17 00:00:00 2001 From: James Date: Tue, 12 Nov 2024 15:02:53 +0700 Subject: [PATCH 2/7] handle error download --- engine/services/download_service.cc | 41 +++++++++++++++++++++++------ engine/services/download_service.h | 4 ++- 2 files changed, 36 insertions(+), 9 deletions(-) diff --git a/engine/services/download_service.cc b/engine/services/download_service.cc index 3f51dcad1..9ce418b09 100644 --- a/engine/services/download_service.cc +++ b/engine/services/download_service.cc @@ -17,7 +17,7 @@ size_t WriteCallback(char* ptr, size_t size, size_t nmemb, void* userdata) { return written; } -void ProcessCompletedTransfers(CURLM* multi_handle) { +cpp::result ProcessCompletedTransfers(CURLM* multi_handle) { CURLMsg* msg; int msgs_left; while ((msg = curl_multi_info_read(multi_handle, &msgs_left))) { @@ -31,6 +31,9 @@ void ProcessCompletedTransfers(CURLM* multi_handle) { if (result != CURLE_OK) { CTL_ERR("Transfer failed for URL: " << url << " Error: " << curl_easy_strerror(result)); + // download failed + return cpp::fail("Transfer failed for URL: " + std::string(url) + + " Error: " + curl_easy_strerror(result)); } else { long response_code; curl_easy_getinfo(handle, CURLINFO_RESPONSE_CODE, &response_code); @@ -39,10 +42,15 @@ void ProcessCompletedTransfers(CURLM* multi_handle) { } else { CTL_ERR("Transfer failed with HTTP code: " << response_code << " for URL: " << url); + // download failed + return cpp::fail("Transfer failed with HTTP code: " + + std::to_string(response_code) + + " for URL: " + std::string(url)); } } } } + return {}; } } // namespace @@ -334,7 +342,8 @@ void DownloadService::ProcessTask(DownloadTask& task, int worker_id) { EmitTaskStarted(task); - ProcessMultiDownload(task, worker_data->multi_handle, task_handles); + auto result = + ProcessMultiDownload(task, worker_data->multi_handle, task_handles); // clean up for (auto& [handle, file] : task_handles) { @@ -343,13 +352,16 @@ void DownloadService::ProcessTask(DownloadTask& task, int worker_id) { fclose(file); } - ExecuteCallback(task); - EmitTaskCompleted(task.id); + if (!result.has_error()) { + // if the download has error, we are not run the callback + ExecuteCallback(task); + EmitTaskCompleted(task.id); + } worker_data->downloading_data_map.clear(); } -void DownloadService::ProcessMultiDownload( +cpp::result DownloadService::ProcessMultiDownload( DownloadTask& task, CURLM* multi_handle, const std::vector>& handles) { int still_running = 0; @@ -357,14 +369,18 @@ void DownloadService::ProcessMultiDownload( curl_multi_perform(multi_handle, &still_running); curl_multi_wait(multi_handle, nullptr, 0, MAX_WAIT_MSECS, nullptr); - ProcessCompletedTransfers(multi_handle); + auto result = ProcessCompletedTransfers(multi_handle); + if (result.has_error()) { + EmitTaskError(task.id); + return cpp::fail(result.error()); + } if (task.status == DownloadTask::Status::Cancelled || stop_flag_) { EmitTaskStopped(task.id); - return; + return cpp::fail("Task " + task.id + " cancelled"); } - } while (still_running); + return {}; } void DownloadService::SetUpCurlHandle(CURL* handle, const DownloadItem& item, @@ -432,6 +448,15 @@ void DownloadService::EmitTaskStopped(const std::string& task_id) { } } +void DownloadService::EmitTaskError(const std::string& task_id) { + if (auto it = active_tasks_.find(task_id); it != active_tasks_.end()) { + event_queue_->enqueue( + EventType::DownloadEvent, + DownloadEvent{.type_ = DownloadEventType::DownloadError, + .download_task_ = *it->second}); + } +} + void DownloadService::EmitTaskCompleted(const std::string& task_id) { std::lock_guard lock(active_tasks_mutex_); if (auto it = active_tasks_.find(task_id); it != active_tasks_.end()) { diff --git a/engine/services/download_service.h b/engine/services/download_service.h index c4336ded5..453e2fe2e 100644 --- a/engine/services/download_service.h +++ b/engine/services/download_service.h @@ -48,7 +48,7 @@ class DownloadService { void ProcessTask(DownloadTask& task, int worker_id); - void ProcessMultiDownload( + cpp::result ProcessMultiDownload( DownloadTask& task, CURLM* multi_handle, const std::vector>& handles); @@ -61,6 +61,8 @@ class DownloadService { void EmitTaskCompleted(const std::string& task_id); + void EmitTaskError(const std::string& task_id); + public: using OnDownloadTaskSuccessfully = std::function; From 8d7c5fa827538cca8a9a1a269f9986cdcfb31089 Mon Sep 17 00:00:00 2001 From: James Date: Wed, 13 Nov 2024 13:35:42 +0700 Subject: [PATCH 3/7] update --- engine/services/download_service.cc | 12 ++++++ engine/services/download_service.h | 64 +++++++++++++++++++---------- 2 files changed, 54 insertions(+), 22 deletions(-) diff --git a/engine/services/download_service.cc b/engine/services/download_service.cc index 9ce418b09..f9a866d84 100644 --- a/engine/services/download_service.cc +++ b/engine/services/download_service.cc @@ -356,6 +356,10 @@ void DownloadService::ProcessTask(DownloadTask& task, int worker_id) { // if the download has error, we are not run the callback ExecuteCallback(task); EmitTaskCompleted(task.id); + { + std::lock_guard lock(event_emit_map_mutex); + event_emit_map_.erase(task.id); + } } worker_data->downloading_data_map.clear(); @@ -372,11 +376,19 @@ cpp::result DownloadService::ProcessMultiDownload( auto result = ProcessCompletedTransfers(multi_handle); if (result.has_error()) { EmitTaskError(task.id); + { + std::lock_guard lock(event_emit_map_mutex); + event_emit_map_.erase(task.id); + } return cpp::fail(result.error()); } if (task.status == DownloadTask::Status::Cancelled || stop_flag_) { EmitTaskStopped(task.id); + { + std::lock_guard lock(event_emit_map_mutex); + event_emit_map_.erase(task.id); + } return cpp::fail("Task " + task.id + " cancelled"); } } while (still_running); diff --git a/engine/services/download_service.h b/engine/services/download_service.h index 453e2fe2e..5f7f57c9d 100644 --- a/engine/services/download_service.h +++ b/engine/services/download_service.h @@ -137,6 +137,11 @@ class DownloadService { callbacks_; std::mutex callbacks_mutex_; + std::unordered_map> + event_emit_map_; + std::mutex event_emit_map_mutex; + void WorkerThread(); bool IsTaskTerminated(const std::string& task_id); @@ -154,18 +159,17 @@ class DownloadService { return 0; } - if (downloading_data->download_service == nullptr) { + auto dl_srv = downloading_data->download_service; + if (dl_srv == nullptr) { return 0; } // Lock during the update and event emission - std::lock_guard lock( - downloading_data->download_service->active_tasks_mutex_); + std::lock_guard lock(dl_srv->active_tasks_mutex_); // Find and update the task - if (auto task_it = downloading_data->download_service->active_tasks_.find( - downloading_data->task_id); - task_it != downloading_data->download_service->active_tasks_.end()) { + if (auto task_it = dl_srv->active_tasks_.find(downloading_data->task_id); + task_it != dl_srv->active_tasks_.end()) { auto& task = task_it->second; // Find the specific item in the task for (auto& item : task->items) { @@ -182,22 +186,38 @@ class DownloadService { item.bytes = dltotal; item.downloadedBytes = dlnow; - static auto last_event_time = std::chrono::steady_clock::now(); - auto current_time = std::chrono::steady_clock::now(); - auto time_since_last_event = - std::chrono::duration_cast( - current_time - last_event_time) - .count(); - - // throttle event by 1 sec - if (time_since_last_event >= 1000) { - downloading_data->download_service->event_queue_->enqueue( - EventType::DownloadEvent, - DownloadEvent{.type_ = DownloadEventType::DownloadUpdated, - .download_task_ = *task}); - - // Update the last event time - last_event_time = current_time; + // Emit the event + { + std::lock_guard event_lock(dl_srv->event_emit_map_mutex); + // find the task id in event_emit_map + // if not found, add it to the map + auto should_emit_event{false}; + if (dl_srv->event_emit_map_.find(task->id) != + dl_srv->event_emit_map_.end()) { + auto last_event_time = dl_srv->event_emit_map_[task->id]; + auto current_time = std::chrono::steady_clock::now(); + + auto time_since_last_event = + std::chrono::duration_cast( + current_time - last_event_time) + .count(); + if (time_since_last_event >= 1000) { + // if the time since last event is more than 1 sec, emit the event + should_emit_event = true; + } + } else { + // if the task id is not found in the map, emit + should_emit_event = true; + } + + if (should_emit_event) { + dl_srv->event_queue_->enqueue( + EventType::DownloadEvent, + DownloadEvent{.type_ = DownloadEventType::DownloadUpdated, + .download_task_ = *task}); + dl_srv->event_emit_map_[task->id] = + std::chrono::steady_clock::now(); + } } break; From 9ae7e39da851b25fee7e5d42a333ebefa591b0d0 Mon Sep 17 00:00:00 2001 From: vansangpfiev Date: Thu, 14 Nov 2024 18:19:58 +0700 Subject: [PATCH 4/7] fix: download progress --- engine/cli/commands/engine_install_cmd.cc | 31 +++++---- engine/cli/commands/engine_update_cmd.cc | 15 ++--- engine/cli/commands/model_pull_cmd.cc | 2 +- engine/cli/utils/download_progress.cc | 81 ++++++++++++++--------- engine/cli/utils/download_progress.h | 17 +++-- 5 files changed, 84 insertions(+), 62 deletions(-) diff --git a/engine/cli/commands/engine_install_cmd.cc b/engine/cli/commands/engine_install_cmd.cc index 2424513eb..1f712d10c 100644 --- a/engine/cli/commands/engine_install_cmd.cc +++ b/engine/cli/commands/engine_install_cmd.cc @@ -36,8 +36,13 @@ bool EngineInstallCmd::Exec(const std::string& engine, DownloadProgress dp; dp.Connect(host_, port_); // engine can be small, so need to start ws first - auto dp_res = std::async(std::launch::deferred, [&dp, &engine] { - return dp.Handle(DownloadType::Engine); + auto dp_res = std::async(std::launch::deferred, [&dp] { + bool need_cuda_download = !system_info_utils::GetCudaVersion().empty(); + if (need_cuda_download) { + return dp.Handle({DownloadType::Engine, DownloadType::CudaToolkit}); + } else { + return dp.Handle({DownloadType::Engine}); + } }); auto versions_url = url_parser::Url{ @@ -133,12 +138,6 @@ bool EngineInstallCmd::Exec(const std::string& engine, if (!dp_res.get()) return false; - bool check_cuda_download = !system_info_utils::GetCudaVersion().empty(); - if (check_cuda_download) { - if (!dp.Handle(DownloadType::CudaToolkit)) - return false; - } - CLI_LOG("Engine " << engine << " downloaded successfully!") return true; } @@ -147,8 +146,14 @@ bool EngineInstallCmd::Exec(const std::string& engine, DownloadProgress dp; dp.Connect(host_, port_); // engine can be small, so need to start ws first - auto dp_res = std::async(std::launch::deferred, - [&dp] { return dp.Handle(DownloadType::Engine); }); + auto dp_res = std::async(std::launch::deferred, [&dp] { + bool need_cuda_download = !system_info_utils::GetCudaVersion().empty(); + if (need_cuda_download) { + return dp.Handle({DownloadType::Engine, DownloadType::CudaToolkit}); + } else { + return dp.Handle({DownloadType::Engine}); + } + }); auto install_url = url_parser::Url{ .protocol = "http", @@ -183,12 +188,6 @@ bool EngineInstallCmd::Exec(const std::string& engine, if (!dp_res.get()) return false; - bool check_cuda_download = !system_info_utils::GetCudaVersion().empty(); - if (check_cuda_download) { - if (!dp.Handle(DownloadType::CudaToolkit)) - return false; - } - CLI_LOG("Engine " << engine << " downloaded successfully!") return true; } diff --git a/engine/cli/commands/engine_update_cmd.cc b/engine/cli/commands/engine_update_cmd.cc index ee3a526c3..9717ddb15 100644 --- a/engine/cli/commands/engine_update_cmd.cc +++ b/engine/cli/commands/engine_update_cmd.cc @@ -24,8 +24,13 @@ bool EngineUpdateCmd::Exec(const std::string& host, int port, DownloadProgress dp; dp.Connect(host, port); // engine can be small, so need to start ws first - auto dp_res = std::async(std::launch::deferred, [&dp, &engine] { - return dp.Handle(DownloadType::Engine); + auto dp_res = std::async(std::launch::deferred, [&dp] { + bool need_cuda_download = !system_info_utils::GetCudaVersion().empty(); + if (need_cuda_download) { + return dp.Handle({DownloadType::Engine, DownloadType::CudaToolkit}); + } else { + return dp.Handle({DownloadType::Engine}); + } }); auto update_url = url_parser::Url{ @@ -48,12 +53,6 @@ bool EngineUpdateCmd::Exec(const std::string& host, int port, if (!dp_res.get()) return false; - bool check_cuda_download = !system_info_utils::GetCudaVersion().empty(); - if (check_cuda_download) { - if (!dp.Handle(DownloadType::CudaToolkit)) - return false; - } - CLI_LOG("Engine " << engine << " updated successfully!") return true; } diff --git a/engine/cli/commands/model_pull_cmd.cc b/engine/cli/commands/model_pull_cmd.cc index 5793c2e09..d769b667a 100644 --- a/engine/cli/commands/model_pull_cmd.cc +++ b/engine/cli/commands/model_pull_cmd.cc @@ -143,7 +143,7 @@ std::optional ModelPullCmd::Exec(const std::string& host, int port, reinterpret_cast(console_ctrl_handler), true); #endif dp.Connect(host, port); - if (!dp.Handle(DownloadType::Model)) + if (!dp.Handle({DownloadType::Model})) return std::nullopt; if (force_stop) return std::nullopt; diff --git a/engine/cli/utils/download_progress.cc b/engine/cli/utils/download_progress.cc index 017d71d0e..e085a660e 100644 --- a/engine/cli/utils/download_progress.cc +++ b/engine/cli/utils/download_progress.cc @@ -34,7 +34,8 @@ bool DownloadProgress::Connect(const std::string& host, int port) { return true; } -bool DownloadProgress::Handle(const DownloadType& event_type) { +bool DownloadProgress::Handle( + const std::unordered_set& event_type) { assert(!!ws_); #if defined(_WIN32) HANDLE h_out = GetStdHandle(STD_OUTPUT_HANDLE); @@ -50,10 +51,14 @@ bool DownloadProgress::Handle(const DownloadType& event_type) { } } #endif - status_ = DownloadStatus::DownloadStarted; + for (auto et : event_type) { + status_[et] = DownloadStatus::DownloadStarted; + } std::unique_ptr> bars; - std::vector> items; + std::unordered_map>> + items; indicators::show_console_cursor(false); auto start = std::chrono::steady_clock::now(); auto handle_message = [this, &bars, &items, event_type, @@ -78,22 +83,28 @@ bool DownloadProgress::Handle(const DownloadType& event_type) { auto ev = cortex::event::GetDownloadEventFromJson( json_helper::ParseJsonString(message)); // Ignore other task type - if (ev.download_task_.type != event_type) { + if (event_type.find(ev.download_task_.type) == event_type.end()) { return; } auto now = std::chrono::steady_clock::now(); if (!bars) { bars = std::make_unique< indicators::DynamicProgress>(); - for (auto& i : ev.download_task_.items) { - items.emplace_back(std::make_unique( - indicators::option::BarWidth{50}, indicators::option::Start{"["}, - indicators::option::Fill{"="}, indicators::option::Lead{">"}, - indicators::option::End{"]"}, - indicators::option::PrefixText{pad_string(Repo2Engine(i.id))}, - indicators::option::ForegroundColor{indicators::Color::white}, - indicators::option::ShowRemainingTime{false})); - bars->push_back(*(items.back())); + } + for (auto& i : ev.download_task_.items) { + if (items.find(i.id) == items.end()) { + auto idx = items.size(); + items[i.id] = std::pair( + idx, + std::make_unique( + indicators::option::BarWidth{50}, + indicators::option::Start{"["}, indicators::option::Fill{"="}, + indicators::option::Lead{">"}, indicators::option::End{"]"}, + indicators::option::PrefixText{pad_string(Repo2Engine(i.id))}, + indicators::option::ForegroundColor{indicators::Color::white}, + indicators::option::ShowRemainingTime{false})); + + bars->push_back(*(items.at(i.id).second)); } } for (int i = 0; i < ev.download_task_.items.size(); i++) { @@ -113,32 +124,36 @@ bool DownloadProgress::Handle(const DownloadType& event_type) { (total - downloaded) / bytes_per_sec); } - (*bars)[i].set_option(indicators::option::PrefixText{ - pad_string(Repo2Engine(it.id)) + - std::to_string(int(static_cast(downloaded) / total * 100)) + - '%'}); - (*bars)[i].set_progress( + (*bars)[items.at(it.id).first].set_option( + indicators::option::PrefixText{ + pad_string(Repo2Engine(it.id)) + + std::to_string( + int(static_cast(downloaded) / total * 100)) + + '%'}); + (*bars)[items.at(it.id).first].set_progress( int(static_cast(downloaded) / total * 100)); - (*bars)[i].set_option(indicators::option::PostfixText{ - time_remaining + " " + - format_utils::BytesToHumanReadable(downloaded) + "/" + - format_utils::BytesToHumanReadable(total)}); + (*bars)[items.at(it.id).first].set_option( + indicators::option::PostfixText{ + time_remaining + " " + + format_utils::BytesToHumanReadable(downloaded) + "/" + + format_utils::BytesToHumanReadable(total)}); } else if (ev.type_ == DownloadStatus::DownloadSuccess) { uint64_t total = it.bytes.value_or(std::numeric_limits::max()); - (*bars)[i].set_progress(100); + (*bars)[items.at(it.id).first].set_progress(100); auto total_str = format_utils::BytesToHumanReadable(total); - (*bars)[i].set_option(indicators::option::PostfixText{ - "00m:00s " + total_str + "/" + total_str}); - (*bars)[i].set_option(indicators::option::PrefixText{ - pad_string(Repo2Engine(it.id)) + "100%"}); - (*bars)[i].set_progress(100); + (*bars)[items.at(it.id).first].set_option( + indicators::option::PostfixText{"00m:00s " + total_str + "/" + + total_str}); + (*bars)[items.at(it.id).first].set_option( + indicators::option::PrefixText{pad_string(Repo2Engine(it.id)) + + "100%"}); + (*bars)[items.at(it.id).first].set_progress(100); CTL_INF("Download success"); } + status_[ev.download_task_.type] = ev.type_; } - - status_ = ev.type_; }; while (ws_->getReadyState() != easywsclient::WebSocket::CLOSED && @@ -152,7 +167,9 @@ bool DownloadProgress::Handle(const DownloadType& event_type) { SetConsoleMode(h_out, dw_original_out_mode); } #endif - if (status_ == DownloadStatus::DownloadError) - return false; + for (auto const& [_, v] : status_) { + if (v == DownloadStatus::DownloadError) + return false; + } return true; } diff --git a/engine/cli/utils/download_progress.h b/engine/cli/utils/download_progress.h index 98fe85654..6ea764ec4 100644 --- a/engine/cli/utils/download_progress.h +++ b/engine/cli/utils/download_progress.h @@ -2,6 +2,7 @@ #include #include #include +#include #include "common/event.h" #include "easywsclient.hpp" @@ -10,19 +11,25 @@ class DownloadProgress { public: bool Connect(const std::string& host, int port); - bool Handle(const DownloadType& event_type); + bool Handle(const std::unordered_set& event_type); void ForceStop() { force_stop_ = true; } private: bool should_stop() const { - return (status_ != DownloadStatus::DownloadStarted && - status_ != DownloadStatus::DownloadUpdated) || - force_stop_; + bool should_stop = true; + for (auto const& [_, v] : status_) { + should_stop &= (v == DownloadStatus::DownloadSuccess); + } + for (auto const& [_, v] : status_) { + should_stop |= (v == DownloadStatus::DownloadError || + v == DownloadStatus::DownloadStopped); + } + return should_stop || force_stop_; } private: std::unique_ptr ws_; - std::atomic status_ = DownloadStatus::DownloadStarted; + std::unordered_map> status_; std::atomic force_stop_ = false; }; \ No newline at end of file From 461601b1fff1e84d0e326629ebb946574df224c9 Mon Sep 17 00:00:00 2001 From: James Date: Thu, 14 Nov 2024 21:24:26 +0700 Subject: [PATCH 5/7] update --- engine/services/download_service.h | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/engine/services/download_service.h b/engine/services/download_service.h index 5f7f57c9d..42073fc43 100644 --- a/engine/services/download_service.h +++ b/engine/services/download_service.h @@ -178,14 +178,13 @@ class DownloadService { continue; } - if (dltotal == 0) { - // if dltotal is 0, we prevent to send the event - break; - } - item.bytes = dltotal; item.downloadedBytes = dlnow; + if (item.bytes == 0 || item.bytes == item.downloadedBytes) { + break; + } + // Emit the event { std::lock_guard event_lock(dl_srv->event_emit_map_mutex); From c01df03ed255913c94dd1c9413f328fcfb7631ed Mon Sep 17 00:00:00 2001 From: James Date: Thu, 14 Nov 2024 22:06:10 +0700 Subject: [PATCH 6/7] fix ci --- engine/e2e-test/test_api_model_start.py | 6 ++++-- engine/e2e-test/test_api_model_stop.py | 10 ++++++---- 2 files changed, 10 insertions(+), 6 deletions(-) diff --git a/engine/e2e-test/test_api_model_start.py b/engine/e2e-test/test_api_model_start.py index 830d32da8..d6a98a78b 100644 --- a/engine/e2e-test/test_api_model_start.py +++ b/engine/e2e-test/test_api_model_start.py @@ -12,7 +12,7 @@ def setup_and_teardown(self): success = start_server() if not success: raise Exception("Failed to start server") - requests.post("http://localhost:3928/v1/engines/llama-cpp") + run("Install engine", ["engines", "install", "llama-cpp"], 5 * 60) run("Delete model", ["models", "delete", "tinyllama:gguf"]) run( "Pull model", @@ -27,5 +27,7 @@ def setup_and_teardown(self): def test_models_start_should_be_successful(self): json_body = {"model": "tinyllama:gguf"} - response = requests.post("http://localhost:3928/v1/models/start", json=json_body) + response = requests.post( + "http://localhost:3928/v1/models/start", json=json_body + ) assert response.status_code == 200, f"status_code: {response.status_code}" diff --git a/engine/e2e-test/test_api_model_stop.py b/engine/e2e-test/test_api_model_stop.py index 97bec671e..dc3b6b77b 100644 --- a/engine/e2e-test/test_api_model_stop.py +++ b/engine/e2e-test/test_api_model_stop.py @@ -1,6 +1,6 @@ import pytest import requests -from test_runner import start_server, stop_server +from test_runner import run, start_server, stop_server class TestApiModelStop: @@ -13,16 +13,18 @@ def setup_and_teardown(self): if not success: raise Exception("Failed to start server") - requests.post("http://localhost:3928/engines/llama-cpp") + run("Install engine", ["engines", "install", "llama-cpp"], 5 * 60) yield - requests.delete("http://localhost:3928/engines/llama-cpp") + run("Uninstall engine", ["engines", "uninstall", "llama-cpp"]) # Teardown stop_server() def test_models_stop_should_be_successful(self): json_body = {"model": "tinyllama:gguf"} - response = requests.post("http://localhost:3928/v1/models/start", json=json_body) + response = requests.post( + "http://localhost:3928/v1/models/start", json=json_body + ) assert response.status_code == 200, f"status_code: {response.status_code}" response = requests.post("http://localhost:3928/v1/models/stop", json=json_body) assert response.status_code == 200, f"status_code: {response.status_code}" From 4c110bfbc3443f99e985ed09ed9c49b36aeb7ed6 Mon Sep 17 00:00:00 2001 From: James Date: Thu, 14 Nov 2024 23:39:25 +0700 Subject: [PATCH 7/7] fix: abort download --- engine/services/download_service.cc | 14 +++++++++++--- 1 file changed, 11 insertions(+), 3 deletions(-) diff --git a/engine/services/download_service.cc b/engine/services/download_service.cc index f9a866d84..08e366151 100644 --- a/engine/services/download_service.cc +++ b/engine/services/download_service.cc @@ -232,6 +232,10 @@ cpp::result DownloadService::StopTask( if (auto it = active_tasks_.find(task_id); it != active_tasks_.end()) { CTL_INF("Found task " + task_id + " in active tasks"); it->second->status = DownloadTask::Status::Cancelled; + { + std::lock_guard lock(stop_mutex_); + tasks_to_stop_.insert(task_id); + } EmitTaskStopped(task_id); return task_id; } @@ -368,7 +372,7 @@ void DownloadService::ProcessTask(DownloadTask& task, int worker_id) { cpp::result DownloadService::ProcessMultiDownload( DownloadTask& task, CURLM* multi_handle, const std::vector>& handles) { - int still_running = 0; + auto still_running = 0; do { curl_multi_perform(multi_handle, &still_running); curl_multi_wait(multi_handle, nullptr, 0, MAX_WAIT_MSECS, nullptr); @@ -383,12 +387,16 @@ cpp::result DownloadService::ProcessMultiDownload( return cpp::fail(result.error()); } - if (task.status == DownloadTask::Status::Cancelled || stop_flag_) { - EmitTaskStopped(task.id); + if (IsTaskTerminated(task.id) || stop_flag_) { + CTL_INF("IsTaskTerminated " + std::to_string(IsTaskTerminated(task.id))); + CTL_INF("stop_flag_ " + std::to_string(stop_flag_)); { std::lock_guard lock(event_emit_map_mutex); event_emit_map_.erase(task.id); } + CTL_INF("Emit task stopped: " << task.id); + EmitTaskStopped(task.id); + RemoveTaskFromStopList(task.id); return cpp::fail("Task " + task.id + " cancelled"); } } while (still_running);