diff --git a/engine/cli/commands/engine_install_cmd.cc b/engine/cli/commands/engine_install_cmd.cc index 42fd78ae0..a0d008c60 100644 --- a/engine/cli/commands/engine_install_cmd.cc +++ b/engine/cli/commands/engine_install_cmd.cc @@ -1,4 +1,5 @@ #include "engine_install_cmd.h" +#include #include "server_start_cmd.h" #include "utils/download_progress.h" #include "utils/engine_constants.h" @@ -31,9 +32,16 @@ bool EngineInstallCmd::Exec(const std::string& engine, } } + DownloadProgress dp; + dp.Connect(host_, port_); + // engine can be small, so need to start ws first + auto dp_res = std::async(std::launch::deferred, + [&dp, &engine] { return dp.Handle(engine); }); + CLI_LOG("Validating download items, please wait..") + httplib::Client cli(host_ + ":" + std::to_string(port_)); Json::Value json_data; - json_data["version"] = version.empty() ? "latest" : version; + json_data["version"] = version.empty() ? "latest" : version; auto data_str = json_data.toStyledString(); cli.set_read_timeout(std::chrono::seconds(60)); auto res = cli.Post("/v1/engines/install/" + engine, httplib::Headers(), @@ -43,18 +51,19 @@ bool EngineInstallCmd::Exec(const std::string& engine, if (res->status != httplib::StatusCode::OK_200) { auto root = json_helper::ParseJsonString(res->body); CLI_LOG(root["message"].asString()); + dp.ForceStop(); return false; + } else { + CLI_LOG("Start downloading.."); } } else { auto err = res.error(); CTL_ERR("HTTP error: " << httplib::to_string(err)); + dp.ForceStop(); return false; } - CLI_LOG("Start downloading ...") - DownloadProgress dp; - dp.Connect(host_, port_); - if (!dp.Handle(engine)) + if (!dp_res.get()) return false; bool check_cuda_download = !system_info_utils::GetCudaVersion().empty(); diff --git a/engine/cli/commands/model_pull_cmd.cc b/engine/cli/commands/model_pull_cmd.cc index 566d1e755..edcd84d63 100644 --- a/engine/cli/commands/model_pull_cmd.cc +++ b/engine/cli/commands/model_pull_cmd.cc @@ -122,7 +122,7 @@ std::optional ModelPullCmd::Exec(const std::string& host, int port, return std::nullopt; } - CLI_LOG("Start downloading ...") + CLI_LOG("Start downloading..") DownloadProgress dp; bool force_stop = false; diff --git a/engine/cli/utils/download_progress.cc b/engine/cli/utils/download_progress.cc index 4f40d47cb..ba1aa5a5c 100644 --- a/engine/cli/utils/download_progress.cc +++ b/engine/cli/utils/download_progress.cc @@ -23,12 +23,14 @@ bool DownloadProgress::Connect(const std::string& host, int port) { bool DownloadProgress::Handle(const std::string& id) { assert(!!ws_); + uint64_t total = std::numeric_limits::max(); status_ = DownloadStatus::DownloadStarted; std::unique_ptr> bars; std::vector> items; indicators::show_console_cursor(false); - auto handle_message = [this, &bars, &items, id](const std::string& message) { + auto handle_message = [this, &bars, &items, &total, + id](const std::string& message) { CTL_INF(message); auto pad_string = [](const std::string& str, @@ -70,7 +72,10 @@ bool DownloadProgress::Handle(const std::string& id) { for (int i = 0; i < ev.download_task_.items.size(); i++) { auto& it = ev.download_task_.items[i]; uint64_t downloaded = it.downloadedBytes.value_or(0); - uint64_t total = it.bytes.value_or(std::numeric_limits::max()); + if (total == 0 || total == std::numeric_limits::max()) { + total = it.bytes.value_or(std::numeric_limits::max()); + CTL_INF("Updated - total: " << total); + } if (ev.type_ == DownloadStatus::DownloadUpdated) { (*bars)[i].set_option(indicators::option::PrefixText{ pad_string(it.id) +