Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 29 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,36 @@ it should be consistent with [huggingface_hub](https://github.com/huggingface/hu

At this time only a limited subset of the functionality is present, the goal is to add new
features over time. We are currently treating this as an internal/external tool, meaning
we will treat what exists as public, and keep backward compatibility in the same regard.
we are currently modifying everything at will for our internal needs. This will eventually
stabilize as it matures to accommodate most of our needs.

If you're interested in using this, you're welcome to do so, but be warned that things may still change under you.

If you want to contribute, you are more than welcome.

However, accepting or implementing new features might be declined for lack of maintenance
time. We're focusing on what we currently need internally. Hopefully that subset is already interesting
to more users.


# How to use

Add the dependency

```bash
cargo add hf-hub # --features tokio
```
The `tokio` feature enables an async (and potentially faster) API.

Use the crate:

```rust
use hf_hub::api::sync::Api;

let api = Api::new().unwrap();

let repo = api.model("bert-base-uncased".to_string());
let _filename = repo.get("config.json").unwrap();

// filename is now the local location within hf cache of the config.json file
```
11 changes: 2 additions & 9 deletions examples/download.rs
Original file line number Diff line number Diff line change
@@ -1,21 +1,14 @@
use hf_hub::{Repo, RepoType};

#[cfg(not(feature = "tokio"))]
fn main() {
let api = hf_hub::api::sync::Api::new().unwrap();
let repo = Repo::new("meta-llama/Llama-2-7b-hf".to_string(), RepoType::Model);

let _filename = api.get(&repo, "model-00001-of-00002.safetensors").unwrap();
let _filename = api.model("meta-llama/Llama-2-7b-hf".to_string()).get("model-00001-of-00002.safetensors").unwrap();
}

#[cfg(feature = "tokio")]
#[tokio::main]
async fn main() {
let api = hf_hub::api::tokio::Api::new().unwrap();
let repo = Repo::new("meta-llama/Llama-2-7b-hf".to_string(), RepoType::Model);

let _filename = api
.get(&repo, "model-00001-of-00002.safetensors")
.await
.unwrap();
let _filename = api.model("meta-llama/Llama-2-7b-hf".to_string()).get("model-00001-of-00002.safetensors").await.unwrap();
}
153 changes: 103 additions & 50 deletions src/api/sync.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use crate::{Cache, Repo};
use crate::{Cache, Repo, RepoType};
use indicatif::{ProgressBar, ProgressStyle};
use rand::{distributions::Alphanumeric, Rng};
use std::collections::HashMap;
Expand Down Expand Up @@ -188,6 +188,7 @@ struct Metadata {
/// The actual Api used to interact with the hub.
/// You can inspect repos with [`Api::info`]
/// or download files with [`Api::download`]
#[derive(Clone)]
pub struct Api {
endpoint: String,
url_template: String,
Expand Down Expand Up @@ -265,24 +266,6 @@ impl Api {
ApiBuilder::new().build()
}

/// Builds the fully qualified URL of `filename` inside `repo` on the remote.
///
/// The placeholders `{endpoint}`, `{repo_id}`, `{revision}` and `{filename}`
/// of the URL template are substituted, in that order.
/// ```
/// # use hf_hub::{api::sync::Api, Repo};
/// let api = Api::new().unwrap();
/// let repo = Repo::model("gpt2".to_string());
/// let url = api.url(&repo, "model.safetensors");
/// assert_eq!(url, "https://huggingface.co/gpt2/resolve/main/model.safetensors");
/// ```
pub fn url(&self, repo: &Repo, filename: &str) -> String {
    let revision = repo.url_revision();
    // Each step feeds the next, so the substitution order matches the template order.
    let with_endpoint = self.url_template.replace("{endpoint}", &self.endpoint);
    let with_repo = with_endpoint.replace("{repo_id}", &repo.url());
    let with_revision = with_repo.replace("{revision}", &revision);
    with_revision.replace("{filename}", filename)
}

/// Get the underlying api client
/// Allows for lower level access
pub fn client(&self) -> &HeaderAgent {
Expand Down Expand Up @@ -371,18 +354,89 @@ impl Api {
Ok(filename)
}

/// Builds an [`ApiRepo`] handle that bundles a clone of this [`Api`]
/// with the given [`Repo`], so operations can target that repo directly.
pub fn repo(&self, repo: Repo) -> ApiRepo {
    let api = self.clone();
    ApiRepo::new(api, repo)
}

/// Shorthand for a handle on a model repository; equivalent to
/// ```
/// # use hf_hub::{api::sync::Api, Repo, RepoType};
/// # let model_id = "gpt2".to_string();
/// let api = Api::new().unwrap();
/// let api = api.repo(Repo::new(model_id, RepoType::Model));
/// ```
pub fn model(&self, model_id: String) -> ApiRepo {
    let repo = Repo::new(model_id, RepoType::Model);
    self.repo(repo)
}

/// Shorthand for a handle on a dataset repository; equivalent to
/// ```
/// # use hf_hub::{api::sync::Api, Repo, RepoType};
/// # let model_id = "gpt2".to_string();
/// let api = Api::new().unwrap();
/// let api = api.repo(Repo::new(model_id, RepoType::Dataset));
/// ```
pub fn dataset(&self, model_id: String) -> ApiRepo {
    let repo = Repo::new(model_id, RepoType::Dataset);
    self.repo(repo)
}

/// Shorthand for a handle on a space repository; equivalent to
/// ```
/// # use hf_hub::{api::sync::Api, Repo, RepoType};
/// # let model_id = "gpt2".to_string();
/// let api = Api::new().unwrap();
/// let api = api.repo(Repo::new(model_id, RepoType::Space));
/// ```
pub fn space(&self, model_id: String) -> ApiRepo {
    let repo = Repo::new(model_id, RepoType::Space);
    self.repo(repo)
}
}

/// Shorthand for accessing things within a particular repo.
///
/// Pairs an [`Api`] with the [`Repo`] it operates on, so methods on this
/// handle do not need the repo passed in on every call.
pub struct ApiRepo{
// The API whose endpoint and URL template are used by this handle.
api: Api,
// The repository every operation on this handle targets.
repo: Repo,
}

impl ApiRepo{
fn new(api: Api, repo: Repo) -> Self{
Self{api, repo}
}
}


impl ApiRepo{
/// Builds the fully qualified URL of `filename` within this repo on the remote.
///
/// The placeholders `{endpoint}`, `{repo_id}`, `{revision}` and `{filename}`
/// of the API's URL template are substituted, in that order.
/// ```
/// # use hf_hub::api::sync::Api;
/// let api = Api::new().unwrap();
/// let url = api.model("gpt2".to_string()).url("model.safetensors");
/// assert_eq!(url, "https://huggingface.co/gpt2/resolve/main/model.safetensors");
/// ```
pub fn url(&self, filename: &str) -> String {
    let revision = self.repo.url_revision();
    // Each step feeds the next, so the substitution order matches the template order.
    let with_endpoint = self
        .api
        .url_template
        .replace("{endpoint}", &self.api.endpoint);
    let with_repo = with_endpoint.replace("{repo_id}", &self.repo.url());
    let with_revision = with_repo.replace("{revision}", &revision);
    with_revision.replace("{filename}", filename)
}


/// This will attempt to fetch the file locally first, then [`ApiRepo::download`]
/// if the file is not present.
/// ```no_run
/// use hf_hub::{api::sync::ApiBuilder, Repo};
/// let api = ApiBuilder::new().build().unwrap();
/// let repo = Repo::model("gpt2".to_string());
/// let local_filename = api.get(&repo, "model.safetensors").unwrap();
pub fn get(&self, repo: &Repo, filename: &str) -> Result<PathBuf, ApiError> {
if let Some(path) = self.cache.get(repo, filename) {
/// use hf_hub::{api::sync::Api};
/// let api = Api::new().unwrap();
/// let local_filename = api.model("gpt2".to_string()).get("model.safetensors").unwrap();
pub fn get(&self, filename: &str) -> Result<PathBuf, ApiError> {
if let Some(path) = self.api.cache.get(&self.repo, filename) {
Ok(path)
} else {
self.download(repo, filename)
self.download(filename)
}
}

Expand All @@ -391,19 +445,18 @@ impl Api {
/// This function requires internet access to verify whether a newer version of the file
/// exists, even if the file is already on disk.
/// ```no_run
/// # use hf_hub::{api::sync::ApiBuilder, Repo};
/// let api = ApiBuilder::new().build().unwrap();
/// let repo = Repo::model("gpt2".to_string());
/// let local_filename = api.download(&repo, "model.safetensors").unwrap();
/// # use hf_hub::api::sync::Api;
/// let api = Api::new().unwrap();
/// let local_filename = api.model("gpt2".to_string()).download("model.safetensors").unwrap();
/// ```
pub fn download(&self, repo: &Repo, filename: &str) -> Result<PathBuf, ApiError> {
let url = self.url(repo, filename);
let metadata = self.metadata(&url)?;
pub fn download(&self, filename: &str) -> Result<PathBuf, ApiError> {
let url = self.url(filename);
let metadata = self.api.metadata(&url)?;

let blob_path = self.cache.blob_path(repo, &metadata.etag);
let blob_path = self.api.cache.blob_path(&self.repo, &metadata.etag);
std::fs::create_dir_all(blob_path.parent().unwrap())?;

let progressbar = if self.progress {
let progressbar = if self.api.progress {
let progress = ProgressBar::new(metadata.size as u64);
progress.set_style(
ProgressStyle::with_template(
Expand All @@ -423,34 +476,33 @@ impl Api {
None
};

let tmp_filename = self.download_tempfile(&url, progressbar)?;
let tmp_filename = self.api.download_tempfile(&url, progressbar)?;

if std::fs::rename(&tmp_filename, &blob_path).is_err() {
// Renaming may fail if locations are different mount points
std::fs::File::create(&blob_path)?;
std::fs::copy(tmp_filename, &blob_path)?;
}

let mut pointer_path = self.cache.pointer_path(repo, &metadata.commit_hash);
let mut pointer_path = self.api.cache.pointer_path(&self.repo, &metadata.commit_hash);
pointer_path.push(filename);
std::fs::create_dir_all(pointer_path.parent().unwrap()).ok();

symlink_or_rename(&blob_path, &pointer_path)?;
self.cache.create_ref(repo, &metadata.commit_hash)?;
self.api.cache.create_ref(&self.repo, &metadata.commit_hash)?;

Ok(pointer_path)
}

/// Get information about the Repo
/// ```
/// use hf_hub::{api::sync::Api, Repo};
/// use hf_hub::{api::sync::Api};
/// let api = Api::new().unwrap();
/// let repo = Repo::model("gpt2".to_string());
/// api.info(&repo);
/// api.model("gpt2".to_string()).info();
/// ```
pub fn info(&self, repo: &Repo) -> Result<RepoInfo, ApiError> {
let url = format!("{}/api/{}", self.endpoint, repo.api_url());
let response = self.client.get(&url).call().map_err(Box::new)?;
pub fn info(&self) -> Result<RepoInfo, ApiError> {
let url = format!("{}/api/{}", self.api.endpoint, self.repo.api_url());
let response = self.api.client.get(&url).call().map_err(Box::new)?;

let model_info = response.into_json()?;

Expand Down Expand Up @@ -499,8 +551,9 @@ mod tests {
.with_cache_dir(tmp.path.clone())
.build()
.unwrap();
let repo = Repo::new("julien-c/dummy-unknown".to_string(), RepoType::Model);
let downloaded_path = api.download(&repo, "config.json").unwrap();

let model_id = "julien-c/dummy-unknown".to_string();
let downloaded_path = api.model(model_id.clone()).download("config.json").unwrap();
assert!(downloaded_path.exists());
let val = Sha256::digest(std::fs::read(&*downloaded_path).unwrap());
assert_eq!(
Expand All @@ -509,7 +562,7 @@ mod tests {
);

// Make sure the file is now seeable without connection
let cache_path = api.cache.get(&repo, "config.json").unwrap();
let cache_path = api.cache.get(&Repo::new(model_id, RepoType::Model), "config.json").unwrap();
assert_eq!(cache_path, downloaded_path);
}

Expand All @@ -526,8 +579,8 @@ mod tests {
RepoType::Dataset,
"refs/convert/parquet".to_string(),
);
let downloaded_path = api
.download(&repo, "wikitext-103-v1/wikitext-test.parquet")
let downloaded_path = api.repo(repo)
.download("wikitext-103-v1/wikitext-test.parquet")
.unwrap();
assert!(downloaded_path.exists());
let val = Sha256::digest(std::fs::read(&*downloaded_path).unwrap());
Expand All @@ -550,7 +603,7 @@ mod tests {
RepoType::Dataset,
"refs/convert/parquet".to_string(),
);
let model_info = api.info(&repo).unwrap();
let model_info = api.repo(repo).info().unwrap();
assert_eq!(
model_info,
RepoInfo {
Expand Down
Loading