diff --git a/common/arg.h b/common/arg.h index 7ab7e2cea43..7fc8cb3c580 100644 --- a/common/arg.h +++ b/common/arg.h @@ -78,11 +78,3 @@ bool common_params_parse(int argc, char ** argv, common_params & params, llama_e // function to be used by test-arg-parser common_params_context common_params_parser_init(common_params & params, llama_example ex, void(*print_usage)(int, char **) = nullptr); - -struct common_remote_params { - std::vector headers; - long timeout = 0; // CURLOPT_TIMEOUT, in seconds ; 0 means no timeout - long max_size = 0; // max size of the response ; unlimited if 0 ; max is 2GB -}; -// get remote file content, returns -std::pair> common_remote_get_content(const std::string & url, const common_remote_params & params); diff --git a/common/download.cpp b/common/download.cpp index eeb32b6a863..76ca20eae58 100644 --- a/common/download.cpp +++ b/common/download.cpp @@ -303,7 +303,8 @@ static bool common_download_head(CURL * curl, // download one single file from remote URL to local path static bool common_download_file_single_online(const std::string & url, const std::string & path, - const std::string & bearer_token) { + const std::string & bearer_token, + const common_header_list & custom_headers) { static const int max_attempts = 3; static const int retry_delay_seconds = 2; for (int i = 0; i < max_attempts; ++i) { @@ -325,6 +326,11 @@ static bool common_download_file_single_online(const std::string & url, common_load_model_from_url_headers headers; curl_easy_setopt(curl.get(), CURLOPT_HEADERDATA, &headers); curl_slist_ptr http_headers; + + for (const auto & h : custom_headers) { + std::string s = h.first + ": " + h.second; + http_headers.ptr = curl_slist_append(http_headers.ptr, s.c_str()); + } const bool was_perform_successful = common_download_head(curl.get(), http_headers, url, bearer_token); if (!was_perform_successful) { head_request_ok = false; @@ -449,8 +455,10 @@ std::pair> common_remote_get_content(const std::string & curl_easy_setopt(curl.get(), CURLOPT_MAXFILESIZE, params.max_size); } http_headers.ptr = curl_slist_append(http_headers.ptr, "User-Agent: llama-cpp"); + for (const auto & header : params.headers) { - http_headers.ptr = curl_slist_append(http_headers.ptr, header.c_str()); + std::string header_ = header.first + ": " + header.second; + http_headers.ptr = curl_slist_append(http_headers.ptr, header_.c_str()); } curl_easy_setopt(curl.get(), CURLOPT_HTTPHEADER, http_headers.ptr); @@ -562,19 +570,69 @@ static bool common_pull_file(httplib::Client & cli, return true; } +static void common_set_clean_host_header(httplib::Headers & headers, const std::string & host) { + if (headers.count("Host")) { + headers.erase("Host"); + } + + std::string clean_host = host; + size_t pos = clean_host.find(':'); + if (pos != std::string::npos) { + clean_host = clean_host.substr(0, pos); + } + + headers.emplace("Host", clean_host); +} + +static void common_resolve_redirects(std::string & url, httplib::Headers & headers) { + for (int r = 0; r < 5; ++r) { + auto [cli, parts] = common_http_client(url); + cli.set_follow_location(false); + common_set_clean_host_header(headers, parts.host); + + httplib::Headers probe_headers = headers; + probe_headers.emplace("Range", "bytes=0-0"); + + auto head = cli.Get(parts.path, probe_headers); + + if (head && (head->status >= 300 && head->status < 400) && head->has_header("Location")) { + std::string location = head->get_header_value("Location"); + if (location.find("://") == std::string::npos) { + url = parts.scheme + "://" + parts.host + location; + } else { + url = location; + } + if (headers.count("Authorization")) { + headers.erase("Authorization"); + } + continue; + } + break; + } + auto parts = common_http_parse_url(url); + common_set_clean_host_header(headers, parts.host); +} + // download one single file from remote URL to local path static bool common_download_file_single_online(const std::string & url, const std::string & path, - const std::string & bearer_token) { + const std::string & bearer_token, + const common_header_list & custom_headers) { static const int max_attempts = 3; static const int retry_delay_seconds = 2; - auto [cli, parts] = common_http_client(url); - httplib::Headers default_headers = {{"User-Agent", "llama-cpp"}}; if (!bearer_token.empty()) { default_headers.insert({"Authorization", "Bearer " + bearer_token}); } + for (const auto & h : custom_headers) { + default_headers.emplace(h.first, h.second); + } + + std::string real_url = url; + common_resolve_redirects(real_url, default_headers); + + auto [cli, parts] = common_http_client(real_url); cli.set_default_headers(default_headers); const bool file_exists = std::filesystem::exists(path); @@ -589,7 +647,9 @@ static bool common_download_file_single_online(const std::string & url, for (int i = 0; i < max_attempts; ++i) { auto head = cli.Head(parts.path); bool head_ok = head && head->status >= 200 && head->status < 300; - if (!head_ok) { + bool head_403 = head && head->status == 403; + + if (!head_ok && !head_403) { LOG_WRN("%s: HEAD invalid http status code received: %d\n", __func__, head ? head->status : -1); if (file_exists) { LOG_INF("%s: Using cached file (HEAD failed): %s\n", __func__, path.c_str()); @@ -598,22 +658,26 @@ static bool common_download_file_single_online(const std::string & url, } std::string etag; - if (head_ok && head->has_header("ETag")) { - etag = head->get_header_value("ETag"); - } - size_t total_size = 0; - if (head_ok && head->has_header("Content-Length")) { - try { - total_size = std::stoull(head->get_header_value("Content-Length")); - } catch (const std::exception& e) { - LOG_WRN("%s: Invalid Content-Length in HEAD response: %s\n", __func__, e.what()); - } - } - bool supports_ranges = false; - if (head_ok && head->has_header("Accept-Ranges")) { - supports_ranges = head->get_header_value("Accept-Ranges") != "none"; + + if (head_ok) { + if (head->has_header("ETag")) { + etag = head->get_header_value("ETag"); + } + if (head->has_header("Content-Length")) { + try { + total_size = std::stoull(head->get_header_value("Content-Length")); + } catch (const std::exception& e) { + LOG_WRN("%s: Invalid Content-Length in HEAD response: %s\n", __func__, e.what()); + } + } + if (head->has_header("Accept-Ranges")) { + supports_ranges = head->get_header_value("Accept-Ranges") != "none"; + } + } else if (head_403) { + LOG_INF("%s: 403 on HEAD, assuming GET/Resume is allowed\n", __func__); + supports_ranges = true; } bool should_download_from_scratch = false; @@ -648,8 +712,9 @@ static bool common_download_file_single_online(const std::string & url, } // start the download + std::string masked_url = common_http_show_masked_url(common_http_parse_url(url)); LOG_INF("%s: trying to download model from %s to %s (etag:%s)...\n", - __func__, common_http_show_masked_url(parts).c_str(), path_temporary.c_str(), etag.c_str()); + __func__, masked_url.c_str(), path_temporary.c_str(), etag.c_str()); const bool was_pull_successful = common_pull_file(cli, parts.path, path_temporary, supports_ranges, existing_size, total_size); if (!was_pull_successful) { if (i + 1 < max_attempts) { @@ -680,13 +745,9 @@ std::pair> common_remote_get_content(const std::string auto [cli, parts] = common_http_client(url); httplib::Headers headers = {{"User-Agent", "llama-cpp"}}; + for (const auto & header : params.headers) { - size_t pos = header.find(':'); - if (pos != std::string::npos) { - headers.emplace(header.substr(0, pos), header.substr(pos + 1)); - } else { - headers.emplace(header, ""); - } + headers.emplace(header.first, header.second); } if (params.timeout > 0) { @@ -718,9 +779,10 @@ std::pair> common_remote_get_content(const std::string static bool common_download_file_single(const std::string & url, const std::string & path, const std::string & bearer_token, - bool offline) { + bool offline, + const common_header_list & headers) { if (!offline) { - return common_download_file_single_online(url, path, bearer_token); + return common_download_file_single_online(url, path, bearer_token, headers); } if (!std::filesystem::exists(path)) { @@ -734,13 +796,24 @@ static bool common_download_file_single(const std::string & url, // download multiple files from remote URLs to local paths // the input is a vector of pairs -static bool common_download_file_multiple(const std::vector> & urls, const std::string & bearer_token, bool offline) { +static bool common_download_file_multiple(const std::vector> & urls, + const std::string & bearer_token, + bool offline, + const common_header_list & headers) { // Prepare download in parallel std::vector> futures_download; + futures_download.reserve(urls.size()); + for (auto const & item : urls) { - futures_download.push_back(std::async(std::launch::async, [bearer_token, offline](const std::pair & it) -> bool { - return common_download_file_single(it.first, it.second, bearer_token, offline); - }, item)); + futures_download.push_back( + std::async( + std::launch::async, + [&bearer_token, offline, &headers](const std::pair & it) -> bool { + return common_download_file_single(it.first, it.second, bearer_token, offline, headers); + }, + item + ) + ); } // Wait for all downloads to complete @@ -753,17 +826,17 @@ static bool common_download_file_multiple(const std::vector(hf_repo_with_tag, ':'); std::string tag = parts.size() > 1 ? parts.back() : "latest"; std::string hf_repo = parts[0]; @@ -839,10 +915,10 @@ common_hf_file_res common_get_hf_file(const std::string & hf_repo_with_tag, cons std::string url = get_model_endpoint() + "v2/" + hf_repo + "/manifests/" + tag; // headers - std::vector headers; - headers.push_back("Accept: application/json"); + common_header_list headers = custom_headers; + headers.push_back({"Accept", "application/json"}); if (!bearer_token.empty()) { - headers.push_back("Authorization: Bearer " + bearer_token); + headers.push_back({"Authorization", "Bearer " + bearer_token}); } // Important: the User-Agent must be "llama-cpp" to get the "ggufFile" field in the response // User-Agent header is already set in common_remote_get_content, no need to set it here @@ -913,8 +989,14 @@ common_hf_file_res common_get_hf_file(const std::string & hf_repo_with_tag, cons // Docker registry functions // -static std::string common_docker_get_token(const std::string & repo) { - std::string url = "https://auth.docker.io/token?service=registry.docker.io&scope=repository:" + repo + ":pull"; +static std::string common_docker_get_token(const std::string & repo, + const common_oci_params & oci_params) { + if (oci_params.auth_url.empty()) { + return ""; + } + std::string url = oci_params.auth_url + + "?service=" + oci_params.auth_service + + "&scope=repository:" + repo + ":pull"; common_remote_params params; auto res = common_remote_get_content(url, params); @@ -933,7 +1015,7 @@ static std::string common_docker_get_token(const std::string & repo) { return response["token"].get(); } -std::string common_docker_resolve_model(const std::string & docker) { +std::string common_docker_resolve_model(const std::string & docker, const common_oci_params & params) { // Parse ai/smollm2:135M-Q4_0 size_t colon_pos = docker.find(':'); std::string repo, tag; @@ -970,16 +1052,20 @@ std::string common_docker_resolve_model(const std::string & docker) { return normalized; }; - std::string token = common_docker_get_token(repo); // Get authentication token + std::string token = common_docker_get_token(repo, params); // Get authentication token // Get manifest // TODO: cache the manifest response so that it appears in the model list - const std::string url_prefix = "https://registry-1.docker.io/v2/" + repo; + const std::string url_prefix = params.registry_url + "/v2/" + repo; std::string manifest_url = url_prefix + "/manifests/" + tag; common_remote_params manifest_params; - manifest_params.headers.push_back("Authorization: Bearer " + token); - manifest_params.headers.push_back( - "Accept: application/vnd.docker.distribution.manifest.v2+json,application/vnd.oci.image.manifest.v1+json"); + + if (!token.empty()) { + manifest_params.headers.push_back({"Authorization", "Bearer " + token}); + } + manifest_params.headers.push_back({"Accept", + "application/vnd.docker.distribution.manifest.v2+json,application/vnd.oci.image.manifest.v1+json" + }); auto manifest_res = common_remote_get_content(manifest_url, manifest_params); if (manifest_res.first != 200) { throw std::runtime_error("Failed to get Docker manifest, HTTP code: " + std::to_string(manifest_res.first)); @@ -990,17 +1076,15 @@ std::string common_docker_resolve_model(const std::string & docker) { std::string gguf_digest; // Find the GGUF layer if (manifest.contains("layers")) { for (const auto & layer : manifest["layers"]) { - if (layer.contains("mediaType")) { - std::string media_type = layer["mediaType"].get(); - if (media_type == "application/vnd.docker.ai.gguf.v3" || - media_type.find("gguf") != std::string::npos) { - gguf_digest = layer["digest"].get(); - break; - } + if (!layer.contains("mediaType") || !layer.contains("digest")) { + continue; + } + if (layer["mediaType"].get() == params.media_type) { + gguf_digest = layer["digest"].get(); + break; } } } - if (gguf_digest.empty()) { throw std::runtime_error("No GGUF layer found in Docker manifest"); } @@ -1016,7 +1100,7 @@ std::string common_docker_resolve_model(const std::string & docker) { std::string local_path = fs_get_cache_file(model_filename); const std::string blob_url = url_prefix + "/blobs/" + gguf_digest; - if (!common_download_file_single(blob_url, local_path, token, false)) { + if (!common_download_file_single(blob_url, local_path, token, false, {})) { throw std::runtime_error("Failed to download Docker Model"); } @@ -1030,11 +1114,11 @@ std::string common_docker_resolve_model(const std::string & docker) { #else -common_hf_file_res common_get_hf_file(const std::string &, const std::string &, bool) { +common_hf_file_res common_get_hf_file(const std::string &, const std::string &, bool, const common_header_list &) { throw std::runtime_error("download functionality is not enabled in this build"); } -bool common_download_model(const common_params_model &, const std::string &, bool) { +bool common_download_model(const common_params_model &, const std::string &, bool, const common_header_list &) { throw std::runtime_error("download functionality is not enabled in this build"); } diff --git a/common/download.h b/common/download.h index 45a6bd6bba8..9ca95f5a69d 100644 --- a/common/download.h +++ b/common/download.h @@ -1,12 +1,21 @@ #pragma once #include +#include struct common_params_model; -// -// download functionalities -// +using common_header = std::pair; +using common_header_list = std::vector; + +struct common_remote_params { + common_header_list headers; + long timeout = 0; // in seconds, 0 means no timeout + long max_size = 0; // unlimited if 0 +}; + +// get remote file content, returns +std::pair> common_remote_get_content(const std::string & url, const common_remote_params & params); struct common_cached_model_info { std::string manifest_path; @@ -39,17 +48,31 @@ struct common_hf_file_res { common_hf_file_res common_get_hf_file( const std::string & hf_repo_with_tag, const std::string & bearer_token, - bool offline); + bool offline, + const common_header_list & headers = {} +); // returns true if download succeeded bool common_download_model( const common_params_model & model, const std::string & bearer_token, - bool offline); + bool offline, + const common_header_list & headers = {} +); // returns list of cached models std::vector common_list_cached_models(); +struct common_oci_params { + std::string registry_url = "https://registry-1.docker.io"; + std::string auth_url = "https://auth.docker.io/token"; + std::string auth_service = "registry.docker.io"; + std::string media_type = "application/vnd.docker.ai.gguf.v3"; +}; + // resolve and download model from Docker registry // return local path to downloaded model file -std::string common_docker_resolve_model(const std::string & docker); +std::string common_docker_resolve_model( + const std::string & docker, + const common_oci_params & params = {} +); diff --git a/tests/test-arg-parser.cpp b/tests/test-arg-parser.cpp index a60ca12fe59..5cde09e6c58 100644 --- a/tests/test-arg-parser.cpp +++ b/tests/test-arg-parser.cpp @@ -1,5 +1,6 @@ #include "arg.h" #include "common.h" +#include "download.h" #include #include diff --git a/tools/run/run.cpp b/tools/run/run.cpp index b90a7253c43..b63f04976eb 100644 --- a/tools/run/run.cpp +++ b/tools/run/run.cpp @@ -2,6 +2,7 @@ #include "common.h" #include "llama-cpp.h" #include "log.h" +#include "download.h" #include "linenoise.cpp/linenoise.h" @@ -21,12 +22,6 @@ # include #endif -#if defined(LLAMA_USE_CURL) -# include -#else -# include "http.h" -#endif - #include #include @@ -303,24 +298,6 @@ class Opt { } }; -struct progress_data { - size_t file_size = 0; - std::chrono::steady_clock::time_point start_time = std::chrono::steady_clock::now(); - bool printed = false; -}; - -static int get_terminal_width() { -#if defined(_WIN32) - CONSOLE_SCREEN_BUFFER_INFO csbi; - GetConsoleScreenBufferInfo(GetStdHandle(STD_OUTPUT_HANDLE), &csbi); - return csbi.srWindow.Right - csbi.srWindow.Left + 1; -#else - struct winsize w; - ioctl(STDOUT_FILENO, TIOCGWINSZ, &w); - return w.ws_col; -#endif -} - class File { public: FILE * file = nullptr; @@ -400,368 +377,6 @@ class File { # endif }; -class HttpClient { - public: - int init(const std::string & url, const std::vector & headers, const std::string & output_file, - const bool progress, std::string * response_str = nullptr) { - if (std::filesystem::exists(output_file)) { - return 0; - } - - std::string output_file_partial; - - if (!output_file.empty()) { - output_file_partial = output_file + ".partial"; - } - - if (download(url, headers, output_file_partial, progress, response_str)) { - return 1; - } - - if (!output_file.empty()) { - try { - std::filesystem::rename(output_file_partial, output_file); - } catch (const std::filesystem::filesystem_error & e) { - printe("Failed to rename '%s' to '%s': %s\n", output_file_partial.c_str(), output_file.c_str(), e.what()); - return 1; - } - } - - return 0; - } - -#ifdef LLAMA_USE_CURL - - ~HttpClient() { - if (chunk) { - curl_slist_free_all(chunk); - } - - if (curl) { - curl_easy_cleanup(curl); - } - } - - private: - CURL * curl = nullptr; - struct curl_slist * chunk = nullptr; - - int download(const std::string & url, const std::vector & headers, const std::string & output_file, - const bool progress, std::string * response_str = nullptr) { - curl = curl_easy_init(); - if (!curl) { - return 1; - } - - progress_data data; - File out; - if (!output_file.empty()) { - if (!out.open(output_file, "ab")) { - printe("Failed to open file for writing\n"); - - return 1; - } - - if (out.lock()) { - printe("Failed to exclusively lock file\n"); - - return 1; - } - } - - set_write_options(response_str, out); - data.file_size = set_resume_point(output_file); - set_progress_options(progress, data); - set_headers(headers); - CURLcode res = perform(url); - if (res != CURLE_OK){ - printe("Fetching resource '%s' failed: %s\n", url.c_str(), curl_easy_strerror(res)); - return 1; - } - - return 0; - } - - void set_write_options(std::string * response_str, const File & out) { - if (response_str) { - curl_easy_setopt(curl, CURLOPT_WRITEFUNCTION, capture_data); - curl_easy_setopt(curl, CURLOPT_WRITEDATA, response_str); - } else { - curl_easy_setopt(curl, CURLOPT_WRITEFUNCTION, write_data); - curl_easy_setopt(curl, CURLOPT_WRITEDATA, out.file); - } - } - - size_t set_resume_point(const std::string & output_file) { - size_t file_size = 0; - if (std::filesystem::exists(output_file)) { - file_size = std::filesystem::file_size(output_file); - curl_easy_setopt(curl, CURLOPT_RESUME_FROM_LARGE, static_cast(file_size)); - } - - return file_size; - } - - void set_progress_options(bool progress, progress_data & data) { - if (progress) { - curl_easy_setopt(curl, CURLOPT_NOPROGRESS, 0L); - curl_easy_setopt(curl, CURLOPT_XFERINFODATA, &data); - curl_easy_setopt(curl, CURLOPT_XFERINFOFUNCTION, update_progress); - } - } - - void set_headers(const std::vector & headers) { - if (!headers.empty()) { - if (chunk) { - curl_slist_free_all(chunk); - chunk = 0; - } - - for (const auto & header : headers) { - chunk = curl_slist_append(chunk, header.c_str()); - } - - curl_easy_setopt(curl, CURLOPT_HTTPHEADER, chunk); - } - } - - CURLcode perform(const std::string & url) { - curl_easy_setopt(curl, CURLOPT_URL, url.c_str()); - curl_easy_setopt(curl, CURLOPT_FOLLOWLOCATION, 1L); - curl_easy_setopt(curl, CURLOPT_DEFAULT_PROTOCOL, "https"); - curl_easy_setopt(curl, CURLOPT_FAILONERROR, 1L); -#ifdef _WIN32 - curl_easy_setopt(curl, CURLOPT_SSL_OPTIONS, CURLSSLOPT_NATIVE_CA); -#endif - return curl_easy_perform(curl); - } - -#else // LLAMA_USE_CURL is not defined - -#define curl_off_t long long // temporary hack - - private: - // this is a direct translation of the cURL download() above - int download(const std::string & url, const std::vector & headers_vec, const std::string & output_file, - const bool progress, std::string * response_str = nullptr) { - try { - auto [cli, url_parts] = common_http_client(url); - - httplib::Headers headers; - for (const auto & h : headers_vec) { - size_t pos = h.find(':'); - if (pos != std::string::npos) { - headers.emplace(h.substr(0, pos), h.substr(pos + 2)); - } - } - - File out; - if (!output_file.empty()) { - if (!out.open(output_file, "ab")) { - printe("Failed to open file for writing\n"); - return 1; - } - if (out.lock()) { - printe("Failed to exclusively lock file\n"); - return 1; - } - } - - size_t resume_offset = 0; - if (!output_file.empty() && std::filesystem::exists(output_file)) { - resume_offset = std::filesystem::file_size(output_file); - if (resume_offset > 0) { - headers.emplace("Range", "bytes=" + std::to_string(resume_offset) + "-"); - } - } - - progress_data data; - data.file_size = resume_offset; - - long long total_size = 0; - long long received_this_session = 0; - - auto response_handler = - [&](const httplib::Response & response) { - if (resume_offset > 0 && response.status != 206) { - printe("\nServer does not support resuming. Restarting download.\n"); - out.file = freopen(output_file.c_str(), "wb", out.file); - if (!out.file) { - return false; - } - data.file_size = 0; - } - if (progress) { - if (response.has_header("Content-Length")) { - total_size = std::stoll(response.get_header_value("Content-Length")); - } else if (response.has_header("Content-Range")) { - auto range = response.get_header_value("Content-Range"); - auto slash = range.find('/'); - if (slash != std::string::npos) { - total_size = std::stoll(range.substr(slash + 1)); - } - } - } - return true; - }; - - auto content_receiver = - [&](const char * chunk, size_t length) { - if (out.file && fwrite(chunk, 1, length, out.file) != length) { - return false; - } - if (response_str) { - response_str->append(chunk, length); - } - received_this_session += length; - - if (progress && total_size > 0) { - update_progress(&data, total_size, received_this_session, 0, 0); - } - return true; - }; - - auto res = cli.Get(url_parts.path, headers, response_handler, content_receiver); - - if (data.printed) { - printe("\n"); - } - - if (!res) { - auto err = res.error(); - printe("Fetching resource '%s' failed: %s\n", url.c_str(), httplib::to_string(err).c_str()); - return 1; - } - - if (res->status >= 400) { - printe("Fetching resource '%s' failed with status code: %d\n", url.c_str(), res->status); - return 1; - } - - } catch (const std::exception & e) { - printe("HTTP request failed: %s\n", e.what()); - return 1; - } - return 0; - } - -#endif // LLAMA_USE_CURL - - static std::string human_readable_time(double seconds) { - int hrs = static_cast(seconds) / 3600; - int mins = (static_cast(seconds) % 3600) / 60; - int secs = static_cast(seconds) % 60; - - if (hrs > 0) { - return string_format("%dh %02dm %02ds", hrs, mins, secs); - } else if (mins > 0) { - return string_format("%dm %02ds", mins, secs); - } else { - return string_format("%ds", secs); - } - } - - static std::string human_readable_size(curl_off_t size) { - static const char * suffix[] = { "B", "KB", "MB", "GB", "TB" }; - char length = sizeof(suffix) / sizeof(suffix[0]); - int i = 0; - double dbl_size = size; - if (size > 1024) { - for (i = 0; (size / 1024) > 0 && i < length - 1; i++, size /= 1024) { - dbl_size = size / 1024.0; - } - } - - return string_format("%.2f %s", dbl_size, suffix[i]); - } - - static int update_progress(void * ptr, curl_off_t total_to_download, curl_off_t now_downloaded, curl_off_t, - curl_off_t) { - progress_data * data = static_cast(ptr); - if (total_to_download <= 0) { - return 0; - } - - total_to_download += data->file_size; - const curl_off_t now_downloaded_plus_file_size = now_downloaded + data->file_size; - const curl_off_t percentage = calculate_percentage(now_downloaded_plus_file_size, total_to_download); - std::string progress_prefix = generate_progress_prefix(percentage); - - const double speed = calculate_speed(now_downloaded, data->start_time); - const double tim = (total_to_download - now_downloaded) / speed; - std::string progress_suffix = - generate_progress_suffix(now_downloaded_plus_file_size, total_to_download, speed, tim); - - int progress_bar_width = calculate_progress_bar_width(progress_prefix, progress_suffix); - std::string progress_bar; - generate_progress_bar(progress_bar_width, percentage, progress_bar); - - print_progress(progress_prefix, progress_bar, progress_suffix); - data->printed = true; - - return 0; - } - - static curl_off_t calculate_percentage(curl_off_t now_downloaded_plus_file_size, curl_off_t total_to_download) { - return (now_downloaded_plus_file_size * 100) / total_to_download; - } - - static std::string generate_progress_prefix(curl_off_t percentage) { - return string_format("%3ld%% |", static_cast(percentage)); - } - - static double calculate_speed(curl_off_t now_downloaded, const std::chrono::steady_clock::time_point & start_time) { - const auto now = std::chrono::steady_clock::now(); - const std::chrono::duration elapsed_seconds = now - start_time; - return now_downloaded / elapsed_seconds.count(); - } - - static std::string generate_progress_suffix(curl_off_t now_downloaded_plus_file_size, curl_off_t total_to_download, - double speed, double estimated_time) { - const int width = 10; - return string_format("%*s/%*s%*s/s%*s", width, human_readable_size(now_downloaded_plus_file_size).c_str(), - width, human_readable_size(total_to_download).c_str(), width, - human_readable_size(speed).c_str(), width, human_readable_time(estimated_time).c_str()); - } - - static int calculate_progress_bar_width(const std::string & progress_prefix, const std::string & progress_suffix) { - int progress_bar_width = get_terminal_width() - progress_prefix.size() - progress_suffix.size() - 3; - if (progress_bar_width < 1) { - progress_bar_width = 1; - } - - return progress_bar_width; - } - - static std::string generate_progress_bar(int progress_bar_width, curl_off_t percentage, - std::string & progress_bar) { - const curl_off_t pos = (percentage * progress_bar_width) / 100; - for (int i = 0; i < progress_bar_width; ++i) { - progress_bar.append((i < pos) ? "█" : " "); - } - - return progress_bar; - } - - static void print_progress(const std::string & progress_prefix, const std::string & progress_bar, - const std::string & progress_suffix) { - printe("\r" LOG_CLR_TO_EOL "%s%s| %s", progress_prefix.c_str(), progress_bar.c_str(), progress_suffix.c_str()); - } - // Function to write data to a file - static size_t write_data(void * ptr, size_t size, size_t nmemb, void * stream) { - FILE * out = static_cast(stream); - return fwrite(ptr, size, nmemb, out); - } - - // Function to capture data into a string - static size_t capture_data(void * ptr, size_t size, size_t nmemb, void * stream) { - std::string * str = static_cast(stream); - str->append(static_cast(ptr), size * nmemb); - return size * nmemb; - } - -}; - class LlamaData { public: llama_model_ptr model; @@ -788,113 +403,37 @@ class LlamaData { } private: - int download(const std::string & url, const std::string & output_file, const bool progress, - const std::vector & headers = {}, std::string * response_str = nullptr) { - HttpClient http; - if (http.init(url, headers, output_file, progress, response_str)) { - return 1; - } - - return 0; - } - - // Helper function to handle model tag extraction and URL construction - std::pair extract_model_and_tag(std::string & model, const std::string & base_url) { - std::string model_tag = "latest"; - const size_t colon_pos = model.find(':'); - if (colon_pos != std::string::npos) { - model_tag = model.substr(colon_pos + 1); - model = model.substr(0, colon_pos); - } - - std::string url = base_url + model + "/manifests/" + model_tag; - - return { model, url }; - } - // Helper function to download and parse the manifest - int download_and_parse_manifest(const std::string & url, const std::vector & headers, - nlohmann::json & manifest) { - std::string manifest_str; - int ret = download(url, "", false, headers, &manifest_str); - if (ret) { - return ret; - } - - manifest = nlohmann::json::parse(manifest_str); - - return 0; - } - - int dl_from_endpoint(std::string & model_endpoint, std::string & model, const std::string & bn) { + static bool resolve_endpoint(const std::string & model_endpoint, + const std::string & model, + common_params_model & params) { // Find the second occurrence of '/' after protocol string size_t pos = model.find('/'); - pos = model.find('/', pos + 1); - std::string hfr, hff; - std::vector headers = { "User-Agent: llama-cpp", "Accept: application/json" }; - std::string url; - - if (pos == std::string::npos) { - auto [model_name, manifest_url] = extract_model_and_tag(model, model_endpoint + "v2/"); - hfr = model_name; - - nlohmann::json manifest; - int ret = download_and_parse_manifest(manifest_url, headers, manifest); - if (ret) { - return ret; - } - - hff = manifest["ggufFile"]["rfilename"]; - } else { - hfr = model.substr(0, pos); - hff = model.substr(pos + 1); - } - - url = model_endpoint + hfr + "/resolve/main/" + hff; - - return download(url, bn, true, headers); - } + pos = model.find('/', pos + 1); - int modelscope_dl(std::string & model, const std::string & bn) { - std::string model_endpoint = "https://modelscope.cn/models/"; - return dl_from_endpoint(model_endpoint, model, bn); - } - - int huggingface_dl(std::string & model, const std::string & bn) { - std::string model_endpoint = get_model_endpoint(); - return dl_from_endpoint(model_endpoint, model, bn); - } - - int ollama_dl(std::string & model, const std::string & bn) { - const std::vector headers = { "Accept: application/vnd.docker.distribution.manifest.v2+json" }; - if (model.find('/') == std::string::npos) { - model = "library/" + model; - } - - auto [model_name, manifest_url] = extract_model_and_tag(model, "https://registry.ollama.ai/v2/"); - nlohmann::json manifest; - int ret = download_and_parse_manifest(manifest_url, {}, manifest); - if (ret) { - return ret; - } + common_hf_file_res res; - std::string layer; - for (const auto & l : manifest["layers"]) { - if (l["mediaType"] == "application/vnd.ollama.image.model") { - layer = l["digest"]; - break; + try { + if (pos == std::string::npos) { + res = common_get_hf_file(model, "", false); + } else { + res.repo = model.substr(0, pos); + res.ggufFile = model.substr(pos + 1); } + } catch (const std::exception & e) { + printe("Invalid repository format\n"); + return false; } - std::string blob_url = "https://registry.ollama.ai/v2/" + model_name + "/blobs/" + layer; - - return download(blob_url, bn, true, headers); + params.url = model_endpoint + res.repo + "/resolve/main/" + res.ggufFile; + return true; } - int github_dl(const std::string & model, const std::string & bn) { - std::string repository = model; - std::string branch = "main"; - const size_t at_pos = model.find('@'); + static bool resolve_github(std::string & model, common_params_model & params) { + std::string repository = model; + std::string branch = "main"; + + const size_t at_pos = model.find('@'); if (at_pos != std::string::npos) { repository = model.substr(0, at_pos); branch = model.substr(at_pos + 1); @@ -903,113 +442,149 @@ class LlamaData { const std::vector repo_parts = string_split(repository, "/"); if (repo_parts.size() < 3) { printe("Invalid GitHub repository format\n"); - return 1; + return false; } - const std::string & org = repo_parts[0]; - const std::string & project = repo_parts[1]; - std::string url = "https://raw.githubusercontent.com/" + org + "/" + project + "/" + branch; + const std::string & org = repo_parts[0]; + const std::string & project = repo_parts[1]; + std::string url = "https://raw.githubusercontent.com/" + org + "/" + project + "/" + branch; + for (size_t i = 2; i < repo_parts.size(); ++i) { url += "/" + repo_parts[i]; } - return download(url, bn, true); + params.url = url; + return true; } - int s3_dl(const std::string & model, const std::string & bn) { + static bool resolve_s3(const std::string & model, common_params_model & params, common_header_list & headers) { const size_t slash_pos = model.find('/'); if (slash_pos == std::string::npos) { - return 1; + return false; } - const std::string bucket = model.substr(0, slash_pos); - const std::string key = model.substr(slash_pos + 1); + const std::string bucket = model.substr(0, slash_pos); + const std::string key = model.substr(slash_pos + 1); + const char * access_key = std::getenv("AWS_ACCESS_KEY_ID"); const char * secret_key = std::getenv("AWS_SECRET_ACCESS_KEY"); + if (!access_key || !secret_key) { printe("AWS credentials not found in environment\n"); - return 1; + return false; } // Generate AWS Signature Version 4 headers // (Implementation requires HMAC-SHA256 and date handling) // Get current timestamp - const time_t now = time(nullptr); - const tm tm = *gmtime(&now); - const std::string date = strftime_fmt("%Y%m%d", tm); - const std::string datetime = strftime_fmt("%Y%m%dT%H%M%SZ", tm); - const std::vector headers = { - "Authorization: AWS4-HMAC-SHA256 Credential=" + std::string(access_key) + "/" + date + - "/us-east-1/s3/aws4_request", - "x-amz-content-sha256: UNSIGNED-PAYLOAD", "x-amz-date: " + datetime - }; - - const std::string url = "https://" + bucket + ".s3.amazonaws.com/" + key; - - return download(url, bn, true, headers); - } + const time_t now = time(nullptr); + const tm tm = *gmtime(&now); + const std::string date = strftime_fmt("%Y%m%d", tm); + const std::string datetime = strftime_fmt("%Y%m%dT%H%M%SZ", tm); + const std::string auth_header = "AWS4-HMAC-SHA256 Credential=" + std::string(access_key) + "/" + date + "/us-east-1/s3/aws4_request"; - std::string basename(const std::string & path) { - const size_t pos = path.find_last_of("/\\"); - if (pos == std::string::npos) { - return path; - } + headers.push_back({"Authorization", auth_header}); + headers.push_back({"x-amz-content-sha256", "UNSIGNED-PAYLOAD"}); + headers.push_back({"x-amz-date", datetime}); - return path.substr(pos + 1); + params.url = "https://" + bucket + ".s3.amazonaws.com/" + key; + return true; } - int rm_until_substring(std::string & model_, const std::string & substring) { - const std::string::size_type pos = model_.find(substring); - if (pos == std::string::npos) { - return 1; + static bool remove_prefix(std::string & url, const std::string & prefix) { + if (string_starts_with(url, prefix)) { + url = url.substr(prefix.length()); + return true; } - - model_ = model_.substr(pos + substring.size()); // Skip past the substring - return 0; + return false; } - int resolve_model(std::string & model_) { - int ret = 0; - if (string_starts_with(model_, "file://") || std::filesystem::exists(model_)) { - rm_until_substring(model_, "://"); - - return ret; + static int resolve_model(std::string & model_) { + if (std::filesystem::exists(model_)) { + return 0; } - const std::string bn = basename(model_); - if (string_starts_with(model_, "hf://") || string_starts_with(model_, "huggingface://") || - string_starts_with(model_, "hf.co/")) { - rm_until_substring(model_, "hf.co/"); - rm_until_substring(model_, "://"); - ret = huggingface_dl(model_, bn); - } else if (string_starts_with(model_, "ms://") || string_starts_with(model_, "modelscope://")) { - rm_until_substring(model_, "://"); - ret = modelscope_dl(model_, bn); - } else if ((string_starts_with(model_, "https://") || string_starts_with(model_, "http://")) && - !string_starts_with(model_, "https://ollama.com/library/")) { - ret = download(model_, bn, true); - } else if (string_starts_with(model_, "github:") || string_starts_with(model_, "github://")) { - rm_until_substring(model_, "github:"); - rm_until_substring(model_, "://"); - ret = github_dl(model_, bn); - } else if (string_starts_with(model_, "s3://")) { - rm_until_substring(model_, "://"); - ret = s3_dl(model_, bn); - } else { // ollama:// or nothing - rm_until_substring(model_, "ollama.com/library/"); - rm_until_substring(model_, "://"); - ret = ollama_dl(model_, bn); + common_params_model m_params; + common_header_list headers; + common_oci_params oci_params; + + bool is_ollama = false; + + if (remove_prefix(model_, "file://")) { + if (std::filesystem::exists(model_)) { + return 0; + } + } else if (remove_prefix(model_, "hf://") || + remove_prefix(model_, "huggingface://")) { + if (!resolve_endpoint(get_model_endpoint(), model_, m_params)) { + return 1; + } + } else if (remove_prefix(model_, "ms://") || + remove_prefix(model_, "modelscope://")) { + if (!resolve_endpoint("https://modelscope.cn/models/", model_, m_params)) { + return 1; + } + } else if (remove_prefix(model_, "s3://")) { + if (!resolve_s3(model_, m_params, headers)) { + return 1; + } + } else if (remove_prefix(model_, "github://")) { + if (!resolve_github(model_, m_params)) { + return 1; + } + } else if (remove_prefix(model_, "ollama://") || + remove_prefix(model_, "https://ollama.com/library/")) { + is_ollama = true; + } else if (string_starts_with(model_, "http://") || + string_starts_with(model_, "https://")) { + m_params.url = model_; + } else { + if (model_.find(".gguf") != std::string::npos) { + printe("Error: Local file not found: %s\n", model_.c_str()); + return 1; + } + // fallback ollama + is_ollama = true; } + try { + if (is_ollama) { + oci_params.registry_url = "https://registry.ollama.ai"; + oci_params.auth_url = ""; // no auth for ollama + oci_params.auth_service = ""; + oci_params.media_type = "application/vnd.ollama.image.model"; + + if (model_.find('/') == std::string::npos) { + model_ = "library/" + model_; + } + model_ = common_docker_resolve_model(model_, oci_params); + } else { + std::string name = std::filesystem::path(m_params.url).filename().string(); - model_ = bn; + if (name.find('?') != std::string::npos) { + name = name.substr(0, name.find('?')); + } + m_params.path = fs_get_cache_file(name); - return ret; + // token and offline are not supported + if (!common_download_model(m_params, "", false, headers)) { + printe("Failed to download model from %s\n", m_params.url.c_str()); + return 1; + } + model_ = m_params.path; + } + } catch (const std::exception & e) { + printe("Model resolution error: %s\n", e.what()); + return 1; + } + return 0; } // Initializes the model and returns a unique pointer to it - llama_model_ptr initialize_model(Opt & opt) { + static llama_model_ptr initialize_model(Opt & opt) { ggml_backend_load_all(); - resolve_model(opt.model_); + if (resolve_model(opt.model_)) { + return nullptr; + } printe("\r" LOG_CLR_TO_EOL "Loading model"); llama_model_ptr model(llama_model_load_from_file(opt.model_.c_str(), opt.model_params)); if (!model) { @@ -1021,7 +596,7 @@ class LlamaData { } // Initializes the context with the specified parameters - llama_context_ptr initialize_context(const llama_model_ptr & model, const Opt & opt) { + static llama_context_ptr initialize_context(const llama_model_ptr & model, const Opt & opt) { llama_context_ptr context(llama_init_from_model(model.get(), opt.ctx_params)); if (!context) { printe("%s: error: failed to create the llama_context\n", __func__); @@ -1031,7 +606,7 @@ class LlamaData { } // Initializes and configures the sampler - llama_sampler_ptr initialize_sampler(const Opt & opt) { + static llama_sampler_ptr initialize_sampler(const Opt & opt) { llama_sampler_ptr sampler(llama_sampler_chain_init(llama_sampler_chain_default_params())); llama_sampler_chain_add(sampler.get(), llama_sampler_init_min_p(0.05f, 1)); llama_sampler_chain_add(sampler.get(), llama_sampler_init_temp(opt.temperature)); @@ -1043,7 +618,7 @@ class LlamaData { // Add a message to `messages` and store its content in `msg_strs` static void add_message(const char * role, const std::string & text, LlamaData & llama_data) { - llama_data.msg_strs.push_back(std::move(text)); + llama_data.msg_strs.push_back(text); llama_data.messages.push_back({ role, llama_data.msg_strs.back().c_str() }); } diff --git a/tools/server/server-common.cpp b/tools/server/server-common.cpp index 18328f3afbd..31bc29314c5 100644 --- a/tools/server/server-common.cpp +++ b/tools/server/server-common.cpp @@ -6,6 +6,7 @@ #include "chat.h" #include "arg.h" // for common_remote_get_content; TODO: use download.h only #include "base64.hpp" +#include "download.h" #include "server-common.h" @@ -867,7 +868,7 @@ json oaicompat_chat_params_parse( // download remote image // TODO @ngxson : maybe make these params configurable common_remote_params params; - params.headers.push_back("User-Agent: llama.cpp/" + build_info); + params.headers.push_back({"User-Agent", "llama.cpp/" + build_info}); params.max_size = 1024 * 1024 * 10; // 10MB params.timeout = 10; // seconds SRV_INF("downloading image from '%s'\n", url.c_str());