diff --git a/.github/workflows/rust.yml b/.github/workflows/rust.yml index b42ebb1..c94a6e8 100644 --- a/.github/workflows/rust.yml +++ b/.github/workflows/rust.yml @@ -26,7 +26,7 @@ jobs: run: cargo build --all-targets --verbose - name: Lint with Clippy - run: cargo clippy --all-targets --all-features -- -D warnings + run: cargo clippy --all-targets --all-features --tests --examples -- -D warnings - name: Run Tests run: cargo test --all-features --verbose diff --git a/examples/download.rs b/examples/download.rs index 4ca4970..dd6799f 100644 --- a/examples/download.rs +++ b/examples/download.rs @@ -2,7 +2,10 @@ fn main() { let api = hf_hub::api::sync::Api::new().unwrap(); - let _filename = api.model("meta-llama/Llama-2-7b-hf".to_string()).get("model-00001-of-00002.safetensors").unwrap(); + let _filename = api + .model("meta-llama/Llama-2-7b-hf".to_string()) + .get("model-00001-of-00002.safetensors") + .unwrap(); } #[cfg(feature = "tokio")] @@ -10,5 +13,9 @@ fn main() { async fn main() { let api = hf_hub::api::tokio::Api::new().unwrap(); - let _filename = api.model("meta-llama/Llama-2-7b-hf".to_string()).get("model-00001-of-00002.safetensors").await.unwrap(); + let _filename = api + .model("meta-llama/Llama-2-7b-hf".to_string()) + .get("model-00001-of-00002.safetensors") + .await + .unwrap(); } diff --git a/src/api/sync.rs b/src/api/sync.rs index 66c659e..f004e71 100644 --- a/src/api/sync.rs +++ b/src/api/sync.rs @@ -1,6 +1,5 @@ use crate::{Cache, Repo, RepoType}; use indicatif::{ProgressBar, ProgressStyle}; -use rand::{distributions::Alphanumeric, Rng}; use std::collections::HashMap; // use reqwest::{ // blocking::Agent, @@ -198,17 +197,6 @@ pub struct Api { progress: bool, } -fn temp_filename() -> PathBuf { - let s: String = rand::thread_rng() - .sample_iter(&Alphanumeric) - .take(7) - .map(char::from) - .collect(); - let mut path = std::env::temp_dir(); - path.push(s); - path -} - fn make_relative(src: &Path, dst: &Path) -> PathBuf { let path = src; let base = dst; @@ -331,7 +319,7 @@ impl Api { url: &str, progressbar: Option, ) -> Result { - let filename = temp_filename(); + let filename = self.cache.temp_path(); // Create the file and set everything properly let mut file = std::fs::File::create(&filename)?; @@ -474,11 +462,7 @@ impl ApiRepo { let tmp_filename = self.api.download_tempfile(&url, progressbar)?; - if std::fs::rename(&tmp_filename, &blob_path).is_err() { - // Renaming may fail if locations are different mount points - std::fs::File::create(&blob_path)?; - std::fs::copy(tmp_filename, &blob_path)?; - } + std::fs::rename(tmp_filename, &blob_path)?; let mut pointer_path = self .api @@ -622,7 +606,7 @@ mod tests { assert_eq!( model_info, RepoInfo { - sha: "2dd3f79917d431e9af1c81bfa96a575741774077".to_string(), + sha: "06ac3f4b846ef171cae5a48a35c3e85f2b44f636".to_string(), siblings: vec![ Siblings { rfilename: ".gitattributes".to_string(), diff --git a/src/api/tokio.rs b/src/api/tokio.rs index 6d6bfaf..584b694 100644 --- a/src/api/tokio.rs +++ b/src/api/tokio.rs @@ -1,7 +1,7 @@ use super::RepoInfo; use crate::{Cache, Repo, RepoType}; use indicatif::{ProgressBar, ProgressStyle}; -use rand::{distributions::Alphanumeric, thread_rng, Rng}; +use rand::Rng; use reqwest::{ header::{ HeaderMap, HeaderName, HeaderValue, InvalidHeaderValue, ToStrError, AUTHORIZATION, @@ -203,17 +203,6 @@ pub struct Api { progress: bool, } -fn temp_filename() -> PathBuf { - let s: String = rand::thread_rng() - .sample_iter(&Alphanumeric) - .take(7) - .map(char::from) - .collect(); - let mut path = std::env::temp_dir(); - path.push(s); - path -} - fn make_relative(src: &Path, dst: &Path) -> PathBuf { let path = src; let base = dst; @@ -266,7 +255,7 @@ fn symlink_or_rename(src: &Path, dst: &Path) -> Result<(), std::io::Error> { } fn jitter() -> usize { - thread_rng().gen_range(0..=500) + rand::thread_rng().gen_range(0..=500) } fn exponential_backoff(base_wait_time: usize, n: usize, max: usize) -> usize { @@ -421,7 +410,7 @@ impl ApiRepo { let mut handles = vec![]; let semaphore = Arc::new(Semaphore::new(self.api.max_files)); let parallel_failures_semaphore = Arc::new(Semaphore::new(self.api.parallel_failures)); - let filename = temp_filename(); + let filename = self.api.cache.temp_path(); // Create the file and set everything properly tokio::fs::File::create(&filename) @@ -564,11 +553,7 @@ impl ApiRepo { .download_tempfile(&url, metadata.size, progressbar) .await?; - if tokio::fs::rename(&tmp_filename, &blob_path).await.is_err() { - // Renaming may fail if locations are different mount points - std::fs::File::create(&blob_path)?; - tokio::fs::copy(tmp_filename, &blob_path).await?; - } + tokio::fs::rename(&tmp_filename, &blob_path).await?; let mut pointer_path = self.api.cache.pointer_path(repo, &metadata.commit_hash); pointer_path.push(filename); @@ -710,7 +695,7 @@ mod tests { assert_eq!( model_info, RepoInfo { - sha: "2dd3f79917d431e9af1c81bfa96a575741774077".to_string(), + sha: "06ac3f4b846ef171cae5a48a35c3e85f2b44f636".to_string(), siblings: vec![ Siblings { rfilename: ".gitattributes".to_string() diff --git a/src/lib.rs b/src/lib.rs index f4c7e6f..323d394 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,5 +1,6 @@ #![deny(missing_docs)] #![doc = include_str!(concat!(env!("CARGO_MANIFEST_DIR"), "/README.md"))] +use rand::{distributions::Alphanumeric, Rng}; use std::io::Write; use std::path::PathBuf; @@ -94,6 +95,20 @@ impl Cache { path.push("token"); path } + + pub(crate) fn temp_path(&self) -> PathBuf { + let mut path = self.path.clone(); + path.push("tmp"); + std::fs::create_dir_all(&path).ok(); + + let s: String = rand::thread_rng() + .sample_iter(&Alphanumeric) + .take(7) + .map(char::from) + .collect(); + path.push(s); + path + } } impl Default for Cache { @@ -202,26 +217,26 @@ mod tests { use super::*; #[test] - #[cfg(not(target_os="windows"))] + #[cfg(not(target_os = "windows"))] fn token_path() { let cache = Cache::default(); let token_path = cache.token_path().to_str().unwrap().to_string(); - if let Ok(hf_home) = std::env::var("HF_HOME"){ + if let Ok(hf_home) = std::env::var("HF_HOME") { assert_eq!(token_path, format!("{hf_home}/token")); - }else{ + } else { let n = "huggingface/token".len(); assert_eq!(&token_path[token_path.len() - n..], "huggingface/token"); } } #[test] - #[cfg(target_os="windows")] + #[cfg(target_os = "windows")] fn token_path() { let cache = Cache::default(); let token_path = cache.token_path().to_str().unwrap().to_string(); - if let Ok(hf_home) = std::env::var("HF_HOME"){ + if let Ok(hf_home) = std::env::var("HF_HOME") { assert_eq!(token_path, format!("{hf_home}\\token")); - }else{ + } else { let n = "huggingface/token".len(); assert_eq!(&token_path[token_path.len() - n..], "huggingface\\token"); }