Skip to content
This repository was archived by the owner on Jul 4, 2025. It is now read-only.

Commit 166cdb5

Browse files
fix: use download event type to listen ws on client side (#1601)
* fix: use download event type to listen ws on client side * fix: format * fix: remove unused --------- Co-authored-by: vansangpfiev <sang@jan.ai>
1 parent 601437d commit 166cdb5

File tree

4 files changed

+27
-13
lines changed

4 files changed

+27
-13
lines changed

engine/cli/commands/engine_install_cmd.cc

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -35,8 +35,9 @@ bool EngineInstallCmd::Exec(const std::string& engine,
3535
DownloadProgress dp;
3636
dp.Connect(host_, port_);
3737
// engine can be small, so need to start ws first
38-
auto dp_res = std::async(std::launch::deferred,
39-
[&dp, &engine] { return dp.Handle(engine); });
38+
auto dp_res = std::async(std::launch::deferred, [&dp] {
39+
return dp.Handle(DownloadType::Engine);
40+
});
4041
CLI_LOG("Validating download items, please wait..")
4142

4243
httplib::Client cli(host_ + ":" + std::to_string(port_));
@@ -68,7 +69,7 @@ bool EngineInstallCmd::Exec(const std::string& engine,
6869

6970
bool check_cuda_download = !system_info_utils::GetCudaVersion().empty();
7071
if (check_cuda_download) {
71-
if (!dp.Handle("cuda"))
72+
if (!dp.Handle(DownloadType::CudaToolkit))
7273
return false;
7374
}
7475

engine/cli/commands/model_pull_cmd.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -149,7 +149,7 @@ std::optional<std::string> ModelPullCmd::Exec(const std::string& host, int port,
149149
reinterpret_cast<PHANDLER_ROUTINE>(console_ctrl_handler), true);
150150
#endif
151151
dp.Connect(host, port);
152-
if (!dp.Handle(model_id))
152+
if (!dp.Handle(DownloadType::Model))
153153
return std::nullopt;
154154
if (force_stop)
155155
return std::nullopt;

engine/cli/utils/download_progress.cc

Lines changed: 21 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -4,10 +4,23 @@
44
#include "common/event.h"
55
#include "indicators/dynamic_progress.hpp"
66
#include "indicators/progress_bar.hpp"
7+
#include "utils/engine_constants.h"
78
#include "utils/format_utils.h"
89
#include "utils/json_helper.h"
910
#include "utils/logging_utils.h"
1011

12+
namespace {
13+
std::string Repo2Engine(const std::string& r) {
14+
if (r == kLlamaRepo) {
15+
return kLlamaEngine;
16+
} else if (r == kOnnxRepo) {
17+
return kOnnxEngine;
18+
} else if (r == kTrtLlmRepo) {
19+
return kTrtLlmEngine;
20+
}
21+
return r;
22+
};
23+
} // namespace
1124
bool DownloadProgress::Connect(const std::string& host, int port) {
1225
if (ws_) {
1326
CTL_INF("Already connected!");
@@ -21,7 +34,7 @@ bool DownloadProgress::Connect(const std::string& host, int port) {
2134
return true;
2235
}
2336

24-
bool DownloadProgress::Handle(const std::string& id) {
37+
bool DownloadProgress::Handle(const DownloadType& event_type) {
2538
assert(!!ws_);
2639
std::unordered_map<std::string, uint64_t> totals;
2740
status_ = DownloadStatus::DownloadStarted;
@@ -30,7 +43,7 @@ bool DownloadProgress::Handle(const std::string& id) {
3043
std::vector<std::unique_ptr<indicators::ProgressBar>> items;
3144
indicators::show_console_cursor(false);
3245
auto handle_message = [this, &bars, &items, &totals,
33-
id](const std::string& message) {
46+
event_type](const std::string& message) {
3447
CTL_INF(message);
3548

3649
auto pad_string = [](const std::string& str,
@@ -50,8 +63,8 @@ bool DownloadProgress::Handle(const std::string& id) {
5063

5164
auto ev = cortex::event::GetDownloadEventFromJson(
5265
json_helper::ParseJsonString(message));
53-
// Ignore other task ids
54-
if (ev.download_task_.id != id) {
66+
// Ignore other task type
67+
if (ev.download_task_.type != event_type) {
5568
return;
5669
}
5770

@@ -63,7 +76,7 @@ bool DownloadProgress::Handle(const std::string& id) {
6376
indicators::option::BarWidth{50}, indicators::option::Start{"["},
6477
indicators::option::Fill{"="}, indicators::option::Lead{">"},
6578
indicators::option::End{"]"},
66-
indicators::option::PrefixText{pad_string(i.id)},
79+
indicators::option::PrefixText{pad_string(Repo2Engine(i.id))},
6780
indicators::option::ForegroundColor{indicators::Color::white},
6881
indicators::option::ShowRemainingTime{true}));
6982
bars->push_back(*(items.back()));
@@ -80,7 +93,7 @@ bool DownloadProgress::Handle(const std::string& id) {
8093
if (ev.type_ == DownloadStatus::DownloadStarted ||
8194
ev.type_ == DownloadStatus::DownloadUpdated) {
8295
(*bars)[i].set_option(indicators::option::PrefixText{
83-
pad_string(it.id) +
96+
pad_string(Repo2Engine(it.id)) +
8497
std::to_string(
8598
int(static_cast<double>(downloaded) / totals[it.id] * 100)) +
8699
'%'});
@@ -94,8 +107,8 @@ bool DownloadProgress::Handle(const std::string& id) {
94107
auto total_str = format_utils::BytesToHumanReadable(totals[it.id]);
95108
(*bars)[i].set_option(
96109
indicators::option::PostfixText{total_str + "/" + total_str});
97-
(*bars)[i].set_option(
98-
indicators::option::PrefixText{pad_string(it.id) + "100%"});
110+
(*bars)[i].set_option(indicators::option::PrefixText{
111+
pad_string(Repo2Engine(it.id)) + "100%"});
99112
(*bars)[i].set_progress(100);
100113

101114
CTL_INF("Download success");

engine/cli/utils/download_progress.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@ class DownloadProgress {
1010
public:
1111
bool Connect(const std::string& host, int port);
1212

13-
bool Handle(const std::string& id);
13+
bool Handle(const DownloadType& event_type);
1414

1515
void ForceStop() { force_stop_ = true; }
1616

0 commit comments

Comments
 (0)