diff --git a/flake.nix b/flake.nix index ddd6bd4..fda9da7 100644 --- a/flake.nix +++ b/flake.nix @@ -22,6 +22,8 @@ default = pkgs.mkShell { buildInputs = with pkgs; [ rustup + pkg-config + openssl ]; }; diff --git a/src/api/mod.rs b/src/api/mod.rs index a5bc6a4..121f545 100644 --- a/src/api/mod.rs +++ b/src/api/mod.rs @@ -9,6 +9,8 @@ pub mod tokio; #[cfg(feature = "ureq")] pub mod sync; +const HF_ENDPOINT: &str = "HF_ENDPOINT"; + /// This trait is used by users of the lib /// to implement custom behavior during file downloads pub trait Progress { diff --git a/src/api/sync.rs b/src/api/sync.rs index ff191d3..3748c19 100644 --- a/src/api/sync.rs +++ b/src/api/sync.rs @@ -1,4 +1,4 @@ -use super::RepoInfo; +use super::{RepoInfo, HF_ENDPOINT}; use crate::api::sync::ApiError::InvalidHeader; use crate::api::Progress; use crate::{Cache, Repo, RepoType}; @@ -133,6 +133,23 @@ impl ApiBuilder { Self::from_cache(cache) } + /// Creates API with values potentially from environment variables. + /// HF_HOME decides the location of the cache folder + /// HF_ENDPOINT modifies the URL for the huggingface location + /// to download files from. + /// ``` + /// use hf_hub::api::sync::ApiBuilder; + /// let api = ApiBuilder::from_env().build().unwrap(); + /// ``` + pub fn from_env() -> Self { + let cache = Cache::from_env(); + let mut builder = Self::from_cache(cache); + if let Ok(endpoint) = std::env::var(HF_ENDPOINT) { + builder = builder.with_endpoint(endpoint); + } + builder + } + /// From a given cache /// ``` /// use hf_hub::{api::sync::ApiBuilder, Cache}; diff --git a/src/api/tokio.rs b/src/api/tokio.rs index 15d06f9..82f79c2 100644 --- a/src/api/tokio.rs +++ b/src/api/tokio.rs @@ -1,5 +1,5 @@ use super::Progress as SyncProgress; -use super::RepoInfo; +use super::{RepoInfo, HF_ENDPOINT}; use crate::{Cache, Repo, RepoType}; use futures::StreamExt; use indicatif::ProgressBar; @@ -133,6 +133,23 @@ impl ApiBuilder { Self::from_cache(cache) } + /// Creates API with values potentially from environment variables. + /// HF_HOME decides the location of the cache folder + /// HF_ENDPOINT modifies the URL for the huggingface location + /// to download files from. + /// ``` + /// use hf_hub::api::tokio::ApiBuilder; + /// let api = ApiBuilder::from_env().build().unwrap(); + /// ``` + pub fn from_env() -> Self { + let cache = Cache::from_env(); + let mut builder = Self::from_cache(cache); + if let Ok(endpoint) = std::env::var(HF_ENDPOINT) { + builder = builder.with_endpoint(endpoint); + } + builder + } + /// High CPU download /// /// This may cause issues on regular desktops as it will saturate @@ -141,12 +158,10 @@ impl ApiBuilder { /// saturate the bandwidth (>500MB/s) better. /// ``` /// use hf_hub::api::tokio::ApiBuilder; - /// let api = ApiBuilder::high().build().unwrap(); + /// let api = ApiBuilder::new().high().build().unwrap(); /// ``` - pub fn high() -> Self { - let cache = Cache::default(); - Self::from_cache(cache) - .with_max_files(num_cpus::get()) + pub fn high(self) -> Self { + self.with_max_files(num_cpus::get()) .with_chunk_size(Some(10_000_000)) } diff --git a/src/lib.rs b/src/lib.rs index 7af706e..a832c28 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -13,6 +13,8 @@ use std::path::PathBuf; #[cfg(any(feature = "tokio", feature = "ureq"))] pub mod api; +const HF_HOME: &str = "HF_HOME"; + /// The type of repo to interact with #[derive(Debug, Clone, Copy)] pub enum RepoType { @@ -37,6 +39,19 @@ impl Cache { Self { path } } + /// Creates cache from environment variable HF_HOME (if defined) otherwise + /// defaults to [`home_dir`]/.cache/huggingface/ + pub fn from_env() -> Self { + match std::env::var(HF_HOME) { + Ok(home) => { + let mut path: PathBuf = home.into(); + path.push("hub"); + Self::new(path) + } + Err(_) => Self::default(), + } + } + /// Creates a new cache object location pub fn path(&self) -> &PathBuf { &self.path @@ -137,6 +152,7 @@ impl CacheRepo { fn new(cache: Cache, repo: Repo) -> Self { Self { cache, repo } } + /// This will get the location of the file within the cache for the remote /// `filename`. Will return `None` if file is not already present in cache. pub fn get(&self, filename: &str) -> Option { @@ -197,15 +213,9 @@ impl CacheRepo { impl Default for Cache { fn default() -> Self { - let mut path = match std::env::var("HF_HOME") { - Ok(home) => home.into(), - Err(_) => { - let mut cache = dirs::home_dir().expect("Cache directory cannot be found"); - cache.push(".cache"); - cache.push("huggingface"); - cache - } - }; + let mut path = dirs::home_dir().expect("Cache directory cannot be found"); + path.push(".cache"); + path.push("huggingface"); path.push("hub"); Self::new(path) } @@ -338,9 +348,9 @@ mod tests { #[test] #[cfg(not(target_os = "windows"))] fn token_path() { - let cache = Cache::default(); + let cache = Cache::from_env(); let token_path = cache.token_path().to_str().unwrap().to_string(); - if let Ok(hf_home) = std::env::var("HF_HOME") { + if let Ok(hf_home) = std::env::var(HF_HOME) { assert_eq!(token_path, format!("{hf_home}/token")); } else { let n = "huggingface/token".len(); @@ -351,9 +361,9 @@ mod tests { #[test] #[cfg(target_os = "windows")] fn token_path() { - let cache = Cache::default(); + let cache = Cache::from_env(); let token_path = cache.token_path().to_str().unwrap().to_string(); - if let Ok(hf_home) = std::env::var("HF_HOME") { + if let Ok(hf_home) = std::env::var(HF_HOME) { assert_eq!(token_path, format!("{hf_home}\\token")); } else { let n = "huggingface/token".len();