Skip to content
This repository was archived by the owner on Jul 4, 2025. It is now read-only.
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 15 additions & 16 deletions engine/cli/commands/engine_install_cmd.cc
Original file line number Diff line number Diff line change
Expand Up @@ -36,8 +36,13 @@ bool EngineInstallCmd::Exec(const std::string& engine,
DownloadProgress dp;
dp.Connect(host_, port_);
// engine can be small, so need to start ws first
auto dp_res = std::async(std::launch::deferred, [&dp, &engine] {
return dp.Handle(DownloadType::Engine);
auto dp_res = std::async(std::launch::deferred, [&dp] {
bool need_cuda_download = !system_info_utils::GetCudaVersion().empty();
if (need_cuda_download) {
return dp.Handle({DownloadType::Engine, DownloadType::CudaToolkit});
} else {
return dp.Handle({DownloadType::Engine});
}
});

auto versions_url = url_parser::Url{
Expand Down Expand Up @@ -133,12 +138,6 @@ bool EngineInstallCmd::Exec(const std::string& engine,
if (!dp_res.get())
return false;

bool check_cuda_download = !system_info_utils::GetCudaVersion().empty();
if (check_cuda_download) {
if (!dp.Handle(DownloadType::CudaToolkit))
return false;
}

CLI_LOG("Engine " << engine << " downloaded successfully!")
return true;
}
Expand All @@ -147,8 +146,14 @@ bool EngineInstallCmd::Exec(const std::string& engine,
DownloadProgress dp;
dp.Connect(host_, port_);
// engine can be small, so need to start ws first
auto dp_res = std::async(std::launch::deferred,
[&dp] { return dp.Handle(DownloadType::Engine); });
auto dp_res = std::async(std::launch::deferred, [&dp] {
bool need_cuda_download = !system_info_utils::GetCudaVersion().empty();
if (need_cuda_download) {
return dp.Handle({DownloadType::Engine, DownloadType::CudaToolkit});
} else {
return dp.Handle({DownloadType::Engine});
}
});

auto install_url = url_parser::Url{
.protocol = "http",
Expand Down Expand Up @@ -183,12 +188,6 @@ bool EngineInstallCmd::Exec(const std::string& engine,
if (!dp_res.get())
return false;

bool check_cuda_download = !system_info_utils::GetCudaVersion().empty();
if (check_cuda_download) {
if (!dp.Handle(DownloadType::CudaToolkit))
return false;
}

CLI_LOG("Engine " << engine << " downloaded successfully!")
return true;
}
Expand Down
15 changes: 7 additions & 8 deletions engine/cli/commands/engine_update_cmd.cc
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,13 @@ bool EngineUpdateCmd::Exec(const std::string& host, int port,
DownloadProgress dp;
dp.Connect(host, port);
// engine can be small, so need to start ws first
auto dp_res = std::async(std::launch::deferred, [&dp, &engine] {
return dp.Handle(DownloadType::Engine);
auto dp_res = std::async(std::launch::deferred, [&dp] {
bool need_cuda_download = !system_info_utils::GetCudaVersion().empty();
if (need_cuda_download) {
return dp.Handle({DownloadType::Engine, DownloadType::CudaToolkit});
} else {
return dp.Handle({DownloadType::Engine});
}
});

auto update_url = url_parser::Url{
Expand All @@ -48,12 +53,6 @@ bool EngineUpdateCmd::Exec(const std::string& host, int port,
if (!dp_res.get())
return false;

bool check_cuda_download = !system_info_utils::GetCudaVersion().empty();
if (check_cuda_download) {
if (!dp.Handle(DownloadType::CudaToolkit))
return false;
}

CLI_LOG("Engine " << engine << " updated successfully!")
return true;
}
Expand Down
2 changes: 1 addition & 1 deletion engine/cli/commands/model_pull_cmd.cc
Original file line number Diff line number Diff line change
Expand Up @@ -143,7 +143,7 @@ std::optional<std::string> ModelPullCmd::Exec(const std::string& host, int port,
reinterpret_cast<PHANDLER_ROUTINE>(console_ctrl_handler), true);
#endif
dp.Connect(host, port);
if (!dp.Handle(DownloadType::Model))
if (!dp.Handle({DownloadType::Model}))
return std::nullopt;
if (force_stop)
return std::nullopt;
Expand Down
81 changes: 49 additions & 32 deletions engine/cli/utils/download_progress.cc
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,8 @@ bool DownloadProgress::Connect(const std::string& host, int port) {
return true;
}

bool DownloadProgress::Handle(const DownloadType& event_type) {
bool DownloadProgress::Handle(
const std::unordered_set<DownloadType>& event_type) {
assert(!!ws_);
#if defined(_WIN32)
HANDLE h_out = GetStdHandle(STD_OUTPUT_HANDLE);
Expand All @@ -50,10 +51,14 @@ bool DownloadProgress::Handle(const DownloadType& event_type) {
}
}
#endif
status_ = DownloadStatus::DownloadStarted;
for (auto et : event_type) {
status_[et] = DownloadStatus::DownloadStarted;
}
std::unique_ptr<indicators::DynamicProgress<indicators::ProgressBar>> bars;

std::vector<std::unique_ptr<indicators::ProgressBar>> items;
std::unordered_map<std::string,
std::pair<int, std::unique_ptr<indicators::ProgressBar>>>
items;
indicators::show_console_cursor(false);
auto start = std::chrono::steady_clock::now();
auto handle_message = [this, &bars, &items, event_type,
Expand All @@ -78,22 +83,28 @@ bool DownloadProgress::Handle(const DownloadType& event_type) {
auto ev = cortex::event::GetDownloadEventFromJson(
json_helper::ParseJsonString(message));
// Ignore other task type
if (ev.download_task_.type != event_type) {
if (event_type.find(ev.download_task_.type) == event_type.end()) {
return;
}
auto now = std::chrono::steady_clock::now();
if (!bars) {
bars = std::make_unique<
indicators::DynamicProgress<indicators::ProgressBar>>();
for (auto& i : ev.download_task_.items) {
items.emplace_back(std::make_unique<indicators::ProgressBar>(
indicators::option::BarWidth{50}, indicators::option::Start{"["},
indicators::option::Fill{"="}, indicators::option::Lead{">"},
indicators::option::End{"]"},
indicators::option::PrefixText{pad_string(Repo2Engine(i.id))},
indicators::option::ForegroundColor{indicators::Color::white},
indicators::option::ShowRemainingTime{false}));
bars->push_back(*(items.back()));
}
for (auto& i : ev.download_task_.items) {
if (items.find(i.id) == items.end()) {
auto idx = items.size();
items[i.id] = std::pair(
idx,
std::make_unique<indicators::ProgressBar>(
indicators::option::BarWidth{50},
indicators::option::Start{"["}, indicators::option::Fill{"="},
indicators::option::Lead{">"}, indicators::option::End{"]"},
indicators::option::PrefixText{pad_string(Repo2Engine(i.id))},
indicators::option::ForegroundColor{indicators::Color::white},
indicators::option::ShowRemainingTime{false}));

bars->push_back(*(items.at(i.id).second));
}
}
for (int i = 0; i < ev.download_task_.items.size(); i++) {
Expand All @@ -113,32 +124,36 @@ bool DownloadProgress::Handle(const DownloadType& event_type) {
(total - downloaded) / bytes_per_sec);
}

(*bars)[i].set_option(indicators::option::PrefixText{
pad_string(Repo2Engine(it.id)) +
std::to_string(int(static_cast<double>(downloaded) / total * 100)) +
'%'});
(*bars)[i].set_progress(
(*bars)[items.at(it.id).first].set_option(
indicators::option::PrefixText{
pad_string(Repo2Engine(it.id)) +
std::to_string(
int(static_cast<double>(downloaded) / total * 100)) +
'%'});
(*bars)[items.at(it.id).first].set_progress(
int(static_cast<double>(downloaded) / total * 100));
(*bars)[i].set_option(indicators::option::PostfixText{
time_remaining + " " +
format_utils::BytesToHumanReadable(downloaded) + "/" +
format_utils::BytesToHumanReadable(total)});
(*bars)[items.at(it.id).first].set_option(
indicators::option::PostfixText{
time_remaining + " " +
format_utils::BytesToHumanReadable(downloaded) + "/" +
format_utils::BytesToHumanReadable(total)});
} else if (ev.type_ == DownloadStatus::DownloadSuccess) {
uint64_t total =
it.bytes.value_or(std::numeric_limits<uint64_t>::max());
(*bars)[i].set_progress(100);
(*bars)[items.at(it.id).first].set_progress(100);
auto total_str = format_utils::BytesToHumanReadable(total);
(*bars)[i].set_option(indicators::option::PostfixText{
"00m:00s " + total_str + "/" + total_str});
(*bars)[i].set_option(indicators::option::PrefixText{
pad_string(Repo2Engine(it.id)) + "100%"});
(*bars)[i].set_progress(100);
(*bars)[items.at(it.id).first].set_option(
indicators::option::PostfixText{"00m:00s " + total_str + "/" +
total_str});
(*bars)[items.at(it.id).first].set_option(
indicators::option::PrefixText{pad_string(Repo2Engine(it.id)) +
"100%"});
(*bars)[items.at(it.id).first].set_progress(100);

CTL_INF("Download success");
}
status_[ev.download_task_.type] = ev.type_;
}

status_ = ev.type_;
};

while (ws_->getReadyState() != easywsclient::WebSocket::CLOSED &&
Expand All @@ -152,7 +167,9 @@ bool DownloadProgress::Handle(const DownloadType& event_type) {
SetConsoleMode(h_out, dw_original_out_mode);
}
#endif
if (status_ == DownloadStatus::DownloadError)
return false;
for (auto const& [_, v] : status_) {
if (v == DownloadStatus::DownloadError)
return false;
}
return true;
}
17 changes: 12 additions & 5 deletions engine/cli/utils/download_progress.h
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
#include <atomic>
#include <memory>
#include <string>
#include <unordered_set>
#include "common/event.h"
#include "easywsclient.hpp"

Expand All @@ -10,19 +11,25 @@ class DownloadProgress {
public:
bool Connect(const std::string& host, int port);

bool Handle(const DownloadType& event_type);
bool Handle(const std::unordered_set<DownloadType>& event_type);

void ForceStop() { force_stop_ = true; }

private:
bool should_stop() const {
return (status_ != DownloadStatus::DownloadStarted &&
status_ != DownloadStatus::DownloadUpdated) ||
force_stop_;
bool should_stop = true;
for (auto const& [_, v] : status_) {
should_stop &= (v == DownloadStatus::DownloadSuccess);
}
for (auto const& [_, v] : status_) {
should_stop |= (v == DownloadStatus::DownloadError ||
v == DownloadStatus::DownloadStopped);
}
return should_stop || force_stop_;
}

private:
std::unique_ptr<easywsclient::WebSocket> ws_;
std::atomic<DownloadStatus> status_ = DownloadStatus::DownloadStarted;
std::unordered_map<DownloadType, std::atomic<DownloadStatus>> status_;
std::atomic<bool> force_stop_ = false;
};
6 changes: 4 additions & 2 deletions engine/e2e-test/test_api_model_start.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ def setup_and_teardown(self):
success = start_server()
if not success:
raise Exception("Failed to start server")
requests.post("http://localhost:3928/v1/engines/llama-cpp")
run("Install engine", ["engines", "install", "llama-cpp"], 5 * 60)
run("Delete model", ["models", "delete", "tinyllama:gguf"])
run(
"Pull model",
Expand All @@ -27,5 +27,7 @@ def setup_and_teardown(self):

def test_models_start_should_be_successful(self):
json_body = {"model": "tinyllama:gguf"}
response = requests.post("http://localhost:3928/v1/models/start", json=json_body)
response = requests.post(
"http://localhost:3928/v1/models/start", json=json_body
)
assert response.status_code == 200, f"status_code: {response.status_code}"
10 changes: 6 additions & 4 deletions engine/e2e-test/test_api_model_stop.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import pytest
import requests
from test_runner import start_server, stop_server
from test_runner import run, start_server, stop_server


class TestApiModelStop:
Expand All @@ -13,16 +13,18 @@ def setup_and_teardown(self):
if not success:
raise Exception("Failed to start server")

requests.post("http://localhost:3928/engines/llama-cpp")
run("Install engine", ["engines", "install", "llama-cpp"], 5 * 60)
yield

requests.delete("http://localhost:3928/engines/llama-cpp")
run("Uninstall engine", ["engines", "uninstall", "llama-cpp"])
# Teardown
stop_server()

def test_models_stop_should_be_successful(self):
json_body = {"model": "tinyllama:gguf"}
response = requests.post("http://localhost:3928/v1/models/start", json=json_body)
response = requests.post(
"http://localhost:3928/v1/models/start", json=json_body
)
assert response.status_code == 200, f"status_code: {response.status_code}"
response = requests.post("http://localhost:3928/v1/models/stop", json=json_body)
assert response.status_code == 200, f"status_code: {response.status_code}"
Loading
Loading