From 1eb34b238488ea791821cb7ca3941d112af8fc5d Mon Sep 17 00:00:00 2001 From: vansangpfiev Date: Thu, 24 Oct 2024 08:50:05 +0700 Subject: [PATCH 01/24] fix: add ws and indicators --- engine/cli/CMakeLists.txt | 1 + engine/cli/command_line_parser.cc | 4 +- engine/cli/commands/model_pull_cmd.cc | 109 +- engine/cli/commands/model_pull_cmd.h | 2 +- engine/cli/utils/easywsclient.cc | 594 +++++ engine/cli/utils/easywsclient.hpp | 85 + engine/cli/utils/indicators.hpp | 3257 +++++++++++++++++++++++++ engine/common/download_task.h | 66 + engine/common/event.h | 28 + engine/services/model_service.cc | 2 +- engine/test/components/test_event.cc | 50 + 11 files changed, 4191 insertions(+), 7 deletions(-) create mode 100644 engine/cli/utils/easywsclient.cc create mode 100644 engine/cli/utils/easywsclient.hpp create mode 100644 engine/cli/utils/indicators.hpp create mode 100644 engine/test/components/test_event.cc diff --git a/engine/cli/CMakeLists.txt b/engine/cli/CMakeLists.txt index 11e2c384b..947bd9347 100644 --- a/engine/cli/CMakeLists.txt +++ b/engine/cli/CMakeLists.txt @@ -80,6 +80,7 @@ add_executable(${TARGET_NAME} main.cc ${CMAKE_CURRENT_SOURCE_DIR}/../services/engine_service.cc ${CMAKE_CURRENT_SOURCE_DIR}/../services/model_service.cc ${CMAKE_CURRENT_SOURCE_DIR}/../services/inference_service.cc + ${CMAKE_CURRENT_SOURCE_DIR}/utils/easywsclient.cc ) target_link_libraries(${TARGET_NAME} PRIVATE httplib::httplib) diff --git a/engine/cli/command_line_parser.cc b/engine/cli/command_line_parser.cc index 23b4c263f..faed315f8 100644 --- a/engine/cli/command_line_parser.cc +++ b/engine/cli/command_line_parser.cc @@ -130,7 +130,9 @@ void CommandLineParser::SetupCommonCommands() { return; } try { - commands::ModelPullCmd(download_service_).Exec(cml_data_.model_id); + commands::ModelPullCmd(download_service_) + .Exec(cml_data_.config.apiServerHost, + std::stoi(cml_data_.config.apiServerPort), cml_data_.model_id); } catch (const std::exception& e) { CLI_LOG(e.what()); } diff --git a/engine/cli/commands/model_pull_cmd.cc b/engine/cli/commands/model_pull_cmd.cc index 4ec5344bb..386380397 100644 --- a/engine/cli/commands/model_pull_cmd.cc +++ b/engine/cli/commands/model_pull_cmd.cc @@ -1,11 +1,112 @@ #include "model_pull_cmd.h" +#include +#include "cli/utils/easywsclient.hpp" +#include "cli/utils/indicators.hpp" +#include "common/event.h" +#include "server_start_cmd.h" +#include "utils/format_utils.h" +#include "utils/json_helper.h" #include "utils/logging_utils.h" namespace commands { -void ModelPullCmd::Exec(const std::string& input) { - auto result = model_service_.DownloadModel(input); - if (result.has_error()) { - CLI_LOG(result.error()); +void ModelPullCmd::Exec(const std::string& host, int port, + const std::string& input) { + // Start server if server is not started yet + if (!commands::IsServerAlive(host, port)) { + CLI_LOG("Starting server ..."); + commands::ServerStartCmd ssc; + if (!ssc.Exec(host, port)) { + return; + } } + + httplib::Client cli(host + ":" + std::to_string(port)); + Json::Value json_data; + json_data["model"] = input; + auto data_str = json_data.toStyledString(); + cli.set_read_timeout(std::chrono::seconds(60)); + auto res = cli.Post("/v1/models/pull", httplib::Headers(), data_str.data(), + data_str.size(), "application/json"); + + if (res) { + if (res->status == httplib::StatusCode::OK_200) { + // CLI_LOG("OK"); + } else { + CTL_ERR("Error:"); + return; + } + } else { + auto err = res.error(); + CTL_ERR("HTTP error: " << httplib::to_string(err)); + return; + } + + std::unique_ptr> + bars; + + std::vector> items; + + auto handle_message = [&bars, &items](const std::string& message) { + // std::cout << message << std::endl; + + auto pad_string = [](const std::string& str, + size_t max_length = 20) -> std::string { + // Check the length of the input string + if (str.length() >= max_length) { + return str.substr( + 0, max_length); // Return truncated string if it's too long + } + + // Calculate the number of spaces needed + size_t padding_size = max_length - str.length(); + + // Create a new string with the original string followed by spaces + return str + std::string(padding_size, ' '); + }; + + auto ev = cortex::event::GetDownloadEventFromJson( + json_helper::ParseJsonString(message)); + // std::cout << downloaded << " " << total << std::endl; + if (!bars) { + bars = std::make_unique< + indicators::DynamicProgress>(); + for (auto& i : ev.download_task_.items) { + items.emplace_back(std::make_unique( + indicators::option::BarWidth{50}, indicators::option::Start{"|"}, + // indicators::option::Fill{"■"}, indicators::option::Lead{"■"}, + // indicators::option::Remainder{" "}, + indicators::option::End{"|"}, indicators::option::PrefixText{pad_string(i.id)}, + indicators::option::PostfixText{"Downloading files"}, + indicators::option::ForegroundColor{indicators::Color::white}, + indicators::option::ShowRemainingTime{true}, + indicators::option::FontStyles{std::vector{ + indicators::FontStyle::bold}})); + bars->push_back(*(items.back())); + } + } else { + for (int i = 0; i < ev.download_task_.items.size(); i++) { + auto& it = ev.download_task_.items[i]; + uint64_t downloaded = it.downloadedBytes.value_or(0); + uint64_t total = it.bytes.value_or(9999); + (*bars)[i].set_progress(static_cast(downloaded) / total * 100); + (*bars)[i].set_option(indicators::option::PostfixText{ + format_utils::BytesToHumanReadable(downloaded) + "/" + + format_utils::BytesToHumanReadable(total)}); + } + } + }; + + auto ws = easywsclient::WebSocket::from_url("ws://" + host + ":" + + std::to_string(port) + "/events"); + // auto result = model_service_.DownloadModel(input); + // if (result.has_error()) { + // CLI_LOG(result.error()); + // } + while (ws->getReadyState() != easywsclient::WebSocket::CLOSED) { + ws->poll(); + ws->dispatch(handle_message); + } + std::cout << "Done" << std::endl; + delete ws; } }; // namespace commands diff --git a/engine/cli/commands/model_pull_cmd.h b/engine/cli/commands/model_pull_cmd.h index 3586b3cd4..444fc0bde 100644 --- a/engine/cli/commands/model_pull_cmd.h +++ b/engine/cli/commands/model_pull_cmd.h @@ -8,7 +8,7 @@ class ModelPullCmd { public: explicit ModelPullCmd(std::shared_ptr download_service) : model_service_{ModelService(download_service)} {}; - void Exec(const std::string& input); + void Exec(const std::string& host, int port, const std::string& input); private: ModelService model_service_; diff --git a/engine/cli/utils/easywsclient.cc b/engine/cli/utils/easywsclient.cc new file mode 100644 index 000000000..5c6ed38e8 --- /dev/null +++ b/engine/cli/utils/easywsclient.cc @@ -0,0 +1,594 @@ + +#ifdef _WIN32 +#if defined(_MSC_VER) && !defined(_CRT_SECURE_NO_WARNINGS) +#define _CRT_SECURE_NO_WARNINGS // _CRT_SECURE_NO_WARNINGS for sscanf errors in MSVC2013 Express +#endif +#ifndef WIN32_LEAN_AND_MEAN +#define WIN32_LEAN_AND_MEAN +#endif +#include +#include +#include +#pragma comment(lib, "ws2_32") +#include +#include +#include +#include +#include +#ifndef _SSIZE_T_DEFINED +typedef int ssize_t; +#define _SSIZE_T_DEFINED +#endif +#ifndef _SOCKET_T_DEFINED +typedef SOCKET socket_t; +#define _SOCKET_T_DEFINED +#endif +#ifndef snprintf +#define snprintf _snprintf_s +#endif +#if _MSC_VER >= 1600 +// vs2010 or later +#include +#else +typedef __int8 int8_t; +typedef unsigned __int8 uint8_t; +typedef __int32 int32_t; +typedef unsigned __int32 uint32_t; +typedef __int64 int64_t; +typedef unsigned __int64 uint64_t; +#endif +#define socketerrno WSAGetLastError() +#define SOCKET_EAGAIN_EINPROGRESS WSAEINPROGRESS +#define SOCKET_EWOULDBLOCK WSAEWOULDBLOCK +#else +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#ifndef _SOCKET_T_DEFINED +typedef int socket_t; +#define _SOCKET_T_DEFINED +#endif +#ifndef INVALID_SOCKET +#define INVALID_SOCKET (-1) +#endif +#ifndef SOCKET_ERROR +#define SOCKET_ERROR (-1) +#endif +#define closesocket(s) ::close(s) +#include +#define socketerrno errno +#define SOCKET_EAGAIN_EINPROGRESS EAGAIN +#define SOCKET_EWOULDBLOCK EWOULDBLOCK +#endif + +#include +#include + +#include "easywsclient.hpp" + +using easywsclient::BytesCallback_Imp; +using easywsclient::Callback_Imp; + +namespace { // private module-only namespace + +socket_t hostname_connect(const std::string& hostname, int port) { + struct addrinfo hints; + struct addrinfo* result; + struct addrinfo* p; + int ret; + socket_t sockfd = INVALID_SOCKET; + char sport[16]; + memset(&hints, 0, sizeof(hints)); + hints.ai_family = AF_UNSPEC; + hints.ai_socktype = SOCK_STREAM; + snprintf(sport, 16, "%d", port); + if ((ret = getaddrinfo(hostname.c_str(), sport, &hints, &result)) != 0) { + fprintf(stderr, "getaddrinfo: %s\n", gai_strerror(ret)); + return 1; + } + for (p = result; p != NULL; p = p->ai_next) { + sockfd = socket(p->ai_family, p->ai_socktype, p->ai_protocol); + if (sockfd == INVALID_SOCKET) { + continue; + } + if (connect(sockfd, p->ai_addr, p->ai_addrlen) != SOCKET_ERROR) { + break; + } + closesocket(sockfd); + sockfd = INVALID_SOCKET; + } + freeaddrinfo(result); + return sockfd; +} + +class _DummyWebSocket : public easywsclient::WebSocket { + public: + void poll(int timeout) {} + void send(const std::string& message) {} + void sendBinary(const std::string& message) {} + void sendBinary(const std::vector& message) {} + void sendPing() {} + void close() {} + readyStateValues getReadyState() const { return CLOSED; } + void _dispatch(Callback_Imp& callable) {} + void _dispatchBinary(BytesCallback_Imp& callable) {} +}; + +class _RealWebSocket : public easywsclient::WebSocket { + public: + // http://tools.ietf.org/html/rfc6455#section-5.2 Base Framing Protocol + // + // 0 1 2 3 + // 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 + // +-+-+-+-+-------+-+-------------+-------------------------------+ + // |F|R|R|R| opcode|M| Payload len | Extended payload length | + // |I|S|S|S| (4) |A| (7) | (16/64) | + // |N|V|V|V| |S| | (if payload len==126/127) | + // | |1|2|3| |K| | | + // +-+-+-+-+-------+-+-------------+ - - - - - - - - - - - - - - - + + // | Extended payload length continued, if payload len == 127 | + // + - - - - - - - - - - - - - - - +-------------------------------+ + // | |Masking-key, if MASK set to 1 | + // +-------------------------------+-------------------------------+ + // | Masking-key (continued) | Payload Data | + // +-------------------------------- - - - - - - - - - - - - - - - + + // : Payload Data continued ... : + // + - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - + + // | Payload Data continued ... | + // +---------------------------------------------------------------+ + struct wsheader_type { + unsigned header_size; + bool fin; + bool mask; + enum opcode_type { + CONTINUATION = 0x0, + TEXT_FRAME = 0x1, + BINARY_FRAME = 0x2, + CLOSE = 8, + PING = 9, + PONG = 0xa, + } opcode; + int N0; + uint64_t N; + uint8_t masking_key[4]; + }; + + std::vector rxbuf; + std::vector txbuf; + std::vector receivedData; + + socket_t sockfd; + readyStateValues readyState; + bool useMask; + bool isRxBad; + + _RealWebSocket(socket_t sockfd, bool useMask) + : sockfd(sockfd), readyState(OPEN), useMask(useMask), isRxBad(false) {} + + readyStateValues getReadyState() const { return readyState; } + + void poll(int timeout) { // timeout in milliseconds + if (readyState == CLOSED) { + if (timeout > 0) { + timeval tv = {timeout / 1000, (timeout % 1000) * 1000}; + select(0, NULL, NULL, NULL, &tv); + } + return; + } + if (timeout != 0) { + fd_set rfds; + fd_set wfds; + timeval tv = {timeout / 1000, (timeout % 1000) * 1000}; + FD_ZERO(&rfds); + FD_ZERO(&wfds); + FD_SET(sockfd, &rfds); + if (txbuf.size()) { + FD_SET(sockfd, &wfds); + } + select(sockfd + 1, &rfds, &wfds, 0, timeout > 0 ? &tv : 0); + } + while (true) { + // FD_ISSET(0, &rfds) will be true + int N = rxbuf.size(); + ssize_t ret; + rxbuf.resize(N + 1500); + ret = recv(sockfd, (char*)&rxbuf[0] + N, 1500, 0); + if (false) { + } else if (ret < 0 && (socketerrno == SOCKET_EWOULDBLOCK || + socketerrno == SOCKET_EAGAIN_EINPROGRESS)) { + rxbuf.resize(N); + break; + } else if (ret <= 0) { + rxbuf.resize(N); + closesocket(sockfd); + readyState = CLOSED; + fputs(ret < 0 ? "Connection error!\n" : "Connection closed!\n", stderr); + break; + } else { + rxbuf.resize(N + ret); + } + } + while (txbuf.size()) { + int ret = ::send(sockfd, (char*)&txbuf[0], txbuf.size(), 0); + if (false) { + } // ?? + else if (ret < 0 && (socketerrno == SOCKET_EWOULDBLOCK || + socketerrno == SOCKET_EAGAIN_EINPROGRESS)) { + break; + } else if (ret <= 0) { + closesocket(sockfd); + readyState = CLOSED; + fputs(ret < 0 ? "Connection error!\n" : "Connection closed!\n", stderr); + break; + } else { + txbuf.erase(txbuf.begin(), txbuf.begin() + ret); + } + } + if (!txbuf.size() && readyState == CLOSING) { + closesocket(sockfd); + readyState = CLOSED; + } + } + + // Callable must have signature: void(const std::string & message). + // Should work with C functions, C++ functors, and C++11 std::function and + // lambda: + //template + //void dispatch(Callable callable) + virtual void _dispatch(Callback_Imp& callable) { + struct CallbackAdapter : public BytesCallback_Imp + // Adapt void(const std::string&) to void(const std::string&) + { + Callback_Imp& callable; + CallbackAdapter(Callback_Imp& callable) : callable(callable) {} + void operator()(const std::vector& message) { + std::string stringMessage(message.begin(), message.end()); + callable(stringMessage); + } + }; + CallbackAdapter bytesCallback(callable); + _dispatchBinary(bytesCallback); + } + + virtual void _dispatchBinary(BytesCallback_Imp& callable) { + // TODO: consider acquiring a lock on rxbuf... + if (isRxBad) { + return; + } + while (true) { + wsheader_type ws; + if (rxbuf.size() < 2) { + return; /* Need at least 2 */ + } + const uint8_t* data = (uint8_t*)&rxbuf[0]; // peek, but don't consume + ws.fin = (data[0] & 0x80) == 0x80; + ws.opcode = (wsheader_type::opcode_type)(data[0] & 0x0f); + ws.mask = (data[1] & 0x80) == 0x80; + ws.N0 = (data[1] & 0x7f); + ws.header_size = 2 + (ws.N0 == 126 ? 2 : 0) + (ws.N0 == 127 ? 8 : 0) + + (ws.mask ? 4 : 0); + if (rxbuf.size() < ws.header_size) { + return; /* Need: ws.header_size - rxbuf.size() */ + } + int i = 0; + if (ws.N0 < 126) { + ws.N = ws.N0; + i = 2; + } else if (ws.N0 == 126) { + ws.N = 0; + ws.N |= ((uint64_t)data[2]) << 8; + ws.N |= ((uint64_t)data[3]) << 0; + i = 4; + } else if (ws.N0 == 127) { + ws.N = 0; + ws.N |= ((uint64_t)data[2]) << 56; + ws.N |= ((uint64_t)data[3]) << 48; + ws.N |= ((uint64_t)data[4]) << 40; + ws.N |= ((uint64_t)data[5]) << 32; + ws.N |= ((uint64_t)data[6]) << 24; + ws.N |= ((uint64_t)data[7]) << 16; + ws.N |= ((uint64_t)data[8]) << 8; + ws.N |= ((uint64_t)data[9]) << 0; + i = 10; + if (ws.N & 0x8000000000000000ull) { + // https://tools.ietf.org/html/rfc6455 writes the "the most + // significant bit MUST be 0." + // + // We can't drop the frame, because (1) we don't we don't + // know how much data to skip over to find the next header, + // and (2) this would be an impractically long length, even + // if it were valid. So just close() and return immediately + // for now. + isRxBad = true; + fprintf(stderr, "ERROR: Frame has invalid frame length. Closing.\n"); + close(); + return; + } + } + if (ws.mask) { + ws.masking_key[0] = ((uint8_t)data[i + 0]) << 0; + ws.masking_key[1] = ((uint8_t)data[i + 1]) << 0; + ws.masking_key[2] = ((uint8_t)data[i + 2]) << 0; + ws.masking_key[3] = ((uint8_t)data[i + 3]) << 0; + } else { + ws.masking_key[0] = 0; + ws.masking_key[1] = 0; + ws.masking_key[2] = 0; + ws.masking_key[3] = 0; + } + + // Note: The checks above should hopefully ensure this addition + // cannot overflow: + if (rxbuf.size() < ws.header_size + ws.N) { + return; /* Need: ws.header_size+ws.N - rxbuf.size() */ + } + + // We got a whole message, now do something with it: + if (false) { + } else if (ws.opcode == wsheader_type::TEXT_FRAME || + ws.opcode == wsheader_type::BINARY_FRAME || + ws.opcode == wsheader_type::CONTINUATION) { + if (ws.mask) { + for (size_t i = 0; i != ws.N; ++i) { + rxbuf[i + ws.header_size] ^= ws.masking_key[i & 0x3]; + } + } + receivedData.insert( + receivedData.end(), rxbuf.begin() + ws.header_size, + rxbuf.begin() + ws.header_size + (size_t)ws.N); // just feed + if (ws.fin) { + callable((const std::vector)receivedData); + receivedData.erase(receivedData.begin(), receivedData.end()); + std::vector().swap(receivedData); // free memory + } + } else if (ws.opcode == wsheader_type::PING) { + if (ws.mask) { + for (size_t i = 0; i != ws.N; ++i) { + rxbuf[i + ws.header_size] ^= ws.masking_key[i & 0x3]; + } + } + std::string data(rxbuf.begin() + ws.header_size, + rxbuf.begin() + ws.header_size + (size_t)ws.N); + sendData(wsheader_type::PONG, data.size(), data.begin(), data.end()); + } else if (ws.opcode == wsheader_type::PONG) { + } else if (ws.opcode == wsheader_type::CLOSE) { + close(); + } else { + fprintf(stderr, "ERROR: Got unexpected WebSocket message.\n"); + close(); + } + + rxbuf.erase(rxbuf.begin(), rxbuf.begin() + ws.header_size + (size_t)ws.N); + } + } + + void sendPing() { + std::string empty; + sendData(wsheader_type::PING, empty.size(), empty.begin(), empty.end()); + } + + void send(const std::string& message) { + sendData(wsheader_type::TEXT_FRAME, message.size(), message.begin(), + message.end()); + } + + void sendBinary(const std::string& message) { + sendData(wsheader_type::BINARY_FRAME, message.size(), message.begin(), + message.end()); + } + + void sendBinary(const std::vector& message) { + sendData(wsheader_type::BINARY_FRAME, message.size(), message.begin(), + message.end()); + } + + template + void sendData(wsheader_type::opcode_type type, uint64_t message_size, + Iterator message_begin, Iterator message_end) { + // TODO: + // Masking key should (must) be derived from a high quality random + // number generator, to mitigate attacks on non-WebSocket friendly + // middleware: + const uint8_t masking_key[4] = {0x12, 0x34, 0x56, 0x78}; + // TODO: consider acquiring a lock on txbuf... + if (readyState == CLOSING || readyState == CLOSED) { + return; + } + std::vector header; + header.assign(2 + (message_size >= 126 ? 2 : 0) + + (message_size >= 65536 ? 6 : 0) + (useMask ? 4 : 0), + 0); + header[0] = 0x80 | type; + if (false) { + } else if (message_size < 126) { + header[1] = (message_size & 0xff) | (useMask ? 0x80 : 0); + if (useMask) { + header[2] = masking_key[0]; + header[3] = masking_key[1]; + header[4] = masking_key[2]; + header[5] = masking_key[3]; + } + } else if (message_size < 65536) { + header[1] = 126 | (useMask ? 0x80 : 0); + header[2] = (message_size >> 8) & 0xff; + header[3] = (message_size >> 0) & 0xff; + if (useMask) { + header[4] = masking_key[0]; + header[5] = masking_key[1]; + header[6] = masking_key[2]; + header[7] = masking_key[3]; + } + } else { // TODO: run coverage testing here + header[1] = 127 | (useMask ? 0x80 : 0); + header[2] = (message_size >> 56) & 0xff; + header[3] = (message_size >> 48) & 0xff; + header[4] = (message_size >> 40) & 0xff; + header[5] = (message_size >> 32) & 0xff; + header[6] = (message_size >> 24) & 0xff; + header[7] = (message_size >> 16) & 0xff; + header[8] = (message_size >> 8) & 0xff; + header[9] = (message_size >> 0) & 0xff; + if (useMask) { + header[10] = masking_key[0]; + header[11] = masking_key[1]; + header[12] = masking_key[2]; + header[13] = masking_key[3]; + } + } + // N.B. - txbuf will keep growing until it can be transmitted over the socket: + txbuf.insert(txbuf.end(), header.begin(), header.end()); + txbuf.insert(txbuf.end(), message_begin, message_end); + if (useMask) { + size_t message_offset = txbuf.size() - message_size; + for (size_t i = 0; i != message_size; ++i) { + txbuf[message_offset + i] ^= masking_key[i & 0x3]; + } + } + } + + void close() { + if (readyState == CLOSING || readyState == CLOSED) { + return; + } + readyState = CLOSING; + uint8_t closeFrame[6] = {0x88, 0x80, 0x00, 0x00, + 0x00, 0x00}; // last 4 bytes are a masking key + std::vector header(closeFrame, closeFrame + 6); + txbuf.insert(txbuf.end(), header.begin(), header.end()); + } +}; + +easywsclient::WebSocket::pointer from_url(const std::string& url, bool useMask, + const std::string& origin) { + char host[512]; + int port; + char path[512]; + if (url.size() >= 512) { + fprintf(stderr, "ERROR: url size limit exceeded: %s\n", url.c_str()); + return NULL; + } + if (origin.size() >= 200) { + fprintf(stderr, "ERROR: origin size limit exceeded: %s\n", origin.c_str()); + return NULL; + } + if (false) { + } else if (sscanf(url.c_str(), "ws://%[^:/]:%d/%s", host, &port, path) == 3) { + } else if (sscanf(url.c_str(), "ws://%[^:/]/%s", host, path) == 2) { + port = 80; + } else if (sscanf(url.c_str(), "ws://%[^:/]:%d", host, &port) == 2) { + path[0] = '\0'; + } else if (sscanf(url.c_str(), "ws://%[^:/]", host) == 1) { + port = 80; + path[0] = '\0'; + } else { + fprintf(stderr, "ERROR: Could not parse WebSocket url: %s\n", url.c_str()); + return NULL; + } + //fprintf(stderr, "easywsclient: connecting: host=%s port=%d path=/%s\n", host, port, path); + socket_t sockfd = hostname_connect(host, port); + if (sockfd == INVALID_SOCKET) { + fprintf(stderr, "Unable to connect to %s:%d\n", host, port); + return NULL; + } + { + // XXX: this should be done non-blocking, + char line[1024]; + int status; + int i; + snprintf(line, 1024, "GET /%s HTTP/1.1\r\n", path); + ::send(sockfd, line, strlen(line), 0); + if (port == 80) { + snprintf(line, 1024, "Host: %s\r\n", host); + ::send(sockfd, line, strlen(line), 0); + } else { + snprintf(line, 1024, "Host: %s:%d\r\n", host, port); + ::send(sockfd, line, strlen(line), 0); + } + snprintf(line, 1024, "Upgrade: websocket\r\n"); + ::send(sockfd, line, strlen(line), 0); + snprintf(line, 1024, "Connection: Upgrade\r\n"); + ::send(sockfd, line, strlen(line), 0); + if (!origin.empty()) { + snprintf(line, 1024, "Origin: %s\r\n", origin.c_str()); + ::send(sockfd, line, strlen(line), 0); + } + snprintf(line, 1024, "Sec-WebSocket-Key: x3JJHMbDL1EzLkh9GBhXDw==\r\n"); + ::send(sockfd, line, strlen(line), 0); + snprintf(line, 1024, "Sec-WebSocket-Version: 13\r\n"); + ::send(sockfd, line, strlen(line), 0); + snprintf(line, 1024, "\r\n"); + ::send(sockfd, line, strlen(line), 0); + for (i = 0; + i < 2 || (i < 1023 && line[i - 2] != '\r' && line[i - 1] != '\n'); + ++i) { + if (recv(sockfd, line + i, 1, 0) == 0) { + return NULL; + } + } + line[i] = 0; + if (i == 1023) { + fprintf(stderr, "ERROR: Got invalid status line connecting to: %s\n", + url.c_str()); + return NULL; + } + if (sscanf(line, "HTTP/1.1 %d", &status) != 1 || status != 101) { + fprintf(stderr, "ERROR: Got bad status connecting to %s: %s", url.c_str(), + line); + return NULL; + } + // TODO: verify response headers, + while (true) { + for (i = 0; + i < 2 || (i < 1023 && line[i - 2] != '\r' && line[i - 1] != '\n'); + ++i) { + if (recv(sockfd, line + i, 1, 0) == 0) { + return NULL; + } + } + if (line[0] == '\r' && line[1] == '\n') { + break; + } + } + } + int flag = 1; + setsockopt(sockfd, IPPROTO_TCP, TCP_NODELAY, (char*)&flag, + sizeof(flag)); // Disable Nagle's algorithm +#ifdef _WIN32 + u_long on = 1; + ioctlsocket(sockfd, FIONBIO, &on); +#else + fcntl(sockfd, F_SETFL, O_NONBLOCK); +#endif + //fprintf(stderr, "Connected to: %s\n", url.c_str()); + return easywsclient::WebSocket::pointer(new _RealWebSocket(sockfd, useMask)); +} + +} // namespace + +namespace easywsclient { + +WebSocket::pointer WebSocket::create_dummy() { + static pointer dummy = pointer(new _DummyWebSocket); + return dummy; +} + +WebSocket::pointer WebSocket::from_url(const std::string& url, + const std::string& origin) { + return ::from_url(url, true, origin); +} + +WebSocket::pointer WebSocket::from_url_no_mask(const std::string& url, + const std::string& origin) { + return ::from_url(url, false, origin); +} + +} // namespace easywsclient \ No newline at end of file diff --git a/engine/cli/utils/easywsclient.hpp b/engine/cli/utils/easywsclient.hpp new file mode 100644 index 000000000..1f0149d2c --- /dev/null +++ b/engine/cli/utils/easywsclient.hpp @@ -0,0 +1,85 @@ +#ifndef EASYWSCLIENT_HPP_20120819_MIOFVASDTNUASZDQPLFD +#define EASYWSCLIENT_HPP_20120819_MIOFVASDTNUASZDQPLFD + +// This code comes from: +// https://github.com/dhbaird/easywsclient +// +// To get the latest version: +// wget https://raw.github.com/dhbaird/easywsclient/master/easywsclient.hpp +// wget https://raw.github.com/dhbaird/easywsclient/master/easywsclient.cpp + +#include +#include + +namespace easywsclient { + +struct Callback_Imp { + virtual void operator()(const std::string& message) = 0; +}; +struct BytesCallback_Imp { + virtual void operator()(const std::vector& message) = 0; +}; + +class WebSocket { + public: + typedef WebSocket* pointer; + typedef enum readyStateValues { + CLOSING, + CLOSED, + CONNECTING, + OPEN + } readyStateValues; + + // Factories: + static pointer create_dummy(); + static pointer from_url(const std::string& url, + const std::string& origin = std::string()); + static pointer from_url_no_mask(const std::string& url, + const std::string& origin = std::string()); + + // Interfaces: + virtual ~WebSocket() {} + virtual void poll(int timeout = 0) = 0; // timeout in milliseconds + virtual void send(const std::string& message) = 0; + virtual void sendBinary(const std::string& message) = 0; + virtual void sendBinary(const std::vector& message) = 0; + virtual void sendPing() = 0; + virtual void close() = 0; + virtual readyStateValues getReadyState() const = 0; + + template + void dispatch(Callable callable) + // For callbacks that accept a string argument. + { // N.B. this is compatible with both C++11 lambdas, functors and C function pointers + struct _Callback : public Callback_Imp { + Callable& callable; + _Callback(Callable& callable) : callable(callable) {} + void operator()(const std::string& message) { callable(message); } + }; + _Callback callback(callable); + _dispatch(callback); + } + + template + void dispatchBinary(Callable callable) + // For callbacks that accept a std::vector argument. + { // N.B. this is compatible with both C++11 lambdas, functors and C function pointers + struct _Callback : public BytesCallback_Imp { + Callable& callable; + _Callback(Callable& callable) : callable(callable) {} + void operator()(const std::vector& message) { + callable(message); + } + }; + _Callback callback(callable); + _dispatchBinary(callback); + } + + protected: + virtual void _dispatch(Callback_Imp& callable) = 0; + virtual void _dispatchBinary(BytesCallback_Imp& callable) = 0; +}; + +} // namespace easywsclient + +#endif /* EASYWSCLIENT_HPP_20120819_MIOFVASDTNUASZDQPLFD */ \ No newline at end of file diff --git a/engine/cli/utils/indicators.hpp b/engine/cli/utils/indicators.hpp new file mode 100644 index 000000000..f034c9441 --- /dev/null +++ b/engine/cli/utils/indicators.hpp @@ -0,0 +1,3257 @@ + +#ifndef INDICATORS_COLOR +#define INDICATORS_COLOR + +namespace indicators { +enum class Color { + grey, + red, + green, + yellow, + blue, + magenta, + cyan, + white, + unspecified +}; +} + +#endif + +#ifndef INDICATORS_FONT_STYLE +#define INDICATORS_FONT_STYLE + +namespace indicators { +enum class FontStyle { + bold, + dark, + italic, + underline, + blink, + reverse, + concealed, + crossed +}; +} + +#endif + +#ifndef INDICATORS_PROGRESS_TYPE +#define INDICATORS_PROGRESS_TYPE + +namespace indicators { +enum class ProgressType { incremental, decremental }; +} + +#endif + +//! +//! termcolor +//! ~~~~~~~~~ +//! +//! termcolor is a header-only c++ library for printing colored messages +//! to the terminal. Written just for fun with a help of the Force. +//! +//! :copyright: (c) 2013 by Ihor Kalnytskyi +//! :license: BSD, see LICENSE for details +//! + +#ifndef TERMCOLOR_HPP_ +#define TERMCOLOR_HPP_ + +#include +#include +#include + +// Detect target's platform and set some macros in order to wrap platform +// specific code this library depends on. +#if defined(_WIN32) || defined(_WIN64) +#define TERMCOLOR_TARGET_WINDOWS +#elif defined(__unix__) || defined(__unix) || \ + (defined(__APPLE__) && defined(__MACH__)) +#define TERMCOLOR_TARGET_POSIX +#endif + +// If implementation has not been explicitly set, try to choose one based on +// target platform. +#if !defined(TERMCOLOR_USE_ANSI_ESCAPE_SEQUENCES) && \ + !defined(TERMCOLOR_USE_WINDOWS_API) && !defined(TERMCOLOR_USE_NOOP) +#if defined(TERMCOLOR_TARGET_POSIX) +#define TERMCOLOR_USE_ANSI_ESCAPE_SEQUENCES +#define TERMCOLOR_AUTODETECTED_IMPLEMENTATION +#elif defined(TERMCOLOR_TARGET_WINDOWS) +#define TERMCOLOR_USE_WINDOWS_API +#define TERMCOLOR_AUTODETECTED_IMPLEMENTATION +#endif +#endif + +// These headers provide isatty()/fileno() functions, which are used for +// testing whether a standard stream refers to the terminal. +#if defined(TERMCOLOR_TARGET_POSIX) +#include +#elif defined(TERMCOLOR_TARGET_WINDOWS) +#if defined(_MSC_VER) +#if !defined(NOMINMAX) +#define NOMINMAX +#endif +#endif +#include +#include +#endif + +namespace termcolor { +// Forward declaration of the `_internal` namespace. +// All comments are below. +namespace _internal { +inline int colorize_index(); +inline FILE* get_standard_stream(const std::ostream& stream); +inline bool is_colorized(std::ostream& stream); +inline bool is_atty(const std::ostream& stream); + +#if defined(TERMCOLOR_TARGET_WINDOWS) +inline void win_change_attributes(std::ostream& stream, int foreground, + int background = -1); +#endif +} // namespace _internal + +inline std::ostream& colorize(std::ostream& stream) { + stream.iword(_internal::colorize_index()) = 1L; + return stream; +} + +inline std::ostream& nocolorize(std::ostream& stream) { + stream.iword(_internal::colorize_index()) = 0L; + return stream; +} + +inline std::ostream& reset(std::ostream& stream) { + if (_internal::is_colorized(stream)) { +#if defined(TERMCOLOR_USE_ANSI_ESCAPE_SEQUENCES) + stream << "\033[00m"; +#elif defined(TERMCOLOR_USE_WINDOWS_API) + _internal::win_change_attributes(stream, -1, -1); +#endif + } + return stream; +} + +inline std::ostream& bold(std::ostream& stream) { + if (_internal::is_colorized(stream)) { +#if defined(TERMCOLOR_USE_ANSI_ESCAPE_SEQUENCES) + stream << "\033[1m"; +#elif defined(TERMCOLOR_USE_WINDOWS_API) +#endif + } + return stream; +} + +inline std::ostream& dark(std::ostream& stream) { + if (_internal::is_colorized(stream)) { +#if defined(TERMCOLOR_USE_ANSI_ESCAPE_SEQUENCES) + stream << "\033[2m"; +#elif defined(TERMCOLOR_USE_WINDOWS_API) +#endif + } + return stream; +} + +inline std::ostream& italic(std::ostream& stream) { + if (_internal::is_colorized(stream)) { +#if defined(TERMCOLOR_USE_ANSI_ESCAPE_SEQUENCES) + stream << "\033[3m"; +#elif defined(TERMCOLOR_USE_WINDOWS_API) +#endif + } + return stream; +} + +inline std::ostream& underline(std::ostream& stream) { + if (_internal::is_colorized(stream)) { +#if defined(TERMCOLOR_USE_ANSI_ESCAPE_SEQUENCES) + stream << "\033[4m"; +#elif defined(TERMCOLOR_USE_WINDOWS_API) + _internal::win_change_attributes(stream, -1, COMMON_LVB_UNDERSCORE); +#endif + } + return stream; +} + +inline std::ostream& blink(std::ostream& stream) { + if (_internal::is_colorized(stream)) { +#if defined(TERMCOLOR_USE_ANSI_ESCAPE_SEQUENCES) + stream << "\033[5m"; +#elif defined(TERMCOLOR_USE_WINDOWS_API) +#endif + } + return stream; +} + +inline std::ostream& reverse(std::ostream& stream) { + if (_internal::is_colorized(stream)) { +#if defined(TERMCOLOR_USE_ANSI_ESCAPE_SEQUENCES) + stream << "\033[7m"; +#elif defined(TERMCOLOR_USE_WINDOWS_API) +#endif + } + return stream; +} + +inline std::ostream& concealed(std::ostream& stream) { + if (_internal::is_colorized(stream)) { +#if defined(TERMCOLOR_USE_ANSI_ESCAPE_SEQUENCES) + stream << "\033[8m"; +#elif defined(TERMCOLOR_USE_WINDOWS_API) +#endif + } + return stream; +} + +inline std::ostream& crossed(std::ostream& stream) { + if (_internal::is_colorized(stream)) { +#if defined(TERMCOLOR_USE_ANSI_ESCAPE_SEQUENCES) + stream << "\033[9m"; +#elif defined(TERMCOLOR_USE_WINDOWS_API) +#endif + } + return stream; +} + +template +inline std::ostream& color(std::ostream& stream) { + if (_internal::is_colorized(stream)) { +#if defined(TERMCOLOR_USE_ANSI_ESCAPE_SEQUENCES) + char command[12]; + std::snprintf(command, sizeof(command), "\033[38;5;%dm", code); + stream << command; +#elif defined(TERMCOLOR_USE_WINDOWS_API) +#endif + } + return stream; +} + +template +inline std::ostream& on_color(std::ostream& stream) { + if (_internal::is_colorized(stream)) { +#if defined(TERMCOLOR_USE_ANSI_ESCAPE_SEQUENCES) + char command[12]; + std::snprintf(command, sizeof(command), "\033[48;5;%dm", code); + stream << command; +#elif defined(TERMCOLOR_USE_WINDOWS_API) +#endif + } + return stream; +} + +template +inline std::ostream& color(std::ostream& stream) { + if (_internal::is_colorized(stream)) { +#if defined(TERMCOLOR_USE_ANSI_ESCAPE_SEQUENCES) + char command[20]; + std::snprintf(command, sizeof(command), "\033[38;2;%d;%d;%dm", r, g, b); + stream << command; +#elif defined(TERMCOLOR_USE_WINDOWS_API) +#endif + } + return stream; +} + +template +inline std::ostream& on_color(std::ostream& stream) { + if (_internal::is_colorized(stream)) { +#if defined(TERMCOLOR_USE_ANSI_ESCAPE_SEQUENCES) + char command[20]; + std::snprintf(command, sizeof(command), "\033[48;2;%d;%d;%dm", r, g, b); + stream << command; +#elif defined(TERMCOLOR_USE_WINDOWS_API) +#endif + } + return stream; +} + +inline std::ostream& grey(std::ostream& stream) { + if (_internal::is_colorized(stream)) { +#if defined(TERMCOLOR_USE_ANSI_ESCAPE_SEQUENCES) + stream << "\033[30m"; +#elif defined(TERMCOLOR_USE_WINDOWS_API) + _internal::win_change_attributes(stream, + 0 // grey (black) + ); +#endif + } + return stream; +} + +inline std::ostream& red(std::ostream& stream) { + if (_internal::is_colorized(stream)) { +#if defined(TERMCOLOR_USE_ANSI_ESCAPE_SEQUENCES) + stream << "\033[31m"; +#elif defined(TERMCOLOR_USE_WINDOWS_API) + _internal::win_change_attributes(stream, FOREGROUND_RED); +#endif + } + return stream; +} + +inline std::ostream& green(std::ostream& stream) { + if (_internal::is_colorized(stream)) { +#if defined(TERMCOLOR_USE_ANSI_ESCAPE_SEQUENCES) + stream << "\033[32m"; +#elif defined(TERMCOLOR_USE_WINDOWS_API) + _internal::win_change_attributes(stream, FOREGROUND_GREEN); +#endif + } + return stream; +} + +inline std::ostream& yellow(std::ostream& stream) { + if (_internal::is_colorized(stream)) { +#if defined(TERMCOLOR_USE_ANSI_ESCAPE_SEQUENCES) + stream << "\033[33m"; +#elif defined(TERMCOLOR_USE_WINDOWS_API) + _internal::win_change_attributes(stream, FOREGROUND_GREEN | FOREGROUND_RED); +#endif + } + return stream; +} + +inline std::ostream& blue(std::ostream& stream) { + if (_internal::is_colorized(stream)) { +#if defined(TERMCOLOR_USE_ANSI_ESCAPE_SEQUENCES) + stream << "\033[34m"; +#elif defined(TERMCOLOR_USE_WINDOWS_API) + _internal::win_change_attributes(stream, FOREGROUND_BLUE); +#endif + } + return stream; +} + +inline std::ostream& magenta(std::ostream& stream) { + if (_internal::is_colorized(stream)) { +#if defined(TERMCOLOR_USE_ANSI_ESCAPE_SEQUENCES) + stream << "\033[35m"; +#elif defined(TERMCOLOR_USE_WINDOWS_API) + _internal::win_change_attributes(stream, FOREGROUND_BLUE | FOREGROUND_RED); +#endif + } + return stream; +} + +inline std::ostream& cyan(std::ostream& stream) { + if (_internal::is_colorized(stream)) { +#if defined(TERMCOLOR_USE_ANSI_ESCAPE_SEQUENCES) + stream << "\033[36m"; +#elif defined(TERMCOLOR_USE_WINDOWS_API) + _internal::win_change_attributes(stream, + FOREGROUND_BLUE | FOREGROUND_GREEN); +#endif + } + return stream; +} + +inline std::ostream& white(std::ostream& stream) { + if (_internal::is_colorized(stream)) { +#if defined(TERMCOLOR_USE_ANSI_ESCAPE_SEQUENCES) + stream << "\033[37m"; +#elif defined(TERMCOLOR_USE_WINDOWS_API) + _internal::win_change_attributes( + stream, FOREGROUND_BLUE | FOREGROUND_GREEN | FOREGROUND_RED); +#endif + } + return stream; +} + +inline std::ostream& bright_grey(std::ostream& stream) { + if (_internal::is_colorized(stream)) { +#if defined(TERMCOLOR_USE_ANSI_ESCAPE_SEQUENCES) + stream << "\033[90m"; +#elif defined(TERMCOLOR_USE_WINDOWS_API) + _internal::win_change_attributes(stream, + 0 | FOREGROUND_INTENSITY // grey (black) + ); +#endif + } + return stream; +} + +inline std::ostream& bright_red(std::ostream& stream) { + if (_internal::is_colorized(stream)) { +#if defined(TERMCOLOR_USE_ANSI_ESCAPE_SEQUENCES) + stream << "\033[91m"; +#elif defined(TERMCOLOR_USE_WINDOWS_API) + _internal::win_change_attributes(stream, + FOREGROUND_RED | FOREGROUND_INTENSITY); +#endif + } + return stream; +} + +inline std::ostream& bright_green(std::ostream& stream) { + if (_internal::is_colorized(stream)) { +#if defined(TERMCOLOR_USE_ANSI_ESCAPE_SEQUENCES) + stream << "\033[92m"; +#elif defined(TERMCOLOR_USE_WINDOWS_API) + _internal::win_change_attributes(stream, + FOREGROUND_GREEN | FOREGROUND_INTENSITY); +#endif + } + return stream; +} + +inline std::ostream& bright_yellow(std::ostream& stream) { + if (_internal::is_colorized(stream)) { +#if defined(TERMCOLOR_USE_ANSI_ESCAPE_SEQUENCES) + stream << "\033[93m"; +#elif defined(TERMCOLOR_USE_WINDOWS_API) + _internal::win_change_attributes( + stream, FOREGROUND_GREEN | FOREGROUND_RED | FOREGROUND_INTENSITY); +#endif + } + return stream; +} + +inline std::ostream& bright_blue(std::ostream& stream) { + if (_internal::is_colorized(stream)) { +#if defined(TERMCOLOR_USE_ANSI_ESCAPE_SEQUENCES) + stream << "\033[94m"; +#elif defined(TERMCOLOR_USE_WINDOWS_API) + _internal::win_change_attributes(stream, + FOREGROUND_BLUE | FOREGROUND_INTENSITY); +#endif + } + return stream; +} + +inline std::ostream& bright_magenta(std::ostream& stream) { + if (_internal::is_colorized(stream)) { +#if defined(TERMCOLOR_USE_ANSI_ESCAPE_SEQUENCES) + stream << "\033[95m"; +#elif defined(TERMCOLOR_USE_WINDOWS_API) + _internal::win_change_attributes( + stream, FOREGROUND_BLUE | FOREGROUND_RED | FOREGROUND_INTENSITY); +#endif + } + return stream; +} + +inline std::ostream& bright_cyan(std::ostream& stream) { + if (_internal::is_colorized(stream)) { +#if defined(TERMCOLOR_USE_ANSI_ESCAPE_SEQUENCES) + stream << "\033[96m"; +#elif defined(TERMCOLOR_USE_WINDOWS_API) + _internal::win_change_attributes( + stream, FOREGROUND_BLUE | FOREGROUND_GREEN | FOREGROUND_INTENSITY); +#endif + } + return stream; +} + +inline std::ostream& bright_white(std::ostream& stream) { + if (_internal::is_colorized(stream)) { +#if defined(TERMCOLOR_USE_ANSI_ESCAPE_SEQUENCES) + stream << "\033[97m"; +#elif defined(TERMCOLOR_USE_WINDOWS_API) + _internal::win_change_attributes( + stream, FOREGROUND_BLUE | FOREGROUND_GREEN | FOREGROUND_RED | + FOREGROUND_INTENSITY); +#endif + } + return stream; +} + +inline std::ostream& on_grey(std::ostream& stream) { + if (_internal::is_colorized(stream)) { +#if defined(TERMCOLOR_USE_ANSI_ESCAPE_SEQUENCES) + stream << "\033[40m"; +#elif defined(TERMCOLOR_USE_WINDOWS_API) + _internal::win_change_attributes(stream, -1, + 0 // grey (black) + ); +#endif + } + return stream; +} + +inline std::ostream& on_red(std::ostream& stream) { + if (_internal::is_colorized(stream)) { +#if defined(TERMCOLOR_USE_ANSI_ESCAPE_SEQUENCES) + stream << "\033[41m"; +#elif defined(TERMCOLOR_USE_WINDOWS_API) + _internal::win_change_attributes(stream, -1, BACKGROUND_RED); +#endif + } + return stream; +} + +inline std::ostream& on_green(std::ostream& stream) { + if (_internal::is_colorized(stream)) { +#if defined(TERMCOLOR_USE_ANSI_ESCAPE_SEQUENCES) + stream << "\033[42m"; +#elif defined(TERMCOLOR_USE_WINDOWS_API) + _internal::win_change_attributes(stream, -1, BACKGROUND_GREEN); +#endif + } + return stream; +} + +inline std::ostream& on_yellow(std::ostream& stream) { + if (_internal::is_colorized(stream)) { +#if defined(TERMCOLOR_USE_ANSI_ESCAPE_SEQUENCES) + stream << "\033[43m"; +#elif defined(TERMCOLOR_USE_WINDOWS_API) + _internal::win_change_attributes(stream, -1, + BACKGROUND_GREEN | BACKGROUND_RED); +#endif + } + return stream; +} + +inline std::ostream& on_blue(std::ostream& stream) { + if (_internal::is_colorized(stream)) { +#if defined(TERMCOLOR_USE_ANSI_ESCAPE_SEQUENCES) + stream << "\033[44m"; +#elif defined(TERMCOLOR_USE_WINDOWS_API) + _internal::win_change_attributes(stream, -1, BACKGROUND_BLUE); +#endif + } + return stream; +} + +inline std::ostream& on_magenta(std::ostream& stream) { + if (_internal::is_colorized(stream)) { +#if defined(TERMCOLOR_USE_ANSI_ESCAPE_SEQUENCES) + stream << "\033[45m"; +#elif defined(TERMCOLOR_USE_WINDOWS_API) + _internal::win_change_attributes(stream, -1, + BACKGROUND_BLUE | BACKGROUND_RED); +#endif + } + return stream; +} + +inline std::ostream& on_cyan(std::ostream& stream) { + if (_internal::is_colorized(stream)) { +#if defined(TERMCOLOR_USE_ANSI_ESCAPE_SEQUENCES) + stream << "\033[46m"; +#elif defined(TERMCOLOR_USE_WINDOWS_API) + _internal::win_change_attributes(stream, -1, + BACKGROUND_GREEN | BACKGROUND_BLUE); +#endif + } + return stream; +} + +inline std::ostream& on_white(std::ostream& stream) { + if (_internal::is_colorized(stream)) { +#if defined(TERMCOLOR_USE_ANSI_ESCAPE_SEQUENCES) + stream << "\033[47m"; +#elif defined(TERMCOLOR_USE_WINDOWS_API) + _internal::win_change_attributes( + stream, -1, BACKGROUND_GREEN | BACKGROUND_BLUE | BACKGROUND_RED); +#endif + } + + return stream; +} + +inline std::ostream& on_bright_grey(std::ostream& stream) { + if (_internal::is_colorized(stream)) { +#if defined(TERMCOLOR_USE_ANSI_ESCAPE_SEQUENCES) + stream << "\033[100m"; +#elif defined(TERMCOLOR_USE_WINDOWS_API) + _internal::win_change_attributes(stream, -1, + 0 | BACKGROUND_INTENSITY // grey (black) + ); +#endif + } + return stream; +} + +inline std::ostream& on_bright_red(std::ostream& stream) { + if (_internal::is_colorized(stream)) { +#if defined(TERMCOLOR_USE_ANSI_ESCAPE_SEQUENCES) + stream << "\033[101m"; +#elif defined(TERMCOLOR_USE_WINDOWS_API) + _internal::win_change_attributes(stream, -1, + BACKGROUND_RED | BACKGROUND_INTENSITY); +#endif + } + return stream; +} + +inline std::ostream& on_bright_green(std::ostream& stream) { + if (_internal::is_colorized(stream)) { +#if defined(TERMCOLOR_USE_ANSI_ESCAPE_SEQUENCES) + stream << "\033[102m"; +#elif defined(TERMCOLOR_USE_WINDOWS_API) + _internal::win_change_attributes(stream, -1, + BACKGROUND_GREEN | BACKGROUND_INTENSITY); +#endif + } + return stream; +} + +inline std::ostream& on_bright_yellow(std::ostream& stream) { + if (_internal::is_colorized(stream)) { +#if defined(TERMCOLOR_USE_ANSI_ESCAPE_SEQUENCES) + stream << "\033[103m"; +#elif defined(TERMCOLOR_USE_WINDOWS_API) + _internal::win_change_attributes( + stream, -1, BACKGROUND_GREEN | BACKGROUND_RED | BACKGROUND_INTENSITY); +#endif + } + return stream; +} + +inline std::ostream& on_bright_blue(std::ostream& stream) { + if (_internal::is_colorized(stream)) { +#if defined(TERMCOLOR_USE_ANSI_ESCAPE_SEQUENCES) + stream << "\033[104m"; +#elif defined(TERMCOLOR_USE_WINDOWS_API) + _internal::win_change_attributes(stream, -1, + BACKGROUND_BLUE | BACKGROUND_INTENSITY); +#endif + } + return stream; +} + +inline std::ostream& on_bright_magenta(std::ostream& stream) { + if (_internal::is_colorized(stream)) { +#if defined(TERMCOLOR_USE_ANSI_ESCAPE_SEQUENCES) + stream << "\033[105m"; +#elif defined(TERMCOLOR_USE_WINDOWS_API) + _internal::win_change_attributes( + stream, -1, BACKGROUND_BLUE | BACKGROUND_RED | BACKGROUND_INTENSITY); +#endif + } + return stream; +} + +inline std::ostream& on_bright_cyan(std::ostream& stream) { + if (_internal::is_colorized(stream)) { +#if defined(TERMCOLOR_USE_ANSI_ESCAPE_SEQUENCES) + stream << "\033[106m"; +#elif defined(TERMCOLOR_USE_WINDOWS_API) + _internal::win_change_attributes( + stream, -1, BACKGROUND_GREEN | BACKGROUND_BLUE | BACKGROUND_INTENSITY); +#endif + } + return stream; +} + +inline std::ostream& on_bright_white(std::ostream& stream) { + if (_internal::is_colorized(stream)) { +#if defined(TERMCOLOR_USE_ANSI_ESCAPE_SEQUENCES) + stream << "\033[107m"; +#elif defined(TERMCOLOR_USE_WINDOWS_API) + _internal::win_change_attributes(stream, -1, + BACKGROUND_GREEN | BACKGROUND_BLUE | + BACKGROUND_RED | BACKGROUND_INTENSITY); +#endif + } + + return stream; +} + +//! Since C++ hasn't a way to hide something in the header from +//! the outer access, I have to introduce this namespace which +//! is used for internal purpose and should't be access from +//! the user code. +namespace _internal { +// An index to be used to access a private storage of I/O streams. See +// colorize / nocolorize I/O manipulators for details. Due to the fact +// that static variables ain't shared between translation units, inline +// function with local static variable is used to do the trick and share +// the variable value between translation units. +inline int colorize_index() { + static int colorize_index = std::ios_base::xalloc(); + return colorize_index; +} + +//! Since C++ hasn't a true way to extract stream handler +//! from the a given `std::ostream` object, I have to write +//! this kind of hack. +inline FILE* get_standard_stream(const std::ostream& stream) { + if (&stream == &std::cout) + return stdout; + else if ((&stream == &std::cerr) || (&stream == &std::clog)) + return stderr; + + return nullptr; +} + +// Say whether a given stream should be colorized or not. It's always +// true for ATTY streams and may be true for streams marked with +// colorize flag. +inline bool is_colorized(std::ostream& stream) { + return is_atty(stream) || static_cast(stream.iword(colorize_index())); +} + +//! Test whether a given `std::ostream` object refers to +//! a terminal. +inline bool is_atty(const std::ostream& stream) { + FILE* std_stream = get_standard_stream(stream); + + // Unfortunately, fileno() ends with segmentation fault + // if invalid file descriptor is passed. So we need to + // handle this case gracefully and assume it's not a tty + // if standard stream is not detected, and 0 is returned. + if (!std_stream) + return false; + +#if defined(TERMCOLOR_TARGET_POSIX) + return ::isatty(fileno(std_stream)); +#elif defined(TERMCOLOR_TARGET_WINDOWS) + return ::_isatty(_fileno(std_stream)); +#else + return false; +#endif +} + +#if defined(TERMCOLOR_TARGET_WINDOWS) +//! Change Windows Terminal colors attribute. If some +//! parameter is `-1` then attribute won't changed. +inline void win_change_attributes(std::ostream& stream, int foreground, + int background) { + // yeah, i know.. it's ugly, it's windows. + static WORD defaultAttributes = 0; + + // Windows doesn't have ANSI escape sequences and so we use special + // API to change Terminal output color. That means we can't + // manipulate colors by means of "std::stringstream" and hence + // should do nothing in this case. + if (!_internal::is_atty(stream)) + return; + + // get terminal handle + HANDLE hTerminal = INVALID_HANDLE_VALUE; + if (&stream == &std::cout) + hTerminal = GetStdHandle(STD_OUTPUT_HANDLE); + else if (&stream == &std::cerr) + hTerminal = GetStdHandle(STD_ERROR_HANDLE); + + // save default terminal attributes if it unsaved + if (!defaultAttributes) { + CONSOLE_SCREEN_BUFFER_INFO info; + if (!GetConsoleScreenBufferInfo(hTerminal, &info)) + return; + defaultAttributes = info.wAttributes; + } + + // restore all default settings + if (foreground == -1 && background == -1) { + SetConsoleTextAttribute(hTerminal, defaultAttributes); + return; + } + + // get current settings + CONSOLE_SCREEN_BUFFER_INFO info; + if (!GetConsoleScreenBufferInfo(hTerminal, &info)) + return; + + if (foreground != -1) { + info.wAttributes &= ~(info.wAttributes & 0x0F); + info.wAttributes |= static_cast(foreground); + } + + if (background != -1) { + info.wAttributes &= ~(info.wAttributes & 0xF0); + info.wAttributes |= static_cast(background); + } + + SetConsoleTextAttribute(hTerminal, info.wAttributes); +} +#endif // TERMCOLOR_TARGET_WINDOWS + +} // namespace _internal + +} // namespace termcolor + +#undef TERMCOLOR_TARGET_POSIX +#undef TERMCOLOR_TARGET_WINDOWS + +#if defined(TERMCOLOR_AUTODETECTED_IMPLEMENTATION) +#undef TERMCOLOR_USE_ANSI_ESCAPE_SEQUENCES +#undef TERMCOLOR_USE_WINDOWS_API +#endif + +#endif // TERMCOLOR_HPP_ + +#ifndef INDICATORS_TERMINAL_SIZE +#define INDICATORS_TERMINAL_SIZE +#include + +#if defined(_WIN32) +#include + +namespace indicators { + +static inline std::pair terminal_size() { + CONSOLE_SCREEN_BUFFER_INFO csbi; + int cols, rows; + GetConsoleScreenBufferInfo(GetStdHandle(STD_OUTPUT_HANDLE), &csbi); + cols = csbi.srWindow.Right - csbi.srWindow.Left + 1; + rows = csbi.srWindow.Bottom - csbi.srWindow.Top + 1; + return {static_cast(rows), static_cast(cols)}; +} + +static inline size_t terminal_width() { + return terminal_size().second; +} + +} // namespace indicators + +#else + +#include //ioctl() and TIOCGWINSZ +#include // for STDOUT_FILENO + +namespace indicators { + +static inline std::pair terminal_size() { + struct winsize size {}; + ioctl(STDOUT_FILENO, TIOCGWINSZ, &size); + return {static_cast(size.ws_row), static_cast(size.ws_col)}; +} + +static inline size_t terminal_width() { + return terminal_size().second; +} + +} // namespace indicators + +#endif + +#endif + +/* +Activity Indicators for Modern C++ +https://github.com/p-ranav/indicators + +Licensed under the MIT License . +SPDX-License-Identifier: MIT +Copyright (c) 2019 Dawid Pilarski . + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. +*/ +#ifndef INDICATORS_SETTING +#define INDICATORS_SETTING + +#include +// #include +// #include +// #include +#include +#include +#include +#include +#include + +namespace indicators { + +namespace details { + +template +struct if_else; + +template <> +struct if_else { + using type = std::true_type; +}; + +template <> +struct if_else { + using type = std::false_type; +}; + +template +struct if_else_type; + +template +struct if_else_type { + using type = True; +}; + +template +struct if_else_type { + using type = False; +}; + +template +struct conjuction; + +template <> +struct conjuction<> : std::true_type {}; + +template +struct conjuction + : if_else_type>::type { +}; + +template +struct disjunction; + +template <> +struct disjunction<> : std::false_type {}; + +template +struct disjunction + : if_else_type>::type {}; + +enum class ProgressBarOption { + bar_width = 0, + prefix_text, + postfix_text, + start, + end, + fill, + lead, + remainder, + max_postfix_text_len, + completed, + show_percentage, + show_elapsed_time, + show_remaining_time, + saved_start_time, + foreground_color, + spinner_show, + spinner_states, + font_styles, + hide_bar_when_complete, + min_progress, + max_progress, + progress_type, + stream +}; + +template +struct Setting { + template ::value>::type> + explicit Setting(Args&&... args) : value(std::forward(args)...) {} + Setting(const Setting&) = default; + Setting(Setting&&) = default; + + static constexpr auto id = Id; + using type = T; + + T value{}; +}; + +template +struct is_setting : std::false_type {}; + +template +struct is_setting> : std::true_type {}; + +template +struct are_settings : if_else...>::value>::type {}; + +template <> +struct are_settings<> : std::true_type {}; + +template +struct is_setting_from_tuple; + +template +struct is_setting_from_tuple> : std::true_type {}; + +template +struct is_setting_from_tuple> + : if_else...>::value>::type { +}; + +template +struct are_settings_from_tuple + : if_else< + conjuction...>::value>::type { +}; + +template +struct always_true { + static constexpr auto value = true; +}; + +template +Default&& get_impl(Default&& def) { + return std::forward(def); +} + +template +auto get_impl(Default&& /*def*/, T&& first, Args&&... /*tail*/) -> + typename std::enable_if<(std::decay::type::id == Id), + decltype(std::forward(first))>::type { + return std::forward(first); +} + +template +auto get_impl(Default&& def, T&& /*first*/, Args&&... tail) -> + typename std::enable_if< + (std::decay::type::id != Id), + decltype(get_impl(std::forward(def), + std::forward(tail)...))>::type { + return get_impl(std::forward(def), std::forward(tail)...); +} + +template ::value, void>::type> +auto get(Default&& def, Args&&... args) + -> decltype(details::get_impl(std::forward(def), + std::forward(args)...)) { + return details::get_impl(std::forward(def), + std::forward(args)...); +} + +template +using StringSetting = Setting; + +template +using IntegerSetting = Setting; + +template +using BooleanSetting = Setting; + +template +struct option_idx; + +template +struct option_idx, counter> + : if_else_type<(Id == T::id), std::integral_constant, + option_idx, counter + 1>>::type { +}; + +template +struct option_idx, counter> { + static_assert(always_true<(ProgressBarOption)Id>::value, + "No such option was found"); +}; + +template +auto get_value(Settings&& settings) + -> decltype(( + std::get::type>::value>( + std::declval()))) { + return std::get::type>::value>( + std::forward(settings)); +} + +} // namespace details + +namespace option { +using BarWidth = details::IntegerSetting; +using PrefixText = + details::StringSetting; +using PostfixText = + details::StringSetting; +using Start = details::StringSetting; +using End = details::StringSetting; +using Fill = details::StringSetting; +using Lead = details::StringSetting; +using Remainder = details::StringSetting; +using MaxPostfixTextLen = + details::IntegerSetting; +using Completed = + details::BooleanSetting; +using ShowPercentage = + details::BooleanSetting; +using ShowElapsedTime = + details::BooleanSetting; +using ShowRemainingTime = + details::BooleanSetting; +using SavedStartTime = + details::BooleanSetting; +using ForegroundColor = + details::Setting; +using ShowSpinner = + details::BooleanSetting; +using SpinnerStates = + details::Setting, + details::ProgressBarOption::spinner_states>; +using HideBarWhenComplete = + details::BooleanSetting; +using FontStyles = details::Setting, + details::ProgressBarOption::font_styles>; +using MinProgress = + details::IntegerSetting; +using MaxProgress = + details::IntegerSetting; +using ProgressType = + details::Setting; +using Stream = + details::Setting; +} // namespace option +} // namespace indicators + +#endif + +#ifndef INDICATORS_CURSOR_CONTROL +#define INDICATORS_CURSOR_CONTROL + +#if defined(_MSC_VER) +#if !defined(NOMINMAX) +#define NOMINMAX +#endif +#include +#include +#else +#include +#endif + +namespace indicators { + +#if defined(_MSC_VER) + +static inline void show_console_cursor(bool const show) { + HANDLE out = GetStdHandle(STD_OUTPUT_HANDLE); + + CONSOLE_CURSOR_INFO cursorInfo; + + GetConsoleCursorInfo(out, &cursorInfo); + cursorInfo.bVisible = show; // set the cursor visibility + SetConsoleCursorInfo(out, &cursorInfo); +} + +static inline void erase_line() { + auto hStdout = GetStdHandle(STD_OUTPUT_HANDLE); + if (!hStdout) + return; + + CONSOLE_SCREEN_BUFFER_INFO csbiInfo; + GetConsoleScreenBufferInfo(hStdout, &csbiInfo); + + COORD cursor; + + cursor.X = 0; + cursor.Y = csbiInfo.dwCursorPosition.Y; + + DWORD count = 0; + + FillConsoleOutputCharacterA(hStdout, ' ', csbiInfo.dwSize.X, cursor, &count); + + FillConsoleOutputAttribute(hStdout, csbiInfo.wAttributes, csbiInfo.dwSize.X, + cursor, &count); + + SetConsoleCursorPosition(hStdout, cursor); +} + +#else + +static inline void show_console_cursor(bool const show) { + std::fputs(show ? "\033[?25h" : "\033[?25l", stdout); +} + +static inline void erase_line() { + std::fputs("\r\033[K", stdout); +} + +#endif + +} // namespace indicators + +#endif + +#ifndef INDICATORS_CURSOR_MOVEMENT +#define INDICATORS_CURSOR_MOVEMENT + +#if defined(_MSC_VER) +#if !defined(NOMINMAX) +#define NOMINMAX +#endif +#include +#include +#else +#include +#endif + +namespace indicators { + +#ifdef _MSC_VER + +static inline void move(int x, int y) { + auto hStdout = GetStdHandle(STD_OUTPUT_HANDLE); + if (!hStdout) + return; + + CONSOLE_SCREEN_BUFFER_INFO csbiInfo; + GetConsoleScreenBufferInfo(hStdout, &csbiInfo); + + COORD cursor; + + cursor.X = csbiInfo.dwCursorPosition.X + x; + cursor.Y = csbiInfo.dwCursorPosition.Y + y; + SetConsoleCursorPosition(hStdout, cursor); +} + +static inline void move_up(int lines) { + move(0, -lines); +} +static inline void move_down(int lines) { + move(0, -lines); +} +static inline void move_right(int cols) { + move(cols, 0); +} +static inline void move_left(int cols) { + move(-cols, 0); +} + +#else + +static inline void move_up(int lines) { + std::cout << "\033[" << lines << "A"; +} +static inline void move_down(int lines) { + std::cout << "\033[" << lines << "B"; +} +static inline void move_right(int cols) { + std::cout << "\033[" << cols << "C"; +} +static inline void move_left(int cols) { + std::cout << "\033[" << cols << "D"; +} + +#endif + +} // namespace indicators + +#endif + +#ifndef INDICATORS_STREAM_HELPER +#define INDICATORS_STREAM_HELPER + +// #include +#ifndef INDICATORS_DISPLAY_WIDTH +#define INDICATORS_DISPLAY_WIDTH + +#include +#include +#include +#include +#include + +namespace unicode { + +namespace details { + +/* + * This is an implementation of wcwidth() and wcswidth() (defined in + * IEEE Std 1002.1-2001) for Unicode. + * + * http://www.opengroup.org/onlinepubs/007904975/functions/wcwidth.html + * http://www.opengroup.org/onlinepubs/007904975/functions/wcswidth.html + * + * In fixed-width output devices, Latin characters all occupy a single + * "cell" position of equal width, whereas ideographic CJK characters + * occupy two such cells. Interoperability between terminal-line + * applications and (teletype-style) character terminals using the + * UTF-8 encoding requires agreement on which character should advance + * the cursor by how many cell positions. No established formal + * standards exist at present on which Unicode character shall occupy + * how many cell positions on character terminals. These routines are + * a first attempt of defining such behavior based on simple rules + * applied to data provided by the Unicode Consortium. + * + * For some graphical characters, the Unicode standard explicitly + * defines a character-cell width via the definition of the East Asian + * FullWidth (F), Wide (W), Half-width (H), and Narrow (Na) classes. + * In all these cases, there is no ambiguity about which width a + * terminal shall use. For characters in the East Asian Ambiguous (A) + * class, the width choice depends purely on a preference of backward + * compatibility with either historic CJK or Western practice. + * Choosing single-width for these characters is easy to justify as + * the appropriate long-term solution, as the CJK practice of + * displaying these characters as double-width comes from historic + * implementation simplicity (8-bit encoded characters were displayed + * single-width and 16-bit ones double-width, even for Greek, + * Cyrillic, etc.) and not any typographic considerations. + * + * Much less clear is the choice of width for the Not East Asian + * (Neutral) class. Existing practice does not dictate a width for any + * of these characters. It would nevertheless make sense + * typographically to allocate two character cells to characters such + * as for instance EM SPACE or VOLUME INTEGRAL, which cannot be + * represented adequately with a single-width glyph. The following + * routines at present merely assign a single-cell width to all + * neutral characters, in the interest of simplicity. This is not + * entirely satisfactory and should be reconsidered before + * establishing a formal standard in this area. At the moment, the + * decision which Not East Asian (Neutral) characters should be + * represented by double-width glyphs cannot yet be answered by + * applying a simple rule from the Unicode database content. Setting + * up a proper standard for the behavior of UTF-8 character terminals + * will require a careful analysis not only of each Unicode character, + * but also of each presentation form, something the author of these + * routines has avoided to do so far. + * + * http://www.unicode.org/unicode/reports/tr11/ + * + * Markus Kuhn -- 2007-05-26 (Unicode 5.0) + * + * Permission to use, copy, modify, and distribute this software + * for any purpose and without fee is hereby granted. The author + * disclaims all warranties with regard to this software. + * + * Latest version: http://www.cl.cam.ac.uk/~mgk25/ucs/wcwidth.c + */ + +struct interval { + int first; + int last; +}; + +/* auxiliary function for binary search in interval table */ +static inline int bisearch(wchar_t ucs, const struct interval* table, int max) { + int min = 0; + int mid; + + if (ucs < table[0].first || ucs > table[max].last) + return 0; + while (max >= min) { + mid = (min + max) / 2; + if (ucs > table[mid].last) + min = mid + 1; + else if (ucs < table[mid].first) + max = mid - 1; + else + return 1; + } + + return 0; +} + +/* The following two functions define the column width of an ISO 10646 + * character as follows: + * + * - The null character (U+0000) has a column width of 0. + * + * - Other C0/C1 control characters and DEL will lead to a return + * value of -1. + * + * - Non-spacing and enclosing combining characters (general + * category code Mn or Me in the Unicode database) have a + * column width of 0. + * + * - SOFT HYPHEN (U+00AD) has a column width of 1. + * + * - Other format characters (general category code Cf in the Unicode + * database) and ZERO WIDTH SPACE (U+200B) have a column width of 0. + * + * - Hangul Jamo medial vowels and final consonants (U+1160-U+11FF) + * have a column width of 0. + * + * - Spacing characters in the East Asian Wide (W) or East Asian + * Full-width (F) category as defined in Unicode Technical + * Report #11 have a column width of 2. + * + * - All remaining characters (including all printable + * ISO 8859-1 and WGL4 characters, Unicode control characters, + * etc.) have a column width of 1. + * + * This implementation assumes that wchar_t characters are encoded + * in ISO 10646. + */ + +static inline int mk_wcwidth(wchar_t ucs) { + /* sorted list of non-overlapping intervals of non-spacing characters */ + /* generated by "uniset +cat=Me +cat=Mn +cat=Cf -00AD +1160-11FF +200B c" */ + static const struct interval combining[] = { + {0x0300, 0x036F}, {0x0483, 0x0486}, {0x0488, 0x0489}, + {0x0591, 0x05BD}, {0x05BF, 0x05BF}, {0x05C1, 0x05C2}, + {0x05C4, 0x05C5}, {0x05C7, 0x05C7}, {0x0600, 0x0603}, + {0x0610, 0x0615}, {0x064B, 0x065E}, {0x0670, 0x0670}, + {0x06D6, 0x06E4}, {0x06E7, 0x06E8}, {0x06EA, 0x06ED}, + {0x070F, 0x070F}, {0x0711, 0x0711}, {0x0730, 0x074A}, + {0x07A6, 0x07B0}, {0x07EB, 0x07F3}, {0x0901, 0x0902}, + {0x093C, 0x093C}, {0x0941, 0x0948}, {0x094D, 0x094D}, + {0x0951, 0x0954}, {0x0962, 0x0963}, {0x0981, 0x0981}, + {0x09BC, 0x09BC}, {0x09C1, 0x09C4}, {0x09CD, 0x09CD}, + {0x09E2, 0x09E3}, {0x0A01, 0x0A02}, {0x0A3C, 0x0A3C}, + {0x0A41, 0x0A42}, {0x0A47, 0x0A48}, {0x0A4B, 0x0A4D}, + {0x0A70, 0x0A71}, {0x0A81, 0x0A82}, {0x0ABC, 0x0ABC}, + {0x0AC1, 0x0AC5}, {0x0AC7, 0x0AC8}, {0x0ACD, 0x0ACD}, + {0x0AE2, 0x0AE3}, {0x0B01, 0x0B01}, {0x0B3C, 0x0B3C}, + {0x0B3F, 0x0B3F}, {0x0B41, 0x0B43}, {0x0B4D, 0x0B4D}, + {0x0B56, 0x0B56}, {0x0B82, 0x0B82}, {0x0BC0, 0x0BC0}, + {0x0BCD, 0x0BCD}, {0x0C3E, 0x0C40}, {0x0C46, 0x0C48}, + {0x0C4A, 0x0C4D}, {0x0C55, 0x0C56}, {0x0CBC, 0x0CBC}, + {0x0CBF, 0x0CBF}, {0x0CC6, 0x0CC6}, {0x0CCC, 0x0CCD}, + {0x0CE2, 0x0CE3}, {0x0D41, 0x0D43}, {0x0D4D, 0x0D4D}, + {0x0DCA, 0x0DCA}, {0x0DD2, 0x0DD4}, {0x0DD6, 0x0DD6}, + {0x0E31, 0x0E31}, {0x0E34, 0x0E3A}, {0x0E47, 0x0E4E}, + {0x0EB1, 0x0EB1}, {0x0EB4, 0x0EB9}, {0x0EBB, 0x0EBC}, + {0x0EC8, 0x0ECD}, {0x0F18, 0x0F19}, {0x0F35, 0x0F35}, + {0x0F37, 0x0F37}, {0x0F39, 0x0F39}, {0x0F71, 0x0F7E}, + {0x0F80, 0x0F84}, {0x0F86, 0x0F87}, {0x0F90, 0x0F97}, + {0x0F99, 0x0FBC}, {0x0FC6, 0x0FC6}, {0x102D, 0x1030}, + {0x1032, 0x1032}, {0x1036, 0x1037}, {0x1039, 0x1039}, + {0x1058, 0x1059}, {0x1160, 0x11FF}, {0x135F, 0x135F}, + {0x1712, 0x1714}, {0x1732, 0x1734}, {0x1752, 0x1753}, + {0x1772, 0x1773}, {0x17B4, 0x17B5}, {0x17B7, 0x17BD}, + {0x17C6, 0x17C6}, {0x17C9, 0x17D3}, {0x17DD, 0x17DD}, + {0x180B, 0x180D}, {0x18A9, 0x18A9}, {0x1920, 0x1922}, + {0x1927, 0x1928}, {0x1932, 0x1932}, {0x1939, 0x193B}, + {0x1A17, 0x1A18}, {0x1B00, 0x1B03}, {0x1B34, 0x1B34}, + {0x1B36, 0x1B3A}, {0x1B3C, 0x1B3C}, {0x1B42, 0x1B42}, + {0x1B6B, 0x1B73}, {0x1DC0, 0x1DCA}, {0x1DFE, 0x1DFF}, + {0x200B, 0x200F}, {0x202A, 0x202E}, {0x2060, 0x2063}, + {0x206A, 0x206F}, {0x20D0, 0x20EF}, {0x302A, 0x302F}, + {0x3099, 0x309A}, {0xA806, 0xA806}, {0xA80B, 0xA80B}, + {0xA825, 0xA826}, {0xFB1E, 0xFB1E}, {0xFE00, 0xFE0F}, + {0xFE20, 0xFE23}, {0xFEFF, 0xFEFF}, {0xFFF9, 0xFFFB}, + {0x10A01, 0x10A03}, {0x10A05, 0x10A06}, {0x10A0C, 0x10A0F}, + {0x10A38, 0x10A3A}, {0x10A3F, 0x10A3F}, {0x1D167, 0x1D169}, + {0x1D173, 0x1D182}, {0x1D185, 0x1D18B}, {0x1D1AA, 0x1D1AD}, + {0x1D242, 0x1D244}, {0xE0001, 0xE0001}, {0xE0020, 0xE007F}, + {0xE0100, 0xE01EF}}; + + /* test for 8-bit control characters */ + if (ucs == 0) + return 0; + if (ucs < 32 || (ucs >= 0x7f && ucs < 0xa0)) + return -1; + + /* binary search in table of non-spacing characters */ + if (bisearch(ucs, combining, sizeof(combining) / sizeof(struct interval) - 1)) + return 0; + + /* if we arrive here, ucs is not a combining or C0/C1 control character */ + + return 1 + + (ucs >= 0x1100 && + (ucs <= 0x115f || /* Hangul Jamo init. consonants */ + ucs == 0x2329 || ucs == 0x232a || + (ucs >= 0x2e80 && ucs <= 0xa4cf && ucs != 0x303f) || /* CJK ... Yi */ + (ucs >= 0xac00 && ucs <= 0xd7a3) || /* Hangul Syllables */ + (ucs >= 0xf900 && + ucs <= 0xfaff) || /* CJK Compatibility Ideographs */ + (ucs >= 0xfe10 && ucs <= 0xfe19) || /* Vertical forms */ + (ucs >= 0xfe30 && ucs <= 0xfe6f) || /* CJK Compatibility Forms */ + (ucs >= 0xff00 && ucs <= 0xff60) || /* Fullwidth Forms */ + (ucs >= 0xffe0 && ucs <= 0xffe6) || + (ucs >= 0x20000 && ucs <= 0x2fffd) || + (ucs >= 0x30000 && ucs <= 0x3fffd))); +} + +static inline int mk_wcswidth(const wchar_t* pwcs, size_t n) { + int w, width = 0; + + for (; *pwcs && n-- > 0; pwcs++) + if ((w = mk_wcwidth(*pwcs)) < 0) + return -1; + else + width += w; + + return width; +} + +/* + * The following functions are the same as mk_wcwidth() and + * mk_wcswidth(), except that spacing characters in the East Asian + * Ambiguous (A) category as defined in Unicode Technical Report #11 + * have a column width of 2. This variant might be useful for users of + * CJK legacy encodings who want to migrate to UCS without changing + * the traditional terminal character-width behaviour. It is not + * otherwise recommended for general use. + */ +static inline int mk_wcwidth_cjk(wchar_t ucs) { + /* sorted list of non-overlapping intervals of East Asian Ambiguous + * characters, generated by "uniset +WIDTH-A -cat=Me -cat=Mn -cat=Cf c" */ + static const struct interval ambiguous[] = { + {0x00A1, 0x00A1}, {0x00A4, 0x00A4}, {0x00A7, 0x00A8}, + {0x00AA, 0x00AA}, {0x00AE, 0x00AE}, {0x00B0, 0x00B4}, + {0x00B6, 0x00BA}, {0x00BC, 0x00BF}, {0x00C6, 0x00C6}, + {0x00D0, 0x00D0}, {0x00D7, 0x00D8}, {0x00DE, 0x00E1}, + {0x00E6, 0x00E6}, {0x00E8, 0x00EA}, {0x00EC, 0x00ED}, + {0x00F0, 0x00F0}, {0x00F2, 0x00F3}, {0x00F7, 0x00FA}, + {0x00FC, 0x00FC}, {0x00FE, 0x00FE}, {0x0101, 0x0101}, + {0x0111, 0x0111}, {0x0113, 0x0113}, {0x011B, 0x011B}, + {0x0126, 0x0127}, {0x012B, 0x012B}, {0x0131, 0x0133}, + {0x0138, 0x0138}, {0x013F, 0x0142}, {0x0144, 0x0144}, + {0x0148, 0x014B}, {0x014D, 0x014D}, {0x0152, 0x0153}, + {0x0166, 0x0167}, {0x016B, 0x016B}, {0x01CE, 0x01CE}, + {0x01D0, 0x01D0}, {0x01D2, 0x01D2}, {0x01D4, 0x01D4}, + {0x01D6, 0x01D6}, {0x01D8, 0x01D8}, {0x01DA, 0x01DA}, + {0x01DC, 0x01DC}, {0x0251, 0x0251}, {0x0261, 0x0261}, + {0x02C4, 0x02C4}, {0x02C7, 0x02C7}, {0x02C9, 0x02CB}, + {0x02CD, 0x02CD}, {0x02D0, 0x02D0}, {0x02D8, 0x02DB}, + {0x02DD, 0x02DD}, {0x02DF, 0x02DF}, {0x0391, 0x03A1}, + {0x03A3, 0x03A9}, {0x03B1, 0x03C1}, {0x03C3, 0x03C9}, + {0x0401, 0x0401}, {0x0410, 0x044F}, {0x0451, 0x0451}, + {0x2010, 0x2010}, {0x2013, 0x2016}, {0x2018, 0x2019}, + {0x201C, 0x201D}, {0x2020, 0x2022}, {0x2024, 0x2027}, + {0x2030, 0x2030}, {0x2032, 0x2033}, {0x2035, 0x2035}, + {0x203B, 0x203B}, {0x203E, 0x203E}, {0x2074, 0x2074}, + {0x207F, 0x207F}, {0x2081, 0x2084}, {0x20AC, 0x20AC}, + {0x2103, 0x2103}, {0x2105, 0x2105}, {0x2109, 0x2109}, + {0x2113, 0x2113}, {0x2116, 0x2116}, {0x2121, 0x2122}, + {0x2126, 0x2126}, {0x212B, 0x212B}, {0x2153, 0x2154}, + {0x215B, 0x215E}, {0x2160, 0x216B}, {0x2170, 0x2179}, + {0x2190, 0x2199}, {0x21B8, 0x21B9}, {0x21D2, 0x21D2}, + {0x21D4, 0x21D4}, {0x21E7, 0x21E7}, {0x2200, 0x2200}, + {0x2202, 0x2203}, {0x2207, 0x2208}, {0x220B, 0x220B}, + {0x220F, 0x220F}, {0x2211, 0x2211}, {0x2215, 0x2215}, + {0x221A, 0x221A}, {0x221D, 0x2220}, {0x2223, 0x2223}, + {0x2225, 0x2225}, {0x2227, 0x222C}, {0x222E, 0x222E}, + {0x2234, 0x2237}, {0x223C, 0x223D}, {0x2248, 0x2248}, + {0x224C, 0x224C}, {0x2252, 0x2252}, {0x2260, 0x2261}, + {0x2264, 0x2267}, {0x226A, 0x226B}, {0x226E, 0x226F}, + {0x2282, 0x2283}, {0x2286, 0x2287}, {0x2295, 0x2295}, + {0x2299, 0x2299}, {0x22A5, 0x22A5}, {0x22BF, 0x22BF}, + {0x2312, 0x2312}, {0x2460, 0x24E9}, {0x24EB, 0x254B}, + {0x2550, 0x2573}, {0x2580, 0x258F}, {0x2592, 0x2595}, + {0x25A0, 0x25A1}, {0x25A3, 0x25A9}, {0x25B2, 0x25B3}, + {0x25B6, 0x25B7}, {0x25BC, 0x25BD}, {0x25C0, 0x25C1}, + {0x25C6, 0x25C8}, {0x25CB, 0x25CB}, {0x25CE, 0x25D1}, + {0x25E2, 0x25E5}, {0x25EF, 0x25EF}, {0x2605, 0x2606}, + {0x2609, 0x2609}, {0x260E, 0x260F}, {0x2614, 0x2615}, + {0x261C, 0x261C}, {0x261E, 0x261E}, {0x2640, 0x2640}, + {0x2642, 0x2642}, {0x2660, 0x2661}, {0x2663, 0x2665}, + {0x2667, 0x266A}, {0x266C, 0x266D}, {0x266F, 0x266F}, + {0x273D, 0x273D}, {0x2776, 0x277F}, {0xE000, 0xF8FF}, + {0xFFFD, 0xFFFD}, {0xF0000, 0xFFFFD}, {0x100000, 0x10FFFD}}; + + /* binary search in table of non-spacing characters */ + if (bisearch(ucs, ambiguous, sizeof(ambiguous) / sizeof(struct interval) - 1)) + return 2; + + return mk_wcwidth(ucs); +} + +static inline int mk_wcswidth_cjk(const wchar_t* pwcs, size_t n) { + int w, width = 0; + + for (; *pwcs && n-- > 0; pwcs++) + if ((w = mk_wcwidth_cjk(*pwcs)) < 0) + return -1; + else + width += w; + + return width; +} + +// convert UTF-8 string to wstring +#ifdef _MSC_VER +static inline std::wstring utf8_decode(const std::string& s) { + auto r = setlocale(LC_ALL, ""); + std::string curLocale; + if (r) + curLocale = r; + const char* _Source = s.c_str(); + size_t _Dsize = std::strlen(_Source) + 1; + wchar_t* _Dest = new wchar_t[_Dsize]; + size_t _Osize; + mbstowcs_s(&_Osize, _Dest, _Dsize, _Source, _Dsize); + std::wstring result = _Dest; + delete[] _Dest; + setlocale(LC_ALL, curLocale.c_str()); + return result; +} +#else +static inline std::wstring utf8_decode(const std::string& s) { + auto r = setlocale(LC_ALL, ""); + std::string curLocale; + if (r) + curLocale = r; + const char* _Source = s.c_str(); + size_t _Dsize = mbstowcs(NULL, _Source, 0) + 1; + wchar_t* _Dest = new wchar_t[_Dsize]; + wmemset(_Dest, 0, _Dsize); + mbstowcs(_Dest, _Source, _Dsize); + std::wstring result = _Dest; + delete[] _Dest; + setlocale(LC_ALL, curLocale.c_str()); + return result; +} +#endif + +} // namespace details + +static inline int display_width(const std::string& input) { + using namespace unicode::details; + return mk_wcswidth(utf8_decode(input).c_str(), input.size()); +} + +static inline int display_width(const std::wstring& input) { + return details::mk_wcswidth(input.c_str(), input.size()); +} + +} // namespace unicode + +#endif +// #include +// #include + +#include +#include +#include +#include +#include +#include + +#include +#include + +namespace indicators { +namespace details { + +inline void set_stream_color(std::ostream& os, Color color) { + switch (color) { + case Color::grey: + os << termcolor::grey; + break; + case Color::red: + os << termcolor::red; + break; + case Color::green: + os << termcolor::green; + break; + case Color::yellow: + os << termcolor::yellow; + break; + case Color::blue: + os << termcolor::blue; + break; + case Color::magenta: + os << termcolor::magenta; + break; + case Color::cyan: + os << termcolor::cyan; + break; + case Color::white: + os << termcolor::white; + break; + default: + assert(false); + } +} + +inline void set_font_style(std::ostream& os, FontStyle style) { + switch (style) { + case FontStyle::bold: + os << termcolor::bold; + break; + case FontStyle::dark: + os << termcolor::dark; + break; + case FontStyle::italic: + os << termcolor::italic; + break; + case FontStyle::underline: + os << termcolor::underline; + break; + case FontStyle::blink: + os << termcolor::blink; + break; + case FontStyle::reverse: + os << termcolor::reverse; + break; + case FontStyle::concealed: + os << termcolor::concealed; + break; + case FontStyle::crossed: + os << termcolor::crossed; + break; + default: + break; + } +} + +inline std::ostream& write_duration(std::ostream& os, + std::chrono::nanoseconds ns) { + using namespace std; + using namespace std::chrono; + using days = duration>; + char fill = os.fill(); + os.fill('0'); + auto d = duration_cast(ns); + ns -= d; + auto h = duration_cast(ns); + ns -= h; + auto m = duration_cast(ns); + ns -= m; + auto s = duration_cast(ns); + if (d.count() > 0) + os << setw(2) << d.count() << "d:"; + if (h.count() > 0) + os << setw(2) << h.count() << "h:"; + os << setw(2) << m.count() << "m:" << setw(2) << s.count() << 's'; + os.fill(fill); + return os; +} + +class BlockProgressScaleWriter { + public: + BlockProgressScaleWriter(std::ostream& os, size_t bar_width) + : os(os), bar_width(bar_width) {} + + std::ostream& write(float progress) { + std::string fill_text{"█"}; + std::vector lead_characters{" ", "▏", "▎", "▍", + "▌", "▋", "▊", "▉"}; + auto value = (std::min)(1.0f, (std::max)(0.0f, progress / 100.0f)); + auto whole_width = std::floor(value * bar_width); + auto remainder_width = fmod((value * bar_width), 1.0f); + auto part_width = std::floor(remainder_width * lead_characters.size()); + std::string lead_text = lead_characters[size_t(part_width)]; + if ((bar_width - whole_width - 1) < 0) + lead_text = ""; + for (size_t i = 0; i < whole_width; ++i) + os << fill_text; + os << lead_text; + for (size_t i = 0; i < (bar_width - whole_width - 1); ++i) + os << " "; + return os; + } + + private: + std::ostream& os; + size_t bar_width = 0; +}; + +class ProgressScaleWriter { + public: + ProgressScaleWriter(std::ostream& os, size_t bar_width, + const std::string& fill, const std::string& lead, + const std::string& remainder) + : os(os), + bar_width(bar_width), + fill(fill), + lead(lead), + remainder(remainder) {} + + std::ostream& write(float progress) { + auto pos = static_cast(progress * bar_width / 100.0); + for (size_t i = 0, current_display_width = 0; i < bar_width;) { + std::string next; + + if (i < pos) { + next = fill; + current_display_width = unicode::display_width(fill); + } else if (i == pos) { + next = lead; + current_display_width = unicode::display_width(lead); + } else { + next = remainder; + current_display_width = unicode::display_width(remainder); + } + + i += current_display_width; + + if (i > bar_width) { + // `next` is larger than the allowed bar width + // fill with empty space instead + os << std::string((bar_width - (i - current_display_width)), ' '); + break; + } + + os << next; + } + return os; + } + + private: + std::ostream& os; + size_t bar_width = 0; + std::string fill; + std::string lead; + std::string remainder; +}; + +class IndeterminateProgressScaleWriter { + public: + IndeterminateProgressScaleWriter(std::ostream& os, size_t bar_width, + const std::string& fill, + const std::string& lead) + : os(os), bar_width(bar_width), fill(fill), lead(lead) {} + + std::ostream& write(size_t progress) { + for (size_t i = 0; i < bar_width;) { + std::string next; + size_t current_display_width = 0; + + if (i < progress) { + next = fill; + current_display_width = unicode::display_width(fill); + } else if (i == progress) { + next = lead; + current_display_width = unicode::display_width(lead); + } else { + next = fill; + current_display_width = unicode::display_width(fill); + } + + i += current_display_width; + + if (i > bar_width) { + // `next` is larger than the allowed bar width + // fill with empty space instead + os << std::string((bar_width - (i - current_display_width)), ' '); + break; + } + + os << next; + } + return os; + } + + private: + std::ostream& os; + size_t bar_width = 0; + std::string fill; + std::string lead; +}; + +} // namespace details +} // namespace indicators + +#endif + +#ifndef INDICATORS_PROGRESS_BAR +#define INDICATORS_PROGRESS_BAR + +// #include + +#include +#include +#include +#include +// #include +// #include +// #include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace indicators { + +class ProgressBar { + using Settings = + std::tuple; + + public: + template ::type...>::value, + void*>::type = nullptr> + explicit ProgressBar(Args&&... args) + : settings_( + details::get( + option::BarWidth{100}, std::forward(args)...), + details::get( + option::PrefixText{}, std::forward(args)...), + details::get( + option::PostfixText{}, std::forward(args)...), + details::get( + option::Start{"["}, std::forward(args)...), + details::get( + option::End{"]"}, std::forward(args)...), + details::get( + option::Fill{"="}, std::forward(args)...), + details::get( + option::Lead{">"}, std::forward(args)...), + details::get( + option::Remainder{" "}, std::forward(args)...), + details::get( + option::MaxPostfixTextLen{0}, std::forward(args)...), + details::get( + option::Completed{false}, std::forward(args)...), + details::get( + option::ShowPercentage{false}, std::forward(args)...), + details::get( + option::ShowElapsedTime{false}, std::forward(args)...), + details::get( + option::ShowRemainingTime{false}, std::forward(args)...), + details::get( + option::SavedStartTime{false}, std::forward(args)...), + details::get( + option::ForegroundColor{Color::unspecified}, + std::forward(args)...), + details::get( + option::FontStyles{std::vector{}}, + std::forward(args)...), + details::get( + option::MinProgress{0}, std::forward(args)...), + details::get( + option::MaxProgress{100}, std::forward(args)...), + details::get( + option::ProgressType{ProgressType::incremental}, + std::forward(args)...), + details::get( + option::Stream{std::cout}, std::forward(args)...)) { + + // if progress is incremental, start from min_progress + // else start from max_progress + const auto type = get_value(); + if (type == ProgressType::incremental) + progress_ = get_value(); + else + progress_ = get_value(); + } + + template + void set_option(details::Setting&& setting) { + static_assert( + !std::is_same( + std::declval()))>::type>::value, + "Setting has wrong type!"); + std::lock_guard lock(mutex_); + get_value() = std::move(setting).value; + } + + template + void set_option(const details::Setting& setting) { + static_assert( + !std::is_same( + std::declval()))>::type>::value, + "Setting has wrong type!"); + std::lock_guard lock(mutex_); + get_value() = setting.value; + } + + void set_option( + const details::Setting< + std::string, details::ProgressBarOption::postfix_text>& setting) { + std::lock_guard lock(mutex_); + get_value() = setting.value; + if (setting.value.length() > + get_value()) { + get_value() = + setting.value.length(); + } + } + + void set_option( + details::Setting&& + setting) { + std::lock_guard lock(mutex_); + get_value() = + std::move(setting).value; + auto& new_value = get_value(); + if (new_value.length() > + get_value()) { + get_value() = + new_value.length(); + } + } + + void set_progress(size_t new_progress) { + { + std::lock_guard lock(mutex_); + progress_ = new_progress; + } + + save_start_time(); + print_progress(); + } + + void tick() { + { + std::lock_guard lock{mutex_}; + const auto type = get_value(); + if (type == ProgressType::incremental) + progress_ += 1; + else + progress_ -= 1; + } + save_start_time(); + print_progress(); + } + + size_t current() { + std::lock_guard lock{mutex_}; + return (std::min)( + progress_, + size_t(get_value())); + } + + bool is_completed() const { + return get_value(); + } + + void mark_as_completed() { + get_value() = true; + print_progress(); + } + + private: + template + auto get_value() + -> decltype((details::get_value(std::declval()).value)) { + return details::get_value(settings_).value; + } + + template + auto get_value() const + -> decltype(( + details::get_value(std::declval()).value)) { + return details::get_value(settings_).value; + } + + size_t progress_{0}; + Settings settings_; + std::chrono::nanoseconds elapsed_; + std::chrono::time_point start_time_point_; + std::mutex mutex_; + + template + friend class MultiProgress; + template + friend class DynamicProgress; + std::atomic multi_progress_mode_{false}; + + void save_start_time() { + auto& show_elapsed_time = + get_value(); + auto& saved_start_time = + get_value(); + auto& show_remaining_time = + get_value(); + if ((show_elapsed_time || show_remaining_time) && !saved_start_time) { + start_time_point_ = std::chrono::high_resolution_clock::now(); + saved_start_time = true; + } + } + + std::pair get_prefix_text() { + std::stringstream os; + os << get_value(); + const auto result = os.str(); + const auto result_size = unicode::display_width(result); + return {result, result_size}; + } + + std::pair get_postfix_text() { + std::stringstream os; + const auto max_progress = + get_value(); + + if (get_value()) { + os << " " + << (std::min)(static_cast(static_cast(progress_) / + max_progress * 100), + size_t(100)) + << "%"; + } + + auto& saved_start_time = + get_value(); + + if (get_value()) { + os << " ["; + if (saved_start_time) + details::write_duration(os, elapsed_); + else + os << "00:00s"; + } + + if (get_value()) { + if (get_value()) + os << "<"; + else + os << " ["; + + if (saved_start_time) { + auto eta = std::chrono::nanoseconds( + progress_ > 0 + ? static_cast(std::ceil(float(elapsed_.count()) * + max_progress / progress_)) + : 0); + auto remaining = eta > elapsed_ ? (eta - elapsed_) : (elapsed_ - eta); + details::write_duration(os, remaining); + } else { + os << "00:00s"; + } + + os << "]"; + } else { + if (get_value()) + os << "]"; + } + + os << " " << get_value(); + + const auto result = os.str(); + const auto result_size = unicode::display_width(result); + return {result, result_size}; + } + + public: + void print_progress(bool from_multi_progress = false) { + std::lock_guard lock{mutex_}; + + auto& os = get_value(); + + const auto type = get_value(); + const auto min_progress = + get_value(); + const auto max_progress = + get_value(); + if (multi_progress_mode_ && !from_multi_progress) { + if ((type == ProgressType::incremental && progress_ >= max_progress) || + (type == ProgressType::decremental && progress_ <= min_progress)) { + get_value() = true; + } + return; + } + auto now = std::chrono::high_resolution_clock::now(); + if (!get_value()) + elapsed_ = std::chrono::duration_cast( + now - start_time_point_); + + if (get_value() != + Color::unspecified) + details::set_stream_color( + os, get_value()); + + for (auto& style : get_value()) + details::set_font_style(os, style); + + const auto prefix_pair = get_prefix_text(); + const auto prefix_text = prefix_pair.first; + const auto prefix_length = prefix_pair.second; + os << "\r" << prefix_text; + + os << get_value(); + + details::ProgressScaleWriter writer{ + os, get_value(), + get_value(), + get_value(), + get_value()}; + writer.write(double(progress_) / double(max_progress) * 100.0f); + + os << get_value(); + + const auto postfix_pair = get_postfix_text(); + const auto postfix_text = postfix_pair.first; + const auto postfix_length = postfix_pair.second; + os << postfix_text; + + // Get length of prefix text and postfix text + const auto start_length = + get_value().size(); + const auto bar_width = get_value(); + const auto end_length = get_value().size(); + const auto terminal_width = terminal_size().second; + // prefix + bar_width + postfix should be <= terminal_width + const int remaining = + terminal_width - (prefix_length + start_length + bar_width + + end_length + postfix_length); + if (prefix_length == -1 || postfix_length == -1) { + os << "\r"; + } else if (remaining > 0) { + os << std::string(remaining, ' ') << "\r"; + } else if (remaining < 0) { + // Do nothing. Maybe in the future truncate postfix with ... + } + os.flush(); + + if ((type == ProgressType::incremental && progress_ >= max_progress) || + (type == ProgressType::decremental && progress_ <= min_progress)) { + get_value() = true; + } + if (get_value() && + !from_multi_progress) // Don't std::endl if calling from MultiProgress + os << termcolor::reset << std::endl; + } +}; + +} // namespace indicators + +#endif + +#ifndef INDICATORS_BLOCK_PROGRESS_BAR +#define INDICATORS_BLOCK_PROGRESS_BAR + +// #include +// #include + +#include +#include +#include +// #include +// #include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace indicators { + +class BlockProgressBar { + using Settings = + std::tuple; + + public: + template ::type...>::value, + void*>::type = nullptr> + explicit BlockProgressBar(Args&&... args) + : settings_( + details::get( + option::ForegroundColor{Color::unspecified}, + std::forward(args)...), + details::get( + option::BarWidth{100}, std::forward(args)...), + details::get( + option::Start{"["}, std::forward(args)...), + details::get( + option::End{"]"}, std::forward(args)...), + details::get( + option::PrefixText{""}, std::forward(args)...), + details::get( + option::PostfixText{""}, std::forward(args)...), + details::get( + option::ShowPercentage{true}, std::forward(args)...), + details::get( + option::ShowElapsedTime{false}, std::forward(args)...), + details::get( + option::ShowRemainingTime{false}, std::forward(args)...), + details::get( + option::Completed{false}, std::forward(args)...), + details::get( + option::SavedStartTime{false}, std::forward(args)...), + details::get( + option::MaxPostfixTextLen{0}, std::forward(args)...), + details::get( + option::FontStyles{std::vector{}}, + std::forward(args)...), + details::get( + option::MaxProgress{100}, std::forward(args)...), + details::get( + option::Stream{std::cout}, std::forward(args)...)) {} + + template + void set_option(details::Setting&& setting) { + static_assert( + !std::is_same( + std::declval()))>::type>::value, + "Setting has wrong type!"); + std::lock_guard lock(mutex_); + get_value() = std::move(setting).value; + } + + template + void set_option(const details::Setting& setting) { + static_assert( + !std::is_same( + std::declval()))>::type>::value, + "Setting has wrong type!"); + std::lock_guard lock(mutex_); + get_value() = setting.value; + } + + void set_option( + const details::Setting< + std::string, details::ProgressBarOption::postfix_text>& setting) { + std::lock_guard lock(mutex_); + get_value() = setting.value; + if (setting.value.length() > + get_value()) { + get_value() = + setting.value.length(); + } + } + + void set_option( + details::Setting&& + setting) { + std::lock_guard lock(mutex_); + get_value() = + std::move(setting).value; + auto& new_value = get_value(); + if (new_value.length() > + get_value()) { + get_value() = + new_value.length(); + } + } + + void set_progress(float value) { + { + std::lock_guard lock{mutex_}; + progress_ = value; + } + save_start_time(); + print_progress(); + } + + void tick() { + { + std::lock_guard lock{mutex_}; + progress_ += 1; + } + save_start_time(); + print_progress(); + } + + size_t current() { + std::lock_guard lock{mutex_}; + return (std::min)( + static_cast(progress_), + size_t(get_value())); + } + + bool is_completed() const { + return get_value(); + } + + void mark_as_completed() { + get_value() = true; + print_progress(); + } + + private: + template + auto get_value() + -> decltype((details::get_value(std::declval()).value)) { + return details::get_value(settings_).value; + } + + template + auto get_value() const + -> decltype(( + details::get_value(std::declval()).value)) { + return details::get_value(settings_).value; + } + + Settings settings_; + float progress_{0.0}; + std::chrono::time_point start_time_point_; + std::mutex mutex_; + + template + friend class MultiProgress; + template + friend class DynamicProgress; + std::atomic multi_progress_mode_{false}; + + void save_start_time() { + auto& show_elapsed_time = + get_value(); + auto& saved_start_time = + get_value(); + auto& show_remaining_time = + get_value(); + if ((show_elapsed_time || show_remaining_time) && !saved_start_time) { + start_time_point_ = std::chrono::high_resolution_clock::now(); + saved_start_time = true; + } + } + + std::pair get_prefix_text() { + std::stringstream os; + os << get_value(); + const auto result = os.str(); + const auto result_size = unicode::display_width(result); + return {result, result_size}; + } + + std::pair get_postfix_text() { + std::stringstream os; + const auto max_progress = + get_value(); + auto now = std::chrono::high_resolution_clock::now(); + auto elapsed = std::chrono::duration_cast( + now - start_time_point_); + + if (get_value()) { + os << " " + << (std::min)(static_cast(progress_ / max_progress * 100.0), + size_t(100)) + << "%"; + } + + auto& saved_start_time = + get_value(); + + if (get_value()) { + os << " ["; + if (saved_start_time) + details::write_duration(os, elapsed); + else + os << "00:00s"; + } + + if (get_value()) { + if (get_value()) + os << "<"; + else + os << " ["; + + if (saved_start_time) { + auto eta = std::chrono::nanoseconds( + progress_ > 0 + ? static_cast(std::ceil(float(elapsed.count()) * + max_progress / progress_)) + : 0); + auto remaining = eta > elapsed ? (eta - elapsed) : (elapsed - eta); + details::write_duration(os, remaining); + } else { + os << "00:00s"; + } + + os << "]"; + } else { + if (get_value()) + os << "]"; + } + + os << " " << get_value(); + + const auto result = os.str(); + const auto result_size = unicode::display_width(result); + return {result, result_size}; + } + + public: + void print_progress(bool from_multi_progress = false) { + std::lock_guard lock{mutex_}; + + auto& os = get_value(); + + const auto max_progress = + get_value(); + if (multi_progress_mode_ && !from_multi_progress) { + if (progress_ > max_progress) { + get_value() = true; + } + return; + } + + if (get_value() != + Color::unspecified) + details::set_stream_color( + os, get_value()); + + for (auto& style : get_value()) + details::set_font_style(os, style); + + const auto prefix_pair = get_prefix_text(); + const auto prefix_text = prefix_pair.first; + const auto prefix_length = prefix_pair.second; + os << "\r" << prefix_text; + + os << get_value(); + + details::BlockProgressScaleWriter writer{ + os, get_value()}; + writer.write(progress_ / max_progress * 100); + + os << get_value(); + + const auto postfix_pair = get_postfix_text(); + const auto postfix_text = postfix_pair.first; + const auto postfix_length = postfix_pair.second; + os << postfix_text; + + // Get length of prefix text and postfix text + const auto start_length = + get_value().size(); + const auto bar_width = get_value(); + const auto end_length = get_value().size(); + const auto terminal_width = terminal_size().second; + // prefix + bar_width + postfix should be <= terminal_width + const int remaining = + terminal_width - (prefix_length + start_length + bar_width + + end_length + postfix_length); + if (prefix_length == -1 || postfix_length == -1) { + os << "\r"; + } else if (remaining > 0) { + os << std::string(remaining, ' ') << "\r"; + } else if (remaining < 0) { + // Do nothing. Maybe in the future truncate postfix with ... + } + os.flush(); + + if (progress_ > max_progress) { + get_value() = true; + } + if (get_value() && + !from_multi_progress) // Don't std::endl if calling from MultiProgress + os << termcolor::reset << std::endl; + } +}; + +} // namespace indicators + +#endif + +#ifndef INDICATORS_INDETERMINATE_PROGRESS_BAR +#define INDICATORS_INDETERMINATE_PROGRESS_BAR + +// #include + +#include +#include +#include +#include +// #include +// #include +// #include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace indicators { + +class IndeterminateProgressBar { + using Settings = + std::tuple; + + enum class Direction { forward, backward }; + + Direction direction_{Direction::forward}; + + public: + template ::type...>::value, + void*>::type = nullptr> + explicit IndeterminateProgressBar(Args&&... args) + : settings_( + details::get( + option::BarWidth{100}, std::forward(args)...), + details::get( + option::PrefixText{}, std::forward(args)...), + details::get( + option::PostfixText{}, std::forward(args)...), + details::get( + option::Start{"["}, std::forward(args)...), + details::get( + option::End{"]"}, std::forward(args)...), + details::get( + option::Fill{"."}, std::forward(args)...), + details::get( + option::Lead{"<==>"}, std::forward(args)...), + details::get( + option::MaxPostfixTextLen{0}, std::forward(args)...), + details::get( + option::Completed{false}, std::forward(args)...), + details::get( + option::ForegroundColor{Color::unspecified}, + std::forward(args)...), + details::get( + option::FontStyles{std::vector{}}, + std::forward(args)...), + details::get( + option::Stream{std::cout}, std::forward(args)...)) { + // starts with [<==>...........] + // progress_ = 0 + + // ends with [...........<==>] + // ^^^^^^^^^^^^^^^^^ bar_width + // ^^^^^^^^^^^^ (bar_width - len(lead)) + // progress_ = bar_width - len(lead) + progress_ = 0; + max_progress_ = get_value() - + get_value().size() + + get_value().size() + + get_value().size(); + } + + template + void set_option(details::Setting&& setting) { + static_assert( + !std::is_same( + std::declval()))>::type>::value, + "Setting has wrong type!"); + std::lock_guard lock(mutex_); + get_value() = std::move(setting).value; + } + + template + void set_option(const details::Setting& setting) { + static_assert( + !std::is_same( + std::declval()))>::type>::value, + "Setting has wrong type!"); + std::lock_guard lock(mutex_); + get_value() = setting.value; + } + + void set_option( + const details::Setting< + std::string, details::ProgressBarOption::postfix_text>& setting) { + std::lock_guard lock(mutex_); + get_value() = setting.value; + if (setting.value.length() > + get_value()) { + get_value() = + setting.value.length(); + } + } + + void set_option( + details::Setting&& + setting) { + std::lock_guard lock(mutex_); + get_value() = + std::move(setting).value; + auto& new_value = get_value(); + if (new_value.length() > + get_value()) { + get_value() = + new_value.length(); + } + } + + void tick() { + { + std::lock_guard lock{mutex_}; + if (get_value()) + return; + + progress_ += (direction_ == Direction::forward) ? 1 : -1; + if (direction_ == Direction::forward && progress_ == max_progress_) { + // time to go back + direction_ = Direction::backward; + } else if (direction_ == Direction::backward && progress_ == 0) { + direction_ = Direction::forward; + } + } + print_progress(); + } + + bool is_completed() { + return get_value(); + } + + void mark_as_completed() { + get_value() = true; + print_progress(); + } + + private: + template + auto get_value() + -> decltype((details::get_value(std::declval()).value)) { + return details::get_value(settings_).value; + } + + template + auto get_value() const + -> decltype(( + details::get_value(std::declval()).value)) { + return details::get_value(settings_).value; + } + + size_t progress_{0}; + size_t max_progress_; + Settings settings_; + std::chrono::nanoseconds elapsed_; + std::mutex mutex_; + + template + friend class MultiProgress; + template + friend class DynamicProgress; + std::atomic multi_progress_mode_{false}; + + std::pair get_prefix_text() { + std::stringstream os; + os << get_value(); + const auto result = os.str(); + const auto result_size = unicode::display_width(result); + return {result, result_size}; + } + + std::pair get_postfix_text() { + std::stringstream os; + os << " " << get_value(); + + const auto result = os.str(); + const auto result_size = unicode::display_width(result); + return {result, result_size}; + } + + public: + void print_progress(bool from_multi_progress = false) { + std::lock_guard lock{mutex_}; + + auto& os = get_value(); + + if (multi_progress_mode_ && !from_multi_progress) { + return; + } + if (get_value() != + Color::unspecified) + details::set_stream_color( + os, get_value()); + + for (auto& style : get_value()) + details::set_font_style(os, style); + + const auto prefix_pair = get_prefix_text(); + const auto prefix_text = prefix_pair.first; + const auto prefix_length = prefix_pair.second; + os << "\r" << prefix_text; + + os << get_value(); + + details::IndeterminateProgressScaleWriter writer{ + os, get_value(), + get_value(), + get_value()}; + writer.write(progress_); + + os << get_value(); + + const auto postfix_pair = get_postfix_text(); + const auto postfix_text = postfix_pair.first; + const auto postfix_length = postfix_pair.second; + os << postfix_text; + + // Get length of prefix text and postfix text + const auto start_length = + get_value().size(); + const auto bar_width = get_value(); + const auto end_length = get_value().size(); + const auto terminal_width = terminal_size().second; + // prefix + bar_width + postfix should be <= terminal_width + const int remaining = + terminal_width - (prefix_length + start_length + bar_width + + end_length + postfix_length); + if (prefix_length == -1 || postfix_length == -1) { + os << "\r"; + } else if (remaining > 0) { + os << std::string(remaining, ' ') << "\r"; + } else if (remaining < 0) { + // Do nothing. Maybe in the future truncate postfix with ... + } + os.flush(); + + if (get_value() && + !from_multi_progress) // Don't std::endl if calling from MultiProgress + os << termcolor::reset << std::endl; + } +}; + +} // namespace indicators + +#endif + +#ifndef INDICATORS_MULTI_PROGRESS +#define INDICATORS_MULTI_PROGRESS +#include +#include +#include +#include +#include + +// #include +// #include +// #include + +namespace indicators { + +template +class MultiProgress { + public: + template ::type> + explicit MultiProgress(Indicators&... bars) { + bars_ = {bars...}; + for (auto& bar : bars_) { + bar.get().multi_progress_mode_ = true; + } + } + + template + typename std::enable_if<(index >= 0 && index < count), void>::type + set_progress(size_t value) { + if (!bars_[index].get().is_completed()) + bars_[index].get().set_progress(value); + print_progress(); + } + + template + typename std::enable_if<(index >= 0 && index < count), void>::type + set_progress(float value) { + if (!bars_[index].get().is_completed()) + bars_[index].get().set_progress(value); + print_progress(); + } + + template + typename std::enable_if<(index >= 0 && index < count), void>::type tick() { + if (!bars_[index].get().is_completed()) + bars_[index].get().tick(); + print_progress(); + } + + template + typename std::enable_if<(index >= 0 && index < count), bool>::type + is_completed() const { + return bars_[index].get().is_completed(); + } + + private: + std::atomic started_{false}; + std::mutex mutex_; + std::vector> bars_; + + bool _all_completed() { + bool result{true}; + for (size_t i = 0; i < count; ++i) + result &= bars_[i].get().is_completed(); + return result; + } + + public: + void print_progress() { + std::lock_guard lock{mutex_}; + if (started_) + move_up(count); + for (auto& bar : bars_) { + bar.get().print_progress(true); + std::cout << "\n"; + } + std::cout << termcolor::reset; + if (!started_) + started_ = true; + } +}; + +} // namespace indicators + +#endif + +#ifndef INDICATORS_DYNAMIC_PROGRESS +#define INDICATORS_DYNAMIC_PROGRESS + +#include +#include +// #include +// #include +// #include +// #include +// #include +#include +#include +#include + +namespace indicators { + +template +class DynamicProgress { + using Settings = std::tuple; + + public: + template + explicit DynamicProgress(Indicators&... bars) { + bars_ = {bars...}; + for (auto& bar : bars_) { + bar.get().multi_progress_mode_ = true; + ++total_count_; + ++incomplete_count_; + } + } + + Indicator& operator[](size_t index) { + print_progress(); + std::lock_guard lock{mutex_}; + return bars_[index].get(); + } + + size_t push_back(Indicator& bar) { + std::lock_guard lock{mutex_}; + bar.multi_progress_mode_ = true; + bars_.push_back(bar); + return bars_.size() - 1; + } + + template + void set_option(details::Setting&& setting) { + static_assert( + !std::is_same( + std::declval()))>::type>::value, + "Setting has wrong type!"); + std::lock_guard lock(mutex_); + get_value() = std::move(setting).value; + } + + template + void set_option(const details::Setting& setting) { + static_assert( + !std::is_same( + std::declval()))>::type>::value, + "Setting has wrong type!"); + std::lock_guard lock(mutex_); + get_value() = setting.value; + } + + private: + Settings settings_; + std::atomic started_{false}; + std::mutex mutex_; + std::vector> bars_; + std::atomic total_count_{0}; + std::atomic incomplete_count_{0}; + + template + auto get_value() + -> decltype((details::get_value(std::declval()).value)) { + return details::get_value(settings_).value; + } + + template + auto get_value() const + -> decltype(( + details::get_value(std::declval()).value)) { + return details::get_value(settings_).value; + } + + public: + void print_progress() { + std::lock_guard lock{mutex_}; + auto& hide_bar_when_complete = + get_value(); + if (hide_bar_when_complete) { + // Hide completed bars + if (started_) { + for (size_t i = 0; i < incomplete_count_; ++i) { + move_up(1); + erase_line(); + std::cout << std::flush; + } + } + incomplete_count_ = 0; + for (auto& bar : bars_) { + if (!bar.get().is_completed()) { + bar.get().print_progress(true); + std::cout << "\n"; + ++incomplete_count_; + } + } + if (!started_) + started_ = true; + } else { + // Don't hide any bars + if (started_) + move_up(static_cast(total_count_)); + for (auto& bar : bars_) { + bar.get().print_progress(true); + std::cout << "\n"; + } + if (!started_) + started_ = true; + } + total_count_ = bars_.size(); + std::cout << termcolor::reset; + } +}; + +} // namespace indicators + +#endif + +#ifndef INDICATORS_PROGRESS_SPINNER +#define INDICATORS_PROGRESS_SPINNER + +// #include + +#include +#include +#include +#include +// #include +// #include +#include +#include +#include +#include +#include +#include +#include + +namespace indicators { + +class ProgressSpinner { + using Settings = + std::tuple; + + public: + template ::type...>::value, + void*>::type = nullptr> + explicit ProgressSpinner(Args&&... args) + : settings_( + details::get( + option::ForegroundColor{Color::unspecified}, + std::forward(args)...), + details::get( + option::PrefixText{}, std::forward(args)...), + details::get( + option::PostfixText{}, std::forward(args)...), + details::get( + option::ShowPercentage{true}, std::forward(args)...), + details::get( + option::ShowElapsedTime{false}, std::forward(args)...), + details::get( + option::ShowRemainingTime{false}, std::forward(args)...), + details::get( + option::ShowSpinner{true}, std::forward(args)...), + details::get( + option::SavedStartTime{false}, std::forward(args)...), + details::get( + option::Completed{false}, std::forward(args)...), + details::get( + option::MaxPostfixTextLen{0}, std::forward(args)...), + details::get( + option::SpinnerStates{std::vector{ + "⠋", "⠙", "⠹", "⠸", "⠼", "⠴", "⠦", "⠧", "⠇", "⠏"}}, + std::forward(args)...), + details::get( + option::FontStyles{std::vector{}}, + std::forward(args)...), + details::get( + option::MaxProgress{100}, std::forward(args)...), + details::get( + option::Stream{std::cout}, std::forward(args)...)) {} + + template + void set_option(details::Setting&& setting) { + static_assert( + !std::is_same( + std::declval()))>::type>::value, + "Setting has wrong type!"); + std::lock_guard lock(mutex_); + get_value() = std::move(setting).value; + } + + template + void set_option(const details::Setting& setting) { + static_assert( + !std::is_same( + std::declval()))>::type>::value, + "Setting has wrong type!"); + std::lock_guard lock(mutex_); + get_value() = setting.value; + } + + void set_option( + const details::Setting< + std::string, details::ProgressBarOption::postfix_text>& setting) { + std::lock_guard lock(mutex_); + get_value() = setting.value; + if (setting.value.length() > + get_value()) { + get_value() = + setting.value.length(); + } + } + + void set_option( + details::Setting&& + setting) { + std::lock_guard lock(mutex_); + get_value() = + std::move(setting).value; + auto& new_value = get_value(); + if (new_value.length() > + get_value()) { + get_value() = + new_value.length(); + } + } + + void set_progress(size_t value) { + { + std::lock_guard lock{mutex_}; + progress_ = value; + } + save_start_time(); + print_progress(); + } + + void tick() { + { + std::lock_guard lock{mutex_}; + progress_ += 1; + } + save_start_time(); + print_progress(); + } + + size_t current() { + std::lock_guard lock{mutex_}; + return (std::min)( + progress_, + size_t(get_value())); + } + + bool is_completed() const { + return get_value(); + } + + void mark_as_completed() { + get_value() = true; + print_progress(); + } + + private: + Settings settings_; + size_t progress_{0}; + size_t index_{0}; + std::chrono::time_point start_time_point_; + std::mutex mutex_; + + template + auto get_value() + -> decltype((details::get_value(std::declval()).value)) { + return details::get_value(settings_).value; + } + + template + auto get_value() const + -> decltype(( + details::get_value(std::declval()).value)) { + return details::get_value(settings_).value; + } + + void save_start_time() { + auto& show_elapsed_time = + get_value(); + auto& show_remaining_time = + get_value(); + auto& saved_start_time = + get_value(); + if ((show_elapsed_time || show_remaining_time) && !saved_start_time) { + start_time_point_ = std::chrono::high_resolution_clock::now(); + saved_start_time = true; + } + } + + public: + void print_progress() { + std::lock_guard lock{mutex_}; + + auto& os = get_value(); + + const auto max_progress = + get_value(); + auto now = std::chrono::high_resolution_clock::now(); + auto elapsed = std::chrono::duration_cast( + now - start_time_point_); + + if (get_value() != + Color::unspecified) + details::set_stream_color( + os, get_value()); + + for (auto& style : get_value()) + details::set_font_style(os, style); + + os << get_value(); + if (get_value()) + os << get_value() + [index_ % + get_value().size()]; + if (get_value()) { + os << " " << std::size_t(progress_ / double(max_progress) * 100) << "%"; + } + + if (get_value()) { + os << " ["; + details::write_duration(os, elapsed); + } + + if (get_value()) { + if (get_value()) + os << "<"; + else + os << " ["; + auto eta = std::chrono::nanoseconds( + progress_ > 0 + ? static_cast(std::ceil(float(elapsed.count()) * + max_progress / progress_)) + : 0); + auto remaining = eta > elapsed ? (eta - elapsed) : (elapsed - eta); + details::write_duration(os, remaining); + os << "]"; + } else { + if (get_value()) + os << "]"; + } + + if (get_value() == 0) + get_value() = 10; + os << " " << get_value() + << std::string( + get_value(), + ' ') + << "\r"; + os.flush(); + index_ += 1; + if (progress_ > max_progress) { + get_value() = true; + } + if (get_value()) + os << termcolor::reset << std::endl; + } +}; + +} // namespace indicators + +#endif diff --git a/engine/common/download_task.h b/engine/common/download_task.h index 5994cdaed..39bf03a99 100644 --- a/engine/common/download_task.h +++ b/engine/common/download_task.h @@ -5,6 +5,7 @@ #include #include #include +#include enum class DownloadType { Model, Engine, Miscellaneous, CudaToolkit, Cortex }; @@ -55,6 +56,22 @@ inline std::string DownloadTypeToString(DownloadType type) { } } +inline DownloadType DownloadTypeFromString(const std::string& str) { + if (str == "Model") { + return DownloadType::Model; + } else if (str == "Engine") { + return DownloadType::Engine; + } else if (str == "Miscellaneous") { + return DownloadType::Miscellaneous; + } else if (str == "CudaToolkit") { + return DownloadType::CudaToolkit; + } else if (str == "Cortex") { + return DownloadType::Cortex; + } else { + return DownloadType::Miscellaneous; + } +} + struct DownloadTask { enum class Status { Pending, InProgress, Completed, Cancelled, Error }; @@ -116,3 +133,52 @@ struct DownloadTask { {"id", id}, {"type", DownloadTypeToString(type)}, {"items", dl_items}}; } }; + +namespace common { +inline DownloadItem GetDownloadItemFromJson(const Json::Value item_json) { + DownloadItem item; + if (!item_json["id"].isNull()) { + item.id = item_json["id"].asString(); + } + if (!item_json["downloadUrl"].isNull()) { + item.downloadUrl = item_json["downloadUrl"].asString(); + } + + if (!item_json["localPath"].isNull()) { + item.localPath = std::filesystem::path(item_json["localPath"].asString()); + } + + if (!item_json["checksum"].isNull()) { + item.checksum = item_json["checksum"].asString(); + } + + if (!item_json["bytes"].isNull()) { + item.bytes = item_json["bytes"].asUInt64(); + } + + if (!item_json["downloadedBytes"].isNull()) { + item.downloadedBytes = item_json["downloadedBytes"].asUInt64(); + } + + return item; +} + +inline DownloadTask GetDownloadTaskFromJson(const Json::Value item_json) { + DownloadTask task; + + if (!item_json["id"].isNull()) { + task.id = item_json["id"].asString(); + } + + if (!item_json["type"].isNull()) { + task.type = DownloadTypeFromString(item_json["type"].asString()); + } + + if (!item_json["items"].isNull() && item_json["items"].isArray()) { + for (auto const& i_json : item_json["items"]) { + task.items.emplace_back(GetDownloadItemFromJson(i_json)); + } + } + return task; +} +} // namespace common \ No newline at end of file diff --git a/engine/common/event.h b/engine/common/event.h index fe68bd04e..c23ebea5f 100644 --- a/engine/common/event.h +++ b/engine/common/event.h @@ -45,6 +45,22 @@ std::string DownloadEventTypeToString(DownloadEventType type) { return "Unknown"; } } + +inline DownloadEventType DownloadEventTypeFromString(const std::string& str) { + if (str == "DownloadStarted") { + return DownloadEventType::DownloadStarted; + } else if (str == "DownloadStopped") { + return DownloadEventType::DownloadStopped; + } else if (str == "DownloadUpdated") { + return DownloadEventType::DownloadUpdated; + } else if (str == "DownloadSuccess") { + return DownloadEventType::DownloadSuccess; + } else if (str == "DownloadError") { + return DownloadEventType::DownloadError; + } else { + return DownloadEventType::DownloadError; + } +} } // namespace struct DownloadEvent : public cortex::event::Event { @@ -57,6 +73,18 @@ struct DownloadEvent : public cortex::event::Event { DownloadEventType type_; DownloadTask download_task_; }; + +inline DownloadEvent GetDownloadEventFromJson(const Json::Value& item_json) { + DownloadEvent ev; + if (!item_json["type"].isNull()) { + ev.type_ = DownloadEventTypeFromString(item_json["type"].asString()); + } + + if (!item_json["task"].isNull()) { + ev.download_task_ = common::GetDownloadTaskFromJson(item_json["task"]); + } + return ev; +} } // namespace cortex::event constexpr std::size_t eventMaxSize = diff --git a/engine/services/model_service.cc b/engine/services/model_service.cc index ae3316c12..d4e373812 100644 --- a/engine/services/model_service.cc +++ b/engine/services/model_service.cc @@ -383,7 +383,7 @@ ModelService::DownloadModelFromCortexsoAsync( if (model_entry.has_value()) { return cpp::fail("Please delete the model before downloading again"); } - auto on_finished = [&, unique_model_id](const DownloadTask& finishedTask) { + auto on_finished = [unique_model_id, branch](const DownloadTask& finishedTask) { const DownloadItem* model_yml_item = nullptr; auto need_parse_gguf = true; diff --git a/engine/test/components/test_event.cc b/engine/test/components/test_event.cc new file mode 100644 index 000000000..d10933f52 --- /dev/null +++ b/engine/test/components/test_event.cc @@ -0,0 +1,50 @@ + +#include "common/event.h" +#include "gtest/gtest.h" +#include "utils/json_helper.h" + +class EventTest : public ::testing::Test {}; + +TEST_F(EventTest, EventFromString) { + // clang-format off + std::string ev_str = R"({ + "task": { + "id": "tinyllama:gguf", + "items": [ + { + "bytes": 668788096, + "checksum": "N/A", + "downloadUrl": "https://huggingface.co/cortexso/tinyllama/resolve/gguf/model.gguf", + "downloadedBytes": 0, + "id": "model.gguf", + "localPath": + "/home/jan/cortexcpp/models/cortex.so/tinyllama/gguf/model.gguf" + }, + { + "bytes": 545, + "checksum": "N/A", + "downloadUrl": "https://huggingface.co/cortexso/tinyllama/resolve/gguf/model.yml", + "downloadedBytes": 0, + "id": "model.yml", + "localPath": + "/home/jan/cortexcpp/models/cortex.so/tinyllama/gguf/model.yml" + } + ], + "type": "Model" + }, + "type": "DownloadStarted" + })"; + // clang-format on + auto root = json_helper::ParseJsonString(ev_str); + std::cout << root.toStyledString() << std::endl; + + auto download_item = common::GetDownloadItemFromJson(root["task"]["items"][0]); + EXPECT_EQ(download_item.downloadUrl, root["task"]["items"][0]["downloadUrl"].asString()); + std::cout << download_item.ToString() << std::endl; + + auto download_task = common::GetDownloadTaskFromJson(root["task"]); + std::cout << download_task.ToString() << std::endl; + + auto ev = cortex::event::GetDownloadEventFromJson(root); + EXPECT_EQ(ev.type_, cortex::event::DownloadEventType::DownloadStarted); +} \ No newline at end of file From 0c12e5ca82f3c4bc3ef874d6584a40b7e7143faa Mon Sep 17 00:00:00 2001 From: vansangpfiev Date: Thu, 24 Oct 2024 16:29:21 +0700 Subject: [PATCH 02/24] fix: more --- engine/cli/CMakeLists.txt | 1 + engine/cli/command_line_parser.cc | 4 +- engine/cli/commands/engine_install_cmd.cc | 78 ++++++++- engine/cli/commands/engine_install_cmd.h | 8 +- engine/cli/commands/model_pull_cmd.cc | 192 ++++++++++++++-------- engine/cli/commands/run_cmd.cc | 6 +- engine/cli/commands/run_cmd.h | 3 + engine/cli/utils/download_manager.cc | 106 ++++++++++++ engine/cli/utils/download_manager.h | 25 +++ engine/services/engine_service.cc | 43 ++--- engine/services/engine_service.h | 12 +- 11 files changed, 364 insertions(+), 114 deletions(-) create mode 100644 engine/cli/utils/download_manager.cc create mode 100644 engine/cli/utils/download_manager.h diff --git a/engine/cli/CMakeLists.txt b/engine/cli/CMakeLists.txt index 947bd9347..3dd9ff27c 100644 --- a/engine/cli/CMakeLists.txt +++ b/engine/cli/CMakeLists.txt @@ -81,6 +81,7 @@ add_executable(${TARGET_NAME} main.cc ${CMAKE_CURRENT_SOURCE_DIR}/../services/model_service.cc ${CMAKE_CURRENT_SOURCE_DIR}/../services/inference_service.cc ${CMAKE_CURRENT_SOURCE_DIR}/utils/easywsclient.cc + ${CMAKE_CURRENT_SOURCE_DIR}/utils/download_manager.cc ) target_link_libraries(${TARGET_NAME} PRIVATE httplib::httplib) diff --git a/engine/cli/command_line_parser.cc b/engine/cli/command_line_parser.cc index faed315f8..59973b021 100644 --- a/engine/cli/command_line_parser.cc +++ b/engine/cli/command_line_parser.cc @@ -464,7 +464,9 @@ void CommandLineParser::EngineInstall(CLI::App* parent, if (std::exchange(executed_, true)) return; try { - commands::EngineInstallCmd(download_service_) + commands::EngineInstallCmd(download_service_, + cml_data_.config.apiServerHost, + std::stoi(cml_data_.config.apiServerPort)) .Exec(engine_name, version, src); } catch (const std::exception& e) { CTL_ERR(e.what()); diff --git a/engine/cli/commands/engine_install_cmd.cc b/engine/cli/commands/engine_install_cmd.cc index 8cf7c1cc7..a4164d52d 100644 --- a/engine/cli/commands/engine_install_cmd.cc +++ b/engine/cli/commands/engine_install_cmd.cc @@ -1,16 +1,82 @@ #include "engine_install_cmd.h" +#include "server_start_cmd.h" +#include "utils/download_manager.h" +#include "utils/engine_constants.h" #include "utils/logging_utils.h" +#include "utils/json_helper.h" namespace commands { +namespace { +std::string NormalizeEngine(const std::string& engine) { + if (engine == kLlamaEngine) { + return kLlamaRepo; + } else if (engine == kOnnxEngine) { + return kOnnxRepo; + } else if (engine == kTrtLlmEngine) { + return kTrtLlmRepo; + } + return engine; +}; +} // namespace -void EngineInstallCmd::Exec(const std::string& engine, +bool EngineInstallCmd::Exec(const std::string& engine, const std::string& version, const std::string& src) { - auto result = engine_service_.InstallEngine(engine, version, src); - if (result.has_error()) { - CLI_LOG(result.error()); - } else if(result && result.value()){ - CLI_LOG("Engine " << engine << " installed successfully!"); + // Handle local install, if fails, fallback to remote install + if (!src.empty()) { + auto res = engine_service_.UnzipEngine(engine, version, src); + if(res.has_error()) { + CLI_LOG(res.error()); + return false; + } + if(res.value()) { + CLI_LOG("Engine " << engine << " installed successfully!"); + return true; + } + } + + // Start server if server is not started yet + if (!commands::IsServerAlive(host_, port_)) { + CLI_LOG("Starting server ..."); + commands::ServerStartCmd ssc; + if (!ssc.Exec(host_, port_)) { + return false; + } + } + + httplib::Client cli(host_ + ":" + std::to_string(port_)); + Json::Value json_data; + auto data_str = json_data.toStyledString(); + cli.set_read_timeout(std::chrono::seconds(60)); + auto res = cli.Post("/v1/engines/install/" + engine, httplib::Headers(), + data_str.data(), data_str.size(), "application/json"); + + if (res) { + if (res->status == httplib::StatusCode::OK_200) { + } else { + auto root = json_helper::ParseJsonString(res->body); + CLI_LOG(root["message"].asString()); + return false; + } + } else { + auto err = res.error(); + CTL_ERR("HTTP error: " << httplib::to_string(err)); + return false; } + + CLI_LOG("Pulling ...") + DownloadManager dm; + dm.Connect(host_, port_); + if (!dm.Handle(NormalizeEngine(engine))) + return false; + + bool check_cuda_download = !system_info_utils::GetCudaVersion().empty(); + if (check_cuda_download) { + if (!dm.Handle("cuda")) + return false; + } + + CLI_LOG("Engine " << engine << " downloaded successfully!") + return true; } }; // namespace commands diff --git a/engine/cli/commands/engine_install_cmd.h b/engine/cli/commands/engine_install_cmd.h index 199d4d319..4a22d03f7 100644 --- a/engine/cli/commands/engine_install_cmd.h +++ b/engine/cli/commands/engine_install_cmd.h @@ -7,13 +7,15 @@ namespace commands { class EngineInstallCmd { public: - explicit EngineInstallCmd(std::shared_ptr download_service) - : engine_service_{EngineService(download_service)} {}; + explicit EngineInstallCmd(std::shared_ptr download_service, const std::string& host, int port) + : engine_service_{EngineService(download_service)}, host_(host), port_(port) {}; - void Exec(const std::string& engine, const std::string& version = "latest", + bool Exec(const std::string& engine, const std::string& version = "latest", const std::string& src = ""); private: EngineService engine_service_; + std::string host_; + int port_; }; } // namespace commands diff --git a/engine/cli/commands/model_pull_cmd.cc b/engine/cli/commands/model_pull_cmd.cc index 386380397..42d784a2a 100644 --- a/engine/cli/commands/model_pull_cmd.cc +++ b/engine/cli/commands/model_pull_cmd.cc @@ -3,14 +3,129 @@ #include "cli/utils/easywsclient.hpp" #include "cli/utils/indicators.hpp" #include "common/event.h" +#include "database/models.h" #include "server_start_cmd.h" +#include "utils/cli_selection_utils.h" +#include "utils/download_manager.h" #include "utils/format_utils.h" +#include "utils/huggingface_utils.h" #include "utils/json_helper.h" #include "utils/logging_utils.h" +#include "utils/string_utils.h" namespace commands { +namespace { +// TODO(sang) request from Server +cpp::result GetModelId(const std::string& input) { + if (input.empty()) { + return cpp::fail( + "Input must be Cortex Model Hub handle or HuggingFace url!"); + } + + if (string_utils::StartsWith(input, "https://")) { + return input; + } + + if (input.find(":") != std::string::npos) { + auto parsed = string_utils::SplitBy(input, ":"); + if (parsed.size() != 2) { + return cpp::fail("Invalid model handle: " + input); + } + return input; + } + + if (input.find("/") != std::string::npos) { + auto parsed = string_utils::SplitBy(input, "/"); + if (parsed.size() != 2) { + return cpp::fail("Invalid model handle: " + input); + } + + auto author = parsed[0]; + auto model_name = parsed[1]; + if (author == "cortexso") { + return author + ":" + model_name; + } + + auto repo_info = + huggingface_utils::GetHuggingFaceModelRepoInfo(author, model_name); + + if (!repo_info.has_value()) { + return cpp::fail("Model not found"); + } + + if (!repo_info->gguf.has_value()) { + return cpp::fail( + "Not a GGUF model. Currently, only GGUF single file is " + "supported."); + } + + std::vector options{}; + for (const auto& sibling : repo_info->siblings) { + if (string_utils::EndsWith(sibling.rfilename, ".gguf")) { + options.push_back(sibling.rfilename); + } + } + auto selection = cli_selection_utils::PrintSelection(options); + std::cout << "Selected: " << selection.value() << std::endl; + + return huggingface_utils::GetDownloadableUrl(author, model_name, + selection.value()); + } + auto branches = + huggingface_utils::GetModelRepositoryBranches("cortexso", input); + if (branches.has_error()) { + return cpp::fail(branches.error()); + } + + auto default_model_branch = huggingface_utils::GetDefaultBranch(input); + + cortex::db::Models modellist_handler; + auto downloaded_model_ids = + modellist_handler.FindRelatedModel(input).value_or( + std::vector{}); + + std::vector avai_download_opts{}; + for (const auto& branch : branches.value()) { + if (branch.second.name == "main") { // main branch only have metadata. skip + continue; + } + auto model_id = input + ":" + branch.second.name; + if (std::find(downloaded_model_ids.begin(), downloaded_model_ids.end(), + model_id) != + downloaded_model_ids.end()) { // if downloaded, we skip it + continue; + } + avai_download_opts.emplace_back(model_id); + } + + if (avai_download_opts.empty()) { + // TODO: only with pull, we return + return cpp::fail("No variant available"); + } + std::optional normalized_def_branch = std::nullopt; + if (default_model_branch.has_value()) { + normalized_def_branch = input + ":" + default_model_branch.value(); + } + string_utils::SortStrings(downloaded_model_ids); + string_utils::SortStrings(avai_download_opts); + auto selection = cli_selection_utils::PrintModelSelection( + downloaded_model_ids, avai_download_opts, normalized_def_branch); + if (!selection.has_value()) { + return cpp::fail("Invalid selection"); + } + return selection.value(); +} + +} // namespace void ModelPullCmd::Exec(const std::string& host, int port, const std::string& input) { + auto r = GetModelId(input); + if (r.has_error()) { + CLI_LOG(r.error()); + return; + } + auto const& model_id = r.value(); + CTL_INF(model_id); // Start server if server is not started yet if (!commands::IsServerAlive(host, port)) { CLI_LOG("Starting server ..."); @@ -22,7 +137,7 @@ void ModelPullCmd::Exec(const std::string& host, int port, httplib::Client cli(host + ":" + std::to_string(port)); Json::Value json_data; - json_data["model"] = input; + json_data["model"] = model_id; auto data_str = json_data.toStyledString(); cli.set_read_timeout(std::chrono::seconds(60)); auto res = cli.Post("/v1/models/pull", httplib::Headers(), data_str.data(), @@ -32,7 +147,8 @@ void ModelPullCmd::Exec(const std::string& host, int port, if (res->status == httplib::StatusCode::OK_200) { // CLI_LOG("OK"); } else { - CTL_ERR("Error:"); + auto root = json_helper::ParseJsonString(res->body); + CLI_LOG(root["message"].asString()); return; } } else { @@ -41,72 +157,12 @@ void ModelPullCmd::Exec(const std::string& host, int port, return; } - std::unique_ptr> - bars; - - std::vector> items; - - auto handle_message = [&bars, &items](const std::string& message) { - // std::cout << message << std::endl; - - auto pad_string = [](const std::string& str, - size_t max_length = 20) -> std::string { - // Check the length of the input string - if (str.length() >= max_length) { - return str.substr( - 0, max_length); // Return truncated string if it's too long - } + CLI_LOG("Pulling ...") + DownloadManager dm; + dm.Connect(host, port); + if (!dm.Handle(model_id)) + return; - // Calculate the number of spaces needed - size_t padding_size = max_length - str.length(); - - // Create a new string with the original string followed by spaces - return str + std::string(padding_size, ' '); - }; - - auto ev = cortex::event::GetDownloadEventFromJson( - json_helper::ParseJsonString(message)); - // std::cout << downloaded << " " << total << std::endl; - if (!bars) { - bars = std::make_unique< - indicators::DynamicProgress>(); - for (auto& i : ev.download_task_.items) { - items.emplace_back(std::make_unique( - indicators::option::BarWidth{50}, indicators::option::Start{"|"}, - // indicators::option::Fill{"■"}, indicators::option::Lead{"■"}, - // indicators::option::Remainder{" "}, - indicators::option::End{"|"}, indicators::option::PrefixText{pad_string(i.id)}, - indicators::option::PostfixText{"Downloading files"}, - indicators::option::ForegroundColor{indicators::Color::white}, - indicators::option::ShowRemainingTime{true}, - indicators::option::FontStyles{std::vector{ - indicators::FontStyle::bold}})); - bars->push_back(*(items.back())); - } - } else { - for (int i = 0; i < ev.download_task_.items.size(); i++) { - auto& it = ev.download_task_.items[i]; - uint64_t downloaded = it.downloadedBytes.value_or(0); - uint64_t total = it.bytes.value_or(9999); - (*bars)[i].set_progress(static_cast(downloaded) / total * 100); - (*bars)[i].set_option(indicators::option::PostfixText{ - format_utils::BytesToHumanReadable(downloaded) + "/" + - format_utils::BytesToHumanReadable(total)}); - } - } - }; - - auto ws = easywsclient::WebSocket::from_url("ws://" + host + ":" + - std::to_string(port) + "/events"); - // auto result = model_service_.DownloadModel(input); - // if (result.has_error()) { - // CLI_LOG(result.error()); - // } - while (ws->getReadyState() != easywsclient::WebSocket::CLOSED) { - ws->poll(); - ws->dispatch(handle_message); - } - std::cout << "Done" << std::endl; - delete ws; + CLI_LOG("Model " << model_id << " downloaded successfully!") } }; // namespace commands diff --git a/engine/cli/commands/run_cmd.cc b/engine/cli/commands/run_cmd.cc index c80f12de1..13e3be4e7 100644 --- a/engine/cli/commands/run_cmd.cc +++ b/engine/cli/commands/run_cmd.cc @@ -8,6 +8,7 @@ #include "server_start_cmd.h" #include "utils/cli_selection_utils.h" #include "utils/logging_utils.h" +#include "engine_install_cmd.h" namespace commands { @@ -114,9 +115,8 @@ void RunCmd::Exec(bool run_detach) { throw std::runtime_error("Engine " + mc.engine + " is incompatible"); } if (required_engine.value().status == EngineService::kNotInstalled) { - auto install_engine_result = engine_service_.InstallEngine(mc.engine); - if (install_engine_result.has_error()) { - throw std::runtime_error(install_engine_result.error()); + if(!EngineInstallCmd(download_service_, host_, port_).Exec(mc.engine)) { + return; } } } diff --git a/engine/cli/commands/run_cmd.h b/engine/cli/commands/run_cmd.h index 4a0d68078..7d3e60054 100644 --- a/engine/cli/commands/run_cmd.h +++ b/engine/cli/commands/run_cmd.h @@ -16,6 +16,7 @@ class RunCmd { : host_{std::move(host)}, port_{port}, model_handle_{std::move(model_handle)}, + download_service_(download_service), engine_service_{EngineService(download_service)}, model_service_{ModelService(download_service)} {}; @@ -26,7 +27,9 @@ class RunCmd { int port_; std::string model_handle_; +std::shared_ptr download_service_; ModelService model_service_; EngineService engine_service_; + }; } // namespace commands diff --git a/engine/cli/utils/download_manager.cc b/engine/cli/utils/download_manager.cc new file mode 100644 index 000000000..915f63348 --- /dev/null +++ b/engine/cli/utils/download_manager.cc @@ -0,0 +1,106 @@ +#include "download_manager.h" +#include +#include "common/event.h" +#include "indicators.hpp" +#include "utils/format_utils.h" +#include "utils/json_helper.h" +#include "utils/logging_utils.h" + +bool DownloadManager::Connect(const std::string& host, int port) { + if (ws_) { + CTL_INF("Already connected!"); + return true; + } + ws_.reset(easywsclient::WebSocket::from_url( + "ws://" + host + ":" + std::to_string(port) + "/events")); + if (!!ws_) + return false; + + return true; +} + +bool DownloadManager::Handle(const std::string& id) { + assert(!!ws_); + status_ = DownloadStatus::DownloadStarted; + std::unique_ptr> + bars; + + std::vector> items; + auto handle_message = [this, &bars, &items, + id](const std::string& message) { + + CTL_INF(message); + + auto pad_string = [](const std::string& str, + size_t max_length = 20) -> std::string { + // Check the length of the input string + if (str.length() >= max_length) { + return str.substr( + 0, max_length); // Return truncated string if it's too long + } + + // Calculate the number of spaces needed + size_t padding_size = max_length - str.length(); + + // Create a new string with the original string followed by spaces + return str + std::string(padding_size, ' '); + }; + + auto ev = cortex::event::GetDownloadEventFromJson( + json_helper::ParseJsonString(message)); + // Ignore other task ids + if (ev.download_task_.id != id) { + + return; + } + + status_ = ev.type_; + // std::cout << downloaded << " " << total << std::endl; + if (!bars) { + bars = std::make_unique< + indicators::DynamicProgress>(); + for (auto& i : ev.download_task_.items) { + items.emplace_back(std::make_unique( + indicators::option::BarWidth{50}, indicators::option::Start{"|"}, + // indicators::option::Fill{"■"}, indicators::option::Lead{"■"}, + // indicators::option::Remainder{" "}, + indicators::option::End{"|"}, + indicators::option::PrefixText{pad_string(i.id)}, + // indicators::option::PostfixText{"Downloading files"}, + indicators::option::ForegroundColor{indicators::Color::white}, + indicators::option::ShowRemainingTime{true} + // indicators::option::FontStyles{std::vector{ + // indicators::FontStyle::bold}} + )); + bars->push_back(*(items.back())); + } + } else { + for (int i = 0; i < ev.download_task_.items.size(); i++) { + auto& it = ev.download_task_.items[i]; + uint64_t downloaded = it.downloadedBytes.value_or(0); + uint64_t total = it.bytes.value_or(9999); + if (status_ == DownloadStatus::DownloadUpdated) { + (*bars)[i].set_progress(static_cast(downloaded) / total * + 100); + (*bars)[i].set_option(indicators::option::PostfixText{ + format_utils::BytesToHumanReadable(downloaded) + "/" + + format_utils::BytesToHumanReadable(total)}); + } else if (status_ == DownloadStatus::DownloadSuccess) { + (*bars)[i].set_progress(100); + (*bars)[i].set_option(indicators::option::PostfixText{ + format_utils::BytesToHumanReadable(total) + "/" + + format_utils::BytesToHumanReadable(total)}); + } + } + } + }; + + while (ws_->getReadyState() != easywsclient::WebSocket::CLOSED && + !should_stop()) { + ws_->poll(); + ws_->dispatch(handle_message); + } + if (status_ == DownloadStatus::DownloadError) + return false; + return true; +} \ No newline at end of file diff --git a/engine/cli/utils/download_manager.h b/engine/cli/utils/download_manager.h new file mode 100644 index 000000000..59cdecef2 --- /dev/null +++ b/engine/cli/utils/download_manager.h @@ -0,0 +1,25 @@ +#pragma once +#include +#include +#include +#include "common/event.h" +#include "easywsclient.hpp" + +using DownloadStatus = cortex::event::DownloadEventType; +class DownloadManager { + public: + bool Connect(const std::string& host, int port); + + bool Handle(const std::string& id); + + private: + bool should_stop() const { + return status_ != DownloadStatus::DownloadStarted && + status_ != DownloadStatus::DownloadUpdated; + } + + private: + // TODO(sang) open multiple sockets + std::unique_ptr ws_; + std::atomic status_ = DownloadStatus::DownloadStarted; +}; \ No newline at end of file diff --git a/engine/services/engine_service.cc b/engine/services/engine_service.cc index 4dfe7fefb..1f1219409 100644 --- a/engine/services/engine_service.cc +++ b/engine/services/engine_service.cc @@ -121,36 +121,23 @@ std::vector EngineService::GetEngineInfoList() const { return engines; } -cpp::result EngineService::InstallEngine( - const std::string& engine, const std::string& version, - const std::string& src) { - auto ne = NormalizeEngine(engine); - if (!src.empty()) { - return UnzipEngine(ne, version, src); - } else { - auto result = DownloadEngine(ne, version); - if (result.has_error()) { - return result; - } - return DownloadCuda(ne); - } -} - cpp::result EngineService::InstallEngineAsync( const std::string& engine, const std::string& version, const std::string& src) { // Although this function is called async, only download tasks are performed async - // TODO(sang) better handler for unzip and download scenarios auto ne = NormalizeEngine(engine); if (!src.empty()) { - return UnzipEngine(ne, version, src); - } else { - auto result = DownloadEngine(ne, version, true /*async*/); - if (result.has_error()) { - return result; + auto res = UnzipEngine(ne, version, src); + // If has error or engine is installed successfully + if (res.has_error() || res.value()) { + return res; } - return DownloadCuda(ne, true /*async*/); } + auto result = DownloadEngine(ne, version, true /*async*/); + if (result.has_error()) { + return result; + } + return DownloadCuda(ne, true /*async*/); } cpp::result EngineService::UnzipEngine( @@ -198,12 +185,16 @@ cpp::result EngineService::UnzipEngine( auto matched_variant = GetMatchedVariant(engine, variants); CTL_INF("Matched variant: " << matched_variant); + if (!found_cuda || matched_variant.empty()) { + return false; + } + if (matched_variant.empty()) { CTL_INF("No variant found for " << hw_inf_.sys_inf->os << "-" << hw_inf_.sys_inf->arch << ", will get engine from remote"); // Go with the remote flow - return DownloadEngine(engine, version); + // return DownloadEngine(engine, version); } else { auto engine_path = file_manager_utils::GetEnginesContainerPath(); archive_utils::ExtractArchive(path + "/" + matched_variant, @@ -211,9 +202,9 @@ cpp::result EngineService::UnzipEngine( } // Not match any cuda binary, download from remote - if (!found_cuda) { - return DownloadCuda(engine); - } + // if (!found_cuda) { + // return DownloadCuda(engine); + // } return true; } diff --git a/engine/services/engine_service.h b/engine/services/engine_service.h index 521771325..0f491edc7 100644 --- a/engine/services/engine_service.h +++ b/engine/services/engine_service.h @@ -43,25 +43,23 @@ class EngineService { std::vector GetEngineInfoList() const; - cpp::result InstallEngine( - const std::string& engine, const std::string& version = "latest", - const std::string& src = ""); - cpp::result InstallEngineAsync( const std::string& engine, const std::string& version = "latest", const std::string& src = ""); cpp::result UninstallEngine(const std::string& engine); - private: cpp::result UnzipEngine(const std::string& engine, const std::string& version, const std::string& path); + private: cpp::result DownloadEngine( - const std::string& engine, const std::string& version = "latest", bool async = false); + const std::string& engine, const std::string& version = "latest", + bool async = false); - cpp::result DownloadCuda(const std::string& engine, bool async = false); + cpp::result DownloadCuda(const std::string& engine, + bool async = false); std::string GetMatchedVariant(const std::string& engine, const std::vector& variants); From 9547b45f0fad09cf89156cbfd7ba7abc8e7160fb Mon Sep 17 00:00:00 2001 From: vansangpfiev Date: Fri, 25 Oct 2024 10:16:27 +0700 Subject: [PATCH 03/24] fix: pull models info from server --- engine/cli/commands/model_pull_cmd.cc | 71 +++++++++++++--- engine/controllers/models.cc | 39 +++++++++ engine/controllers/models.h | 3 + engine/services/model_service.cc | 116 +++++++++++++++++++++++++- engine/services/model_service.h | 11 +++ 5 files changed, 228 insertions(+), 12 deletions(-) diff --git a/engine/cli/commands/model_pull_cmd.cc b/engine/cli/commands/model_pull_cmd.cc index 42d784a2a..cec4d26a7 100644 --- a/engine/cli/commands/model_pull_cmd.cc +++ b/engine/cli/commands/model_pull_cmd.cc @@ -119,13 +119,8 @@ cpp::result GetModelId(const std::string& input) { } // namespace void ModelPullCmd::Exec(const std::string& host, int port, const std::string& input) { - auto r = GetModelId(input); - if (r.has_error()) { - CLI_LOG(r.error()); - return; - } - auto const& model_id = r.value(); - CTL_INF(model_id); + std::string model_id = input; + // Start server if server is not started yet if (!commands::IsServerAlive(host, port)) { CLI_LOG("Starting server ..."); @@ -136,16 +131,72 @@ void ModelPullCmd::Exec(const std::string& host, int port, } httplib::Client cli(host + ":" + std::to_string(port)); + cli.set_read_timeout(std::chrono::seconds(60)); + Json::Value j_data; + j_data["model"] = input; + auto d_str = j_data.toStyledString(); + auto res = cli.Post("/models/pull/info", httplib::Headers(), d_str.data(), + d_str.size(), "application/json"); + + if (res) { + if (res->status == httplib::StatusCode::OK_200) { + // CLI_LOG(res->body); + auto root = json_helper::ParseJsonString(res->body); + std::string id = root["id"].asString(); + bool is_cortexso = root["cortexso"].asBool(); + std::string default_branch = root["defaultBranch"].asString(); + std::vector downloaded; + for (auto const& v : root["downloadedModels"]) { + downloaded.push_back(v.asString()); + } + std::vector avails; + for (auto const& v : root["availableModels"]) { + avails.push_back(v.asString()); + } + + if (downloaded.empty() && avails.empty()) { + + } else { + if (is_cortexso) { + auto selection = cli_selection_utils::PrintModelSelection( + downloaded, avails, + default_branch.empty() + ? std::nullopt + : std::optional(default_branch)); + + if (!selection.has_value()) { + CLI_LOG("Invalid selection"); + return; + } + model_id = selection.value(); + + } else { + auto selection = cli_selection_utils::PrintSelection(avails); + std::cout << "Selected: " << selection.value() << std::endl; + model_id = selection.value(); + } + } + } else { + auto root = json_helper::ParseJsonString(res->body); + CLI_LOG(root["message"].asString()); + return; + } + } else { + auto err = res.error(); + CTL_ERR("HTTP error: " << httplib::to_string(err)); + return; + } + Json::Value json_data; json_data["model"] = model_id; auto data_str = json_data.toStyledString(); cli.set_read_timeout(std::chrono::seconds(60)); - auto res = cli.Post("/v1/models/pull", httplib::Headers(), data_str.data(), - data_str.size(), "application/json"); + res = cli.Post("/v1/models/pull", httplib::Headers(), data_str.data(), + data_str.size(), "application/json"); if (res) { if (res->status == httplib::StatusCode::OK_200) { - // CLI_LOG("OK"); + } else { auto root = json_helper::ParseJsonString(res->body); CLI_LOG(root["message"].asString()); diff --git a/engine/controllers/models.cc b/engine/controllers/models.cc index 174c89184..43ded0ef9 100644 --- a/engine/controllers/models.cc +++ b/engine/controllers/models.cc @@ -65,6 +65,45 @@ void Models::PullModel(const HttpRequestPtr& req, } } +void Models::GetModelPullInfo( + const HttpRequestPtr& req, + std::function&& callback) const { + if (!http_util::HasFieldInReq(req, callback, "model")) { + return; + } + + auto model_handle = (*(req->getJsonObject())).get("model", "").asString(); + std::cout << model_handle << std::endl; + auto res = model_service_->GetModelPullInfo(model_handle); + if (res.has_error()) { + Json::Value ret; + ret["message"] = res.error(); + auto resp = cortex_utils::CreateCortexHttpJsonResponse(ret); + resp->setStatusCode(k400BadRequest); + callback(resp); + } else { + auto const& info = res.value(); + Json::Value ret; + Json::Value downloaded(Json::arrayValue); + for (auto const& s : info.downloaded_models) { + downloaded.append(s); + } + Json::Value avails(Json::arrayValue); + for (auto const& s : info.available_models) { + avails.append(s); + } + ret["id"] = info.id; + ret["cortexso"] = info.cortexso; + ret["defaultBranch"] = info.default_branch; + ret["message"] = "Get model pull information successfully"; + ret["downloadedModels"] = downloaded; + ret["availableModels"] = avails; + auto resp = cortex_utils::CreateCortexHttpJsonResponse(ret); + resp->setStatusCode(k200OK); + callback(resp); + } +} + void Models::AbortPullModel( const HttpRequestPtr& req, std::function&& callback) { diff --git a/engine/controllers/models.h b/engine/controllers/models.h index cacec2e48..b48a0d1aa 100644 --- a/engine/controllers/models.h +++ b/engine/controllers/models.h @@ -11,6 +11,7 @@ class Models : public drogon::HttpController { public: METHOD_LIST_BEGIN METHOD_ADD(Models::PullModel, "/pull", Post); + METHOD_ADD(Models::GetModelPullInfo, "/pull/info", Post); METHOD_ADD(Models::AbortPullModel, "/pull", Delete); METHOD_ADD(Models::ListModel, "", Get); METHOD_ADD(Models::GetModel, "/{1}", Get); @@ -39,6 +40,8 @@ class Models : public drogon::HttpController { void PullModel(const HttpRequestPtr& req, std::function&& callback); + void GetModelPullInfo(const HttpRequestPtr& req, + std::function&& callback) const; void AbortPullModel(const HttpRequestPtr& req, std::function&& callback); void ListModel(const HttpRequestPtr& req, diff --git a/engine/services/model_service.cc b/engine/services/model_service.cc index d4e373812..36b8203c3 100644 --- a/engine/services/model_service.cc +++ b/engine/services/model_service.cc @@ -15,7 +15,6 @@ #include "utils/logging_utils.h" #include "utils/result.hpp" #include "utils/string_utils.h" -#include "utils/json_helper.h" namespace { void ParseGguf(const DownloadItem& ggufDownloadItem, @@ -383,7 +382,8 @@ ModelService::DownloadModelFromCortexsoAsync( if (model_entry.has_value()) { return cpp::fail("Please delete the model before downloading again"); } - auto on_finished = [unique_model_id, branch](const DownloadTask& finishedTask) { + auto on_finished = [unique_model_id, + branch](const DownloadTask& finishedTask) { const DownloadItem* model_yml_item = nullptr; auto need_parse_gguf = true; @@ -723,6 +723,118 @@ cpp::result ModelService::GetModelStatus( } } +cpp::result ModelService::GetModelPullInfo( + const std::string& input) { + if (input.empty()) { + return cpp::fail( + "Input must be Cortex Model Hub handle or HuggingFace url!"); + } + + if (string_utils::StartsWith(input, "https://")) { + return ModelPullInfo{ + .id = input, .downloaded_models = {}, .available_models = {}}; + } + + if (input.find(":") != std::string::npos) { + auto parsed = string_utils::SplitBy(input, ":"); + if (parsed.size() != 2) { + return cpp::fail("Invalid model handle: " + input); + } + return ModelPullInfo{ + .id = input, .downloaded_models = {}, .available_models = {}}; + } + + if (input.find("/") != std::string::npos) { + auto parsed = string_utils::SplitBy(input, "/"); + if (parsed.size() != 2) { + return cpp::fail("Invalid model handle: " + input); + } + + auto author = parsed[0]; + auto model_name = parsed[1]; + if (author == "cortexso") { + return ModelPullInfo{ + .id = author + ":" + model_name, + .downloaded_models = {}, + .available_models = {}, + .cortexso = true, + }; + } + + auto repo_info = + huggingface_utils::GetHuggingFaceModelRepoInfo(author, model_name); + + if (!repo_info.has_value()) { + return cpp::fail("Model not found"); + } + + if (!repo_info->gguf.has_value()) { + return cpp::fail( + "Not a GGUF model. Currently, only GGUF single file is " + "supported."); + } + + std::vector options{}; + for (const auto& sibling : repo_info->siblings) { + if (string_utils::EndsWith(sibling.rfilename, ".gguf")) { + options.push_back(sibling.rfilename); + } + } + // auto selection = cli_selection_utils::PrintSelection(options); + // std::cout << "Selected: " << selection.value() << std::endl; + + return ModelPullInfo{ + .id = input, .downloaded_models = {}, .available_models = options}; + } + auto branches = + huggingface_utils::GetModelRepositoryBranches("cortexso", input); + if (branches.has_error()) { + return cpp::fail(branches.error()); + } + + auto default_model_branch = huggingface_utils::GetDefaultBranch(input); + + cortex::db::Models modellist_handler; + auto downloaded_model_ids = + modellist_handler.FindRelatedModel(input).value_or( + std::vector{}); + + std::vector avai_download_opts{}; + for (const auto& branch : branches.value()) { + if (branch.second.name == "main") { // main branch only have metadata. skip + continue; + } + auto model_id = input + ":" + branch.second.name; + if (std::find(downloaded_model_ids.begin(), downloaded_model_ids.end(), + model_id) != + downloaded_model_ids.end()) { // if downloaded, we skip it + continue; + } + avai_download_opts.emplace_back(model_id); + } + + if (avai_download_opts.empty()) { + // TODO: only with pull, we return + return cpp::fail("No variant available"); + } + std::optional normalized_def_branch = std::nullopt; + if (default_model_branch.has_value()) { + normalized_def_branch = input + ":" + default_model_branch.value(); + } + string_utils::SortStrings(downloaded_model_ids); + string_utils::SortStrings(avai_download_opts); + // auto selection = cli_selection_utils::PrintModelSelection( + // downloaded_model_ids, avai_download_opts, normalized_def_branch); + // if (!selection.has_value()) { + // return cpp::fail("Invalid selection"); + // } + return ModelPullInfo{.id = input, + .default_branch = normalized_def_branch.value_or(""), + .downloaded_models = downloaded_model_ids, + .available_models = avai_download_opts, + .cortexso = true}; +} + cpp::result ModelService::AbortDownloadModel( const std::string& task_id) { return download_service_->StopTask(task_id); diff --git a/engine/services/model_service.h b/engine/services/model_service.h index 5adc5a01e..eef949937 100644 --- a/engine/services/model_service.h +++ b/engine/services/model_service.h @@ -7,6 +7,14 @@ #include "services/download_service.h" #include "services/inference_service.h" +struct ModelPullInfo { + std::string id; + std::string default_branch; + std::vector downloaded_models; + std::vector available_models; + bool cortexso = false; +}; + class ModelService { public: constexpr auto static kHuggingFaceHost = "huggingface.co"; @@ -54,6 +62,9 @@ class ModelService { cpp::result GetModelStatus( const std::string& host, int port, const std::string& model_handle); + cpp::result GetModelPullInfo( + const std::string& model_handle); + cpp::result HandleUrl(const std::string& url); cpp::result HandleDownloadUrlAsync( From 808ff246856426c2b9486397d2c99c6fee354610 Mon Sep 17 00:00:00 2001 From: vansangpfiev Date: Mon, 28 Oct 2024 08:05:45 +0700 Subject: [PATCH 04/24] fix: model_source --- engine/cli/commands/model_pull_cmd.cc | 2 +- engine/controllers/models.cc | 2 +- engine/services/model_service.cc | 10 +++------- engine/services/model_service.h | 2 +- 4 files changed, 6 insertions(+), 10 deletions(-) diff --git a/engine/cli/commands/model_pull_cmd.cc b/engine/cli/commands/model_pull_cmd.cc index cec4d26a7..8235cd1b5 100644 --- a/engine/cli/commands/model_pull_cmd.cc +++ b/engine/cli/commands/model_pull_cmd.cc @@ -143,7 +143,7 @@ void ModelPullCmd::Exec(const std::string& host, int port, // CLI_LOG(res->body); auto root = json_helper::ParseJsonString(res->body); std::string id = root["id"].asString(); - bool is_cortexso = root["cortexso"].asBool(); + bool is_cortexso = root["modelSource"].asString() == "cortexso"; std::string default_branch = root["defaultBranch"].asString(); std::vector downloaded; for (auto const& v : root["downloadedModels"]) { diff --git a/engine/controllers/models.cc b/engine/controllers/models.cc index 025abf4a6..77c6f3423 100644 --- a/engine/controllers/models.cc +++ b/engine/controllers/models.cc @@ -93,7 +93,7 @@ void Models::GetModelPullInfo( avails.append(s); } ret["id"] = info.id; - ret["cortexso"] = info.cortexso; + ret["modelSource"] = info.model_source; ret["defaultBranch"] = info.default_branch; ret["message"] = "Get model pull information successfully"; ret["downloadedModels"] = downloaded; diff --git a/engine/services/model_service.cc b/engine/services/model_service.cc index aa86ac925..4dbd1d873 100644 --- a/engine/services/model_service.cc +++ b/engine/services/model_service.cc @@ -785,7 +785,7 @@ cpp::result ModelService::GetModelPullInfo( .id = author + ":" + model_name, .downloaded_models = {}, .available_models = {}, - .cortexso = true, + .model_source = "cortexso", }; } @@ -851,16 +851,12 @@ cpp::result ModelService::GetModelPullInfo( } string_utils::SortStrings(downloaded_model_ids); string_utils::SortStrings(avai_download_opts); - // auto selection = cli_selection_utils::PrintModelSelection( - // downloaded_model_ids, avai_download_opts, normalized_def_branch); - // if (!selection.has_value()) { - // return cpp::fail("Invalid selection"); - // } + return ModelPullInfo{.id = input, .default_branch = normalized_def_branch.value_or(""), .downloaded_models = downloaded_model_ids, .available_models = avai_download_opts, - .cortexso = true}; + .model_source = "cortexso"}; } cpp::result ModelService::AbortDownloadModel( diff --git a/engine/services/model_service.h b/engine/services/model_service.h index 9137bf9ce..6da76fd9f 100644 --- a/engine/services/model_service.h +++ b/engine/services/model_service.h @@ -12,7 +12,7 @@ struct ModelPullInfo { std::string default_branch; std::vector downloaded_models; std::vector available_models; - bool cortexso = false; + std::string model_source; }; struct StartParameterOverride { From 9405117e4a06f80bbb9fe9eb3d6da2d76a315568 Mon Sep 17 00:00:00 2001 From: vansangpfiev Date: Mon, 28 Oct 2024 08:08:22 +0700 Subject: [PATCH 05/24] fix: rename --- engine/cli/CMakeLists.txt | 2 +- engine/cli/commands/engine_install_cmd.cc | 4 ++-- engine/cli/commands/model_pull_cmd.cc | 4 ++-- .../cli/utils/{download_manager.cc => download_progress.cc} | 6 +++--- .../cli/utils/{download_manager.h => download_progress.h} | 3 +-- 5 files changed, 9 insertions(+), 10 deletions(-) rename engine/cli/utils/{download_manager.cc => download_progress.cc} (95%) rename engine/cli/utils/{download_manager.h => download_progress.h} (90%) diff --git a/engine/cli/CMakeLists.txt b/engine/cli/CMakeLists.txt index 3dd9ff27c..da234132d 100644 --- a/engine/cli/CMakeLists.txt +++ b/engine/cli/CMakeLists.txt @@ -81,7 +81,7 @@ add_executable(${TARGET_NAME} main.cc ${CMAKE_CURRENT_SOURCE_DIR}/../services/model_service.cc ${CMAKE_CURRENT_SOURCE_DIR}/../services/inference_service.cc ${CMAKE_CURRENT_SOURCE_DIR}/utils/easywsclient.cc - ${CMAKE_CURRENT_SOURCE_DIR}/utils/download_manager.cc + ${CMAKE_CURRENT_SOURCE_DIR}/utils/download_progress.cc ) target_link_libraries(${TARGET_NAME} PRIVATE httplib::httplib) diff --git a/engine/cli/commands/engine_install_cmd.cc b/engine/cli/commands/engine_install_cmd.cc index a4164d52d..98d42c67b 100644 --- a/engine/cli/commands/engine_install_cmd.cc +++ b/engine/cli/commands/engine_install_cmd.cc @@ -1,6 +1,6 @@ #include "engine_install_cmd.h" #include "server_start_cmd.h" -#include "utils/download_manager.h" +#include "utils/download_progress.h" #include "utils/engine_constants.h" #include "utils/logging_utils.h" #include "utils/json_helper.h" @@ -65,7 +65,7 @@ bool EngineInstallCmd::Exec(const std::string& engine, } CLI_LOG("Pulling ...") - DownloadManager dm; + DownloadProgress dm; dm.Connect(host_, port_); if (!dm.Handle(NormalizeEngine(engine))) return false; diff --git a/engine/cli/commands/model_pull_cmd.cc b/engine/cli/commands/model_pull_cmd.cc index 8235cd1b5..1d516a0ee 100644 --- a/engine/cli/commands/model_pull_cmd.cc +++ b/engine/cli/commands/model_pull_cmd.cc @@ -6,7 +6,7 @@ #include "database/models.h" #include "server_start_cmd.h" #include "utils/cli_selection_utils.h" -#include "utils/download_manager.h" +#include "utils/download_progress.h" #include "utils/format_utils.h" #include "utils/huggingface_utils.h" #include "utils/json_helper.h" @@ -209,7 +209,7 @@ void ModelPullCmd::Exec(const std::string& host, int port, } CLI_LOG("Pulling ...") - DownloadManager dm; + DownloadProgress dm; dm.Connect(host, port); if (!dm.Handle(model_id)) return; diff --git a/engine/cli/utils/download_manager.cc b/engine/cli/utils/download_progress.cc similarity index 95% rename from engine/cli/utils/download_manager.cc rename to engine/cli/utils/download_progress.cc index 915f63348..79f70e5d2 100644 --- a/engine/cli/utils/download_manager.cc +++ b/engine/cli/utils/download_progress.cc @@ -1,4 +1,4 @@ -#include "download_manager.h" +#include "download_progress.h" #include #include "common/event.h" #include "indicators.hpp" @@ -6,7 +6,7 @@ #include "utils/json_helper.h" #include "utils/logging_utils.h" -bool DownloadManager::Connect(const std::string& host, int port) { +bool DownloadProgress::Connect(const std::string& host, int port) { if (ws_) { CTL_INF("Already connected!"); return true; @@ -19,7 +19,7 @@ bool DownloadManager::Connect(const std::string& host, int port) { return true; } -bool DownloadManager::Handle(const std::string& id) { +bool DownloadProgress::Handle(const std::string& id) { assert(!!ws_); status_ = DownloadStatus::DownloadStarted; std::unique_ptr> diff --git a/engine/cli/utils/download_manager.h b/engine/cli/utils/download_progress.h similarity index 90% rename from engine/cli/utils/download_manager.h rename to engine/cli/utils/download_progress.h index 59cdecef2..6511b9537 100644 --- a/engine/cli/utils/download_manager.h +++ b/engine/cli/utils/download_progress.h @@ -6,7 +6,7 @@ #include "easywsclient.hpp" using DownloadStatus = cortex::event::DownloadEventType; -class DownloadManager { +class DownloadProgress { public: bool Connect(const std::string& host, int port); @@ -19,7 +19,6 @@ class DownloadManager { } private: - // TODO(sang) open multiple sockets std::unique_ptr ws_; std::atomic status_ = DownloadStatus::DownloadStarted; }; \ No newline at end of file From 53d85de2b2998e87ee38bdd55184788bff03eaef Mon Sep 17 00:00:00 2001 From: vansangpfiev Date: Mon, 28 Oct 2024 08:18:08 +0700 Subject: [PATCH 06/24] fix: download cortexso --- engine/services/model_service.cc | 62 ++++++++++++++------------------ 1 file changed, 27 insertions(+), 35 deletions(-) diff --git a/engine/services/model_service.cc b/engine/services/model_service.cc index 4dbd1d873..28284a7f0 100644 --- a/engine/services/model_service.cc +++ b/engine/services/model_service.cc @@ -757,6 +757,7 @@ cpp::result ModelService::GetModelPullInfo( return cpp::fail( "Input must be Cortex Model Hub handle or HuggingFace url!"); } + auto model_name = input; if (string_utils::StartsWith(input, "https://")) { return ModelPullInfo{ @@ -779,52 +780,43 @@ cpp::result ModelService::GetModelPullInfo( } auto author = parsed[0]; - auto model_name = parsed[1]; - if (author == "cortexso") { - return ModelPullInfo{ - .id = author + ":" + model_name, - .downloaded_models = {}, - .available_models = {}, - .model_source = "cortexso", - }; - } + model_name = parsed[1]; + if (author != "cortexso") { + auto repo_info = + huggingface_utils::GetHuggingFaceModelRepoInfo(author, model_name); - auto repo_info = - huggingface_utils::GetHuggingFaceModelRepoInfo(author, model_name); - - if (!repo_info.has_value()) { - return cpp::fail("Model not found"); - } + if (!repo_info.has_value()) { + return cpp::fail("Model not found"); + } - if (!repo_info->gguf.has_value()) { - return cpp::fail( - "Not a GGUF model. Currently, only GGUF single file is " - "supported."); - } + if (!repo_info->gguf.has_value()) { + return cpp::fail( + "Not a GGUF model. Currently, only GGUF single file is " + "supported."); + } - std::vector options{}; - for (const auto& sibling : repo_info->siblings) { - if (string_utils::EndsWith(sibling.rfilename, ".gguf")) { - options.push_back(sibling.rfilename); + std::vector options{}; + for (const auto& sibling : repo_info->siblings) { + if (string_utils::EndsWith(sibling.rfilename, ".gguf")) { + options.push_back(sibling.rfilename); + } } - } - // auto selection = cli_selection_utils::PrintSelection(options); - // std::cout << "Selected: " << selection.value() << std::endl; - return ModelPullInfo{ - .id = input, .downloaded_models = {}, .available_models = options}; + return ModelPullInfo{ + .id = input, .downloaded_models = {}, .available_models = options}; + } } auto branches = - huggingface_utils::GetModelRepositoryBranches("cortexso", input); + huggingface_utils::GetModelRepositoryBranches("cortexso", model_name); if (branches.has_error()) { return cpp::fail(branches.error()); } - auto default_model_branch = huggingface_utils::GetDefaultBranch(input); + auto default_model_branch = huggingface_utils::GetDefaultBranch(model_name); cortex::db::Models modellist_handler; auto downloaded_model_ids = - modellist_handler.FindRelatedModel(input).value_or( + modellist_handler.FindRelatedModel(model_name).value_or( std::vector{}); std::vector avai_download_opts{}; @@ -832,7 +824,7 @@ cpp::result ModelService::GetModelPullInfo( if (branch.second.name == "main") { // main branch only have metadata. skip continue; } - auto model_id = input + ":" + branch.second.name; + auto model_id = model_name + ":" + branch.second.name; if (std::find(downloaded_model_ids.begin(), downloaded_model_ids.end(), model_id) != downloaded_model_ids.end()) { // if downloaded, we skip it @@ -847,12 +839,12 @@ cpp::result ModelService::GetModelPullInfo( } std::optional normalized_def_branch = std::nullopt; if (default_model_branch.has_value()) { - normalized_def_branch = input + ":" + default_model_branch.value(); + normalized_def_branch = model_name + ":" + default_model_branch.value(); } string_utils::SortStrings(downloaded_model_ids); string_utils::SortStrings(avai_download_opts); - return ModelPullInfo{.id = input, + return ModelPullInfo{.id = model_name, .default_branch = normalized_def_branch.value_or(""), .downloaded_models = downloaded_model_ids, .available_models = avai_download_opts, From c43d9729c454098eb347680f1424449dc398287b Mon Sep 17 00:00:00 2001 From: vansangpfiev Date: Mon, 28 Oct 2024 09:28:34 +0700 Subject: [PATCH 07/24] fix: pull models --- engine/cli/commands/model_pull_cmd.cc | 119 +++----------------------- engine/controllers/models.cc | 1 + engine/services/model_service.cc | 43 +++++++--- engine/services/model_service.h | 1 + 4 files changed, 46 insertions(+), 118 deletions(-) diff --git a/engine/cli/commands/model_pull_cmd.cc b/engine/cli/commands/model_pull_cmd.cc index 1d516a0ee..f697bd0da 100644 --- a/engine/cli/commands/model_pull_cmd.cc +++ b/engine/cli/commands/model_pull_cmd.cc @@ -14,112 +14,13 @@ #include "utils/string_utils.h" namespace commands { -namespace { -// TODO(sang) request from Server -cpp::result GetModelId(const std::string& input) { - if (input.empty()) { - return cpp::fail( - "Input must be Cortex Model Hub handle or HuggingFace url!"); - } - - if (string_utils::StartsWith(input, "https://")) { - return input; - } - - if (input.find(":") != std::string::npos) { - auto parsed = string_utils::SplitBy(input, ":"); - if (parsed.size() != 2) { - return cpp::fail("Invalid model handle: " + input); - } - return input; - } - - if (input.find("/") != std::string::npos) { - auto parsed = string_utils::SplitBy(input, "/"); - if (parsed.size() != 2) { - return cpp::fail("Invalid model handle: " + input); - } - - auto author = parsed[0]; - auto model_name = parsed[1]; - if (author == "cortexso") { - return author + ":" + model_name; - } - - auto repo_info = - huggingface_utils::GetHuggingFaceModelRepoInfo(author, model_name); - - if (!repo_info.has_value()) { - return cpp::fail("Model not found"); - } - - if (!repo_info->gguf.has_value()) { - return cpp::fail( - "Not a GGUF model. Currently, only GGUF single file is " - "supported."); - } - - std::vector options{}; - for (const auto& sibling : repo_info->siblings) { - if (string_utils::EndsWith(sibling.rfilename, ".gguf")) { - options.push_back(sibling.rfilename); - } - } - auto selection = cli_selection_utils::PrintSelection(options); - std::cout << "Selected: " << selection.value() << std::endl; - return huggingface_utils::GetDownloadableUrl(author, model_name, - selection.value()); - } - auto branches = - huggingface_utils::GetModelRepositoryBranches("cortexso", input); - if (branches.has_error()) { - return cpp::fail(branches.error()); - } - - auto default_model_branch = huggingface_utils::GetDefaultBranch(input); - - cortex::db::Models modellist_handler; - auto downloaded_model_ids = - modellist_handler.FindRelatedModel(input).value_or( - std::vector{}); - - std::vector avai_download_opts{}; - for (const auto& branch : branches.value()) { - if (branch.second.name == "main") { // main branch only have metadata. skip - continue; - } - auto model_id = input + ":" + branch.second.name; - if (std::find(downloaded_model_ids.begin(), downloaded_model_ids.end(), - model_id) != - downloaded_model_ids.end()) { // if downloaded, we skip it - continue; - } - avai_download_opts.emplace_back(model_id); - } - - if (avai_download_opts.empty()) { - // TODO: only with pull, we return - return cpp::fail("No variant available"); - } - std::optional normalized_def_branch = std::nullopt; - if (default_model_branch.has_value()) { - normalized_def_branch = input + ":" + default_model_branch.value(); - } - string_utils::SortStrings(downloaded_model_ids); - string_utils::SortStrings(avai_download_opts); - auto selection = cli_selection_utils::PrintModelSelection( - downloaded_model_ids, avai_download_opts, normalized_def_branch); - if (!selection.has_value()) { - return cpp::fail("Invalid selection"); - } - return selection.value(); -} - -} // namespace void ModelPullCmd::Exec(const std::string& host, int port, const std::string& input) { + // model_id: use to check the download progress + // model: use as a parameter for pull API std::string model_id = input; + std::string model = input; // Start server if server is not started yet if (!commands::IsServerAlive(host, port)) { @@ -130,6 +31,7 @@ void ModelPullCmd::Exec(const std::string& host, int port, } } + // Get model info from Server httplib::Client cli(host + ":" + std::to_string(port)); cli.set_read_timeout(std::chrono::seconds(60)); Json::Value j_data; @@ -153,9 +55,11 @@ void ModelPullCmd::Exec(const std::string& host, int port, for (auto const& v : root["availableModels"]) { avails.push_back(v.asString()); } + std::string download_url = root["downloadUrl"].asString(); if (downloaded.empty() && avails.empty()) { - + model_id = id; + model = download_url; } else { if (is_cortexso) { auto selection = cli_selection_utils::PrintModelSelection( @@ -169,11 +73,12 @@ void ModelPullCmd::Exec(const std::string& host, int port, return; } model_id = selection.value(); - + model = model_id; } else { auto selection = cli_selection_utils::PrintSelection(avails); - std::cout << "Selected: " << selection.value() << std::endl; - model_id = selection.value(); + CLI_LOG("Selected: " << selection.value()); + model_id = id + ":" + selection.value(); + model = download_url + selection.value(); } } } else { @@ -188,7 +93,7 @@ void ModelPullCmd::Exec(const std::string& host, int port, } Json::Value json_data; - json_data["model"] = model_id; + json_data["model"] = model; auto data_str = json_data.toStyledString(); cli.set_read_timeout(std::chrono::seconds(60)); res = cli.Post("/v1/models/pull", httplib::Headers(), data_str.data(), diff --git a/engine/controllers/models.cc b/engine/controllers/models.cc index 77c6f3423..df921e1d2 100644 --- a/engine/controllers/models.cc +++ b/engine/controllers/models.cc @@ -98,6 +98,7 @@ void Models::GetModelPullInfo( ret["message"] = "Get model pull information successfully"; ret["downloadedModels"] = downloaded; ret["availableModels"] = avails; + ret["downloadUrl"] = info.download_url; auto resp = cortex_utils::CreateCortexHttpJsonResponse(ret); resp->setStatusCode(k200OK); callback(resp); diff --git a/engine/services/model_service.cc b/engine/services/model_service.cc index 28284a7f0..285237bd7 100644 --- a/engine/services/model_service.cc +++ b/engine/services/model_service.cc @@ -237,7 +237,7 @@ cpp::result ModelService::HandleDownloadUrlAsync( auto file_name{url_obj.pathParams.back()}; if (author == "cortexso") { - return DownloadModelFromCortexsoAsync(model_id); + return DownloadModelFromCortexsoAsync(model_id, url_obj.pathParams[3]); } if (url_obj.pathParams.size() < 5) { @@ -280,7 +280,7 @@ cpp::result ModelService::HandleDownloadUrlAsync( .localPath = local_path, }}}}; - auto on_finished = [&](const DownloadTask& finishedTask) { + auto on_finished = [author](const DownloadTask& finishedTask) { auto gguf_download_item = finishedTask.items[0]; ParseGguf(gguf_download_item, author); }; @@ -345,7 +345,7 @@ cpp::result ModelService::HandleUrl( .localPath = local_path, }}}}; - auto on_finished = [&](const DownloadTask& finishedTask) { + auto on_finished = [author](const DownloadTask& finishedTask) { auto gguf_download_item = finishedTask.items[0]; ParseGguf(gguf_download_item, author); }; @@ -437,7 +437,7 @@ cpp::result ModelService::DownloadModelFromCortexso( } std::string model_id{name + ":" + branch}; - auto on_finished = [&, model_id](const DownloadTask& finishedTask) { + auto on_finished = [branch, model_id](const DownloadTask& finishedTask) { const DownloadItem* model_yml_item = nullptr; auto need_parse_gguf = true; @@ -760,8 +760,26 @@ cpp::result ModelService::GetModelPullInfo( auto model_name = input; if (string_utils::StartsWith(input, "https://")) { - return ModelPullInfo{ - .id = input, .downloaded_models = {}, .available_models = {}}; + auto url_obj = url_parser::FromUrlString(input); + + if (url_obj.host == kHuggingFaceHost) { + if (url_obj.pathParams[2] == "blob") { + url_obj.pathParams[2] = "resolve"; + } + } + auto author{url_obj.pathParams[0]}; + auto model_id{url_obj.pathParams[1]}; + auto file_name{url_obj.pathParams.back()}; + if (author == "cortexso") { + return ModelPullInfo{.id = model_id + ":" + url_obj.pathParams[3], + .downloaded_models = {}, + .available_models = {}, + .download_url = url_parser::FromUrl(url_obj)}; + } + return ModelPullInfo{.id = author + ":" + model_id + ":" + file_name, + .downloaded_models = {}, + .available_models = {}, + .download_url = url_parser::FromUrl(url_obj)}; } if (input.find(":") != std::string::npos) { @@ -770,7 +788,7 @@ cpp::result ModelService::GetModelPullInfo( return cpp::fail("Invalid model handle: " + input); } return ModelPullInfo{ - .id = input, .downloaded_models = {}, .available_models = {}}; + .id = input, .downloaded_models = {}, .available_models = {}, .download_url = input}; } if (input.find("/") != std::string::npos) { @@ -803,7 +821,11 @@ cpp::result ModelService::GetModelPullInfo( } return ModelPullInfo{ - .id = input, .downloaded_models = {}, .available_models = options}; + .id = author + ":" + model_name, + .downloaded_models = {}, + .available_models = options, + .download_url = + huggingface_utils::GetDownloadableUrl(author, model_name, "")}; } } auto branches = @@ -815,9 +837,8 @@ cpp::result ModelService::GetModelPullInfo( auto default_model_branch = huggingface_utils::GetDefaultBranch(model_name); cortex::db::Models modellist_handler; - auto downloaded_model_ids = - modellist_handler.FindRelatedModel(model_name).value_or( - std::vector{}); + auto downloaded_model_ids = modellist_handler.FindRelatedModel(model_name) + .value_or(std::vector{}); std::vector avai_download_opts{}; for (const auto& branch : branches.value()) { diff --git a/engine/services/model_service.h b/engine/services/model_service.h index 6da76fd9f..495685982 100644 --- a/engine/services/model_service.h +++ b/engine/services/model_service.h @@ -13,6 +13,7 @@ struct ModelPullInfo { std::vector downloaded_models; std::vector available_models; std::string model_source; + std::string download_url; }; struct StartParameterOverride { From a601ad8ecd57b72790e5478e358d2b704c7812c1 Mon Sep 17 00:00:00 2001 From: vansangpfiev Date: Mon, 28 Oct 2024 09:30:38 +0700 Subject: [PATCH 08/24] fix: remove comments --- engine/services/engine_service.cc | 6 ------ 1 file changed, 6 deletions(-) diff --git a/engine/services/engine_service.cc b/engine/services/engine_service.cc index 1f1219409..5e706be27 100644 --- a/engine/services/engine_service.cc +++ b/engine/services/engine_service.cc @@ -194,18 +194,12 @@ cpp::result EngineService::UnzipEngine( << hw_inf_.sys_inf->arch << ", will get engine from remote"); // Go with the remote flow - // return DownloadEngine(engine, version); } else { auto engine_path = file_manager_utils::GetEnginesContainerPath(); archive_utils::ExtractArchive(path + "/" + matched_variant, engine_path.string()); } - // Not match any cuda binary, download from remote - // if (!found_cuda) { - // return DownloadCuda(engine); - // } - return true; } From b762a19f68b2757f0d41bf33bbd2b04300c6acfb Mon Sep 17 00:00:00 2001 From: vansangpfiev Date: Mon, 28 Oct 2024 09:40:34 +0700 Subject: [PATCH 09/24] fix: rename --- engine/cli/commands/engine_install_cmd.cc | 8 ++++---- engine/cli/commands/model_pull_cmd.cc | 6 +++--- 2 files changed, 7 insertions(+), 7 deletions(-) diff --git a/engine/cli/commands/engine_install_cmd.cc b/engine/cli/commands/engine_install_cmd.cc index 98d42c67b..4b1b9c792 100644 --- a/engine/cli/commands/engine_install_cmd.cc +++ b/engine/cli/commands/engine_install_cmd.cc @@ -65,14 +65,14 @@ bool EngineInstallCmd::Exec(const std::string& engine, } CLI_LOG("Pulling ...") - DownloadProgress dm; - dm.Connect(host_, port_); - if (!dm.Handle(NormalizeEngine(engine))) + DownloadProgress dp; + dp.Connect(host_, port_); + if (!dp.Handle(NormalizeEngine(engine))) return false; bool check_cuda_download = !system_info_utils::GetCudaVersion().empty(); if (check_cuda_download) { - if (!dm.Handle("cuda")) + if (!dp.Handle("cuda")) return false; } diff --git a/engine/cli/commands/model_pull_cmd.cc b/engine/cli/commands/model_pull_cmd.cc index f697bd0da..d9201fa86 100644 --- a/engine/cli/commands/model_pull_cmd.cc +++ b/engine/cli/commands/model_pull_cmd.cc @@ -114,9 +114,9 @@ void ModelPullCmd::Exec(const std::string& host, int port, } CLI_LOG("Pulling ...") - DownloadProgress dm; - dm.Connect(host, port); - if (!dm.Handle(model_id)) + DownloadProgress dp; + dp.Connect(host, port); + if (!dp.Handle(model_id)) return; CLI_LOG("Model " << model_id << " downloaded successfully!") From 84653ae0a1fac485d377d67aff3a649058c93c37 Mon Sep 17 00:00:00 2001 From: vansangpfiev Date: Mon, 28 Oct 2024 09:48:55 +0700 Subject: [PATCH 10/24] fix: change download UI --- engine/cli/utils/download_progress.cc | 24 ++++++++++++++---------- 1 file changed, 14 insertions(+), 10 deletions(-) diff --git a/engine/cli/utils/download_progress.cc b/engine/cli/utils/download_progress.cc index 79f70e5d2..4c1aaf3b1 100644 --- a/engine/cli/utils/download_progress.cc +++ b/engine/cli/utils/download_progress.cc @@ -22,13 +22,10 @@ bool DownloadProgress::Connect(const std::string& host, int port) { bool DownloadProgress::Handle(const std::string& id) { assert(!!ws_); status_ = DownloadStatus::DownloadStarted; - std::unique_ptr> - bars; - - std::vector> items; - auto handle_message = [this, &bars, &items, - id](const std::string& message) { + std::unique_ptr> bars; + std::vector> items; + auto handle_message = [this, &bars, &items, id](const std::string& message) { CTL_INF(message); auto pad_string = [](const std::string& str, @@ -58,13 +55,14 @@ bool DownloadProgress::Handle(const std::string& id) { // std::cout << downloaded << " " << total << std::endl; if (!bars) { bars = std::make_unique< - indicators::DynamicProgress>(); + indicators::DynamicProgress>(); for (auto& i : ev.download_task_.items) { - items.emplace_back(std::make_unique( - indicators::option::BarWidth{50}, indicators::option::Start{"|"}, + items.emplace_back(std::make_unique( + indicators::option::BarWidth{50}, indicators::option::Start{"["}, // indicators::option::Fill{"■"}, indicators::option::Lead{"■"}, + indicators::option::Fill{"="}, indicators::option::Lead{">"}, // indicators::option::Remainder{" "}, - indicators::option::End{"|"}, + indicators::option::End{"]"}, indicators::option::PrefixText{pad_string(i.id)}, // indicators::option::PostfixText{"Downloading files"}, indicators::option::ForegroundColor{indicators::Color::white}, @@ -82,11 +80,17 @@ bool DownloadProgress::Handle(const std::string& id) { if (status_ == DownloadStatus::DownloadUpdated) { (*bars)[i].set_progress(static_cast(downloaded) / total * 100); + (*bars)[i].set_option(indicators::option::PrefixText{ + pad_string(it.id) + + std::to_string(int(static_cast(downloaded) / total * 100)) + + '%'}); (*bars)[i].set_option(indicators::option::PostfixText{ format_utils::BytesToHumanReadable(downloaded) + "/" + format_utils::BytesToHumanReadable(total)}); } else if (status_ == DownloadStatus::DownloadSuccess) { (*bars)[i].set_progress(100); + (*bars)[i].set_option( + indicators::option::PrefixText{pad_string(it.id) + "100%"}); (*bars)[i].set_option(indicators::option::PostfixText{ format_utils::BytesToHumanReadable(total) + "/" + format_utils::BytesToHumanReadable(total)}); From 0146cc26af6ffa8ac4d7b08d4eb59f3807a37900 Mon Sep 17 00:00:00 2001 From: vansangpfiev Date: Mon, 28 Oct 2024 09:50:00 +0700 Subject: [PATCH 11/24] fix: comment out --- engine/controllers/models.cc | 1 - 1 file changed, 1 deletion(-) diff --git a/engine/controllers/models.cc b/engine/controllers/models.cc index df921e1d2..602c81ab6 100644 --- a/engine/controllers/models.cc +++ b/engine/controllers/models.cc @@ -73,7 +73,6 @@ void Models::GetModelPullInfo( } auto model_handle = (*(req->getJsonObject())).get("model", "").asString(); - std::cout << model_handle << std::endl; auto res = model_service_->GetModelPullInfo(model_handle); if (res.has_error()) { Json::Value ret; From 303e70fe25ccf610a03ce220638cb8a5b84d9136 Mon Sep 17 00:00:00 2001 From: vansangpfiev Date: Mon, 28 Oct 2024 13:11:29 +0700 Subject: [PATCH 12/24] fix: e2e tests --- engine/cli/commands/engine_install_cmd.cc | 2 +- engine/cli/commands/model_pull_cmd.cc | 2 +- engine/cli/utils/download_progress.cc | 19 +++++++++------ engine/e2e-test/test_cli_engine_install.py | 23 +++++++++++++++---- ..._cli_model_pull_cortexso_with_selection.py | 13 +++++++++++ .../test_cli_model_pull_direct_url.py | 18 +++++++++++++-- .../test_cli_model_pull_from_cortexso.py | 11 +++++++++ ..._cli_model_pull_hugging_face_repository.py | 11 +++++++++ engine/e2e-test/test_runner.py | 4 ++-- 9 files changed, 85 insertions(+), 18 deletions(-) diff --git a/engine/cli/commands/engine_install_cmd.cc b/engine/cli/commands/engine_install_cmd.cc index 4b1b9c792..27017f1bf 100644 --- a/engine/cli/commands/engine_install_cmd.cc +++ b/engine/cli/commands/engine_install_cmd.cc @@ -64,7 +64,7 @@ bool EngineInstallCmd::Exec(const std::string& engine, return false; } - CLI_LOG("Pulling ...") + CLI_LOG("Start downloading ...") DownloadProgress dp; dp.Connect(host_, port_); if (!dp.Handle(NormalizeEngine(engine))) diff --git a/engine/cli/commands/model_pull_cmd.cc b/engine/cli/commands/model_pull_cmd.cc index d9201fa86..cbc62cb44 100644 --- a/engine/cli/commands/model_pull_cmd.cc +++ b/engine/cli/commands/model_pull_cmd.cc @@ -113,7 +113,7 @@ void ModelPullCmd::Exec(const std::string& host, int port, return; } - CLI_LOG("Pulling ...") + CLI_LOG("Start downloading ...") DownloadProgress dp; dp.Connect(host, port); if (!dp.Handle(model_id)) diff --git a/engine/cli/utils/download_progress.cc b/engine/cli/utils/download_progress.cc index 4c1aaf3b1..0bc3b2926 100644 --- a/engine/cli/utils/download_progress.cc +++ b/engine/cli/utils/download_progress.cc @@ -25,6 +25,7 @@ bool DownloadProgress::Handle(const std::string& id) { std::unique_ptr> bars; std::vector> items; + indicators::show_console_cursor(false); auto handle_message = [this, &bars, &items, id](const std::string& message) { CTL_INF(message); @@ -78,22 +79,25 @@ bool DownloadProgress::Handle(const std::string& id) { uint64_t downloaded = it.downloadedBytes.value_or(0); uint64_t total = it.bytes.value_or(9999); if (status_ == DownloadStatus::DownloadUpdated) { - (*bars)[i].set_progress(static_cast(downloaded) / total * - 100); (*bars)[i].set_option(indicators::option::PrefixText{ pad_string(it.id) + - std::to_string(int(static_cast(downloaded) / total * 100)) + + std::to_string( + int(static_cast(downloaded) / total * 100)) + '%'}); + (*bars)[i].set_progress( + int(static_cast(downloaded) / total * 100)); (*bars)[i].set_option(indicators::option::PostfixText{ format_utils::BytesToHumanReadable(downloaded) + "/" + format_utils::BytesToHumanReadable(total)}); } else if (status_ == DownloadStatus::DownloadSuccess) { - (*bars)[i].set_progress(100); (*bars)[i].set_option( indicators::option::PrefixText{pad_string(it.id) + "100%"}); - (*bars)[i].set_option(indicators::option::PostfixText{ - format_utils::BytesToHumanReadable(total) + "/" + - format_utils::BytesToHumanReadable(total)}); + (*bars)[i].set_progress(100); + auto total_str = format_utils::BytesToHumanReadable(total); + (*bars)[i].set_option( + indicators::option::PostfixText{total_str + "/" + total_str}); + + CTL_INF("Download success"); } } } @@ -104,6 +108,7 @@ bool DownloadProgress::Handle(const std::string& id) { ws_->poll(); ws_->dispatch(handle_message); } + indicators::show_console_cursor(true); if (status_ == DownloadStatus::DownloadError) return false; return true; diff --git a/engine/e2e-test/test_cli_engine_install.py b/engine/e2e-test/test_cli_engine_install.py index b4c27f3ef..c444c9bec 100644 --- a/engine/e2e-test/test_cli_engine_install.py +++ b/engine/e2e-test/test_cli_engine_install.py @@ -1,17 +1,29 @@ import platform import tempfile - +import os +from pathlib import Path import pytest from test_runner import run class TestCliEngineInstall: + def setup_and_teardown(self): + # Setup + success = start_server() + if not success: + raise Exception("Failed to start server") + + yield + + # Teardown + stop_server() def test_engines_install_llamacpp_should_be_successfully(self): exit_code, output, error = run( - "Install Engine", ["engines", "install", "llama-cpp"], timeout=None + "Install Engine", ["engines", "install", "llama-cpp"], timeout=None, capture = False ) - assert "Start downloading" in output, "Should display downloading message" + root = Path.home() + assert os.path.exists(root / "cortexcpp" / "engines" / "cortex.llamacpp") assert exit_code == 0, f"Install engine failed with error: {error}" @pytest.mark.skipif(platform.system() != "Darwin", reason="macOS-specific test") @@ -32,9 +44,10 @@ def test_engines_install_onnx_on_tensorrt_should_be_failed(self): def test_engines_install_pre_release_llamacpp(self): exit_code, output, error = run( - "Install Engine", ["engines", "install", "llama-cpp", "-v", "v0.1.29"], timeout=600 + "Install Engine", ["engines", "install", "llama-cpp", "-v", "v0.1.29"], timeout=None, capture = False ) - assert "Start downloading" in output, "Should display downloading message" + root = Path.home() + assert os.path.exists(root / "cortexcpp" / "engines" / "cortex.llamacpp") assert exit_code == 0, f"Install engine failed with error: {error}" def test_engines_should_fallback_to_download_llamacpp_engine_if_not_exists(self): diff --git a/engine/e2e-test/test_cli_model_pull_cortexso_with_selection.py b/engine/e2e-test/test_cli_model_pull_cortexso_with_selection.py index 619833e16..8c3de8d98 100644 --- a/engine/e2e-test/test_cli_model_pull_cortexso_with_selection.py +++ b/engine/e2e-test/test_cli_model_pull_cortexso_with_selection.py @@ -1,8 +1,21 @@ from test_runner import popen +import os +from pathlib import Path class TestCliModelPullCortexsoWithSelection: + def setup_and_teardown(self): + # Setup + success = start_server() + if not success: + raise Exception("Failed to start server") + + yield + + # Teardown + stop_server() + def test_pull_model_from_cortexso_should_display_list_and_allow_user_to_choose( self, ): diff --git a/engine/e2e-test/test_cli_model_pull_direct_url.py b/engine/e2e-test/test_cli_model_pull_direct_url.py index 4907ced1f..b10d1593d 100644 --- a/engine/e2e-test/test_cli_model_pull_direct_url.py +++ b/engine/e2e-test/test_cli_model_pull_direct_url.py @@ -1,8 +1,20 @@ from test_runner import run - +import os +from pathlib import Path class TestCliModelPullDirectUrl: + def setup_and_teardown(self): + # Setup + success = start_server() + if not success: + raise Exception("Failed to start server") + + yield + + # Teardown + stop_server() + def test_model_pull_with_direct_url_should_be_success(self): exit_code, output, error = run( "Pull model", @@ -10,8 +22,10 @@ def test_model_pull_with_direct_url_should_be_success(self): "pull", "https://huggingface.co/TheBloke/TinyLlama-1.1B-Chat-v0.3-GGUF/blob/main/tinyllama-1.1b-chat-v0.3.Q2_K.gguf", ], - timeout=None, + timeout=None, capture=False ) + root = Path.home() + assert os.path.exists(root / "cortexcpp" / "models" / "huggingface.co/TheBloke/TinyLlama-1.1B-Chat-v0.3-GGUF/tinyllama-1.1b-chat-v0.3.Q2_K.gguf") assert exit_code == 0, f"Model pull failed with error: {error}" # TODO: verify that the model has been pull successfully # TODO: skip this test. since download model is taking too long diff --git a/engine/e2e-test/test_cli_model_pull_from_cortexso.py b/engine/e2e-test/test_cli_model_pull_from_cortexso.py index c9c3f4c40..1791e39a6 100644 --- a/engine/e2e-test/test_cli_model_pull_from_cortexso.py +++ b/engine/e2e-test/test_cli_model_pull_from_cortexso.py @@ -4,6 +4,17 @@ class TestCliModelPullCortexso: + def setup_and_teardown(self): + # Setup + success = start_server() + if not success: + raise Exception("Failed to start server") + + yield + + # Teardown + stop_server() + def test_model_pull_with_direct_url_should_be_success(self): exit_code, output, error = run( "Pull model", diff --git a/engine/e2e-test/test_cli_model_pull_hugging_face_repository.py b/engine/e2e-test/test_cli_model_pull_hugging_face_repository.py index 50b7e832b..996ac086c 100644 --- a/engine/e2e-test/test_cli_model_pull_hugging_face_repository.py +++ b/engine/e2e-test/test_cli_model_pull_hugging_face_repository.py @@ -4,6 +4,17 @@ class TestCliModelPullHuggingFaceRepository: + def setup_and_teardown(self): + # Setup + success = start_server() + if not success: + raise Exception("Failed to start server") + + yield + + # Teardown + stop_server() + def test_model_pull_hugging_face_repository(self): """ Test pull model pervll/bge-reranker-v2-gemma-Q4_K_M-GGUF from issue #1017 diff --git a/engine/e2e-test/test_runner.py b/engine/e2e-test/test_runner.py index 320b8e332..20a8490a4 100644 --- a/engine/e2e-test/test_runner.py +++ b/engine/e2e-test/test_runner.py @@ -24,14 +24,14 @@ def getExecutablePath() -> str: # Execute a command -def run(test_name: str, arguments: List[str], timeout=timeout) -> (int, str, str): +def run(test_name: str, arguments: List[str], timeout=timeout, capture = True) -> (int, str, str): executable_path = getExecutablePath() print("Running:", test_name) print("Command:", [executable_path] + arguments) result = subprocess.run( [executable_path] + arguments, - capture_output=True, + capture_output=capture, text=True, timeout=timeout, ) From 3f2454e342965fa104f1b1b09dd7477de9898160 Mon Sep 17 00:00:00 2001 From: vansangpfiev Date: Mon, 28 Oct 2024 13:38:50 +0700 Subject: [PATCH 13/24] fix: run, start --- engine/cli/commands/model_pull_cmd.cc | 19 ++++++++++--------- engine/cli/commands/model_pull_cmd.h | 5 ++++- engine/cli/commands/model_start_cmd.cc | 2 +- engine/cli/commands/run_cmd.cc | 15 +++++++++------ engine/cli/commands/run_cmd.h | 6 +++--- 5 files changed, 27 insertions(+), 20 deletions(-) diff --git a/engine/cli/commands/model_pull_cmd.cc b/engine/cli/commands/model_pull_cmd.cc index cbc62cb44..5ad6215ff 100644 --- a/engine/cli/commands/model_pull_cmd.cc +++ b/engine/cli/commands/model_pull_cmd.cc @@ -15,8 +15,8 @@ namespace commands { -void ModelPullCmd::Exec(const std::string& host, int port, - const std::string& input) { +std::optional ModelPullCmd::Exec(const std::string& host, int port, + const std::string& input) { // model_id: use to check the download progress // model: use as a parameter for pull API std::string model_id = input; @@ -27,7 +27,7 @@ void ModelPullCmd::Exec(const std::string& host, int port, CLI_LOG("Starting server ..."); commands::ServerStartCmd ssc; if (!ssc.Exec(host, port)) { - return; + return std::nullopt; } } @@ -70,7 +70,7 @@ void ModelPullCmd::Exec(const std::string& host, int port, if (!selection.has_value()) { CLI_LOG("Invalid selection"); - return; + return std::nullopt; } model_id = selection.value(); model = model_id; @@ -84,12 +84,12 @@ void ModelPullCmd::Exec(const std::string& host, int port, } else { auto root = json_helper::ParseJsonString(res->body); CLI_LOG(root["message"].asString()); - return; + return std::nullopt; } } else { auto err = res.error(); CTL_ERR("HTTP error: " << httplib::to_string(err)); - return; + return std::nullopt; } Json::Value json_data; @@ -105,20 +105,21 @@ void ModelPullCmd::Exec(const std::string& host, int port, } else { auto root = json_helper::ParseJsonString(res->body); CLI_LOG(root["message"].asString()); - return; + return std::nullopt; } } else { auto err = res.error(); CTL_ERR("HTTP error: " << httplib::to_string(err)); - return; + return std::nullopt; } CLI_LOG("Start downloading ...") DownloadProgress dp; dp.Connect(host, port); if (!dp.Handle(model_id)) - return; + return std::nullopt; CLI_LOG("Model " << model_id << " downloaded successfully!") + return model_id; } }; // namespace commands diff --git a/engine/cli/commands/model_pull_cmd.h b/engine/cli/commands/model_pull_cmd.h index 444fc0bde..ebb85e52f 100644 --- a/engine/cli/commands/model_pull_cmd.h +++ b/engine/cli/commands/model_pull_cmd.h @@ -8,7 +8,10 @@ class ModelPullCmd { public: explicit ModelPullCmd(std::shared_ptr download_service) : model_service_{ModelService(download_service)} {}; - void Exec(const std::string& host, int port, const std::string& input); + explicit ModelPullCmd(const ModelService& model_service) + : model_service_{model_service} {}; + std::optional Exec(const std::string& host, int port, + const std::string& input); private: ModelService model_service_; diff --git a/engine/cli/commands/model_start_cmd.cc b/engine/cli/commands/model_start_cmd.cc index 9041e7e07..1055805f5 100644 --- a/engine/cli/commands/model_start_cmd.cc +++ b/engine/cli/commands/model_start_cmd.cc @@ -14,7 +14,7 @@ bool ModelStartCmd::Exec(const std::string& host, int port, const std::string& model_handle, bool print_success_log) { std::optional model_id = - SelectLocalModel(model_service_, model_handle); + SelectLocalModel(host, port, model_service_, model_handle); if (!model_id.has_value()) { return false; diff --git a/engine/cli/commands/run_cmd.cc b/engine/cli/commands/run_cmd.cc index 13e3be4e7..d09298cd5 100644 --- a/engine/cli/commands/run_cmd.cc +++ b/engine/cli/commands/run_cmd.cc @@ -3,16 +3,18 @@ #include "config/yaml_config.h" #include "cortex_upd_cmd.h" #include "database/models.h" +#include "engine_install_cmd.h" +#include "model_pull_cmd.h" #include "model_start_cmd.h" #include "model_status_cmd.h" #include "server_start_cmd.h" #include "utils/cli_selection_utils.h" #include "utils/logging_utils.h" -#include "engine_install_cmd.h" namespace commands { -std::optional SelectLocalModel(ModelService& model_service, +std::optional SelectLocalModel(std::string host, int port, + ModelService& model_service, const std::string& model_handle) { std::optional model_id = model_handle; cortex::db::Models modellist_handler; @@ -43,8 +45,8 @@ std::optional SelectLocalModel(ModelService& model_service, } else { auto related_models_ids = modellist_handler.FindRelatedModel(model_handle); if (related_models_ids.has_error() || related_models_ids.value().empty()) { - auto result = model_service.DownloadModel(model_handle); - if (result.has_error()) { + auto result = ModelPullCmd(model_service).Exec(host, port, model_handle); + if (!result) { CLI_LOG("Model " << model_handle << " not found!"); return std::nullopt; } @@ -80,7 +82,7 @@ std::string Repo2Engine(const std::string& r) { void RunCmd::Exec(bool run_detach) { std::optional model_id = - SelectLocalModel(model_service_, model_handle_); + SelectLocalModel(host_, port_, model_service_, model_handle_); if (!model_id.has_value()) { return; } @@ -115,7 +117,8 @@ void RunCmd::Exec(bool run_detach) { throw std::runtime_error("Engine " + mc.engine + " is incompatible"); } if (required_engine.value().status == EngineService::kNotInstalled) { - if(!EngineInstallCmd(download_service_, host_, port_).Exec(mc.engine)) { + if (!EngineInstallCmd(download_service_, host_, port_) + .Exec(mc.engine)) { return; } } diff --git a/engine/cli/commands/run_cmd.h b/engine/cli/commands/run_cmd.h index 7d3e60054..46a687fce 100644 --- a/engine/cli/commands/run_cmd.h +++ b/engine/cli/commands/run_cmd.h @@ -6,7 +6,8 @@ namespace commands { -std::optional SelectLocalModel(ModelService& model_service, +std::optional SelectLocalModel(std::string host, int port, + ModelService& model_service, const std::string& model_handle); class RunCmd { @@ -27,9 +28,8 @@ class RunCmd { int port_; std::string model_handle_; -std::shared_ptr download_service_; + std::shared_ptr download_service_; ModelService model_service_; EngineService engine_service_; - }; } // namespace commands From f7568390cdcbb3ae4b3d4188bce88bddf8074034 Mon Sep 17 00:00:00 2001 From: vansangpfiev Date: Mon, 28 Oct 2024 14:07:24 +0700 Subject: [PATCH 14/24] fix: start server --- engine/cli/commands/server_start_cmd.cc | 4 ++++ engine/e2e-test/test_cli_engine_install.py | 4 ++-- engine/e2e-test/test_cli_engine_uninstall.py | 2 +- 3 files changed, 7 insertions(+), 3 deletions(-) diff --git a/engine/cli/commands/server_start_cmd.cc b/engine/cli/commands/server_start_cmd.cc index cd06a3ba3..ca5363fa6 100644 --- a/engine/cli/commands/server_start_cmd.cc +++ b/engine/cli/commands/server_start_cmd.cc @@ -100,6 +100,10 @@ bool ServerStartCmd::Exec(const std::string& host, int port) { auto data_path = file_manager_utils::GetEnginesContainerPath(); auto llamacpp_path = data_path / "cortex.llamacpp/"; auto trt_path = data_path / "cortex.tensorrt-llm/"; + if (!std::filesystem::exists(llamacpp_path)) { + std::filesystem::create_directory(llamacpp_path); + } + auto new_v = trt_path.string() + ":" + llamacpp_path.string() + ":" + v; setenv(name, new_v.c_str(), true); CTL_INF("LD_LIBRARY_PATH: " << getenv(name)); diff --git a/engine/e2e-test/test_cli_engine_install.py b/engine/e2e-test/test_cli_engine_install.py index c444c9bec..572e62ed9 100644 --- a/engine/e2e-test/test_cli_engine_install.py +++ b/engine/e2e-test/test_cli_engine_install.py @@ -23,7 +23,7 @@ def test_engines_install_llamacpp_should_be_successfully(self): "Install Engine", ["engines", "install", "llama-cpp"], timeout=None, capture = False ) root = Path.home() - assert os.path.exists(root / "cortexcpp" / "engines" / "cortex.llamacpp") + assert os.path.exists(root / "cortexcpp" / "engines" / "cortex.llamacpp" / "version.txt") assert exit_code == 0, f"Install engine failed with error: {error}" @pytest.mark.skipif(platform.system() != "Darwin", reason="macOS-specific test") @@ -47,7 +47,7 @@ def test_engines_install_pre_release_llamacpp(self): "Install Engine", ["engines", "install", "llama-cpp", "-v", "v0.1.29"], timeout=None, capture = False ) root = Path.home() - assert os.path.exists(root / "cortexcpp" / "engines" / "cortex.llamacpp") + assert os.path.exists(root / "cortexcpp" / "engines" / "cortex.llamacpp" / "version.txt") assert exit_code == 0, f"Install engine failed with error: {error}" def test_engines_should_fallback_to_download_llamacpp_engine_if_not_exists(self): diff --git a/engine/e2e-test/test_cli_engine_uninstall.py b/engine/e2e-test/test_cli_engine_uninstall.py index 5190cee7a..23b621b0e 100644 --- a/engine/e2e-test/test_cli_engine_uninstall.py +++ b/engine/e2e-test/test_cli_engine_uninstall.py @@ -12,7 +12,7 @@ def setup_and_teardown(self): raise Exception("Failed to start server") # Preinstall llamacpp engine - run("Install Engine", ["engines", "install", "llama-cpp"],timeout = None) + run("Install Engine", ["engines", "install", "llama-cpp"],timeout = None, capture = False) yield From beb44eddb07afdd460d05a50c8b8d118f952c7a0 Mon Sep 17 00:00:00 2001 From: vansangpfiev Date: Mon, 28 Oct 2024 14:22:48 +0700 Subject: [PATCH 15/24] fix: e2e --- engine/e2e-test/test_create_log_folder.py | 1 + 1 file changed, 1 insertion(+) diff --git a/engine/e2e-test/test_create_log_folder.py b/engine/e2e-test/test_create_log_folder.py index 8b667141b..5dbbd521c 100644 --- a/engine/e2e-test/test_create_log_folder.py +++ b/engine/e2e-test/test_create_log_folder.py @@ -10,6 +10,7 @@ class TestCreateLogFolder: @pytest.fixture(autouse=True) def setup_and_teardown(self): # Setup + stop_server() root = Path.home() if os.path.exists(root / "cortexcpp" / "logs"): shutil.rmtree(root / "cortexcpp" / "logs") From ca1b436995dc387d95c74a8e69c1b0abffa14ddf Mon Sep 17 00:00:00 2001 From: vansangpfiev Date: Tue, 29 Oct 2024 07:43:13 +0700 Subject: [PATCH 16/24] fix: remove --- engine/cli/utils/download_progress.cc | 11 ++--------- 1 file changed, 2 insertions(+), 9 deletions(-) diff --git a/engine/cli/utils/download_progress.cc b/engine/cli/utils/download_progress.cc index 0bc3b2926..d8ded2361 100644 --- a/engine/cli/utils/download_progress.cc +++ b/engine/cli/utils/download_progress.cc @@ -48,29 +48,22 @@ bool DownloadProgress::Handle(const std::string& id) { json_helper::ParseJsonString(message)); // Ignore other task ids if (ev.download_task_.id != id) { - return; } status_ = ev.type_; - // std::cout << downloaded << " " << total << std::endl; + if (!bars) { bars = std::make_unique< indicators::DynamicProgress>(); for (auto& i : ev.download_task_.items) { items.emplace_back(std::make_unique( indicators::option::BarWidth{50}, indicators::option::Start{"["}, - // indicators::option::Fill{"■"}, indicators::option::Lead{"■"}, indicators::option::Fill{"="}, indicators::option::Lead{">"}, - // indicators::option::Remainder{" "}, indicators::option::End{"]"}, indicators::option::PrefixText{pad_string(i.id)}, - // indicators::option::PostfixText{"Downloading files"}, indicators::option::ForegroundColor{indicators::Color::white}, - indicators::option::ShowRemainingTime{true} - // indicators::option::FontStyles{std::vector{ - // indicators::FontStyle::bold}} - )); + indicators::option::ShowRemainingTime{true})); bars->push_back(*(items.back())); } } else { From a30da37c6ba2ce672473603ba17b333933376354 Mon Sep 17 00:00:00 2001 From: vansangpfiev Date: Tue, 29 Oct 2024 08:59:34 +0700 Subject: [PATCH 17/24] fix: abort model --- engine/cli/commands/model_pull_cmd.cc | 61 ++++++++++++++++++++++++++- engine/cli/commands/model_pull_cmd.h | 4 ++ engine/cli/utils/download_progress.h | 8 +++- 3 files changed, 69 insertions(+), 4 deletions(-) diff --git a/engine/cli/commands/model_pull_cmd.cc b/engine/cli/commands/model_pull_cmd.cc index 5ad6215ff..57bfc1e71 100644 --- a/engine/cli/commands/model_pull_cmd.cc +++ b/engine/cli/commands/model_pull_cmd.cc @@ -11,12 +11,19 @@ #include "utils/huggingface_utils.h" #include "utils/json_helper.h" #include "utils/logging_utils.h" +#include "utils/scope_exit.h" #include "utils/string_utils.h" namespace commands { - +std::function shutdown_handler; +inline void signal_handler(int signal) { + if (shutdown_handler) { + shutdown_handler(signal); + } +} std::optional ModelPullCmd::Exec(const std::string& host, int port, const std::string& input) { + // model_id: use to check the download progress // model: use as a parameter for pull API std::string model_id = input; @@ -92,6 +99,7 @@ std::optional ModelPullCmd::Exec(const std::string& host, int port, return std::nullopt; } + // Send request download model to server Json::Value json_data; json_data["model"] = model; auto data_str = json_data.toStyledString(); @@ -115,11 +123,60 @@ std::optional ModelPullCmd::Exec(const std::string& host, int port, CLI_LOG("Start downloading ...") DownloadProgress dp; + bool force_stop = false; + + shutdown_handler = [this, &dp, &host, &port, &model_id, &force_stop](int) { + force_stop = true; + AbortModelPull(host, port, model_id); + dp.ForceStop(); + }; + + utils::ScopeExit se([]() { shutdown_handler = {}; }); +#if defined(__unix__) || (defined(__APPLE__) && defined(__MACH__)) + struct sigaction sigint_action; + sigint_action.sa_handler = signal_handler; + sigemptyset(&sigint_action.sa_mask); + sigint_action.sa_flags = 0; + sigaction(SIGINT, &sigint_action, NULL); + sigaction(SIGTERM, &sigint_action, NULL); +#elif defined(_WIN32) + auto console_ctrl_handler = +[](DWORD ctrl_type) -> BOOL { + return (ctrl_type == CTRL_C_EVENT) ? (signal_handler(SIGINT), true) : false; + }; + SetConsoleCtrlHandler( + reinterpret_cast(console_ctrl_handler), true); +#endif dp.Connect(host, port); if (!dp.Handle(model_id)) return std::nullopt; - + if (force_stop) + return std::nullopt; CLI_LOG("Model " << model_id << " downloaded successfully!") return model_id; } + +bool ModelPullCmd::AbortModelPull(const std::string& host, int port, + const std::string& task_id) { + Json::Value json_data; + json_data["taskId"] = task_id; + auto data_str = json_data.toStyledString(); + httplib::Client cli(host + ":" + std::to_string(port)); + cli.set_read_timeout(std::chrono::seconds(60)); + auto res = cli.Delete("/v1/models/pull", httplib::Headers(), data_str.data(), + data_str.size(), "application/json"); + if (res) { + if (res->status == httplib::StatusCode::OK_200) { + std::cout << "OK" << std::endl; + return true; + } else { + auto root = json_helper::ParseJsonString(res->body); + CLI_LOG(root["message"].asString()); + return false; + } + } else { + auto err = res.error(); + CTL_ERR("HTTP error: " << httplib::to_string(err)); + return false; + } +} }; // namespace commands diff --git a/engine/cli/commands/model_pull_cmd.h b/engine/cli/commands/model_pull_cmd.h index ebb85e52f..d05759dbc 100644 --- a/engine/cli/commands/model_pull_cmd.h +++ b/engine/cli/commands/model_pull_cmd.h @@ -13,6 +13,10 @@ class ModelPullCmd { std::optional Exec(const std::string& host, int port, const std::string& input); + private: + bool AbortModelPull(const std::string& host, int port, + const std::string& task_id); + private: ModelService model_service_; }; diff --git a/engine/cli/utils/download_progress.h b/engine/cli/utils/download_progress.h index 6511b9537..4f71e6d84 100644 --- a/engine/cli/utils/download_progress.h +++ b/engine/cli/utils/download_progress.h @@ -12,13 +12,17 @@ class DownloadProgress { bool Handle(const std::string& id); + void ForceStop() { force_stop_ = true; } + private: bool should_stop() const { - return status_ != DownloadStatus::DownloadStarted && - status_ != DownloadStatus::DownloadUpdated; + return (status_ != DownloadStatus::DownloadStarted && + status_ != DownloadStatus::DownloadUpdated) || + force_stop_; } private: std::unique_ptr ws_; std::atomic status_ = DownloadStatus::DownloadStarted; + std::atomic force_stop_ = false; }; \ No newline at end of file From 2703b98f919d218b09434bfdba10f488a1b26447 Mon Sep 17 00:00:00 2001 From: vansangpfiev Date: Tue, 29 Oct 2024 09:06:13 +0700 Subject: [PATCH 18/24] fix: build --- engine/cli/commands/model_pull_cmd.cc | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/engine/cli/commands/model_pull_cmd.cc b/engine/cli/commands/model_pull_cmd.cc index 57bfc1e71..098469585 100644 --- a/engine/cli/commands/model_pull_cmd.cc +++ b/engine/cli/commands/model_pull_cmd.cc @@ -13,6 +13,9 @@ #include "utils/logging_utils.h" #include "utils/scope_exit.h" #include "utils/string_utils.h" +#if defined(_WIN32) +#include +#endif namespace commands { std::function shutdown_handler; @@ -166,7 +169,7 @@ bool ModelPullCmd::AbortModelPull(const std::string& host, int port, data_str.size(), "application/json"); if (res) { if (res->status == httplib::StatusCode::OK_200) { - std::cout << "OK" << std::endl; + CTL_INF("Abort model pull successfully: " << task_id); return true; } else { auto root = json_helper::ParseJsonString(res->body); From 300f199abcf104c92e4882dff004724e91be99cb Mon Sep 17 00:00:00 2001 From: vansangpfiev Date: Tue, 29 Oct 2024 09:25:09 +0700 Subject: [PATCH 19/24] fix: clean code --- engine/cli/commands/model_pull_cmd.cc | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/engine/cli/commands/model_pull_cmd.cc b/engine/cli/commands/model_pull_cmd.cc index 098469585..90f95e96f 100644 --- a/engine/cli/commands/model_pull_cmd.cc +++ b/engine/cli/commands/model_pull_cmd.cc @@ -111,9 +111,7 @@ std::optional ModelPullCmd::Exec(const std::string& host, int port, data_str.size(), "application/json"); if (res) { - if (res->status == httplib::StatusCode::OK_200) { - - } else { + if (res->status != httplib::StatusCode::OK_200) { auto root = json_helper::ParseJsonString(res->body); CLI_LOG(root["message"].asString()); return std::nullopt; From d355bf39d78dd1411bdf1f9edd43164a79e8e7e2 Mon Sep 17 00:00:00 2001 From: vansangpfiev Date: Tue, 29 Oct 2024 09:28:07 +0700 Subject: [PATCH 20/24] fix: clean more --- engine/cli/commands/engine_install_cmd.cc | 11 +++++------ 1 file changed, 5 insertions(+), 6 deletions(-) diff --git a/engine/cli/commands/engine_install_cmd.cc b/engine/cli/commands/engine_install_cmd.cc index 27017f1bf..b116e03c7 100644 --- a/engine/cli/commands/engine_install_cmd.cc +++ b/engine/cli/commands/engine_install_cmd.cc @@ -2,8 +2,8 @@ #include "server_start_cmd.h" #include "utils/download_progress.h" #include "utils/engine_constants.h" -#include "utils/logging_utils.h" #include "utils/json_helper.h" +#include "utils/logging_utils.h" namespace commands { namespace { @@ -23,13 +23,13 @@ bool EngineInstallCmd::Exec(const std::string& engine, const std::string& version, const std::string& src) { // Handle local install, if fails, fallback to remote install - if (!src.empty()) { + if (!src.empty()) { auto res = engine_service_.UnzipEngine(engine, version, src); - if(res.has_error()) { + if (res.has_error()) { CLI_LOG(res.error()); return false; } - if(res.value()) { + if (res.value()) { CLI_LOG("Engine " << engine << " installed successfully!"); return true; } @@ -52,8 +52,7 @@ bool EngineInstallCmd::Exec(const std::string& engine, data_str.data(), data_str.size(), "application/json"); if (res) { - if (res->status == httplib::StatusCode::OK_200) { - } else { + if (res->status != httplib::StatusCode::OK_200) { auto root = json_helper::ParseJsonString(res->body); CLI_LOG(root["message"].asString()); return false; From cbe37e0ed9cf487eda5f3b531298972ebd887c50 Mon Sep 17 00:00:00 2001 From: vansangpfiev Date: Tue, 29 Oct 2024 09:44:00 +0700 Subject: [PATCH 21/24] fix: normalize engine id --- engine/cli/commands/engine_install_cmd.cc | 15 +-------------- engine/services/engine_service.cc | 15 +++++++++++++-- 2 files changed, 14 insertions(+), 16 deletions(-) diff --git a/engine/cli/commands/engine_install_cmd.cc b/engine/cli/commands/engine_install_cmd.cc index b116e03c7..4cb9c0277 100644 --- a/engine/cli/commands/engine_install_cmd.cc +++ b/engine/cli/commands/engine_install_cmd.cc @@ -6,19 +6,6 @@ #include "utils/logging_utils.h" namespace commands { -namespace { -std::string NormalizeEngine(const std::string& engine) { - if (engine == kLlamaEngine) { - return kLlamaRepo; - } else if (engine == kOnnxEngine) { - return kOnnxRepo; - } else if (engine == kTrtLlmEngine) { - return kTrtLlmRepo; - } - return engine; -}; -} // namespace - bool EngineInstallCmd::Exec(const std::string& engine, const std::string& version, const std::string& src) { @@ -66,7 +53,7 @@ bool EngineInstallCmd::Exec(const std::string& engine, CLI_LOG("Start downloading ...") DownloadProgress dp; dp.Connect(host_, port_); - if (!dp.Handle(NormalizeEngine(engine))) + if (!dp.Handle(engine)) return false; bool check_cuda_download = !system_info_utils::GetCudaVersion().empty(); diff --git a/engine/services/engine_service.cc b/engine/services/engine_service.cc index 5e706be27..9d2ef42c0 100644 --- a/engine/services/engine_service.cc +++ b/engine/services/engine_service.cc @@ -44,6 +44,17 @@ std::string NormalizeEngine(const std::string& engine) { } return engine; }; + +std::string Repo2Engine(const std::string& r) { + if (r == kLlamaRepo) { + return kLlamaEngine; + } else if (r == kOnnxRepo) { + return kOnnxEngine; + } else if (r == kTrtLlmRepo) { + return kTrtLlmEngine; + } + return r; +}; } // namespace cpp::result EngineService::GetEngineInfo( @@ -314,10 +325,10 @@ cpp::result EngineService::DownloadEngine( CTL_INF("Engine folder path: " << engine_folder_path.string() << "\n"); auto local_path = engine_folder_path / file_name; - auto downloadTask{DownloadTask{.id = engine, + auto downloadTask{DownloadTask{.id = Repo2Engine(engine), .type = DownloadType::Engine, .items = {DownloadItem{ - .id = engine, + .id = Repo2Engine(engine), .downloadUrl = download_url, .localPath = local_path, }}}}; From 4fd7d5d4c8cc4d754bb316e7f2e06273da8a4efd Mon Sep 17 00:00:00 2001 From: vansangpfiev Date: Tue, 29 Oct 2024 09:52:26 +0700 Subject: [PATCH 22/24] fix: use auto --- engine/cli/commands/model_pull_cmd.cc | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/engine/cli/commands/model_pull_cmd.cc b/engine/cli/commands/model_pull_cmd.cc index 90f95e96f..189b9268d 100644 --- a/engine/cli/commands/model_pull_cmd.cc +++ b/engine/cli/commands/model_pull_cmd.cc @@ -29,8 +29,8 @@ std::optional ModelPullCmd::Exec(const std::string& host, int port, // model_id: use to check the download progress // model: use as a parameter for pull API - std::string model_id = input; - std::string model = input; + auto model_id = input; + auto model = input; // Start server if server is not started yet if (!commands::IsServerAlive(host, port)) { @@ -54,9 +54,9 @@ std::optional ModelPullCmd::Exec(const std::string& host, int port, if (res->status == httplib::StatusCode::OK_200) { // CLI_LOG(res->body); auto root = json_helper::ParseJsonString(res->body); - std::string id = root["id"].asString(); + auto id = root["id"].asString(); bool is_cortexso = root["modelSource"].asString() == "cortexso"; - std::string default_branch = root["defaultBranch"].asString(); + auto default_branch = root["defaultBranch"].asString(); std::vector downloaded; for (auto const& v : root["downloadedModels"]) { downloaded.push_back(v.asString()); @@ -65,7 +65,7 @@ std::optional ModelPullCmd::Exec(const std::string& host, int port, for (auto const& v : root["availableModels"]) { avails.push_back(v.asString()); } - std::string download_url = root["downloadUrl"].asString(); + auto download_url = root["downloadUrl"].asString(); if (downloaded.empty() && avails.empty()) { model_id = id; From c92bc689f09d00990e71e254565846a0f090c94f Mon Sep 17 00:00:00 2001 From: vansangpfiev Date: Tue, 29 Oct 2024 10:09:49 +0700 Subject: [PATCH 23/24] fix: use vcpkg for indicators --- engine/cli/CMakeLists.txt | 3 + engine/cli/commands/model_pull_cmd.cc | 2 - engine/cli/utils/download_progress.cc | 3 +- engine/cli/utils/indicators.hpp | 3257 ------------------------- engine/vcpkg.json | 3 +- 5 files changed, 7 insertions(+), 3261 deletions(-) delete mode 100644 engine/cli/utils/indicators.hpp diff --git a/engine/cli/CMakeLists.txt b/engine/cli/CMakeLists.txt index da234132d..19f206a40 100644 --- a/engine/cli/CMakeLists.txt +++ b/engine/cli/CMakeLists.txt @@ -71,6 +71,8 @@ find_package(tabulate CONFIG REQUIRED) find_package(CURL REQUIRED) find_package(SQLiteCpp REQUIRED) find_package(Trantor CONFIG REQUIRED) +find_package(indicators CONFIG REQUIRED) + add_executable(${TARGET_NAME} main.cc ${CMAKE_CURRENT_SOURCE_DIR}/../utils/cpuid/cpu_info.cc @@ -95,6 +97,7 @@ target_link_libraries(${TARGET_NAME} PRIVATE JsonCpp::JsonCpp OpenSSL::SSL OpenS ${CMAKE_THREAD_LIBS_INIT}) target_link_libraries(${TARGET_NAME} PRIVATE SQLiteCpp) target_link_libraries(${TARGET_NAME} PRIVATE Trantor::Trantor) +target_link_libraries(${TARGET_NAME} PRIVATE indicators::indicators) # ############################################################################## diff --git a/engine/cli/commands/model_pull_cmd.cc b/engine/cli/commands/model_pull_cmd.cc index 189b9268d..3a8f202d3 100644 --- a/engine/cli/commands/model_pull_cmd.cc +++ b/engine/cli/commands/model_pull_cmd.cc @@ -1,7 +1,5 @@ #include "model_pull_cmd.h" #include -#include "cli/utils/easywsclient.hpp" -#include "cli/utils/indicators.hpp" #include "common/event.h" #include "database/models.h" #include "server_start_cmd.h" diff --git a/engine/cli/utils/download_progress.cc b/engine/cli/utils/download_progress.cc index d8ded2361..d7c48d3a6 100644 --- a/engine/cli/utils/download_progress.cc +++ b/engine/cli/utils/download_progress.cc @@ -1,7 +1,8 @@ #include "download_progress.h" #include #include "common/event.h" -#include "indicators.hpp" +#include "indicators/dynamic_progress.hpp" +#include "indicators/progress_bar.hpp" #include "utils/format_utils.h" #include "utils/json_helper.h" #include "utils/logging_utils.h" diff --git a/engine/cli/utils/indicators.hpp b/engine/cli/utils/indicators.hpp deleted file mode 100644 index f034c9441..000000000 --- a/engine/cli/utils/indicators.hpp +++ /dev/null @@ -1,3257 +0,0 @@ - -#ifndef INDICATORS_COLOR -#define INDICATORS_COLOR - -namespace indicators { -enum class Color { - grey, - red, - green, - yellow, - blue, - magenta, - cyan, - white, - unspecified -}; -} - -#endif - -#ifndef INDICATORS_FONT_STYLE -#define INDICATORS_FONT_STYLE - -namespace indicators { -enum class FontStyle { - bold, - dark, - italic, - underline, - blink, - reverse, - concealed, - crossed -}; -} - -#endif - -#ifndef INDICATORS_PROGRESS_TYPE -#define INDICATORS_PROGRESS_TYPE - -namespace indicators { -enum class ProgressType { incremental, decremental }; -} - -#endif - -//! -//! termcolor -//! ~~~~~~~~~ -//! -//! termcolor is a header-only c++ library for printing colored messages -//! to the terminal. Written just for fun with a help of the Force. -//! -//! :copyright: (c) 2013 by Ihor Kalnytskyi -//! :license: BSD, see LICENSE for details -//! - -#ifndef TERMCOLOR_HPP_ -#define TERMCOLOR_HPP_ - -#include -#include -#include - -// Detect target's platform and set some macros in order to wrap platform -// specific code this library depends on. -#if defined(_WIN32) || defined(_WIN64) -#define TERMCOLOR_TARGET_WINDOWS -#elif defined(__unix__) || defined(__unix) || \ - (defined(__APPLE__) && defined(__MACH__)) -#define TERMCOLOR_TARGET_POSIX -#endif - -// If implementation has not been explicitly set, try to choose one based on -// target platform. -#if !defined(TERMCOLOR_USE_ANSI_ESCAPE_SEQUENCES) && \ - !defined(TERMCOLOR_USE_WINDOWS_API) && !defined(TERMCOLOR_USE_NOOP) -#if defined(TERMCOLOR_TARGET_POSIX) -#define TERMCOLOR_USE_ANSI_ESCAPE_SEQUENCES -#define TERMCOLOR_AUTODETECTED_IMPLEMENTATION -#elif defined(TERMCOLOR_TARGET_WINDOWS) -#define TERMCOLOR_USE_WINDOWS_API -#define TERMCOLOR_AUTODETECTED_IMPLEMENTATION -#endif -#endif - -// These headers provide isatty()/fileno() functions, which are used for -// testing whether a standard stream refers to the terminal. -#if defined(TERMCOLOR_TARGET_POSIX) -#include -#elif defined(TERMCOLOR_TARGET_WINDOWS) -#if defined(_MSC_VER) -#if !defined(NOMINMAX) -#define NOMINMAX -#endif -#endif -#include -#include -#endif - -namespace termcolor { -// Forward declaration of the `_internal` namespace. -// All comments are below. -namespace _internal { -inline int colorize_index(); -inline FILE* get_standard_stream(const std::ostream& stream); -inline bool is_colorized(std::ostream& stream); -inline bool is_atty(const std::ostream& stream); - -#if defined(TERMCOLOR_TARGET_WINDOWS) -inline void win_change_attributes(std::ostream& stream, int foreground, - int background = -1); -#endif -} // namespace _internal - -inline std::ostream& colorize(std::ostream& stream) { - stream.iword(_internal::colorize_index()) = 1L; - return stream; -} - -inline std::ostream& nocolorize(std::ostream& stream) { - stream.iword(_internal::colorize_index()) = 0L; - return stream; -} - -inline std::ostream& reset(std::ostream& stream) { - if (_internal::is_colorized(stream)) { -#if defined(TERMCOLOR_USE_ANSI_ESCAPE_SEQUENCES) - stream << "\033[00m"; -#elif defined(TERMCOLOR_USE_WINDOWS_API) - _internal::win_change_attributes(stream, -1, -1); -#endif - } - return stream; -} - -inline std::ostream& bold(std::ostream& stream) { - if (_internal::is_colorized(stream)) { -#if defined(TERMCOLOR_USE_ANSI_ESCAPE_SEQUENCES) - stream << "\033[1m"; -#elif defined(TERMCOLOR_USE_WINDOWS_API) -#endif - } - return stream; -} - -inline std::ostream& dark(std::ostream& stream) { - if (_internal::is_colorized(stream)) { -#if defined(TERMCOLOR_USE_ANSI_ESCAPE_SEQUENCES) - stream << "\033[2m"; -#elif defined(TERMCOLOR_USE_WINDOWS_API) -#endif - } - return stream; -} - -inline std::ostream& italic(std::ostream& stream) { - if (_internal::is_colorized(stream)) { -#if defined(TERMCOLOR_USE_ANSI_ESCAPE_SEQUENCES) - stream << "\033[3m"; -#elif defined(TERMCOLOR_USE_WINDOWS_API) -#endif - } - return stream; -} - -inline std::ostream& underline(std::ostream& stream) { - if (_internal::is_colorized(stream)) { -#if defined(TERMCOLOR_USE_ANSI_ESCAPE_SEQUENCES) - stream << "\033[4m"; -#elif defined(TERMCOLOR_USE_WINDOWS_API) - _internal::win_change_attributes(stream, -1, COMMON_LVB_UNDERSCORE); -#endif - } - return stream; -} - -inline std::ostream& blink(std::ostream& stream) { - if (_internal::is_colorized(stream)) { -#if defined(TERMCOLOR_USE_ANSI_ESCAPE_SEQUENCES) - stream << "\033[5m"; -#elif defined(TERMCOLOR_USE_WINDOWS_API) -#endif - } - return stream; -} - -inline std::ostream& reverse(std::ostream& stream) { - if (_internal::is_colorized(stream)) { -#if defined(TERMCOLOR_USE_ANSI_ESCAPE_SEQUENCES) - stream << "\033[7m"; -#elif defined(TERMCOLOR_USE_WINDOWS_API) -#endif - } - return stream; -} - -inline std::ostream& concealed(std::ostream& stream) { - if (_internal::is_colorized(stream)) { -#if defined(TERMCOLOR_USE_ANSI_ESCAPE_SEQUENCES) - stream << "\033[8m"; -#elif defined(TERMCOLOR_USE_WINDOWS_API) -#endif - } - return stream; -} - -inline std::ostream& crossed(std::ostream& stream) { - if (_internal::is_colorized(stream)) { -#if defined(TERMCOLOR_USE_ANSI_ESCAPE_SEQUENCES) - stream << "\033[9m"; -#elif defined(TERMCOLOR_USE_WINDOWS_API) -#endif - } - return stream; -} - -template -inline std::ostream& color(std::ostream& stream) { - if (_internal::is_colorized(stream)) { -#if defined(TERMCOLOR_USE_ANSI_ESCAPE_SEQUENCES) - char command[12]; - std::snprintf(command, sizeof(command), "\033[38;5;%dm", code); - stream << command; -#elif defined(TERMCOLOR_USE_WINDOWS_API) -#endif - } - return stream; -} - -template -inline std::ostream& on_color(std::ostream& stream) { - if (_internal::is_colorized(stream)) { -#if defined(TERMCOLOR_USE_ANSI_ESCAPE_SEQUENCES) - char command[12]; - std::snprintf(command, sizeof(command), "\033[48;5;%dm", code); - stream << command; -#elif defined(TERMCOLOR_USE_WINDOWS_API) -#endif - } - return stream; -} - -template -inline std::ostream& color(std::ostream& stream) { - if (_internal::is_colorized(stream)) { -#if defined(TERMCOLOR_USE_ANSI_ESCAPE_SEQUENCES) - char command[20]; - std::snprintf(command, sizeof(command), "\033[38;2;%d;%d;%dm", r, g, b); - stream << command; -#elif defined(TERMCOLOR_USE_WINDOWS_API) -#endif - } - return stream; -} - -template -inline std::ostream& on_color(std::ostream& stream) { - if (_internal::is_colorized(stream)) { -#if defined(TERMCOLOR_USE_ANSI_ESCAPE_SEQUENCES) - char command[20]; - std::snprintf(command, sizeof(command), "\033[48;2;%d;%d;%dm", r, g, b); - stream << command; -#elif defined(TERMCOLOR_USE_WINDOWS_API) -#endif - } - return stream; -} - -inline std::ostream& grey(std::ostream& stream) { - if (_internal::is_colorized(stream)) { -#if defined(TERMCOLOR_USE_ANSI_ESCAPE_SEQUENCES) - stream << "\033[30m"; -#elif defined(TERMCOLOR_USE_WINDOWS_API) - _internal::win_change_attributes(stream, - 0 // grey (black) - ); -#endif - } - return stream; -} - -inline std::ostream& red(std::ostream& stream) { - if (_internal::is_colorized(stream)) { -#if defined(TERMCOLOR_USE_ANSI_ESCAPE_SEQUENCES) - stream << "\033[31m"; -#elif defined(TERMCOLOR_USE_WINDOWS_API) - _internal::win_change_attributes(stream, FOREGROUND_RED); -#endif - } - return stream; -} - -inline std::ostream& green(std::ostream& stream) { - if (_internal::is_colorized(stream)) { -#if defined(TERMCOLOR_USE_ANSI_ESCAPE_SEQUENCES) - stream << "\033[32m"; -#elif defined(TERMCOLOR_USE_WINDOWS_API) - _internal::win_change_attributes(stream, FOREGROUND_GREEN); -#endif - } - return stream; -} - -inline std::ostream& yellow(std::ostream& stream) { - if (_internal::is_colorized(stream)) { -#if defined(TERMCOLOR_USE_ANSI_ESCAPE_SEQUENCES) - stream << "\033[33m"; -#elif defined(TERMCOLOR_USE_WINDOWS_API) - _internal::win_change_attributes(stream, FOREGROUND_GREEN | FOREGROUND_RED); -#endif - } - return stream; -} - -inline std::ostream& blue(std::ostream& stream) { - if (_internal::is_colorized(stream)) { -#if defined(TERMCOLOR_USE_ANSI_ESCAPE_SEQUENCES) - stream << "\033[34m"; -#elif defined(TERMCOLOR_USE_WINDOWS_API) - _internal::win_change_attributes(stream, FOREGROUND_BLUE); -#endif - } - return stream; -} - -inline std::ostream& magenta(std::ostream& stream) { - if (_internal::is_colorized(stream)) { -#if defined(TERMCOLOR_USE_ANSI_ESCAPE_SEQUENCES) - stream << "\033[35m"; -#elif defined(TERMCOLOR_USE_WINDOWS_API) - _internal::win_change_attributes(stream, FOREGROUND_BLUE | FOREGROUND_RED); -#endif - } - return stream; -} - -inline std::ostream& cyan(std::ostream& stream) { - if (_internal::is_colorized(stream)) { -#if defined(TERMCOLOR_USE_ANSI_ESCAPE_SEQUENCES) - stream << "\033[36m"; -#elif defined(TERMCOLOR_USE_WINDOWS_API) - _internal::win_change_attributes(stream, - FOREGROUND_BLUE | FOREGROUND_GREEN); -#endif - } - return stream; -} - -inline std::ostream& white(std::ostream& stream) { - if (_internal::is_colorized(stream)) { -#if defined(TERMCOLOR_USE_ANSI_ESCAPE_SEQUENCES) - stream << "\033[37m"; -#elif defined(TERMCOLOR_USE_WINDOWS_API) - _internal::win_change_attributes( - stream, FOREGROUND_BLUE | FOREGROUND_GREEN | FOREGROUND_RED); -#endif - } - return stream; -} - -inline std::ostream& bright_grey(std::ostream& stream) { - if (_internal::is_colorized(stream)) { -#if defined(TERMCOLOR_USE_ANSI_ESCAPE_SEQUENCES) - stream << "\033[90m"; -#elif defined(TERMCOLOR_USE_WINDOWS_API) - _internal::win_change_attributes(stream, - 0 | FOREGROUND_INTENSITY // grey (black) - ); -#endif - } - return stream; -} - -inline std::ostream& bright_red(std::ostream& stream) { - if (_internal::is_colorized(stream)) { -#if defined(TERMCOLOR_USE_ANSI_ESCAPE_SEQUENCES) - stream << "\033[91m"; -#elif defined(TERMCOLOR_USE_WINDOWS_API) - _internal::win_change_attributes(stream, - FOREGROUND_RED | FOREGROUND_INTENSITY); -#endif - } - return stream; -} - -inline std::ostream& bright_green(std::ostream& stream) { - if (_internal::is_colorized(stream)) { -#if defined(TERMCOLOR_USE_ANSI_ESCAPE_SEQUENCES) - stream << "\033[92m"; -#elif defined(TERMCOLOR_USE_WINDOWS_API) - _internal::win_change_attributes(stream, - FOREGROUND_GREEN | FOREGROUND_INTENSITY); -#endif - } - return stream; -} - -inline std::ostream& bright_yellow(std::ostream& stream) { - if (_internal::is_colorized(stream)) { -#if defined(TERMCOLOR_USE_ANSI_ESCAPE_SEQUENCES) - stream << "\033[93m"; -#elif defined(TERMCOLOR_USE_WINDOWS_API) - _internal::win_change_attributes( - stream, FOREGROUND_GREEN | FOREGROUND_RED | FOREGROUND_INTENSITY); -#endif - } - return stream; -} - -inline std::ostream& bright_blue(std::ostream& stream) { - if (_internal::is_colorized(stream)) { -#if defined(TERMCOLOR_USE_ANSI_ESCAPE_SEQUENCES) - stream << "\033[94m"; -#elif defined(TERMCOLOR_USE_WINDOWS_API) - _internal::win_change_attributes(stream, - FOREGROUND_BLUE | FOREGROUND_INTENSITY); -#endif - } - return stream; -} - -inline std::ostream& bright_magenta(std::ostream& stream) { - if (_internal::is_colorized(stream)) { -#if defined(TERMCOLOR_USE_ANSI_ESCAPE_SEQUENCES) - stream << "\033[95m"; -#elif defined(TERMCOLOR_USE_WINDOWS_API) - _internal::win_change_attributes( - stream, FOREGROUND_BLUE | FOREGROUND_RED | FOREGROUND_INTENSITY); -#endif - } - return stream; -} - -inline std::ostream& bright_cyan(std::ostream& stream) { - if (_internal::is_colorized(stream)) { -#if defined(TERMCOLOR_USE_ANSI_ESCAPE_SEQUENCES) - stream << "\033[96m"; -#elif defined(TERMCOLOR_USE_WINDOWS_API) - _internal::win_change_attributes( - stream, FOREGROUND_BLUE | FOREGROUND_GREEN | FOREGROUND_INTENSITY); -#endif - } - return stream; -} - -inline std::ostream& bright_white(std::ostream& stream) { - if (_internal::is_colorized(stream)) { -#if defined(TERMCOLOR_USE_ANSI_ESCAPE_SEQUENCES) - stream << "\033[97m"; -#elif defined(TERMCOLOR_USE_WINDOWS_API) - _internal::win_change_attributes( - stream, FOREGROUND_BLUE | FOREGROUND_GREEN | FOREGROUND_RED | - FOREGROUND_INTENSITY); -#endif - } - return stream; -} - -inline std::ostream& on_grey(std::ostream& stream) { - if (_internal::is_colorized(stream)) { -#if defined(TERMCOLOR_USE_ANSI_ESCAPE_SEQUENCES) - stream << "\033[40m"; -#elif defined(TERMCOLOR_USE_WINDOWS_API) - _internal::win_change_attributes(stream, -1, - 0 // grey (black) - ); -#endif - } - return stream; -} - -inline std::ostream& on_red(std::ostream& stream) { - if (_internal::is_colorized(stream)) { -#if defined(TERMCOLOR_USE_ANSI_ESCAPE_SEQUENCES) - stream << "\033[41m"; -#elif defined(TERMCOLOR_USE_WINDOWS_API) - _internal::win_change_attributes(stream, -1, BACKGROUND_RED); -#endif - } - return stream; -} - -inline std::ostream& on_green(std::ostream& stream) { - if (_internal::is_colorized(stream)) { -#if defined(TERMCOLOR_USE_ANSI_ESCAPE_SEQUENCES) - stream << "\033[42m"; -#elif defined(TERMCOLOR_USE_WINDOWS_API) - _internal::win_change_attributes(stream, -1, BACKGROUND_GREEN); -#endif - } - return stream; -} - -inline std::ostream& on_yellow(std::ostream& stream) { - if (_internal::is_colorized(stream)) { -#if defined(TERMCOLOR_USE_ANSI_ESCAPE_SEQUENCES) - stream << "\033[43m"; -#elif defined(TERMCOLOR_USE_WINDOWS_API) - _internal::win_change_attributes(stream, -1, - BACKGROUND_GREEN | BACKGROUND_RED); -#endif - } - return stream; -} - -inline std::ostream& on_blue(std::ostream& stream) { - if (_internal::is_colorized(stream)) { -#if defined(TERMCOLOR_USE_ANSI_ESCAPE_SEQUENCES) - stream << "\033[44m"; -#elif defined(TERMCOLOR_USE_WINDOWS_API) - _internal::win_change_attributes(stream, -1, BACKGROUND_BLUE); -#endif - } - return stream; -} - -inline std::ostream& on_magenta(std::ostream& stream) { - if (_internal::is_colorized(stream)) { -#if defined(TERMCOLOR_USE_ANSI_ESCAPE_SEQUENCES) - stream << "\033[45m"; -#elif defined(TERMCOLOR_USE_WINDOWS_API) - _internal::win_change_attributes(stream, -1, - BACKGROUND_BLUE | BACKGROUND_RED); -#endif - } - return stream; -} - -inline std::ostream& on_cyan(std::ostream& stream) { - if (_internal::is_colorized(stream)) { -#if defined(TERMCOLOR_USE_ANSI_ESCAPE_SEQUENCES) - stream << "\033[46m"; -#elif defined(TERMCOLOR_USE_WINDOWS_API) - _internal::win_change_attributes(stream, -1, - BACKGROUND_GREEN | BACKGROUND_BLUE); -#endif - } - return stream; -} - -inline std::ostream& on_white(std::ostream& stream) { - if (_internal::is_colorized(stream)) { -#if defined(TERMCOLOR_USE_ANSI_ESCAPE_SEQUENCES) - stream << "\033[47m"; -#elif defined(TERMCOLOR_USE_WINDOWS_API) - _internal::win_change_attributes( - stream, -1, BACKGROUND_GREEN | BACKGROUND_BLUE | BACKGROUND_RED); -#endif - } - - return stream; -} - -inline std::ostream& on_bright_grey(std::ostream& stream) { - if (_internal::is_colorized(stream)) { -#if defined(TERMCOLOR_USE_ANSI_ESCAPE_SEQUENCES) - stream << "\033[100m"; -#elif defined(TERMCOLOR_USE_WINDOWS_API) - _internal::win_change_attributes(stream, -1, - 0 | BACKGROUND_INTENSITY // grey (black) - ); -#endif - } - return stream; -} - -inline std::ostream& on_bright_red(std::ostream& stream) { - if (_internal::is_colorized(stream)) { -#if defined(TERMCOLOR_USE_ANSI_ESCAPE_SEQUENCES) - stream << "\033[101m"; -#elif defined(TERMCOLOR_USE_WINDOWS_API) - _internal::win_change_attributes(stream, -1, - BACKGROUND_RED | BACKGROUND_INTENSITY); -#endif - } - return stream; -} - -inline std::ostream& on_bright_green(std::ostream& stream) { - if (_internal::is_colorized(stream)) { -#if defined(TERMCOLOR_USE_ANSI_ESCAPE_SEQUENCES) - stream << "\033[102m"; -#elif defined(TERMCOLOR_USE_WINDOWS_API) - _internal::win_change_attributes(stream, -1, - BACKGROUND_GREEN | BACKGROUND_INTENSITY); -#endif - } - return stream; -} - -inline std::ostream& on_bright_yellow(std::ostream& stream) { - if (_internal::is_colorized(stream)) { -#if defined(TERMCOLOR_USE_ANSI_ESCAPE_SEQUENCES) - stream << "\033[103m"; -#elif defined(TERMCOLOR_USE_WINDOWS_API) - _internal::win_change_attributes( - stream, -1, BACKGROUND_GREEN | BACKGROUND_RED | BACKGROUND_INTENSITY); -#endif - } - return stream; -} - -inline std::ostream& on_bright_blue(std::ostream& stream) { - if (_internal::is_colorized(stream)) { -#if defined(TERMCOLOR_USE_ANSI_ESCAPE_SEQUENCES) - stream << "\033[104m"; -#elif defined(TERMCOLOR_USE_WINDOWS_API) - _internal::win_change_attributes(stream, -1, - BACKGROUND_BLUE | BACKGROUND_INTENSITY); -#endif - } - return stream; -} - -inline std::ostream& on_bright_magenta(std::ostream& stream) { - if (_internal::is_colorized(stream)) { -#if defined(TERMCOLOR_USE_ANSI_ESCAPE_SEQUENCES) - stream << "\033[105m"; -#elif defined(TERMCOLOR_USE_WINDOWS_API) - _internal::win_change_attributes( - stream, -1, BACKGROUND_BLUE | BACKGROUND_RED | BACKGROUND_INTENSITY); -#endif - } - return stream; -} - -inline std::ostream& on_bright_cyan(std::ostream& stream) { - if (_internal::is_colorized(stream)) { -#if defined(TERMCOLOR_USE_ANSI_ESCAPE_SEQUENCES) - stream << "\033[106m"; -#elif defined(TERMCOLOR_USE_WINDOWS_API) - _internal::win_change_attributes( - stream, -1, BACKGROUND_GREEN | BACKGROUND_BLUE | BACKGROUND_INTENSITY); -#endif - } - return stream; -} - -inline std::ostream& on_bright_white(std::ostream& stream) { - if (_internal::is_colorized(stream)) { -#if defined(TERMCOLOR_USE_ANSI_ESCAPE_SEQUENCES) - stream << "\033[107m"; -#elif defined(TERMCOLOR_USE_WINDOWS_API) - _internal::win_change_attributes(stream, -1, - BACKGROUND_GREEN | BACKGROUND_BLUE | - BACKGROUND_RED | BACKGROUND_INTENSITY); -#endif - } - - return stream; -} - -//! Since C++ hasn't a way to hide something in the header from -//! the outer access, I have to introduce this namespace which -//! is used for internal purpose and should't be access from -//! the user code. -namespace _internal { -// An index to be used to access a private storage of I/O streams. See -// colorize / nocolorize I/O manipulators for details. Due to the fact -// that static variables ain't shared between translation units, inline -// function with local static variable is used to do the trick and share -// the variable value between translation units. -inline int colorize_index() { - static int colorize_index = std::ios_base::xalloc(); - return colorize_index; -} - -//! Since C++ hasn't a true way to extract stream handler -//! from the a given `std::ostream` object, I have to write -//! this kind of hack. -inline FILE* get_standard_stream(const std::ostream& stream) { - if (&stream == &std::cout) - return stdout; - else if ((&stream == &std::cerr) || (&stream == &std::clog)) - return stderr; - - return nullptr; -} - -// Say whether a given stream should be colorized or not. It's always -// true for ATTY streams and may be true for streams marked with -// colorize flag. -inline bool is_colorized(std::ostream& stream) { - return is_atty(stream) || static_cast(stream.iword(colorize_index())); -} - -//! Test whether a given `std::ostream` object refers to -//! a terminal. -inline bool is_atty(const std::ostream& stream) { - FILE* std_stream = get_standard_stream(stream); - - // Unfortunately, fileno() ends with segmentation fault - // if invalid file descriptor is passed. So we need to - // handle this case gracefully and assume it's not a tty - // if standard stream is not detected, and 0 is returned. - if (!std_stream) - return false; - -#if defined(TERMCOLOR_TARGET_POSIX) - return ::isatty(fileno(std_stream)); -#elif defined(TERMCOLOR_TARGET_WINDOWS) - return ::_isatty(_fileno(std_stream)); -#else - return false; -#endif -} - -#if defined(TERMCOLOR_TARGET_WINDOWS) -//! Change Windows Terminal colors attribute. If some -//! parameter is `-1` then attribute won't changed. -inline void win_change_attributes(std::ostream& stream, int foreground, - int background) { - // yeah, i know.. it's ugly, it's windows. - static WORD defaultAttributes = 0; - - // Windows doesn't have ANSI escape sequences and so we use special - // API to change Terminal output color. That means we can't - // manipulate colors by means of "std::stringstream" and hence - // should do nothing in this case. - if (!_internal::is_atty(stream)) - return; - - // get terminal handle - HANDLE hTerminal = INVALID_HANDLE_VALUE; - if (&stream == &std::cout) - hTerminal = GetStdHandle(STD_OUTPUT_HANDLE); - else if (&stream == &std::cerr) - hTerminal = GetStdHandle(STD_ERROR_HANDLE); - - // save default terminal attributes if it unsaved - if (!defaultAttributes) { - CONSOLE_SCREEN_BUFFER_INFO info; - if (!GetConsoleScreenBufferInfo(hTerminal, &info)) - return; - defaultAttributes = info.wAttributes; - } - - // restore all default settings - if (foreground == -1 && background == -1) { - SetConsoleTextAttribute(hTerminal, defaultAttributes); - return; - } - - // get current settings - CONSOLE_SCREEN_BUFFER_INFO info; - if (!GetConsoleScreenBufferInfo(hTerminal, &info)) - return; - - if (foreground != -1) { - info.wAttributes &= ~(info.wAttributes & 0x0F); - info.wAttributes |= static_cast(foreground); - } - - if (background != -1) { - info.wAttributes &= ~(info.wAttributes & 0xF0); - info.wAttributes |= static_cast(background); - } - - SetConsoleTextAttribute(hTerminal, info.wAttributes); -} -#endif // TERMCOLOR_TARGET_WINDOWS - -} // namespace _internal - -} // namespace termcolor - -#undef TERMCOLOR_TARGET_POSIX -#undef TERMCOLOR_TARGET_WINDOWS - -#if defined(TERMCOLOR_AUTODETECTED_IMPLEMENTATION) -#undef TERMCOLOR_USE_ANSI_ESCAPE_SEQUENCES -#undef TERMCOLOR_USE_WINDOWS_API -#endif - -#endif // TERMCOLOR_HPP_ - -#ifndef INDICATORS_TERMINAL_SIZE -#define INDICATORS_TERMINAL_SIZE -#include - -#if defined(_WIN32) -#include - -namespace indicators { - -static inline std::pair terminal_size() { - CONSOLE_SCREEN_BUFFER_INFO csbi; - int cols, rows; - GetConsoleScreenBufferInfo(GetStdHandle(STD_OUTPUT_HANDLE), &csbi); - cols = csbi.srWindow.Right - csbi.srWindow.Left + 1; - rows = csbi.srWindow.Bottom - csbi.srWindow.Top + 1; - return {static_cast(rows), static_cast(cols)}; -} - -static inline size_t terminal_width() { - return terminal_size().second; -} - -} // namespace indicators - -#else - -#include //ioctl() and TIOCGWINSZ -#include // for STDOUT_FILENO - -namespace indicators { - -static inline std::pair terminal_size() { - struct winsize size {}; - ioctl(STDOUT_FILENO, TIOCGWINSZ, &size); - return {static_cast(size.ws_row), static_cast(size.ws_col)}; -} - -static inline size_t terminal_width() { - return terminal_size().second; -} - -} // namespace indicators - -#endif - -#endif - -/* -Activity Indicators for Modern C++ -https://github.com/p-ranav/indicators - -Licensed under the MIT License . -SPDX-License-Identifier: MIT -Copyright (c) 2019 Dawid Pilarski . - -Permission is hereby granted, free of charge, to any person obtaining a copy -of this software and associated documentation files (the "Software"), to deal -in the Software without restriction, including without limitation the rights -to use, copy, modify, merge, publish, distribute, sublicense, and/or sell -copies of the Software, and to permit persons to whom the Software is -furnished to do so, subject to the following conditions: - -The above copyright notice and this permission notice shall be included in all -copies or substantial portions of the Software. - -THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE -AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE -SOFTWARE. -*/ -#ifndef INDICATORS_SETTING -#define INDICATORS_SETTING - -#include -// #include -// #include -// #include -#include -#include -#include -#include -#include - -namespace indicators { - -namespace details { - -template -struct if_else; - -template <> -struct if_else { - using type = std::true_type; -}; - -template <> -struct if_else { - using type = std::false_type; -}; - -template -struct if_else_type; - -template -struct if_else_type { - using type = True; -}; - -template -struct if_else_type { - using type = False; -}; - -template -struct conjuction; - -template <> -struct conjuction<> : std::true_type {}; - -template -struct conjuction - : if_else_type>::type { -}; - -template -struct disjunction; - -template <> -struct disjunction<> : std::false_type {}; - -template -struct disjunction - : if_else_type>::type {}; - -enum class ProgressBarOption { - bar_width = 0, - prefix_text, - postfix_text, - start, - end, - fill, - lead, - remainder, - max_postfix_text_len, - completed, - show_percentage, - show_elapsed_time, - show_remaining_time, - saved_start_time, - foreground_color, - spinner_show, - spinner_states, - font_styles, - hide_bar_when_complete, - min_progress, - max_progress, - progress_type, - stream -}; - -template -struct Setting { - template ::value>::type> - explicit Setting(Args&&... args) : value(std::forward(args)...) {} - Setting(const Setting&) = default; - Setting(Setting&&) = default; - - static constexpr auto id = Id; - using type = T; - - T value{}; -}; - -template -struct is_setting : std::false_type {}; - -template -struct is_setting> : std::true_type {}; - -template -struct are_settings : if_else...>::value>::type {}; - -template <> -struct are_settings<> : std::true_type {}; - -template -struct is_setting_from_tuple; - -template -struct is_setting_from_tuple> : std::true_type {}; - -template -struct is_setting_from_tuple> - : if_else...>::value>::type { -}; - -template -struct are_settings_from_tuple - : if_else< - conjuction...>::value>::type { -}; - -template -struct always_true { - static constexpr auto value = true; -}; - -template -Default&& get_impl(Default&& def) { - return std::forward(def); -} - -template -auto get_impl(Default&& /*def*/, T&& first, Args&&... /*tail*/) -> - typename std::enable_if<(std::decay::type::id == Id), - decltype(std::forward(first))>::type { - return std::forward(first); -} - -template -auto get_impl(Default&& def, T&& /*first*/, Args&&... tail) -> - typename std::enable_if< - (std::decay::type::id != Id), - decltype(get_impl(std::forward(def), - std::forward(tail)...))>::type { - return get_impl(std::forward(def), std::forward(tail)...); -} - -template ::value, void>::type> -auto get(Default&& def, Args&&... args) - -> decltype(details::get_impl(std::forward(def), - std::forward(args)...)) { - return details::get_impl(std::forward(def), - std::forward(args)...); -} - -template -using StringSetting = Setting; - -template -using IntegerSetting = Setting; - -template -using BooleanSetting = Setting; - -template -struct option_idx; - -template -struct option_idx, counter> - : if_else_type<(Id == T::id), std::integral_constant, - option_idx, counter + 1>>::type { -}; - -template -struct option_idx, counter> { - static_assert(always_true<(ProgressBarOption)Id>::value, - "No such option was found"); -}; - -template -auto get_value(Settings&& settings) - -> decltype(( - std::get::type>::value>( - std::declval()))) { - return std::get::type>::value>( - std::forward(settings)); -} - -} // namespace details - -namespace option { -using BarWidth = details::IntegerSetting; -using PrefixText = - details::StringSetting; -using PostfixText = - details::StringSetting; -using Start = details::StringSetting; -using End = details::StringSetting; -using Fill = details::StringSetting; -using Lead = details::StringSetting; -using Remainder = details::StringSetting; -using MaxPostfixTextLen = - details::IntegerSetting; -using Completed = - details::BooleanSetting; -using ShowPercentage = - details::BooleanSetting; -using ShowElapsedTime = - details::BooleanSetting; -using ShowRemainingTime = - details::BooleanSetting; -using SavedStartTime = - details::BooleanSetting; -using ForegroundColor = - details::Setting; -using ShowSpinner = - details::BooleanSetting; -using SpinnerStates = - details::Setting, - details::ProgressBarOption::spinner_states>; -using HideBarWhenComplete = - details::BooleanSetting; -using FontStyles = details::Setting, - details::ProgressBarOption::font_styles>; -using MinProgress = - details::IntegerSetting; -using MaxProgress = - details::IntegerSetting; -using ProgressType = - details::Setting; -using Stream = - details::Setting; -} // namespace option -} // namespace indicators - -#endif - -#ifndef INDICATORS_CURSOR_CONTROL -#define INDICATORS_CURSOR_CONTROL - -#if defined(_MSC_VER) -#if !defined(NOMINMAX) -#define NOMINMAX -#endif -#include -#include -#else -#include -#endif - -namespace indicators { - -#if defined(_MSC_VER) - -static inline void show_console_cursor(bool const show) { - HANDLE out = GetStdHandle(STD_OUTPUT_HANDLE); - - CONSOLE_CURSOR_INFO cursorInfo; - - GetConsoleCursorInfo(out, &cursorInfo); - cursorInfo.bVisible = show; // set the cursor visibility - SetConsoleCursorInfo(out, &cursorInfo); -} - -static inline void erase_line() { - auto hStdout = GetStdHandle(STD_OUTPUT_HANDLE); - if (!hStdout) - return; - - CONSOLE_SCREEN_BUFFER_INFO csbiInfo; - GetConsoleScreenBufferInfo(hStdout, &csbiInfo); - - COORD cursor; - - cursor.X = 0; - cursor.Y = csbiInfo.dwCursorPosition.Y; - - DWORD count = 0; - - FillConsoleOutputCharacterA(hStdout, ' ', csbiInfo.dwSize.X, cursor, &count); - - FillConsoleOutputAttribute(hStdout, csbiInfo.wAttributes, csbiInfo.dwSize.X, - cursor, &count); - - SetConsoleCursorPosition(hStdout, cursor); -} - -#else - -static inline void show_console_cursor(bool const show) { - std::fputs(show ? "\033[?25h" : "\033[?25l", stdout); -} - -static inline void erase_line() { - std::fputs("\r\033[K", stdout); -} - -#endif - -} // namespace indicators - -#endif - -#ifndef INDICATORS_CURSOR_MOVEMENT -#define INDICATORS_CURSOR_MOVEMENT - -#if defined(_MSC_VER) -#if !defined(NOMINMAX) -#define NOMINMAX -#endif -#include -#include -#else -#include -#endif - -namespace indicators { - -#ifdef _MSC_VER - -static inline void move(int x, int y) { - auto hStdout = GetStdHandle(STD_OUTPUT_HANDLE); - if (!hStdout) - return; - - CONSOLE_SCREEN_BUFFER_INFO csbiInfo; - GetConsoleScreenBufferInfo(hStdout, &csbiInfo); - - COORD cursor; - - cursor.X = csbiInfo.dwCursorPosition.X + x; - cursor.Y = csbiInfo.dwCursorPosition.Y + y; - SetConsoleCursorPosition(hStdout, cursor); -} - -static inline void move_up(int lines) { - move(0, -lines); -} -static inline void move_down(int lines) { - move(0, -lines); -} -static inline void move_right(int cols) { - move(cols, 0); -} -static inline void move_left(int cols) { - move(-cols, 0); -} - -#else - -static inline void move_up(int lines) { - std::cout << "\033[" << lines << "A"; -} -static inline void move_down(int lines) { - std::cout << "\033[" << lines << "B"; -} -static inline void move_right(int cols) { - std::cout << "\033[" << cols << "C"; -} -static inline void move_left(int cols) { - std::cout << "\033[" << cols << "D"; -} - -#endif - -} // namespace indicators - -#endif - -#ifndef INDICATORS_STREAM_HELPER -#define INDICATORS_STREAM_HELPER - -// #include -#ifndef INDICATORS_DISPLAY_WIDTH -#define INDICATORS_DISPLAY_WIDTH - -#include -#include -#include -#include -#include - -namespace unicode { - -namespace details { - -/* - * This is an implementation of wcwidth() and wcswidth() (defined in - * IEEE Std 1002.1-2001) for Unicode. - * - * http://www.opengroup.org/onlinepubs/007904975/functions/wcwidth.html - * http://www.opengroup.org/onlinepubs/007904975/functions/wcswidth.html - * - * In fixed-width output devices, Latin characters all occupy a single - * "cell" position of equal width, whereas ideographic CJK characters - * occupy two such cells. Interoperability between terminal-line - * applications and (teletype-style) character terminals using the - * UTF-8 encoding requires agreement on which character should advance - * the cursor by how many cell positions. No established formal - * standards exist at present on which Unicode character shall occupy - * how many cell positions on character terminals. These routines are - * a first attempt of defining such behavior based on simple rules - * applied to data provided by the Unicode Consortium. - * - * For some graphical characters, the Unicode standard explicitly - * defines a character-cell width via the definition of the East Asian - * FullWidth (F), Wide (W), Half-width (H), and Narrow (Na) classes. - * In all these cases, there is no ambiguity about which width a - * terminal shall use. For characters in the East Asian Ambiguous (A) - * class, the width choice depends purely on a preference of backward - * compatibility with either historic CJK or Western practice. - * Choosing single-width for these characters is easy to justify as - * the appropriate long-term solution, as the CJK practice of - * displaying these characters as double-width comes from historic - * implementation simplicity (8-bit encoded characters were displayed - * single-width and 16-bit ones double-width, even for Greek, - * Cyrillic, etc.) and not any typographic considerations. - * - * Much less clear is the choice of width for the Not East Asian - * (Neutral) class. Existing practice does not dictate a width for any - * of these characters. It would nevertheless make sense - * typographically to allocate two character cells to characters such - * as for instance EM SPACE or VOLUME INTEGRAL, which cannot be - * represented adequately with a single-width glyph. The following - * routines at present merely assign a single-cell width to all - * neutral characters, in the interest of simplicity. This is not - * entirely satisfactory and should be reconsidered before - * establishing a formal standard in this area. At the moment, the - * decision which Not East Asian (Neutral) characters should be - * represented by double-width glyphs cannot yet be answered by - * applying a simple rule from the Unicode database content. Setting - * up a proper standard for the behavior of UTF-8 character terminals - * will require a careful analysis not only of each Unicode character, - * but also of each presentation form, something the author of these - * routines has avoided to do so far. - * - * http://www.unicode.org/unicode/reports/tr11/ - * - * Markus Kuhn -- 2007-05-26 (Unicode 5.0) - * - * Permission to use, copy, modify, and distribute this software - * for any purpose and without fee is hereby granted. The author - * disclaims all warranties with regard to this software. - * - * Latest version: http://www.cl.cam.ac.uk/~mgk25/ucs/wcwidth.c - */ - -struct interval { - int first; - int last; -}; - -/* auxiliary function for binary search in interval table */ -static inline int bisearch(wchar_t ucs, const struct interval* table, int max) { - int min = 0; - int mid; - - if (ucs < table[0].first || ucs > table[max].last) - return 0; - while (max >= min) { - mid = (min + max) / 2; - if (ucs > table[mid].last) - min = mid + 1; - else if (ucs < table[mid].first) - max = mid - 1; - else - return 1; - } - - return 0; -} - -/* The following two functions define the column width of an ISO 10646 - * character as follows: - * - * - The null character (U+0000) has a column width of 0. - * - * - Other C0/C1 control characters and DEL will lead to a return - * value of -1. - * - * - Non-spacing and enclosing combining characters (general - * category code Mn or Me in the Unicode database) have a - * column width of 0. - * - * - SOFT HYPHEN (U+00AD) has a column width of 1. - * - * - Other format characters (general category code Cf in the Unicode - * database) and ZERO WIDTH SPACE (U+200B) have a column width of 0. - * - * - Hangul Jamo medial vowels and final consonants (U+1160-U+11FF) - * have a column width of 0. - * - * - Spacing characters in the East Asian Wide (W) or East Asian - * Full-width (F) category as defined in Unicode Technical - * Report #11 have a column width of 2. - * - * - All remaining characters (including all printable - * ISO 8859-1 and WGL4 characters, Unicode control characters, - * etc.) have a column width of 1. - * - * This implementation assumes that wchar_t characters are encoded - * in ISO 10646. - */ - -static inline int mk_wcwidth(wchar_t ucs) { - /* sorted list of non-overlapping intervals of non-spacing characters */ - /* generated by "uniset +cat=Me +cat=Mn +cat=Cf -00AD +1160-11FF +200B c" */ - static const struct interval combining[] = { - {0x0300, 0x036F}, {0x0483, 0x0486}, {0x0488, 0x0489}, - {0x0591, 0x05BD}, {0x05BF, 0x05BF}, {0x05C1, 0x05C2}, - {0x05C4, 0x05C5}, {0x05C7, 0x05C7}, {0x0600, 0x0603}, - {0x0610, 0x0615}, {0x064B, 0x065E}, {0x0670, 0x0670}, - {0x06D6, 0x06E4}, {0x06E7, 0x06E8}, {0x06EA, 0x06ED}, - {0x070F, 0x070F}, {0x0711, 0x0711}, {0x0730, 0x074A}, - {0x07A6, 0x07B0}, {0x07EB, 0x07F3}, {0x0901, 0x0902}, - {0x093C, 0x093C}, {0x0941, 0x0948}, {0x094D, 0x094D}, - {0x0951, 0x0954}, {0x0962, 0x0963}, {0x0981, 0x0981}, - {0x09BC, 0x09BC}, {0x09C1, 0x09C4}, {0x09CD, 0x09CD}, - {0x09E2, 0x09E3}, {0x0A01, 0x0A02}, {0x0A3C, 0x0A3C}, - {0x0A41, 0x0A42}, {0x0A47, 0x0A48}, {0x0A4B, 0x0A4D}, - {0x0A70, 0x0A71}, {0x0A81, 0x0A82}, {0x0ABC, 0x0ABC}, - {0x0AC1, 0x0AC5}, {0x0AC7, 0x0AC8}, {0x0ACD, 0x0ACD}, - {0x0AE2, 0x0AE3}, {0x0B01, 0x0B01}, {0x0B3C, 0x0B3C}, - {0x0B3F, 0x0B3F}, {0x0B41, 0x0B43}, {0x0B4D, 0x0B4D}, - {0x0B56, 0x0B56}, {0x0B82, 0x0B82}, {0x0BC0, 0x0BC0}, - {0x0BCD, 0x0BCD}, {0x0C3E, 0x0C40}, {0x0C46, 0x0C48}, - {0x0C4A, 0x0C4D}, {0x0C55, 0x0C56}, {0x0CBC, 0x0CBC}, - {0x0CBF, 0x0CBF}, {0x0CC6, 0x0CC6}, {0x0CCC, 0x0CCD}, - {0x0CE2, 0x0CE3}, {0x0D41, 0x0D43}, {0x0D4D, 0x0D4D}, - {0x0DCA, 0x0DCA}, {0x0DD2, 0x0DD4}, {0x0DD6, 0x0DD6}, - {0x0E31, 0x0E31}, {0x0E34, 0x0E3A}, {0x0E47, 0x0E4E}, - {0x0EB1, 0x0EB1}, {0x0EB4, 0x0EB9}, {0x0EBB, 0x0EBC}, - {0x0EC8, 0x0ECD}, {0x0F18, 0x0F19}, {0x0F35, 0x0F35}, - {0x0F37, 0x0F37}, {0x0F39, 0x0F39}, {0x0F71, 0x0F7E}, - {0x0F80, 0x0F84}, {0x0F86, 0x0F87}, {0x0F90, 0x0F97}, - {0x0F99, 0x0FBC}, {0x0FC6, 0x0FC6}, {0x102D, 0x1030}, - {0x1032, 0x1032}, {0x1036, 0x1037}, {0x1039, 0x1039}, - {0x1058, 0x1059}, {0x1160, 0x11FF}, {0x135F, 0x135F}, - {0x1712, 0x1714}, {0x1732, 0x1734}, {0x1752, 0x1753}, - {0x1772, 0x1773}, {0x17B4, 0x17B5}, {0x17B7, 0x17BD}, - {0x17C6, 0x17C6}, {0x17C9, 0x17D3}, {0x17DD, 0x17DD}, - {0x180B, 0x180D}, {0x18A9, 0x18A9}, {0x1920, 0x1922}, - {0x1927, 0x1928}, {0x1932, 0x1932}, {0x1939, 0x193B}, - {0x1A17, 0x1A18}, {0x1B00, 0x1B03}, {0x1B34, 0x1B34}, - {0x1B36, 0x1B3A}, {0x1B3C, 0x1B3C}, {0x1B42, 0x1B42}, - {0x1B6B, 0x1B73}, {0x1DC0, 0x1DCA}, {0x1DFE, 0x1DFF}, - {0x200B, 0x200F}, {0x202A, 0x202E}, {0x2060, 0x2063}, - {0x206A, 0x206F}, {0x20D0, 0x20EF}, {0x302A, 0x302F}, - {0x3099, 0x309A}, {0xA806, 0xA806}, {0xA80B, 0xA80B}, - {0xA825, 0xA826}, {0xFB1E, 0xFB1E}, {0xFE00, 0xFE0F}, - {0xFE20, 0xFE23}, {0xFEFF, 0xFEFF}, {0xFFF9, 0xFFFB}, - {0x10A01, 0x10A03}, {0x10A05, 0x10A06}, {0x10A0C, 0x10A0F}, - {0x10A38, 0x10A3A}, {0x10A3F, 0x10A3F}, {0x1D167, 0x1D169}, - {0x1D173, 0x1D182}, {0x1D185, 0x1D18B}, {0x1D1AA, 0x1D1AD}, - {0x1D242, 0x1D244}, {0xE0001, 0xE0001}, {0xE0020, 0xE007F}, - {0xE0100, 0xE01EF}}; - - /* test for 8-bit control characters */ - if (ucs == 0) - return 0; - if (ucs < 32 || (ucs >= 0x7f && ucs < 0xa0)) - return -1; - - /* binary search in table of non-spacing characters */ - if (bisearch(ucs, combining, sizeof(combining) / sizeof(struct interval) - 1)) - return 0; - - /* if we arrive here, ucs is not a combining or C0/C1 control character */ - - return 1 + - (ucs >= 0x1100 && - (ucs <= 0x115f || /* Hangul Jamo init. consonants */ - ucs == 0x2329 || ucs == 0x232a || - (ucs >= 0x2e80 && ucs <= 0xa4cf && ucs != 0x303f) || /* CJK ... Yi */ - (ucs >= 0xac00 && ucs <= 0xd7a3) || /* Hangul Syllables */ - (ucs >= 0xf900 && - ucs <= 0xfaff) || /* CJK Compatibility Ideographs */ - (ucs >= 0xfe10 && ucs <= 0xfe19) || /* Vertical forms */ - (ucs >= 0xfe30 && ucs <= 0xfe6f) || /* CJK Compatibility Forms */ - (ucs >= 0xff00 && ucs <= 0xff60) || /* Fullwidth Forms */ - (ucs >= 0xffe0 && ucs <= 0xffe6) || - (ucs >= 0x20000 && ucs <= 0x2fffd) || - (ucs >= 0x30000 && ucs <= 0x3fffd))); -} - -static inline int mk_wcswidth(const wchar_t* pwcs, size_t n) { - int w, width = 0; - - for (; *pwcs && n-- > 0; pwcs++) - if ((w = mk_wcwidth(*pwcs)) < 0) - return -1; - else - width += w; - - return width; -} - -/* - * The following functions are the same as mk_wcwidth() and - * mk_wcswidth(), except that spacing characters in the East Asian - * Ambiguous (A) category as defined in Unicode Technical Report #11 - * have a column width of 2. This variant might be useful for users of - * CJK legacy encodings who want to migrate to UCS without changing - * the traditional terminal character-width behaviour. It is not - * otherwise recommended for general use. - */ -static inline int mk_wcwidth_cjk(wchar_t ucs) { - /* sorted list of non-overlapping intervals of East Asian Ambiguous - * characters, generated by "uniset +WIDTH-A -cat=Me -cat=Mn -cat=Cf c" */ - static const struct interval ambiguous[] = { - {0x00A1, 0x00A1}, {0x00A4, 0x00A4}, {0x00A7, 0x00A8}, - {0x00AA, 0x00AA}, {0x00AE, 0x00AE}, {0x00B0, 0x00B4}, - {0x00B6, 0x00BA}, {0x00BC, 0x00BF}, {0x00C6, 0x00C6}, - {0x00D0, 0x00D0}, {0x00D7, 0x00D8}, {0x00DE, 0x00E1}, - {0x00E6, 0x00E6}, {0x00E8, 0x00EA}, {0x00EC, 0x00ED}, - {0x00F0, 0x00F0}, {0x00F2, 0x00F3}, {0x00F7, 0x00FA}, - {0x00FC, 0x00FC}, {0x00FE, 0x00FE}, {0x0101, 0x0101}, - {0x0111, 0x0111}, {0x0113, 0x0113}, {0x011B, 0x011B}, - {0x0126, 0x0127}, {0x012B, 0x012B}, {0x0131, 0x0133}, - {0x0138, 0x0138}, {0x013F, 0x0142}, {0x0144, 0x0144}, - {0x0148, 0x014B}, {0x014D, 0x014D}, {0x0152, 0x0153}, - {0x0166, 0x0167}, {0x016B, 0x016B}, {0x01CE, 0x01CE}, - {0x01D0, 0x01D0}, {0x01D2, 0x01D2}, {0x01D4, 0x01D4}, - {0x01D6, 0x01D6}, {0x01D8, 0x01D8}, {0x01DA, 0x01DA}, - {0x01DC, 0x01DC}, {0x0251, 0x0251}, {0x0261, 0x0261}, - {0x02C4, 0x02C4}, {0x02C7, 0x02C7}, {0x02C9, 0x02CB}, - {0x02CD, 0x02CD}, {0x02D0, 0x02D0}, {0x02D8, 0x02DB}, - {0x02DD, 0x02DD}, {0x02DF, 0x02DF}, {0x0391, 0x03A1}, - {0x03A3, 0x03A9}, {0x03B1, 0x03C1}, {0x03C3, 0x03C9}, - {0x0401, 0x0401}, {0x0410, 0x044F}, {0x0451, 0x0451}, - {0x2010, 0x2010}, {0x2013, 0x2016}, {0x2018, 0x2019}, - {0x201C, 0x201D}, {0x2020, 0x2022}, {0x2024, 0x2027}, - {0x2030, 0x2030}, {0x2032, 0x2033}, {0x2035, 0x2035}, - {0x203B, 0x203B}, {0x203E, 0x203E}, {0x2074, 0x2074}, - {0x207F, 0x207F}, {0x2081, 0x2084}, {0x20AC, 0x20AC}, - {0x2103, 0x2103}, {0x2105, 0x2105}, {0x2109, 0x2109}, - {0x2113, 0x2113}, {0x2116, 0x2116}, {0x2121, 0x2122}, - {0x2126, 0x2126}, {0x212B, 0x212B}, {0x2153, 0x2154}, - {0x215B, 0x215E}, {0x2160, 0x216B}, {0x2170, 0x2179}, - {0x2190, 0x2199}, {0x21B8, 0x21B9}, {0x21D2, 0x21D2}, - {0x21D4, 0x21D4}, {0x21E7, 0x21E7}, {0x2200, 0x2200}, - {0x2202, 0x2203}, {0x2207, 0x2208}, {0x220B, 0x220B}, - {0x220F, 0x220F}, {0x2211, 0x2211}, {0x2215, 0x2215}, - {0x221A, 0x221A}, {0x221D, 0x2220}, {0x2223, 0x2223}, - {0x2225, 0x2225}, {0x2227, 0x222C}, {0x222E, 0x222E}, - {0x2234, 0x2237}, {0x223C, 0x223D}, {0x2248, 0x2248}, - {0x224C, 0x224C}, {0x2252, 0x2252}, {0x2260, 0x2261}, - {0x2264, 0x2267}, {0x226A, 0x226B}, {0x226E, 0x226F}, - {0x2282, 0x2283}, {0x2286, 0x2287}, {0x2295, 0x2295}, - {0x2299, 0x2299}, {0x22A5, 0x22A5}, {0x22BF, 0x22BF}, - {0x2312, 0x2312}, {0x2460, 0x24E9}, {0x24EB, 0x254B}, - {0x2550, 0x2573}, {0x2580, 0x258F}, {0x2592, 0x2595}, - {0x25A0, 0x25A1}, {0x25A3, 0x25A9}, {0x25B2, 0x25B3}, - {0x25B6, 0x25B7}, {0x25BC, 0x25BD}, {0x25C0, 0x25C1}, - {0x25C6, 0x25C8}, {0x25CB, 0x25CB}, {0x25CE, 0x25D1}, - {0x25E2, 0x25E5}, {0x25EF, 0x25EF}, {0x2605, 0x2606}, - {0x2609, 0x2609}, {0x260E, 0x260F}, {0x2614, 0x2615}, - {0x261C, 0x261C}, {0x261E, 0x261E}, {0x2640, 0x2640}, - {0x2642, 0x2642}, {0x2660, 0x2661}, {0x2663, 0x2665}, - {0x2667, 0x266A}, {0x266C, 0x266D}, {0x266F, 0x266F}, - {0x273D, 0x273D}, {0x2776, 0x277F}, {0xE000, 0xF8FF}, - {0xFFFD, 0xFFFD}, {0xF0000, 0xFFFFD}, {0x100000, 0x10FFFD}}; - - /* binary search in table of non-spacing characters */ - if (bisearch(ucs, ambiguous, sizeof(ambiguous) / sizeof(struct interval) - 1)) - return 2; - - return mk_wcwidth(ucs); -} - -static inline int mk_wcswidth_cjk(const wchar_t* pwcs, size_t n) { - int w, width = 0; - - for (; *pwcs && n-- > 0; pwcs++) - if ((w = mk_wcwidth_cjk(*pwcs)) < 0) - return -1; - else - width += w; - - return width; -} - -// convert UTF-8 string to wstring -#ifdef _MSC_VER -static inline std::wstring utf8_decode(const std::string& s) { - auto r = setlocale(LC_ALL, ""); - std::string curLocale; - if (r) - curLocale = r; - const char* _Source = s.c_str(); - size_t _Dsize = std::strlen(_Source) + 1; - wchar_t* _Dest = new wchar_t[_Dsize]; - size_t _Osize; - mbstowcs_s(&_Osize, _Dest, _Dsize, _Source, _Dsize); - std::wstring result = _Dest; - delete[] _Dest; - setlocale(LC_ALL, curLocale.c_str()); - return result; -} -#else -static inline std::wstring utf8_decode(const std::string& s) { - auto r = setlocale(LC_ALL, ""); - std::string curLocale; - if (r) - curLocale = r; - const char* _Source = s.c_str(); - size_t _Dsize = mbstowcs(NULL, _Source, 0) + 1; - wchar_t* _Dest = new wchar_t[_Dsize]; - wmemset(_Dest, 0, _Dsize); - mbstowcs(_Dest, _Source, _Dsize); - std::wstring result = _Dest; - delete[] _Dest; - setlocale(LC_ALL, curLocale.c_str()); - return result; -} -#endif - -} // namespace details - -static inline int display_width(const std::string& input) { - using namespace unicode::details; - return mk_wcswidth(utf8_decode(input).c_str(), input.size()); -} - -static inline int display_width(const std::wstring& input) { - return details::mk_wcswidth(input.c_str(), input.size()); -} - -} // namespace unicode - -#endif -// #include -// #include - -#include -#include -#include -#include -#include -#include - -#include -#include - -namespace indicators { -namespace details { - -inline void set_stream_color(std::ostream& os, Color color) { - switch (color) { - case Color::grey: - os << termcolor::grey; - break; - case Color::red: - os << termcolor::red; - break; - case Color::green: - os << termcolor::green; - break; - case Color::yellow: - os << termcolor::yellow; - break; - case Color::blue: - os << termcolor::blue; - break; - case Color::magenta: - os << termcolor::magenta; - break; - case Color::cyan: - os << termcolor::cyan; - break; - case Color::white: - os << termcolor::white; - break; - default: - assert(false); - } -} - -inline void set_font_style(std::ostream& os, FontStyle style) { - switch (style) { - case FontStyle::bold: - os << termcolor::bold; - break; - case FontStyle::dark: - os << termcolor::dark; - break; - case FontStyle::italic: - os << termcolor::italic; - break; - case FontStyle::underline: - os << termcolor::underline; - break; - case FontStyle::blink: - os << termcolor::blink; - break; - case FontStyle::reverse: - os << termcolor::reverse; - break; - case FontStyle::concealed: - os << termcolor::concealed; - break; - case FontStyle::crossed: - os << termcolor::crossed; - break; - default: - break; - } -} - -inline std::ostream& write_duration(std::ostream& os, - std::chrono::nanoseconds ns) { - using namespace std; - using namespace std::chrono; - using days = duration>; - char fill = os.fill(); - os.fill('0'); - auto d = duration_cast(ns); - ns -= d; - auto h = duration_cast(ns); - ns -= h; - auto m = duration_cast(ns); - ns -= m; - auto s = duration_cast(ns); - if (d.count() > 0) - os << setw(2) << d.count() << "d:"; - if (h.count() > 0) - os << setw(2) << h.count() << "h:"; - os << setw(2) << m.count() << "m:" << setw(2) << s.count() << 's'; - os.fill(fill); - return os; -} - -class BlockProgressScaleWriter { - public: - BlockProgressScaleWriter(std::ostream& os, size_t bar_width) - : os(os), bar_width(bar_width) {} - - std::ostream& write(float progress) { - std::string fill_text{"█"}; - std::vector lead_characters{" ", "▏", "▎", "▍", - "▌", "▋", "▊", "▉"}; - auto value = (std::min)(1.0f, (std::max)(0.0f, progress / 100.0f)); - auto whole_width = std::floor(value * bar_width); - auto remainder_width = fmod((value * bar_width), 1.0f); - auto part_width = std::floor(remainder_width * lead_characters.size()); - std::string lead_text = lead_characters[size_t(part_width)]; - if ((bar_width - whole_width - 1) < 0) - lead_text = ""; - for (size_t i = 0; i < whole_width; ++i) - os << fill_text; - os << lead_text; - for (size_t i = 0; i < (bar_width - whole_width - 1); ++i) - os << " "; - return os; - } - - private: - std::ostream& os; - size_t bar_width = 0; -}; - -class ProgressScaleWriter { - public: - ProgressScaleWriter(std::ostream& os, size_t bar_width, - const std::string& fill, const std::string& lead, - const std::string& remainder) - : os(os), - bar_width(bar_width), - fill(fill), - lead(lead), - remainder(remainder) {} - - std::ostream& write(float progress) { - auto pos = static_cast(progress * bar_width / 100.0); - for (size_t i = 0, current_display_width = 0; i < bar_width;) { - std::string next; - - if (i < pos) { - next = fill; - current_display_width = unicode::display_width(fill); - } else if (i == pos) { - next = lead; - current_display_width = unicode::display_width(lead); - } else { - next = remainder; - current_display_width = unicode::display_width(remainder); - } - - i += current_display_width; - - if (i > bar_width) { - // `next` is larger than the allowed bar width - // fill with empty space instead - os << std::string((bar_width - (i - current_display_width)), ' '); - break; - } - - os << next; - } - return os; - } - - private: - std::ostream& os; - size_t bar_width = 0; - std::string fill; - std::string lead; - std::string remainder; -}; - -class IndeterminateProgressScaleWriter { - public: - IndeterminateProgressScaleWriter(std::ostream& os, size_t bar_width, - const std::string& fill, - const std::string& lead) - : os(os), bar_width(bar_width), fill(fill), lead(lead) {} - - std::ostream& write(size_t progress) { - for (size_t i = 0; i < bar_width;) { - std::string next; - size_t current_display_width = 0; - - if (i < progress) { - next = fill; - current_display_width = unicode::display_width(fill); - } else if (i == progress) { - next = lead; - current_display_width = unicode::display_width(lead); - } else { - next = fill; - current_display_width = unicode::display_width(fill); - } - - i += current_display_width; - - if (i > bar_width) { - // `next` is larger than the allowed bar width - // fill with empty space instead - os << std::string((bar_width - (i - current_display_width)), ' '); - break; - } - - os << next; - } - return os; - } - - private: - std::ostream& os; - size_t bar_width = 0; - std::string fill; - std::string lead; -}; - -} // namespace details -} // namespace indicators - -#endif - -#ifndef INDICATORS_PROGRESS_BAR -#define INDICATORS_PROGRESS_BAR - -// #include - -#include -#include -#include -#include -// #include -// #include -// #include -#include -#include -#include -#include -#include -#include -#include -#include -#include - -namespace indicators { - -class ProgressBar { - using Settings = - std::tuple; - - public: - template ::type...>::value, - void*>::type = nullptr> - explicit ProgressBar(Args&&... args) - : settings_( - details::get( - option::BarWidth{100}, std::forward(args)...), - details::get( - option::PrefixText{}, std::forward(args)...), - details::get( - option::PostfixText{}, std::forward(args)...), - details::get( - option::Start{"["}, std::forward(args)...), - details::get( - option::End{"]"}, std::forward(args)...), - details::get( - option::Fill{"="}, std::forward(args)...), - details::get( - option::Lead{">"}, std::forward(args)...), - details::get( - option::Remainder{" "}, std::forward(args)...), - details::get( - option::MaxPostfixTextLen{0}, std::forward(args)...), - details::get( - option::Completed{false}, std::forward(args)...), - details::get( - option::ShowPercentage{false}, std::forward(args)...), - details::get( - option::ShowElapsedTime{false}, std::forward(args)...), - details::get( - option::ShowRemainingTime{false}, std::forward(args)...), - details::get( - option::SavedStartTime{false}, std::forward(args)...), - details::get( - option::ForegroundColor{Color::unspecified}, - std::forward(args)...), - details::get( - option::FontStyles{std::vector{}}, - std::forward(args)...), - details::get( - option::MinProgress{0}, std::forward(args)...), - details::get( - option::MaxProgress{100}, std::forward(args)...), - details::get( - option::ProgressType{ProgressType::incremental}, - std::forward(args)...), - details::get( - option::Stream{std::cout}, std::forward(args)...)) { - - // if progress is incremental, start from min_progress - // else start from max_progress - const auto type = get_value(); - if (type == ProgressType::incremental) - progress_ = get_value(); - else - progress_ = get_value(); - } - - template - void set_option(details::Setting&& setting) { - static_assert( - !std::is_same( - std::declval()))>::type>::value, - "Setting has wrong type!"); - std::lock_guard lock(mutex_); - get_value() = std::move(setting).value; - } - - template - void set_option(const details::Setting& setting) { - static_assert( - !std::is_same( - std::declval()))>::type>::value, - "Setting has wrong type!"); - std::lock_guard lock(mutex_); - get_value() = setting.value; - } - - void set_option( - const details::Setting< - std::string, details::ProgressBarOption::postfix_text>& setting) { - std::lock_guard lock(mutex_); - get_value() = setting.value; - if (setting.value.length() > - get_value()) { - get_value() = - setting.value.length(); - } - } - - void set_option( - details::Setting&& - setting) { - std::lock_guard lock(mutex_); - get_value() = - std::move(setting).value; - auto& new_value = get_value(); - if (new_value.length() > - get_value()) { - get_value() = - new_value.length(); - } - } - - void set_progress(size_t new_progress) { - { - std::lock_guard lock(mutex_); - progress_ = new_progress; - } - - save_start_time(); - print_progress(); - } - - void tick() { - { - std::lock_guard lock{mutex_}; - const auto type = get_value(); - if (type == ProgressType::incremental) - progress_ += 1; - else - progress_ -= 1; - } - save_start_time(); - print_progress(); - } - - size_t current() { - std::lock_guard lock{mutex_}; - return (std::min)( - progress_, - size_t(get_value())); - } - - bool is_completed() const { - return get_value(); - } - - void mark_as_completed() { - get_value() = true; - print_progress(); - } - - private: - template - auto get_value() - -> decltype((details::get_value(std::declval()).value)) { - return details::get_value(settings_).value; - } - - template - auto get_value() const - -> decltype(( - details::get_value(std::declval()).value)) { - return details::get_value(settings_).value; - } - - size_t progress_{0}; - Settings settings_; - std::chrono::nanoseconds elapsed_; - std::chrono::time_point start_time_point_; - std::mutex mutex_; - - template - friend class MultiProgress; - template - friend class DynamicProgress; - std::atomic multi_progress_mode_{false}; - - void save_start_time() { - auto& show_elapsed_time = - get_value(); - auto& saved_start_time = - get_value(); - auto& show_remaining_time = - get_value(); - if ((show_elapsed_time || show_remaining_time) && !saved_start_time) { - start_time_point_ = std::chrono::high_resolution_clock::now(); - saved_start_time = true; - } - } - - std::pair get_prefix_text() { - std::stringstream os; - os << get_value(); - const auto result = os.str(); - const auto result_size = unicode::display_width(result); - return {result, result_size}; - } - - std::pair get_postfix_text() { - std::stringstream os; - const auto max_progress = - get_value(); - - if (get_value()) { - os << " " - << (std::min)(static_cast(static_cast(progress_) / - max_progress * 100), - size_t(100)) - << "%"; - } - - auto& saved_start_time = - get_value(); - - if (get_value()) { - os << " ["; - if (saved_start_time) - details::write_duration(os, elapsed_); - else - os << "00:00s"; - } - - if (get_value()) { - if (get_value()) - os << "<"; - else - os << " ["; - - if (saved_start_time) { - auto eta = std::chrono::nanoseconds( - progress_ > 0 - ? static_cast(std::ceil(float(elapsed_.count()) * - max_progress / progress_)) - : 0); - auto remaining = eta > elapsed_ ? (eta - elapsed_) : (elapsed_ - eta); - details::write_duration(os, remaining); - } else { - os << "00:00s"; - } - - os << "]"; - } else { - if (get_value()) - os << "]"; - } - - os << " " << get_value(); - - const auto result = os.str(); - const auto result_size = unicode::display_width(result); - return {result, result_size}; - } - - public: - void print_progress(bool from_multi_progress = false) { - std::lock_guard lock{mutex_}; - - auto& os = get_value(); - - const auto type = get_value(); - const auto min_progress = - get_value(); - const auto max_progress = - get_value(); - if (multi_progress_mode_ && !from_multi_progress) { - if ((type == ProgressType::incremental && progress_ >= max_progress) || - (type == ProgressType::decremental && progress_ <= min_progress)) { - get_value() = true; - } - return; - } - auto now = std::chrono::high_resolution_clock::now(); - if (!get_value()) - elapsed_ = std::chrono::duration_cast( - now - start_time_point_); - - if (get_value() != - Color::unspecified) - details::set_stream_color( - os, get_value()); - - for (auto& style : get_value()) - details::set_font_style(os, style); - - const auto prefix_pair = get_prefix_text(); - const auto prefix_text = prefix_pair.first; - const auto prefix_length = prefix_pair.second; - os << "\r" << prefix_text; - - os << get_value(); - - details::ProgressScaleWriter writer{ - os, get_value(), - get_value(), - get_value(), - get_value()}; - writer.write(double(progress_) / double(max_progress) * 100.0f); - - os << get_value(); - - const auto postfix_pair = get_postfix_text(); - const auto postfix_text = postfix_pair.first; - const auto postfix_length = postfix_pair.second; - os << postfix_text; - - // Get length of prefix text and postfix text - const auto start_length = - get_value().size(); - const auto bar_width = get_value(); - const auto end_length = get_value().size(); - const auto terminal_width = terminal_size().second; - // prefix + bar_width + postfix should be <= terminal_width - const int remaining = - terminal_width - (prefix_length + start_length + bar_width + - end_length + postfix_length); - if (prefix_length == -1 || postfix_length == -1) { - os << "\r"; - } else if (remaining > 0) { - os << std::string(remaining, ' ') << "\r"; - } else if (remaining < 0) { - // Do nothing. Maybe in the future truncate postfix with ... - } - os.flush(); - - if ((type == ProgressType::incremental && progress_ >= max_progress) || - (type == ProgressType::decremental && progress_ <= min_progress)) { - get_value() = true; - } - if (get_value() && - !from_multi_progress) // Don't std::endl if calling from MultiProgress - os << termcolor::reset << std::endl; - } -}; - -} // namespace indicators - -#endif - -#ifndef INDICATORS_BLOCK_PROGRESS_BAR -#define INDICATORS_BLOCK_PROGRESS_BAR - -// #include -// #include - -#include -#include -#include -// #include -// #include -#include -#include -#include -#include -#include -#include -#include -#include - -namespace indicators { - -class BlockProgressBar { - using Settings = - std::tuple; - - public: - template ::type...>::value, - void*>::type = nullptr> - explicit BlockProgressBar(Args&&... args) - : settings_( - details::get( - option::ForegroundColor{Color::unspecified}, - std::forward(args)...), - details::get( - option::BarWidth{100}, std::forward(args)...), - details::get( - option::Start{"["}, std::forward(args)...), - details::get( - option::End{"]"}, std::forward(args)...), - details::get( - option::PrefixText{""}, std::forward(args)...), - details::get( - option::PostfixText{""}, std::forward(args)...), - details::get( - option::ShowPercentage{true}, std::forward(args)...), - details::get( - option::ShowElapsedTime{false}, std::forward(args)...), - details::get( - option::ShowRemainingTime{false}, std::forward(args)...), - details::get( - option::Completed{false}, std::forward(args)...), - details::get( - option::SavedStartTime{false}, std::forward(args)...), - details::get( - option::MaxPostfixTextLen{0}, std::forward(args)...), - details::get( - option::FontStyles{std::vector{}}, - std::forward(args)...), - details::get( - option::MaxProgress{100}, std::forward(args)...), - details::get( - option::Stream{std::cout}, std::forward(args)...)) {} - - template - void set_option(details::Setting&& setting) { - static_assert( - !std::is_same( - std::declval()))>::type>::value, - "Setting has wrong type!"); - std::lock_guard lock(mutex_); - get_value() = std::move(setting).value; - } - - template - void set_option(const details::Setting& setting) { - static_assert( - !std::is_same( - std::declval()))>::type>::value, - "Setting has wrong type!"); - std::lock_guard lock(mutex_); - get_value() = setting.value; - } - - void set_option( - const details::Setting< - std::string, details::ProgressBarOption::postfix_text>& setting) { - std::lock_guard lock(mutex_); - get_value() = setting.value; - if (setting.value.length() > - get_value()) { - get_value() = - setting.value.length(); - } - } - - void set_option( - details::Setting&& - setting) { - std::lock_guard lock(mutex_); - get_value() = - std::move(setting).value; - auto& new_value = get_value(); - if (new_value.length() > - get_value()) { - get_value() = - new_value.length(); - } - } - - void set_progress(float value) { - { - std::lock_guard lock{mutex_}; - progress_ = value; - } - save_start_time(); - print_progress(); - } - - void tick() { - { - std::lock_guard lock{mutex_}; - progress_ += 1; - } - save_start_time(); - print_progress(); - } - - size_t current() { - std::lock_guard lock{mutex_}; - return (std::min)( - static_cast(progress_), - size_t(get_value())); - } - - bool is_completed() const { - return get_value(); - } - - void mark_as_completed() { - get_value() = true; - print_progress(); - } - - private: - template - auto get_value() - -> decltype((details::get_value(std::declval()).value)) { - return details::get_value(settings_).value; - } - - template - auto get_value() const - -> decltype(( - details::get_value(std::declval()).value)) { - return details::get_value(settings_).value; - } - - Settings settings_; - float progress_{0.0}; - std::chrono::time_point start_time_point_; - std::mutex mutex_; - - template - friend class MultiProgress; - template - friend class DynamicProgress; - std::atomic multi_progress_mode_{false}; - - void save_start_time() { - auto& show_elapsed_time = - get_value(); - auto& saved_start_time = - get_value(); - auto& show_remaining_time = - get_value(); - if ((show_elapsed_time || show_remaining_time) && !saved_start_time) { - start_time_point_ = std::chrono::high_resolution_clock::now(); - saved_start_time = true; - } - } - - std::pair get_prefix_text() { - std::stringstream os; - os << get_value(); - const auto result = os.str(); - const auto result_size = unicode::display_width(result); - return {result, result_size}; - } - - std::pair get_postfix_text() { - std::stringstream os; - const auto max_progress = - get_value(); - auto now = std::chrono::high_resolution_clock::now(); - auto elapsed = std::chrono::duration_cast( - now - start_time_point_); - - if (get_value()) { - os << " " - << (std::min)(static_cast(progress_ / max_progress * 100.0), - size_t(100)) - << "%"; - } - - auto& saved_start_time = - get_value(); - - if (get_value()) { - os << " ["; - if (saved_start_time) - details::write_duration(os, elapsed); - else - os << "00:00s"; - } - - if (get_value()) { - if (get_value()) - os << "<"; - else - os << " ["; - - if (saved_start_time) { - auto eta = std::chrono::nanoseconds( - progress_ > 0 - ? static_cast(std::ceil(float(elapsed.count()) * - max_progress / progress_)) - : 0); - auto remaining = eta > elapsed ? (eta - elapsed) : (elapsed - eta); - details::write_duration(os, remaining); - } else { - os << "00:00s"; - } - - os << "]"; - } else { - if (get_value()) - os << "]"; - } - - os << " " << get_value(); - - const auto result = os.str(); - const auto result_size = unicode::display_width(result); - return {result, result_size}; - } - - public: - void print_progress(bool from_multi_progress = false) { - std::lock_guard lock{mutex_}; - - auto& os = get_value(); - - const auto max_progress = - get_value(); - if (multi_progress_mode_ && !from_multi_progress) { - if (progress_ > max_progress) { - get_value() = true; - } - return; - } - - if (get_value() != - Color::unspecified) - details::set_stream_color( - os, get_value()); - - for (auto& style : get_value()) - details::set_font_style(os, style); - - const auto prefix_pair = get_prefix_text(); - const auto prefix_text = prefix_pair.first; - const auto prefix_length = prefix_pair.second; - os << "\r" << prefix_text; - - os << get_value(); - - details::BlockProgressScaleWriter writer{ - os, get_value()}; - writer.write(progress_ / max_progress * 100); - - os << get_value(); - - const auto postfix_pair = get_postfix_text(); - const auto postfix_text = postfix_pair.first; - const auto postfix_length = postfix_pair.second; - os << postfix_text; - - // Get length of prefix text and postfix text - const auto start_length = - get_value().size(); - const auto bar_width = get_value(); - const auto end_length = get_value().size(); - const auto terminal_width = terminal_size().second; - // prefix + bar_width + postfix should be <= terminal_width - const int remaining = - terminal_width - (prefix_length + start_length + bar_width + - end_length + postfix_length); - if (prefix_length == -1 || postfix_length == -1) { - os << "\r"; - } else if (remaining > 0) { - os << std::string(remaining, ' ') << "\r"; - } else if (remaining < 0) { - // Do nothing. Maybe in the future truncate postfix with ... - } - os.flush(); - - if (progress_ > max_progress) { - get_value() = true; - } - if (get_value() && - !from_multi_progress) // Don't std::endl if calling from MultiProgress - os << termcolor::reset << std::endl; - } -}; - -} // namespace indicators - -#endif - -#ifndef INDICATORS_INDETERMINATE_PROGRESS_BAR -#define INDICATORS_INDETERMINATE_PROGRESS_BAR - -// #include - -#include -#include -#include -#include -// #include -// #include -// #include -#include -#include -#include -#include -#include -#include -#include -#include -#include - -namespace indicators { - -class IndeterminateProgressBar { - using Settings = - std::tuple; - - enum class Direction { forward, backward }; - - Direction direction_{Direction::forward}; - - public: - template ::type...>::value, - void*>::type = nullptr> - explicit IndeterminateProgressBar(Args&&... args) - : settings_( - details::get( - option::BarWidth{100}, std::forward(args)...), - details::get( - option::PrefixText{}, std::forward(args)...), - details::get( - option::PostfixText{}, std::forward(args)...), - details::get( - option::Start{"["}, std::forward(args)...), - details::get( - option::End{"]"}, std::forward(args)...), - details::get( - option::Fill{"."}, std::forward(args)...), - details::get( - option::Lead{"<==>"}, std::forward(args)...), - details::get( - option::MaxPostfixTextLen{0}, std::forward(args)...), - details::get( - option::Completed{false}, std::forward(args)...), - details::get( - option::ForegroundColor{Color::unspecified}, - std::forward(args)...), - details::get( - option::FontStyles{std::vector{}}, - std::forward(args)...), - details::get( - option::Stream{std::cout}, std::forward(args)...)) { - // starts with [<==>...........] - // progress_ = 0 - - // ends with [...........<==>] - // ^^^^^^^^^^^^^^^^^ bar_width - // ^^^^^^^^^^^^ (bar_width - len(lead)) - // progress_ = bar_width - len(lead) - progress_ = 0; - max_progress_ = get_value() - - get_value().size() + - get_value().size() + - get_value().size(); - } - - template - void set_option(details::Setting&& setting) { - static_assert( - !std::is_same( - std::declval()))>::type>::value, - "Setting has wrong type!"); - std::lock_guard lock(mutex_); - get_value() = std::move(setting).value; - } - - template - void set_option(const details::Setting& setting) { - static_assert( - !std::is_same( - std::declval()))>::type>::value, - "Setting has wrong type!"); - std::lock_guard lock(mutex_); - get_value() = setting.value; - } - - void set_option( - const details::Setting< - std::string, details::ProgressBarOption::postfix_text>& setting) { - std::lock_guard lock(mutex_); - get_value() = setting.value; - if (setting.value.length() > - get_value()) { - get_value() = - setting.value.length(); - } - } - - void set_option( - details::Setting&& - setting) { - std::lock_guard lock(mutex_); - get_value() = - std::move(setting).value; - auto& new_value = get_value(); - if (new_value.length() > - get_value()) { - get_value() = - new_value.length(); - } - } - - void tick() { - { - std::lock_guard lock{mutex_}; - if (get_value()) - return; - - progress_ += (direction_ == Direction::forward) ? 1 : -1; - if (direction_ == Direction::forward && progress_ == max_progress_) { - // time to go back - direction_ = Direction::backward; - } else if (direction_ == Direction::backward && progress_ == 0) { - direction_ = Direction::forward; - } - } - print_progress(); - } - - bool is_completed() { - return get_value(); - } - - void mark_as_completed() { - get_value() = true; - print_progress(); - } - - private: - template - auto get_value() - -> decltype((details::get_value(std::declval()).value)) { - return details::get_value(settings_).value; - } - - template - auto get_value() const - -> decltype(( - details::get_value(std::declval()).value)) { - return details::get_value(settings_).value; - } - - size_t progress_{0}; - size_t max_progress_; - Settings settings_; - std::chrono::nanoseconds elapsed_; - std::mutex mutex_; - - template - friend class MultiProgress; - template - friend class DynamicProgress; - std::atomic multi_progress_mode_{false}; - - std::pair get_prefix_text() { - std::stringstream os; - os << get_value(); - const auto result = os.str(); - const auto result_size = unicode::display_width(result); - return {result, result_size}; - } - - std::pair get_postfix_text() { - std::stringstream os; - os << " " << get_value(); - - const auto result = os.str(); - const auto result_size = unicode::display_width(result); - return {result, result_size}; - } - - public: - void print_progress(bool from_multi_progress = false) { - std::lock_guard lock{mutex_}; - - auto& os = get_value(); - - if (multi_progress_mode_ && !from_multi_progress) { - return; - } - if (get_value() != - Color::unspecified) - details::set_stream_color( - os, get_value()); - - for (auto& style : get_value()) - details::set_font_style(os, style); - - const auto prefix_pair = get_prefix_text(); - const auto prefix_text = prefix_pair.first; - const auto prefix_length = prefix_pair.second; - os << "\r" << prefix_text; - - os << get_value(); - - details::IndeterminateProgressScaleWriter writer{ - os, get_value(), - get_value(), - get_value()}; - writer.write(progress_); - - os << get_value(); - - const auto postfix_pair = get_postfix_text(); - const auto postfix_text = postfix_pair.first; - const auto postfix_length = postfix_pair.second; - os << postfix_text; - - // Get length of prefix text and postfix text - const auto start_length = - get_value().size(); - const auto bar_width = get_value(); - const auto end_length = get_value().size(); - const auto terminal_width = terminal_size().second; - // prefix + bar_width + postfix should be <= terminal_width - const int remaining = - terminal_width - (prefix_length + start_length + bar_width + - end_length + postfix_length); - if (prefix_length == -1 || postfix_length == -1) { - os << "\r"; - } else if (remaining > 0) { - os << std::string(remaining, ' ') << "\r"; - } else if (remaining < 0) { - // Do nothing. Maybe in the future truncate postfix with ... - } - os.flush(); - - if (get_value() && - !from_multi_progress) // Don't std::endl if calling from MultiProgress - os << termcolor::reset << std::endl; - } -}; - -} // namespace indicators - -#endif - -#ifndef INDICATORS_MULTI_PROGRESS -#define INDICATORS_MULTI_PROGRESS -#include -#include -#include -#include -#include - -// #include -// #include -// #include - -namespace indicators { - -template -class MultiProgress { - public: - template ::type> - explicit MultiProgress(Indicators&... bars) { - bars_ = {bars...}; - for (auto& bar : bars_) { - bar.get().multi_progress_mode_ = true; - } - } - - template - typename std::enable_if<(index >= 0 && index < count), void>::type - set_progress(size_t value) { - if (!bars_[index].get().is_completed()) - bars_[index].get().set_progress(value); - print_progress(); - } - - template - typename std::enable_if<(index >= 0 && index < count), void>::type - set_progress(float value) { - if (!bars_[index].get().is_completed()) - bars_[index].get().set_progress(value); - print_progress(); - } - - template - typename std::enable_if<(index >= 0 && index < count), void>::type tick() { - if (!bars_[index].get().is_completed()) - bars_[index].get().tick(); - print_progress(); - } - - template - typename std::enable_if<(index >= 0 && index < count), bool>::type - is_completed() const { - return bars_[index].get().is_completed(); - } - - private: - std::atomic started_{false}; - std::mutex mutex_; - std::vector> bars_; - - bool _all_completed() { - bool result{true}; - for (size_t i = 0; i < count; ++i) - result &= bars_[i].get().is_completed(); - return result; - } - - public: - void print_progress() { - std::lock_guard lock{mutex_}; - if (started_) - move_up(count); - for (auto& bar : bars_) { - bar.get().print_progress(true); - std::cout << "\n"; - } - std::cout << termcolor::reset; - if (!started_) - started_ = true; - } -}; - -} // namespace indicators - -#endif - -#ifndef INDICATORS_DYNAMIC_PROGRESS -#define INDICATORS_DYNAMIC_PROGRESS - -#include -#include -// #include -// #include -// #include -// #include -// #include -#include -#include -#include - -namespace indicators { - -template -class DynamicProgress { - using Settings = std::tuple; - - public: - template - explicit DynamicProgress(Indicators&... bars) { - bars_ = {bars...}; - for (auto& bar : bars_) { - bar.get().multi_progress_mode_ = true; - ++total_count_; - ++incomplete_count_; - } - } - - Indicator& operator[](size_t index) { - print_progress(); - std::lock_guard lock{mutex_}; - return bars_[index].get(); - } - - size_t push_back(Indicator& bar) { - std::lock_guard lock{mutex_}; - bar.multi_progress_mode_ = true; - bars_.push_back(bar); - return bars_.size() - 1; - } - - template - void set_option(details::Setting&& setting) { - static_assert( - !std::is_same( - std::declval()))>::type>::value, - "Setting has wrong type!"); - std::lock_guard lock(mutex_); - get_value() = std::move(setting).value; - } - - template - void set_option(const details::Setting& setting) { - static_assert( - !std::is_same( - std::declval()))>::type>::value, - "Setting has wrong type!"); - std::lock_guard lock(mutex_); - get_value() = setting.value; - } - - private: - Settings settings_; - std::atomic started_{false}; - std::mutex mutex_; - std::vector> bars_; - std::atomic total_count_{0}; - std::atomic incomplete_count_{0}; - - template - auto get_value() - -> decltype((details::get_value(std::declval()).value)) { - return details::get_value(settings_).value; - } - - template - auto get_value() const - -> decltype(( - details::get_value(std::declval()).value)) { - return details::get_value(settings_).value; - } - - public: - void print_progress() { - std::lock_guard lock{mutex_}; - auto& hide_bar_when_complete = - get_value(); - if (hide_bar_when_complete) { - // Hide completed bars - if (started_) { - for (size_t i = 0; i < incomplete_count_; ++i) { - move_up(1); - erase_line(); - std::cout << std::flush; - } - } - incomplete_count_ = 0; - for (auto& bar : bars_) { - if (!bar.get().is_completed()) { - bar.get().print_progress(true); - std::cout << "\n"; - ++incomplete_count_; - } - } - if (!started_) - started_ = true; - } else { - // Don't hide any bars - if (started_) - move_up(static_cast(total_count_)); - for (auto& bar : bars_) { - bar.get().print_progress(true); - std::cout << "\n"; - } - if (!started_) - started_ = true; - } - total_count_ = bars_.size(); - std::cout << termcolor::reset; - } -}; - -} // namespace indicators - -#endif - -#ifndef INDICATORS_PROGRESS_SPINNER -#define INDICATORS_PROGRESS_SPINNER - -// #include - -#include -#include -#include -#include -// #include -// #include -#include -#include -#include -#include -#include -#include -#include - -namespace indicators { - -class ProgressSpinner { - using Settings = - std::tuple; - - public: - template ::type...>::value, - void*>::type = nullptr> - explicit ProgressSpinner(Args&&... args) - : settings_( - details::get( - option::ForegroundColor{Color::unspecified}, - std::forward(args)...), - details::get( - option::PrefixText{}, std::forward(args)...), - details::get( - option::PostfixText{}, std::forward(args)...), - details::get( - option::ShowPercentage{true}, std::forward(args)...), - details::get( - option::ShowElapsedTime{false}, std::forward(args)...), - details::get( - option::ShowRemainingTime{false}, std::forward(args)...), - details::get( - option::ShowSpinner{true}, std::forward(args)...), - details::get( - option::SavedStartTime{false}, std::forward(args)...), - details::get( - option::Completed{false}, std::forward(args)...), - details::get( - option::MaxPostfixTextLen{0}, std::forward(args)...), - details::get( - option::SpinnerStates{std::vector{ - "⠋", "⠙", "⠹", "⠸", "⠼", "⠴", "⠦", "⠧", "⠇", "⠏"}}, - std::forward(args)...), - details::get( - option::FontStyles{std::vector{}}, - std::forward(args)...), - details::get( - option::MaxProgress{100}, std::forward(args)...), - details::get( - option::Stream{std::cout}, std::forward(args)...)) {} - - template - void set_option(details::Setting&& setting) { - static_assert( - !std::is_same( - std::declval()))>::type>::value, - "Setting has wrong type!"); - std::lock_guard lock(mutex_); - get_value() = std::move(setting).value; - } - - template - void set_option(const details::Setting& setting) { - static_assert( - !std::is_same( - std::declval()))>::type>::value, - "Setting has wrong type!"); - std::lock_guard lock(mutex_); - get_value() = setting.value; - } - - void set_option( - const details::Setting< - std::string, details::ProgressBarOption::postfix_text>& setting) { - std::lock_guard lock(mutex_); - get_value() = setting.value; - if (setting.value.length() > - get_value()) { - get_value() = - setting.value.length(); - } - } - - void set_option( - details::Setting&& - setting) { - std::lock_guard lock(mutex_); - get_value() = - std::move(setting).value; - auto& new_value = get_value(); - if (new_value.length() > - get_value()) { - get_value() = - new_value.length(); - } - } - - void set_progress(size_t value) { - { - std::lock_guard lock{mutex_}; - progress_ = value; - } - save_start_time(); - print_progress(); - } - - void tick() { - { - std::lock_guard lock{mutex_}; - progress_ += 1; - } - save_start_time(); - print_progress(); - } - - size_t current() { - std::lock_guard lock{mutex_}; - return (std::min)( - progress_, - size_t(get_value())); - } - - bool is_completed() const { - return get_value(); - } - - void mark_as_completed() { - get_value() = true; - print_progress(); - } - - private: - Settings settings_; - size_t progress_{0}; - size_t index_{0}; - std::chrono::time_point start_time_point_; - std::mutex mutex_; - - template - auto get_value() - -> decltype((details::get_value(std::declval()).value)) { - return details::get_value(settings_).value; - } - - template - auto get_value() const - -> decltype(( - details::get_value(std::declval()).value)) { - return details::get_value(settings_).value; - } - - void save_start_time() { - auto& show_elapsed_time = - get_value(); - auto& show_remaining_time = - get_value(); - auto& saved_start_time = - get_value(); - if ((show_elapsed_time || show_remaining_time) && !saved_start_time) { - start_time_point_ = std::chrono::high_resolution_clock::now(); - saved_start_time = true; - } - } - - public: - void print_progress() { - std::lock_guard lock{mutex_}; - - auto& os = get_value(); - - const auto max_progress = - get_value(); - auto now = std::chrono::high_resolution_clock::now(); - auto elapsed = std::chrono::duration_cast( - now - start_time_point_); - - if (get_value() != - Color::unspecified) - details::set_stream_color( - os, get_value()); - - for (auto& style : get_value()) - details::set_font_style(os, style); - - os << get_value(); - if (get_value()) - os << get_value() - [index_ % - get_value().size()]; - if (get_value()) { - os << " " << std::size_t(progress_ / double(max_progress) * 100) << "%"; - } - - if (get_value()) { - os << " ["; - details::write_duration(os, elapsed); - } - - if (get_value()) { - if (get_value()) - os << "<"; - else - os << " ["; - auto eta = std::chrono::nanoseconds( - progress_ > 0 - ? static_cast(std::ceil(float(elapsed.count()) * - max_progress / progress_)) - : 0); - auto remaining = eta > elapsed ? (eta - elapsed) : (elapsed - eta); - details::write_duration(os, remaining); - os << "]"; - } else { - if (get_value()) - os << "]"; - } - - if (get_value() == 0) - get_value() = 10; - os << " " << get_value() - << std::string( - get_value(), - ' ') - << "\r"; - os.flush(); - index_ += 1; - if (progress_ > max_progress) { - get_value() = true; - } - if (get_value()) - os << termcolor::reset << std::endl; - } -}; - -} // namespace indicators - -#endif diff --git a/engine/vcpkg.json b/engine/vcpkg.json index cfab8d2f5..1f8d31bcc 100644 --- a/engine/vcpkg.json +++ b/engine/vcpkg.json @@ -16,6 +16,7 @@ "tabulate", "eventpp", "sqlitecpp", - "trantor" + "trantor", + "indicators" ] } From e7aabe78546085fccf08ca62582fd890123a254e Mon Sep 17 00:00:00 2001 From: vansangpfiev Date: Tue, 29 Oct 2024 10:39:35 +0700 Subject: [PATCH 24/24] fix: download progress --- engine/cli/utils/download_progress.cc | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/engine/cli/utils/download_progress.cc b/engine/cli/utils/download_progress.cc index d7c48d3a6..2613fe413 100644 --- a/engine/cli/utils/download_progress.cc +++ b/engine/cli/utils/download_progress.cc @@ -52,8 +52,6 @@ bool DownloadProgress::Handle(const std::string& id) { return; } - status_ = ev.type_; - if (!bars) { bars = std::make_unique< indicators::DynamicProgress>(); @@ -72,7 +70,7 @@ bool DownloadProgress::Handle(const std::string& id) { auto& it = ev.download_task_.items[i]; uint64_t downloaded = it.downloadedBytes.value_or(0); uint64_t total = it.bytes.value_or(9999); - if (status_ == DownloadStatus::DownloadUpdated) { + if (ev.type_ == DownloadStatus::DownloadUpdated) { (*bars)[i].set_option(indicators::option::PrefixText{ pad_string(it.id) + std::to_string( @@ -83,18 +81,20 @@ bool DownloadProgress::Handle(const std::string& id) { (*bars)[i].set_option(indicators::option::PostfixText{ format_utils::BytesToHumanReadable(downloaded) + "/" + format_utils::BytesToHumanReadable(total)}); - } else if (status_ == DownloadStatus::DownloadSuccess) { - (*bars)[i].set_option( - indicators::option::PrefixText{pad_string(it.id) + "100%"}); + } else if (ev.type_ == DownloadStatus::DownloadSuccess) { (*bars)[i].set_progress(100); auto total_str = format_utils::BytesToHumanReadable(total); (*bars)[i].set_option( indicators::option::PostfixText{total_str + "/" + total_str}); + (*bars)[i].set_option( + indicators::option::PrefixText{pad_string(it.id) + "100%"}); + (*bars)[i].set_progress(100); CTL_INF("Download success"); } } } + status_ = ev.type_; }; while (ws_->getReadyState() != easywsclient::WebSocket::CLOSED &&