diff --git a/README.md b/README.md
index df6a5ff..9d8e56d 100644
--- a/README.md
+++ b/README.md
@@ -7,8 +7,36 @@
 it should be consistent with [huggingface_hub](https://github.com/huggingface/huggingface_hub).
 At this time only a limited subset of the functionality is present, the goal is to add new features over time.
 We are currently treating this as an internal/external tool, meaning
-we will treat what exists as public, and keep backward compatibility in the same regard.
+we are currently modifying everything at will for our internal needs. This will eventually
+stabilize as the library matures to accommodate most of our needs.
+
+If you're interested in using this, you're welcome to do so, but be warned that the API may still shift under you.
+
+If you want to contribute, you are more than welcome. However, new features may be declined for lack of maintenance time; we're focusing on what we currently need internally. Hopefully that subset is already interesting to more users.
+
+
+# How to use
+
+Add the dependency:
+
+```bash
+cargo add hf-hub # --features tokio
+```
+
+The `tokio` feature enables an async (and potentially faster) API.
+
+Use the crate:
+
+```rust
+use hf_hub::api::sync::Api;
+
+let api = Api::new().unwrap();
+
+let repo = api.model("bert-base-uncased".to_string());
+let _filename = repo.get("config.json").unwrap();
+
+// `_filename` is now the local path of `config.json` inside the hf cache
+```
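+
+With the `tokio` feature enabled, the async equivalent is a minimal sketch of
+the same flow (same repo and file as above; errors are unwrapped for brevity):
+
+```rust
+use hf_hub::api::tokio::Api;
+
+#[tokio::main]
+async fn main() {
+    let api = Api::new().unwrap();
+
+    let repo = api.model("bert-base-uncased".to_string());
+    let _filename = repo.get("config.json").await.unwrap();
+}
+```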
diff --git a/examples/download.rs b/examples/download.rs
index 18e35e0..4ca4970 100644
--- a/examples/download.rs
+++ b/examples/download.rs
@@ -1,21 +1,14 @@
-use hf_hub::{Repo, RepoType};
-
 #[cfg(not(feature = "tokio"))]
 fn main() {
     let api = hf_hub::api::sync::Api::new().unwrap();
-    let repo = Repo::new("meta-llama/Llama-2-7b-hf".to_string(), RepoType::Model);
-    let _filename = api.get(&repo, "model-00001-of-00002.safetensors").unwrap();
+    let _filename = api
+        .model("meta-llama/Llama-2-7b-hf".to_string())
+        .get("model-00001-of-00002.safetensors")
+        .unwrap();
 }
 
 #[cfg(feature = "tokio")]
 #[tokio::main]
 async fn main() {
     let api = hf_hub::api::tokio::Api::new().unwrap();
-    let repo = Repo::new("meta-llama/Llama-2-7b-hf".to_string(), RepoType::Model);
-    let _filename = api
-        .get(&repo, "model-00001-of-00002.safetensors")
-        .await
-        .unwrap();
+    let _filename = api
+        .model("meta-llama/Llama-2-7b-hf".to_string())
+        .get("model-00001-of-00002.safetensors")
+        .await
+        .unwrap();
 }
diff --git a/src/api/sync.rs b/src/api/sync.rs
index 676b68b..4742976 100644
--- a/src/api/sync.rs
+++ b/src/api/sync.rs
@@ -1,4 +1,4 @@
-use crate::{Cache, Repo};
+use crate::{Cache, Repo, RepoType};
 use indicatif::{ProgressBar, ProgressStyle};
 use rand::{distributions::Alphanumeric, Rng};
 use std::collections::HashMap;
@@ -188,6 +188,7 @@ struct Metadata {
 /// The actual Api used to interact with the hub.
-/// You can inspect repos with [`Api::info`]
-/// or download files with [`Api::download`]
+/// You can inspect repos with [`ApiRepo::info`]
+/// or download files with [`ApiRepo::download`]
+#[derive(Clone)]
 pub struct Api {
     endpoint: String,
     url_template: String,
@@ -265,24 +266,6 @@ impl Api {
         ApiBuilder::new().build()
     }
 
-    /// Get the fully qualified URL of the remote filename
-    /// ```
-    /// # use hf_hub::{api::sync::Api, Repo};
-    /// let api = Api::new().unwrap();
-    /// let repo = Repo::model("gpt2".to_string());
-    /// let url = api.url(&repo, "model.safetensors");
-    /// assert_eq!(url, "https://huggingface.co/gpt2/resolve/main/model.safetensors");
-    /// ```
-    pub fn url(&self, repo: &Repo, filename: &str) -> String {
-        let endpoint = &self.endpoint;
-        let revision = &repo.url_revision();
-        self.url_template
-            .replace("{endpoint}", endpoint)
-            .replace("{repo_id}", &repo.url())
-            .replace("{revision}", revision)
-            .replace("{filename}", filename)
-    }
-
     /// Get the underlying api client
     /// Allows for lower level access
     pub fn client(&self) -> &HeaderAgent {
@@ -371,18 +354,89 @@ impl Api {
         Ok(filename)
     }
 
+    /// Creates a new handle [`ApiRepo`] which contains operations
+    /// on a particular [`Repo`]
+    pub fn repo(&self, repo: Repo) -> ApiRepo {
+        ApiRepo::new(self.clone(), repo)
+    }
+
+    /// Simple wrapper over:
+    /// ```
+    /// # use hf_hub::{api::sync::Api, Repo, RepoType};
+    /// # let model_id = "gpt2".to_string();
+    /// let api = Api::new().unwrap();
+    /// let api = api.repo(Repo::new(model_id, RepoType::Model));
+    /// ```
+    pub fn model(&self, model_id: String) -> ApiRepo {
+        self.repo(Repo::new(model_id, RepoType::Model))
+    }
+
+    /// Simple wrapper over:
+    /// ```
+    /// # use hf_hub::{api::sync::Api, Repo, RepoType};
+    /// # let dataset_id = "wikitext".to_string();
+    /// let api = Api::new().unwrap();
+    /// let api = api.repo(Repo::new(dataset_id, RepoType::Dataset));
+    /// ```
+    pub fn dataset(&self, dataset_id: String) -> ApiRepo {
+        self.repo(Repo::new(dataset_id, RepoType::Dataset))
+    }
+
+    /// Simple wrapper over:
+    /// ```
+    /// # use hf_hub::{api::sync::Api, Repo, RepoType};
+    /// # let space_id = "gpt2".to_string();
+    /// let api = Api::new().unwrap();
+    /// let api = api.repo(Repo::new(space_id, RepoType::Space));
+    /// ```
+    pub fn space(&self, space_id: String) -> ApiRepo {
+        self.repo(Repo::new(space_id, RepoType::Space))
+    }
+}
+
+/// Shorthand for accessing things within a particular repo
+pub struct ApiRepo {
+    api: Api,
+    repo: Repo,
+}
+
+impl ApiRepo {
+    fn new(api: Api, repo: Repo) -> Self {
+        Self { api, repo }
+    }
+
+    /// Get the fully qualified URL of the remote filename
+    /// ```
+    /// # use hf_hub::api::sync::Api;
+    /// let api = Api::new().unwrap();
+    /// let url = api.model("gpt2".to_string()).url("model.safetensors");
+    /// assert_eq!(url, "https://huggingface.co/gpt2/resolve/main/model.safetensors");
+    /// ```
+    pub fn url(&self, filename: &str) -> String {
+        let endpoint = &self.api.endpoint;
+        let revision = &self.repo.url_revision();
+        self.api
+            .url_template
+            .replace("{endpoint}", endpoint)
+            .replace("{repo_id}", &self.repo.url())
+            .replace("{revision}", revision)
+            .replace("{filename}", filename)
+    }
+
+    /// This will attempt to fetch the file locally first, then [`ApiRepo::download`]
     /// if the file is not present.
     /// ```no_run
-    /// use hf_hub::{api::sync::ApiBuilder, Repo};
-    /// let api = ApiBuilder::new().build().unwrap();
-    /// let repo = Repo::model("gpt2".to_string());
-    /// let local_filename = api.get(&repo, "model.safetensors").unwrap();
-    pub fn get(&self, repo: &Repo, filename: &str) -> Result<PathBuf, ApiError> {
-        if let Some(path) = self.cache.get(repo, filename) {
+    /// use hf_hub::api::sync::Api;
+    /// let api = Api::new().unwrap();
+    /// let local_filename = api.model("gpt2".to_string()).get("model.safetensors").unwrap();
+    /// ```
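+    /// For example, fetching a dataset file pinned to a specific revision works
+    /// the same way (the repo and revision here mirror the integration tests below):
+    /// ```no_run
+    /// # use hf_hub::{api::sync::Api, Repo, RepoType};
+    /// let api = Api::new().unwrap();
+    /// let repo = Repo::with_revision(
+    ///     "wikitext".to_string(),
+    ///     RepoType::Dataset,
+    ///     "refs/convert/parquet".to_string(),
+    /// );
+    /// let local_filename = api.repo(repo).get("wikitext-103-v1/wikitext-test.parquet").unwrap();
+    /// ```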
+    pub fn get(&self, filename: &str) -> Result<PathBuf, ApiError> {
+        if let Some(path) = self.api.cache.get(&self.repo, filename) {
             Ok(path)
         } else {
-            self.download(repo, filename)
+            self.download(filename)
         }
     }
 
@@ -391,19 +445,18 @@ impl Api {
     /// This function requires internet access to check whether new versions of the
     /// file exist, even if a file is already present on disk at that location.
     /// ```no_run
-    /// # use hf_hub::{api::sync::ApiBuilder, Repo};
-    /// let api = ApiBuilder::new().build().unwrap();
-    /// let repo = Repo::model("gpt2".to_string());
-    /// let local_filename = api.download(&repo, "model.safetensors").unwrap();
+    /// # use hf_hub::api::sync::Api;
+    /// let api = Api::new().unwrap();
+    /// let local_filename = api.model("gpt2".to_string()).download("model.safetensors").unwrap();
     /// ```
-    pub fn download(&self, repo: &Repo, filename: &str) -> Result<PathBuf, ApiError> {
-        let url = self.url(repo, filename);
-        let metadata = self.metadata(&url)?;
+    pub fn download(&self, filename: &str) -> Result<PathBuf, ApiError> {
+        let url = self.url(filename);
+        let metadata = self.api.metadata(&url)?;
 
-        let blob_path = self.cache.blob_path(repo, &metadata.etag);
+        let blob_path = self.api.cache.blob_path(&self.repo, &metadata.etag);
         std::fs::create_dir_all(blob_path.parent().unwrap())?;
 
-        let progressbar = if self.progress {
+        let progressbar = if self.api.progress {
             let progress = ProgressBar::new(metadata.size as u64);
             progress.set_style(
                 ProgressStyle::with_template(
@@ -423,7 +476,7 @@
             None
         };
 
-        let tmp_filename = self.download_tempfile(&url, progressbar)?;
+        let tmp_filename = self.api.download_tempfile(&url, progressbar)?;
 
         if std::fs::rename(&tmp_filename, &blob_path).is_err() {
             // Renaming may fail if locations are different mount points
             std::fs::copy(tmp_filename, &blob_path)?;
         }
 
-        let mut pointer_path = self.cache.pointer_path(repo, &metadata.commit_hash);
+        let mut pointer_path = self.api.cache.pointer_path(&self.repo, &metadata.commit_hash);
         pointer_path.push(filename);
         std::fs::create_dir_all(pointer_path.parent().unwrap()).ok();
 
         symlink_or_rename(&blob_path, &pointer_path)?;
-        self.cache.create_ref(repo, &metadata.commit_hash)?;
+        self.api.cache.create_ref(&self.repo, &metadata.commit_hash)?;
 
         Ok(pointer_path)
     }
 
     /// Get information about the Repo
     /// ```
-    /// use hf_hub::{api::sync::Api, Repo};
+    /// use hf_hub::api::sync::Api;
     /// let api = Api::new().unwrap();
-    /// let repo = Repo::model("gpt2".to_string());
-    /// api.info(&repo);
+    /// api.model("gpt2".to_string()).info();
     /// ```
-    pub fn info(&self, repo: &Repo) -> Result<RepoInfo, ApiError> {
-        let url = format!("{}/api/{}", self.endpoint, repo.api_url());
-        let response = self.client.get(&url).call().map_err(Box::new)?;
+    pub fn info(&self) -> Result<RepoInfo, ApiError> {
+        let url = format!("{}/api/{}", self.api.endpoint, self.repo.api_url());
+        let response = self.api.client.get(&url).call().map_err(Box::new)?;
 
         let model_info = response.into_json()?;
@@ -499,8 +551,9 @@ mod tests {
             .with_cache_dir(tmp.path.clone())
             .build()
             .unwrap();
Repo::new("julien-c/dummy-unknown".to_string(), RepoType::Model); - let downloaded_path = api.download(&repo, "config.json").unwrap(); + + let model_id = "julien-c/dummy-unknown".to_string(); + let downloaded_path = api.model(model_id.clone()).download("config.json").unwrap(); assert!(downloaded_path.exists()); let val = Sha256::digest(std::fs::read(&*downloaded_path).unwrap()); assert_eq!( @@ -509,7 +562,7 @@ mod tests { ); // Make sure the file is now seeable without connection - let cache_path = api.cache.get(&repo, "config.json").unwrap(); + let cache_path = api.cache.get(&Repo::new(model_id, RepoType::Model), "config.json").unwrap(); assert_eq!(cache_path, downloaded_path); } @@ -526,8 +579,8 @@ mod tests { RepoType::Dataset, "refs/convert/parquet".to_string(), ); - let downloaded_path = api - .download(&repo, "wikitext-103-v1/wikitext-test.parquet") + let downloaded_path = api.repo(repo) + .download("wikitext-103-v1/wikitext-test.parquet") .unwrap(); assert!(downloaded_path.exists()); let val = Sha256::digest(std::fs::read(&*downloaded_path).unwrap()); @@ -550,7 +603,7 @@ mod tests { RepoType::Dataset, "refs/convert/parquet".to_string(), ); - let model_info = api.info(&repo).unwrap(); + let model_info = api.repo(repo).info().unwrap(); assert_eq!( model_info, RepoInfo { diff --git a/src/api/tokio.rs b/src/api/tokio.rs index 9c233ef..e3cde6b 100644 --- a/src/api/tokio.rs +++ b/src/api/tokio.rs @@ -1,5 +1,5 @@ use super::RepoInfo; -use crate::{Cache, Repo}; +use crate::{Cache, Repo, RepoType}; use indicatif::{ProgressBar, ProgressStyle}; use rand::{distributions::Alphanumeric, thread_rng, Rng}; use reqwest::{ @@ -189,6 +189,7 @@ struct Metadata { /// The actual Api used to interacto with the hub. /// You can inspect repos with [`Api::info`] /// or download files with [`Api::download`] +#[derive(Clone)] pub struct Api { endpoint: String, url_template: String, @@ -278,24 +279,6 @@ impl Api { ApiBuilder::new().build() } - /// Get the fully qualified URL of the remote filename - /// ``` - /// # use hf_hub::{api::tokio::Api, Repo}; - /// let api = Api::new().unwrap(); - /// let repo = Repo::model("gpt2".to_string()); - /// let url = api.url(&repo, "model.safetensors"); - /// assert_eq!(url, "https://huggingface.co/gpt2/resolve/main/model.safetensors"); - /// ``` - pub fn url(&self, repo: &Repo, filename: &str) -> String { - let endpoint = &self.endpoint; - let revision = &repo.url_revision(); - self.url_template - .replace("{endpoint}", endpoint) - .replace("{repo_id}", &repo.url()) - .replace("{revision}", revision) - .replace("{filename}", filename) - } - /// Get the underlying api client /// Allows for lower level access pub fn client(&self) -> &Client { @@ -358,6 +341,79 @@ impl Api { }) } + + /// Creates a new handle [`ApiRepo`] which contains operations + /// on a particular [`Repo`] + pub fn repo(&self, repo: Repo) -> ApiRepo{ + ApiRepo::new(self.clone(), repo) + } + + /// Simple wrapper over + /// ``` + /// # use hf_hub::{api::tokio::Api, Repo, RepoType}; + /// # let model_id = "gpt2".to_string(); + /// let api = Api::new().unwrap(); + /// let api = api.repo(Repo::new(model_id, RepoType::Model)); + /// ``` + pub fn model(&self, model_id: String) -> ApiRepo{ + self.repo(Repo::new(model_id, RepoType::Model)) + } + + /// Simple wrapper over + /// ``` + /// # use hf_hub::{api::tokio::Api, Repo, RepoType}; + /// # let model_id = "gpt2".to_string(); + /// let api = Api::new().unwrap(); + /// let api = api.repo(Repo::new(model_id, RepoType::Dataset)); + /// ``` + pub fn dataset(&self, 
+    pub fn repo(&self, repo: Repo) -> ApiRepo {
+        ApiRepo::new(self.clone(), repo)
+    }
+
+    /// Simple wrapper over:
+    /// ```
+    /// # use hf_hub::{api::tokio::Api, Repo, RepoType};
+    /// # let model_id = "gpt2".to_string();
+    /// let api = Api::new().unwrap();
+    /// let api = api.repo(Repo::new(model_id, RepoType::Model));
+    /// ```
+    pub fn model(&self, model_id: String) -> ApiRepo {
+        self.repo(Repo::new(model_id, RepoType::Model))
+    }
+
+    /// Simple wrapper over:
+    /// ```
+    /// # use hf_hub::{api::tokio::Api, Repo, RepoType};
+    /// # let dataset_id = "wikitext".to_string();
+    /// let api = Api::new().unwrap();
+    /// let api = api.repo(Repo::new(dataset_id, RepoType::Dataset));
+    /// ```
+    pub fn dataset(&self, dataset_id: String) -> ApiRepo {
+        self.repo(Repo::new(dataset_id, RepoType::Dataset))
+    }
+
+    /// Simple wrapper over:
+    /// ```
+    /// # use hf_hub::{api::tokio::Api, Repo, RepoType};
+    /// # let space_id = "gpt2".to_string();
+    /// let api = Api::new().unwrap();
+    /// let api = api.repo(Repo::new(space_id, RepoType::Space));
+    /// ```
+    pub fn space(&self, space_id: String) -> ApiRepo {
+        self.repo(Repo::new(space_id, RepoType::Space))
+    }
+}
+
+/// Shorthand for accessing things within a particular repo
+pub struct ApiRepo {
+    api: Api,
+    repo: Repo,
+}
+
+impl ApiRepo {
+    fn new(api: Api, repo: Repo) -> Self {
+        Self { api, repo }
+    }
+
+    /// Get the fully qualified URL of the remote filename
+    /// ```
+    /// # use hf_hub::api::tokio::Api;
+    /// let api = Api::new().unwrap();
+    /// let url = api.model("gpt2".to_string()).url("model.safetensors");
+    /// assert_eq!(url, "https://huggingface.co/gpt2/resolve/main/model.safetensors");
+    /// ```
+    pub fn url(&self, filename: &str) -> String {
+        let endpoint = &self.api.endpoint;
+        let revision = &self.repo.url_revision();
+        self.api
+            .url_template
+            .replace("{endpoint}", endpoint)
+            .replace("{repo_id}", &self.repo.url())
+            .replace("{revision}", revision)
+            .replace("{filename}", filename)
+    }
+
     async fn download_tempfile(
         &self,
         url: &str,
         length: usize,
@@ -365,8 +421,8 @@
         progressbar: Option<ProgressBar>,
     ) -> Result<PathBuf, ApiError> {
         let mut handles = vec![];
-        let semaphore = Arc::new(Semaphore::new(self.max_files));
-        let parallel_failures_semaphore = Arc::new(Semaphore::new(self.parallel_failures));
+        let semaphore = Arc::new(Semaphore::new(self.api.max_files));
+        let parallel_failures_semaphore = Arc::new(Semaphore::new(self.api.parallel_failures));
         let filename = temp_filename();
 
         // Create the file and set everything properly
         tokio::fs::File::create(&filename)
             .await?
             .set_len(length as u64)
             .await?;
 
-        let chunk_size = self.chunk_size;
+        let chunk_size = self.api.chunk_size;
         for start in (0..length).step_by(chunk_size) {
             let url = url.to_string();
             let filename = filename.clone();
-            let client = self.client.clone();
+            let client = self.api.client.clone();
 
             let stop = std::cmp::min(start + chunk_size - 1, length);
             let permit = semaphore.clone().acquire_owned().await?;
-            let parallel_failures = self.parallel_failures;
-            let max_retries = self.max_retries;
+            let parallel_failures = self.api.parallel_failures;
+            let max_retries = self.api.max_retries;
             let parallel_failures_semaphore = parallel_failures_semaphore.clone();
             let progress = progressbar.clone();
             handles.push(tokio::spawn(async move {
@@ -454,17 +510,16 @@
     /// This will attempt to fetch the file locally first, then [`ApiRepo::download`]
     /// if the file is not present.
     /// ```no_run
-    /// # use hf_hub::{api::tokio::ApiBuilder, Repo};
+    /// # use hf_hub::api::tokio::Api;
     /// # tokio_test::block_on(async {
-    /// let api = ApiBuilder::new().build().unwrap();
-    /// let repo = Repo::model("gpt2".to_string());
-    /// let local_filename = api.get(&repo, "model.safetensors").await.unwrap();
+    /// let api = Api::new().unwrap();
+    /// let local_filename = api.model("gpt2".to_string()).get("model.safetensors").await.unwrap();
     /// # })
+    /// ```
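+    /// The same handle can be reused for several files (the gpt2 file names here
+    /// are purely illustrative):
+    /// ```no_run
+    /// # use hf_hub::api::tokio::Api;
+    /// # tokio_test::block_on(async {
+    /// let repo = Api::new().unwrap().model("gpt2".to_string());
+    /// let _config = repo.get("config.json").await.unwrap();
+    /// let _weights = repo.get("model.safetensors").await.unwrap();
+    /// # })
+    /// ```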
-    pub async fn get(&self, repo: &Repo, filename: &str) -> Result<PathBuf, ApiError> {
-        if let Some(path) = self.cache.get(repo, filename) {
+    pub async fn get(&self, filename: &str) -> Result<PathBuf, ApiError> {
+        if let Some(path) = self.api.cache.get(&self.repo, filename) {
             Ok(path)
         } else {
-            self.download(repo, filename).await
+            self.download(filename).await
         }
     }
 
@@ -473,21 +528,21 @@
     /// This function requires internet access to check whether new versions of the
     /// file exist, even if a file is already present on disk at that location.
     /// ```no_run
-    /// # use hf_hub::{api::tokio::ApiBuilder, Repo};
+    /// # use hf_hub::api::tokio::Api;
     /// # tokio_test::block_on(async {
-    /// let api = ApiBuilder::new().build().unwrap();
-    /// let repo = Repo::model("gpt2".to_string());
-    /// let local_filename = api.download(&repo, "model.safetensors").await.unwrap();
+    /// let api = Api::new().unwrap();
+    /// let local_filename = api.model("gpt2".to_string()).download("model.safetensors").await.unwrap();
     /// # })
     /// ```
-    pub async fn download(&self, repo: &Repo, filename: &str) -> Result<PathBuf, ApiError> {
-        let url = self.url(repo, filename);
-        let metadata = self.metadata(&url).await?;
+    pub async fn download(&self, filename: &str) -> Result<PathBuf, ApiError> {
+        let repo = &self.repo;
+        let url = self.url(filename);
+        let metadata = self.api.metadata(&url).await?;
 
-        let blob_path = self.cache.blob_path(repo, &metadata.etag);
+        let blob_path = self.api.cache.blob_path(repo, &metadata.etag);
         std::fs::create_dir_all(blob_path.parent().unwrap())?;
 
-        let progressbar = if self.progress {
+        let progressbar = if self.api.progress {
             let progress = ProgressBar::new(metadata.size as u64);
             progress.set_style(
                 ProgressStyle::with_template(
@@ -517,28 +572,28 @@
             tokio::fs::copy(tmp_filename, &blob_path).await?;
         }
 
-        let mut pointer_path = self.cache.pointer_path(repo, &metadata.commit_hash);
+        let mut pointer_path = self.api.cache.pointer_path(repo, &metadata.commit_hash);
         pointer_path.push(filename);
         std::fs::create_dir_all(pointer_path.parent().unwrap()).ok();
 
         symlink_or_rename(&blob_path, &pointer_path)?;
-        self.cache.create_ref(repo, &metadata.commit_hash)?;
+        self.api.cache.create_ref(repo, &metadata.commit_hash)?;
 
         Ok(pointer_path)
     }
 
     /// Get information about the Repo
     /// ```
-    /// # use hf_hub::{api::tokio::Api, Repo};
+    /// # use hf_hub::api::tokio::Api;
     /// # tokio_test::block_on(async {
     /// let api = Api::new().unwrap();
-    /// let repo = Repo::model("gpt2".to_string());
-    /// api.info(&repo);
+    /// api.model("gpt2".to_string()).info().await;
     /// # })
     /// ```
-    pub async fn info(&self, repo: &Repo) -> Result<RepoInfo, ApiError> {
-        let url = format!("{}/api/{}", self.endpoint, repo.api_url());
-        let response = self.client.get(url).send().await?;
+    pub async fn info(&self) -> Result<RepoInfo, ApiError> {
+        let repo = &self.repo;
+        let url = format!("{}/api/{}", self.api.endpoint, repo.api_url());
+        let response = self.api.client.get(url).send().await?;
 
         let response = response.error_for_status()?;
         let model_info = response.json().await?;
@@ -588,8 +643,9 @@ mod tests {
             .with_cache_dir(tmp.path.clone())
             .build()
             .unwrap();
Repo::new("julien-c/dummy-unknown".to_string(), RepoType::Model); - let downloaded_path = api.download(&repo, "config.json").await.unwrap(); + let model_id = "julien-c/dummy-unknown".to_string(); + let repo = Repo::new(model_id.clone(), RepoType::Model); + let downloaded_path = api.model(model_id).download("config.json").await.unwrap(); assert!(downloaded_path.exists()); let val = Sha256::digest(std::fs::read(&*downloaded_path).unwrap()); assert_eq!( @@ -615,8 +671,8 @@ mod tests { RepoType::Dataset, "refs/convert/parquet".to_string(), ); - let downloaded_path = api - .download(&repo, "wikitext-103-v1/wikitext-test.parquet") + let downloaded_path = api.repo(repo) + .download("wikitext-103-v1/wikitext-test.parquet") .await .unwrap(); assert!(downloaded_path.exists()); @@ -640,7 +696,7 @@ mod tests { RepoType::Dataset, "refs/convert/parquet".to_string(), ); - let model_info = api.info(&repo).await.unwrap(); + let model_info = api.repo(repo).info().await.unwrap(); assert_eq!( model_info, RepoInfo { diff --git a/src/lib.rs b/src/lib.rs index 29ca708..f4c7e6f 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -20,6 +20,7 @@ pub enum RepoType { } /// A local struct used to fetch information from the cache folder. +#[derive(Clone)] pub struct Cache { path: PathBuf, } @@ -201,11 +202,28 @@ mod tests { use super::*; #[test] + #[cfg(not(target_os="windows"))] fn token_path() { let cache = Cache::default(); let token_path = cache.token_path().to_str().unwrap().to_string(); - let n = "huggingface/token".len(); + if let Ok(hf_home) = std::env::var("HF_HOME"){ + assert_eq!(token_path, format!("{hf_home}/token")); + }else{ + let n = "huggingface/token".len(); + assert_eq!(&token_path[token_path.len() - n..], "huggingface/token"); + } + } - assert_eq!(&token_path[token_path.len() - n..], "huggingface/token"); + #[test] + #[cfg(target_os="windows")] + fn token_path() { + let cache = Cache::default(); + let token_path = cache.token_path().to_str().unwrap().to_string(); + if let Ok(hf_home) = std::env::var("HF_HOME"){ + assert_eq!(token_path, format!("{hf_home}\\token")); + }else{ + let n = "huggingface/token".len(); + assert_eq!(&token_path[token_path.len() - n..], "huggingface\\token"); + } } }