Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 0 additions & 8 deletions common/arg.h
Original file line number Diff line number Diff line change
Expand Up @@ -78,11 +78,3 @@ bool common_params_parse(int argc, char ** argv, common_params & params, llama_e

// function to be used by test-arg-parser
common_params_context common_params_parser_init(common_params & params, llama_example ex, void(*print_usage)(int, char **) = nullptr);

struct common_remote_params {
std::vector<std::string> headers;
long timeout = 0; // CURLOPT_TIMEOUT, in seconds ; 0 means no timeout
long max_size = 0; // max size of the response ; unlimited if 0 ; max is 2GB
};
// get remote file content, returns <http_code, raw_response_body>
std::pair<long, std::vector<char>> common_remote_get_content(const std::string & url, const common_remote_params & params);
208 changes: 146 additions & 62 deletions common/download.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -303,7 +303,8 @@ static bool common_download_head(CURL * curl,
// download one single file from remote URL to local path
static bool common_download_file_single_online(const std::string & url,
const std::string & path,
const std::string & bearer_token) {
const std::string & bearer_token,
const common_header_list & custom_headers) {
static const int max_attempts = 3;
static const int retry_delay_seconds = 2;
for (int i = 0; i < max_attempts; ++i) {
Expand All @@ -325,6 +326,11 @@ static bool common_download_file_single_online(const std::string & url,
common_load_model_from_url_headers headers;
curl_easy_setopt(curl.get(), CURLOPT_HEADERDATA, &headers);
curl_slist_ptr http_headers;

for (const auto & h : custom_headers) {
std::string s = h.first + ": " + h.second;
http_headers.ptr = curl_slist_append(http_headers.ptr, s.c_str());
}
const bool was_perform_successful = common_download_head(curl.get(), http_headers, url, bearer_token);
if (!was_perform_successful) {
head_request_ok = false;
Expand Down Expand Up @@ -449,8 +455,10 @@ std::pair<long, std::vector<char>> common_remote_get_content(const std::string &
curl_easy_setopt(curl.get(), CURLOPT_MAXFILESIZE, params.max_size);
}
http_headers.ptr = curl_slist_append(http_headers.ptr, "User-Agent: llama-cpp");

for (const auto & header : params.headers) {
http_headers.ptr = curl_slist_append(http_headers.ptr, header.c_str());
std::string header_ = header.first + ": " + header.second;
http_headers.ptr = curl_slist_append(http_headers.ptr, header_.c_str());
}
curl_easy_setopt(curl.get(), CURLOPT_HTTPHEADER, http_headers.ptr);

Expand Down Expand Up @@ -562,19 +570,69 @@ static bool common_pull_file(httplib::Client & cli,
return true;
}

static void common_set_clean_host_header(httplib::Headers & headers, const std::string & host) {
if (headers.count("Host")) {
headers.erase("Host");
}

std::string clean_host = host;
size_t pos = clean_host.find(':');
if (pos != std::string::npos) {
clean_host = clean_host.substr(0, pos);
}

headers.emplace("Host", clean_host);
}

static void common_resolve_redirects(std::string & url, httplib::Headers & headers) {
for (int r = 0; r < 5; ++r) {
auto [cli, parts] = common_http_client(url);
cli.set_follow_location(false);
common_set_clean_host_header(headers, parts.host);

httplib::Headers probe_headers = headers;
probe_headers.emplace("Range", "bytes=0-0");

auto head = cli.Get(parts.path, probe_headers);

if (head && (head->status >= 300 && head->status < 400) && head->has_header("Location")) {
std::string location = head->get_header_value("Location");
if (location.find("://") == std::string::npos) {
url = parts.scheme + "://" + parts.host + location;
} else {
url = location;
}
if (headers.count("Authorization")) {
headers.erase("Authorization");
}
continue;
}
break;
}
auto parts = common_http_parse_url(url);
common_set_clean_host_header(headers, parts.host);
}

// download one single file from remote URL to local path
static bool common_download_file_single_online(const std::string & url,
const std::string & path,
const std::string & bearer_token) {
const std::string & bearer_token,
const common_header_list & custom_headers) {
static const int max_attempts = 3;
static const int retry_delay_seconds = 2;

auto [cli, parts] = common_http_client(url);

httplib::Headers default_headers = {{"User-Agent", "llama-cpp"}};
if (!bearer_token.empty()) {
default_headers.insert({"Authorization", "Bearer " + bearer_token});
}
for (const auto & h : custom_headers) {
default_headers.emplace(h.first, h.second);
}

std::string real_url = url;
common_resolve_redirects(real_url, default_headers);

auto [cli, parts] = common_http_client(real_url);
cli.set_default_headers(default_headers);

const bool file_exists = std::filesystem::exists(path);
Expand All @@ -589,7 +647,9 @@ static bool common_download_file_single_online(const std::string & url,
for (int i = 0; i < max_attempts; ++i) {
auto head = cli.Head(parts.path);
bool head_ok = head && head->status >= 200 && head->status < 300;
if (!head_ok) {
bool head_403 = head && head->status == 403;

if (!head_ok && !head_403) {
LOG_WRN("%s: HEAD invalid http status code received: %d\n", __func__, head ? head->status : -1);
if (file_exists) {
LOG_INF("%s: Using cached file (HEAD failed): %s\n", __func__, path.c_str());
Expand All @@ -598,22 +658,26 @@ static bool common_download_file_single_online(const std::string & url,
}

std::string etag;
if (head_ok && head->has_header("ETag")) {
etag = head->get_header_value("ETag");
}

size_t total_size = 0;
if (head_ok && head->has_header("Content-Length")) {
try {
total_size = std::stoull(head->get_header_value("Content-Length"));
} catch (const std::exception& e) {
LOG_WRN("%s: Invalid Content-Length in HEAD response: %s\n", __func__, e.what());
}
}

bool supports_ranges = false;
if (head_ok && head->has_header("Accept-Ranges")) {
supports_ranges = head->get_header_value("Accept-Ranges") != "none";

if (head_ok) {
if (head->has_header("ETag")) {
etag = head->get_header_value("ETag");
}
if (head->has_header("Content-Length")) {
try {
total_size = std::stoull(head->get_header_value("Content-Length"));
} catch (const std::exception& e) {
LOG_WRN("%s: Invalid Content-Length in HEAD response: %s\n", __func__, e.what());
}
}
if (head->has_header("Accept-Ranges")) {
supports_ranges = head->get_header_value("Accept-Ranges") != "none";
}
} else if (head_403) {
LOG_INF("%s: 403 on HEAD, assuming GET/Resume is allowed\n", __func__);
supports_ranges = true;
}

bool should_download_from_scratch = false;
Expand Down Expand Up @@ -648,8 +712,9 @@ static bool common_download_file_single_online(const std::string & url,
}

// start the download
std::string masked_url = common_http_show_masked_url(common_http_parse_url(url));
LOG_INF("%s: trying to download model from %s to %s (etag:%s)...\n",
__func__, common_http_show_masked_url(parts).c_str(), path_temporary.c_str(), etag.c_str());
__func__, masked_url.c_str(), path_temporary.c_str(), etag.c_str());
const bool was_pull_successful = common_pull_file(cli, parts.path, path_temporary, supports_ranges, existing_size, total_size);
if (!was_pull_successful) {
if (i + 1 < max_attempts) {
Expand Down Expand Up @@ -680,13 +745,9 @@ std::pair<long, std::vector<char>> common_remote_get_content(const std::string
auto [cli, parts] = common_http_client(url);

httplib::Headers headers = {{"User-Agent", "llama-cpp"}};

for (const auto & header : params.headers) {
size_t pos = header.find(':');
if (pos != std::string::npos) {
headers.emplace(header.substr(0, pos), header.substr(pos + 1));
} else {
headers.emplace(header, "");
}
headers.emplace(header.first, header.second);
}

if (params.timeout > 0) {
Expand Down Expand Up @@ -718,9 +779,10 @@ std::pair<long, std::vector<char>> common_remote_get_content(const std::string
static bool common_download_file_single(const std::string & url,
const std::string & path,
const std::string & bearer_token,
bool offline) {
bool offline,
const common_header_list & headers) {
if (!offline) {
return common_download_file_single_online(url, path, bearer_token);
return common_download_file_single_online(url, path, bearer_token, headers);
}

if (!std::filesystem::exists(path)) {
Expand All @@ -734,13 +796,24 @@ static bool common_download_file_single(const std::string & url,

// download multiple files from remote URLs to local paths
// the input is a vector of pairs <url, path>
static bool common_download_file_multiple(const std::vector<std::pair<std::string, std::string>> & urls, const std::string & bearer_token, bool offline) {
static bool common_download_file_multiple(const std::vector<std::pair<std::string, std::string>> & urls,
const std::string & bearer_token,
bool offline,
const common_header_list & headers) {
// Prepare download in parallel
std::vector<std::future<bool>> futures_download;
futures_download.reserve(urls.size());

for (auto const & item : urls) {
futures_download.push_back(std::async(std::launch::async, [bearer_token, offline](const std::pair<std::string, std::string> & it) -> bool {
return common_download_file_single(it.first, it.second, bearer_token, offline);
}, item));
futures_download.push_back(
std::async(
std::launch::async,
[&bearer_token, offline, &headers](const std::pair<std::string, std::string> & it) -> bool {
return common_download_file_single(it.first, it.second, bearer_token, offline, headers);
},
item
)
);
}

// Wait for all downloads to complete
Expand All @@ -753,17 +826,17 @@ static bool common_download_file_multiple(const std::vector<std::pair<std::strin
return true;
}

bool common_download_model(
const common_params_model & model,
const std::string & bearer_token,
bool offline) {
bool common_download_model(const common_params_model & model,
const std::string & bearer_token,
bool offline,
const common_header_list & headers) {
// Basic validation of the model.url
if (model.url.empty()) {
LOG_ERR("%s: invalid model url\n", __func__);
return false;
}

if (!common_download_file_single(model.url, model.path, bearer_token, offline)) {
if (!common_download_file_single(model.url, model.path, bearer_token, offline, headers)) {
return false;
}

Expand Down Expand Up @@ -822,13 +895,16 @@ bool common_download_model(
}

// Download in parallel
common_download_file_multiple(urls, bearer_token, offline);
common_download_file_multiple(urls, bearer_token, offline, headers);
}

return true;
}

common_hf_file_res common_get_hf_file(const std::string & hf_repo_with_tag, const std::string & bearer_token, bool offline) {
common_hf_file_res common_get_hf_file(const std::string & hf_repo_with_tag,
const std::string & bearer_token,
bool offline,
const common_header_list & custom_headers) {
auto parts = string_split<std::string>(hf_repo_with_tag, ':');
std::string tag = parts.size() > 1 ? parts.back() : "latest";
std::string hf_repo = parts[0];
Expand All @@ -839,10 +915,10 @@ common_hf_file_res common_get_hf_file(const std::string & hf_repo_with_tag, cons
std::string url = get_model_endpoint() + "v2/" + hf_repo + "/manifests/" + tag;

// headers
std::vector<std::string> headers;
headers.push_back("Accept: application/json");
common_header_list headers = custom_headers;
headers.push_back({"Accept", "application/json"});
if (!bearer_token.empty()) {
headers.push_back("Authorization: Bearer " + bearer_token);
headers.push_back({"Authorization", "Bearer " + bearer_token});
}
// Important: the User-Agent must be "llama-cpp" to get the "ggufFile" field in the response
// User-Agent header is already set in common_remote_get_content, no need to set it here
Expand Down Expand Up @@ -913,8 +989,14 @@ common_hf_file_res common_get_hf_file(const std::string & hf_repo_with_tag, cons
// Docker registry functions
//

static std::string common_docker_get_token(const std::string & repo) {
std::string url = "https://auth.docker.io/token?service=registry.docker.io&scope=repository:" + repo + ":pull";
static std::string common_docker_get_token(const std::string & repo,
const common_oci_params & oci_params) {
if (oci_params.auth_url.empty()) {
return "";
}
std::string url = oci_params.auth_url
+ "?service=" + oci_params.auth_service
+ "&scope=repository:" + repo + ":pull";

common_remote_params params;
auto res = common_remote_get_content(url, params);
Expand All @@ -933,7 +1015,7 @@ static std::string common_docker_get_token(const std::string & repo) {
return response["token"].get<std::string>();
}

std::string common_docker_resolve_model(const std::string & docker) {
std::string common_docker_resolve_model(const std::string & docker, const common_oci_params & params) {
// Parse ai/smollm2:135M-Q4_0
size_t colon_pos = docker.find(':');
std::string repo, tag;
Expand Down Expand Up @@ -970,16 +1052,20 @@ std::string common_docker_resolve_model(const std::string & docker) {
return normalized;
};

std::string token = common_docker_get_token(repo); // Get authentication token
std::string token = common_docker_get_token(repo, params); // Get authentication token

// Get manifest
// TODO: cache the manifest response so that it appears in the model list
const std::string url_prefix = "https://registry-1.docker.io/v2/" + repo;
const std::string url_prefix = params.registry_url + "/v2/" + repo;
std::string manifest_url = url_prefix + "/manifests/" + tag;
common_remote_params manifest_params;
manifest_params.headers.push_back("Authorization: Bearer " + token);
manifest_params.headers.push_back(
"Accept: application/vnd.docker.distribution.manifest.v2+json,application/vnd.oci.image.manifest.v1+json");

if (!token.empty()) {
manifest_params.headers.push_back({"Authorization", "Bearer " + token});
}
manifest_params.headers.push_back({"Accept",
"application/vnd.docker.distribution.manifest.v2+json,application/vnd.oci.image.manifest.v1+json"
});
auto manifest_res = common_remote_get_content(manifest_url, manifest_params);
if (manifest_res.first != 200) {
throw std::runtime_error("Failed to get Docker manifest, HTTP code: " + std::to_string(manifest_res.first));
Expand All @@ -990,17 +1076,15 @@ std::string common_docker_resolve_model(const std::string & docker) {
std::string gguf_digest; // Find the GGUF layer
if (manifest.contains("layers")) {
for (const auto & layer : manifest["layers"]) {
if (layer.contains("mediaType")) {
std::string media_type = layer["mediaType"].get<std::string>();
if (media_type == "application/vnd.docker.ai.gguf.v3" ||
media_type.find("gguf") != std::string::npos) {
gguf_digest = layer["digest"].get<std::string>();
break;
}
if (!layer.contains("mediaType") || !layer.contains("digest")) {
continue;
}
if (layer["mediaType"].get<std::string>() == params.media_type) {
gguf_digest = layer["digest"].get<std::string>();
break;
}
}
}

if (gguf_digest.empty()) {
throw std::runtime_error("No GGUF layer found in Docker manifest");
}
Expand All @@ -1016,7 +1100,7 @@ std::string common_docker_resolve_model(const std::string & docker) {
std::string local_path = fs_get_cache_file(model_filename);

const std::string blob_url = url_prefix + "/blobs/" + gguf_digest;
if (!common_download_file_single(blob_url, local_path, token, false)) {
if (!common_download_file_single(blob_url, local_path, token, false, {})) {
throw std::runtime_error("Failed to download Docker Model");
}

Expand All @@ -1030,11 +1114,11 @@ std::string common_docker_resolve_model(const std::string & docker) {

#else

common_hf_file_res common_get_hf_file(const std::string &, const std::string &, bool) {
common_hf_file_res common_get_hf_file(const std::string &, const std::string &, bool, const common_header_list &) {
throw std::runtime_error("download functionality is not enabled in this build");
}

bool common_download_model(const common_params_model &, const std::string &, bool) {
bool common_download_model(const common_params_model &, const std::string &, bool, const common_header_list &) {
throw std::runtime_error("download functionality is not enabled in this build");
}

Expand Down
Loading
Loading