diff --git a/Cargo.toml b/Cargo.toml index 3fdc73a..6f40478 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -37,8 +37,8 @@ native-tls = { version = "0.2.11", optional = true } [features] default = ["online"] -online = ["dep:ureq", "dep:native-tls", "dep:rand", "dep:serde", "dep:serde_json", "dep:indicatif", "dep:num_cpus", "dep:thiserror"] -tokio = ["dep:reqwest", "dep:tokio", "dep:futures", "dep:rand", "dep:serde", "dep:serde_json", "dep:indicatif", "dep:num_cpus", "dep:thiserror"] +online = ["dep:ureq", "dep:native-tls", "dep:rand", "dep:serde", "dep:serde_json", "dep:indicatif", "dep:thiserror"] +tokio = ["dep:reqwest", "dep:tokio", "tokio/rt-multi-thread", "dep:futures", "dep:rand", "dep:serde", "dep:serde_json", "dep:indicatif", "dep:num_cpus", "dep:thiserror"] [dev-dependencies] hex-literal = "0.4.1" diff --git a/examples/download.rs b/examples/download.rs new file mode 100644 index 0000000..18e35e0 --- /dev/null +++ b/examples/download.rs @@ -0,0 +1,21 @@ +use hf_hub::{Repo, RepoType}; + +#[cfg(not(feature = "tokio"))] +fn main() { + let api = hf_hub::api::sync::Api::new().unwrap(); + let repo = Repo::new("meta-llama/Llama-2-7b-hf".to_string(), RepoType::Model); + + let _filename = api.get(&repo, "model-00001-of-00002.safetensors").unwrap(); +} + +#[cfg(feature = "tokio")] +#[tokio::main] +async fn main() { + let api = hf_hub::api::tokio::Api::new().unwrap(); + let repo = Repo::new("meta-llama/Llama-2-7b-hf".to_string(), RepoType::Model); + + let _filename = api + .get(&repo, "model-00001-of-00002.safetensors") + .await + .unwrap(); +} diff --git a/src/api/sync.rs b/src/api/sync.rs index 50909eb..676b68b 100644 --- a/src/api/sync.rs +++ b/src/api/sync.rs @@ -1,6 +1,6 @@ use crate::{Cache, Repo}; use indicatif::{ProgressBar, ProgressStyle}; -use rand::{distributions::Alphanumeric, thread_rng, Rng}; +use rand::{distributions::Alphanumeric, Rng}; use std::collections::HashMap; // use reqwest::{ // blocking::Agent, @@ -12,7 +12,6 @@ use std::collections::HashMap; // Error as ReqwestError, // }; use super::RepoInfo; -use std::io::{Seek, SeekFrom, Write}; use std::num::ParseIntError; use std::path::{Component, Path, PathBuf}; use thiserror::Error; @@ -93,9 +92,6 @@ pub struct ApiBuilder { cache: Cache, url_template: String, token: Option, - chunk_size: usize, - parallel_failures: usize, - max_retries: usize, progress: bool, } @@ -133,9 +129,6 @@ impl ApiBuilder { url_template: "{endpoint}/{repo_id}/resolve/{revision}/{filename}".to_string(), cache, token, - chunk_size: 10_000_000, - parallel_failures: 0, - max_retries: 0, progress, } } @@ -180,9 +173,6 @@ impl ApiBuilder { client, no_redirect_client, - chunk_size: self.chunk_size, - parallel_failures: self.parallel_failures, - max_retries: self.max_retries, progress: self.progress, }) } @@ -204,9 +194,6 @@ pub struct Api { cache: Cache, client: HeaderAgent, no_redirect_client: HeaderAgent, - chunk_size: usize, - parallel_failures: usize, - max_retries: usize, progress: bool, } @@ -272,14 +259,6 @@ fn symlink_or_rename(src: &Path, dst: &Path) -> Result<(), std::io::Error> { Ok(()) } -fn jitter() -> usize { - thread_rng().gen_range(0..=500) -} - -fn exponential_backoff(base_wait_time: usize, n: usize, max: usize) -> usize { - (base_wait_time + n.pow(2) + jitter()).min(max) -} - impl Api { /// Creates a default Api, for Api options See [`ApiBuilder`] pub fn new() -> Result { @@ -367,95 +346,29 @@ impl Api { fn download_tempfile( &self, url: &str, - length: usize, progressbar: Option, ) -> Result { let filename = temp_filename(); // Create the file and set everything properly - std::fs::File::create(&filename)?.set_len(length as u64)?; - - let chunk_size = self.chunk_size; - - let n_chunks = (length + chunk_size - 1) / chunk_size; - let n_threads = num_cpus::get(); - let chunks_per_thread = (n_chunks + n_threads - 1) / n_threads; - let handles = (0..n_threads).map(|thread_id| { - let url = url.to_string(); - let filename = filename.clone(); - let client = self.client.clone(); - let parallel_failures = self.parallel_failures; - let max_retries = self.max_retries; - let progress = progressbar.clone(); - std::thread::spawn(move || { - for chunk_id in chunks_per_thread * thread_id - ..std::cmp::min(chunks_per_thread * (thread_id + 1), n_chunks) - { - let start = chunk_id * chunk_size; - let stop = std::cmp::min(start + chunk_size - 1, length); - let mut chunk = Self::download_chunk(&client, &url, &filename, start, stop); - let mut i = 0; - if parallel_failures > 0 { - while let Err(dlerr) = chunk { - let wait_time = exponential_backoff(300, i, 10_000); - std::thread::sleep(std::time::Duration::from_millis(wait_time as u64)); - - chunk = Self::download_chunk(&client, &url, &filename, start, stop); - i += 1; - if i > max_retries { - return Err(ApiError::TooManyRetries(dlerr.into())); - } - } - } - if let Some(p) = &progress { - p.inc((stop - start) as u64); - } - chunk? - } - Ok(()) - }) - }); - - let results: Result, ApiError> = - handles.into_iter().flat_map(|h| h.join()).collect(); - - results?; - if let Some(p) = progressbar { - p.finish() - } - Ok(filename) - } + let mut file = std::fs::File::create(&filename)?; - fn download_chunk( - client: &HeaderAgent, - url: &str, - filename: &PathBuf, - start: usize, - stop: usize, - ) -> Result<(), ApiError> { - // Process each socket concurrently. - let range = format!("bytes={start}-{stop}"); - let mut file = std::fs::OpenOptions::new().write(true).open(filename)?; - file.seek(SeekFrom::Start(start as u64))?; - let response = client + let response = self.client .get(url) - .set(RANGE, &range) .call() .map_err(Box::new)?; - const MAX: usize = 4096; - let mut buffer: [u8; MAX] = [0; MAX]; let mut reader = response.into_reader(); - let mut remaining = stop - start; - while remaining > 0 { - let to_read = if remaining > MAX { MAX } else { remaining }; + if let Some(p) = &progressbar{ + reader = Box::new(p.wrap_read(reader)); + } + + std::io::copy(&mut reader, &mut file)?; - reader.read_exact(&mut buffer[0..to_read])?; - remaining -= to_read; - file.write_all(&buffer[0..to_read])?; + if let Some(p) = progressbar { + p.finish() } - // file.write_all(&content)?; - Ok(()) + Ok(filename) } /// This will attempt the fetch the file locally first, then [`Api.download`] @@ -510,7 +423,7 @@ impl Api { None }; - let tmp_filename = self.download_tempfile(&url, metadata.size, progressbar)?; + let tmp_filename = self.download_tempfile(&url, progressbar)?; if std::fs::rename(&tmp_filename, &blob_path).is_err() { // Renaming may fail if locations are different mount points diff --git a/src/lib.rs b/src/lib.rs index 1a4b68c..29ca708 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -97,16 +97,16 @@ impl Cache { impl Default for Cache { fn default() -> Self { - let path = match std::env::var("HF_HOME") { + let mut path = match std::env::var("HF_HOME") { Ok(home) => home.into(), Err(_) => { let mut cache = dirs::home_dir().expect("Cache directory cannot be found"); cache.push(".cache"); cache.push("huggingface"); - cache.push("hub"); cache } }; + path.push("hub"); Self::new(path) } }