Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 23 additions & 0 deletions src/api/mod.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,29 @@
use serde::Deserialize;

/// The asynchronous version of the API
#[cfg(feature = "tokio")]
pub mod tokio;

/// The synchronous version of the API
pub mod sync;

/// A simplified description of a single remote file stored in a hub repo.
///
/// Deserialized from the hub's JSON; only the file's path is kept.
#[derive(Clone, Debug, PartialEq, Deserialize)]
pub struct Siblings {
    /// Path of the file, relative to the root of the repo.
    pub rfilename: String,
}

/// Repo description as returned by the hub.
///
/// Deserialized from the hub's JSON; carries the commit sha together with
/// the list of files found in the repo.
#[derive(Clone, Debug, PartialEq, Deserialize)]
pub struct RepoInfo {
    /// Git commit sha
    pub sha: String,
    /// Files of the repo, see [`Siblings`]
    pub siblings: Vec<Siblings>,
}

/// Crate-internal, minimal view of a repo description: only the commit sha
/// is deserialized, every other field of the payload is ignored.
#[derive(Deserialize, Debug)]
pub(crate) struct RepoSha {
    /// Git commit sha
    pub sha: String,
}
105 changes: 63 additions & 42 deletions src/api/sync.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ use std::collections::HashMap;
// redirect::Policy,
// Error as ReqwestError,
// };
use serde::Deserialize;
use crate::api::{RepoInfo, RepoSha};
use std::io::{Seek, SeekFrom, Write};
use std::num::ParseIntError;
use std::path::{Component, Path, PathBuf};
Expand All @@ -33,19 +33,19 @@ type HeaderName = &'static str;

/// Simple wrapper over [`ureq::Agent`] to include default headers
#[derive(Clone)]
pub struct HeaderAgent{
pub struct HeaderAgent {
agent: Agent,
headers: HeaderMap,
}

impl HeaderAgent{
fn new(agent: Agent, headers:HeaderMap) -> Self{
Self{agent, headers}
impl HeaderAgent {
fn new(agent: Agent, headers: HeaderMap) -> Self {
Self { agent, headers }
}

fn get(&self, url: &str) -> ureq::Request{
fn get(&self, url: &str) -> ureq::Request {
let mut request = self.agent.get(url);
for (header, value) in &self.headers{
for (header, value) in &self.headers {
request = request.set(header, &value);
}
request
Expand All @@ -70,7 +70,6 @@ pub enum ApiError {
// /// The header value is not valid utf-8
// #[error("header value is not a string")]
// ToStr(#[from] ToStrError),

/// Error in the request
#[error("request error: {0}")]
RequestError(#[from] ureq::Error),
Expand All @@ -88,20 +87,6 @@ pub enum ApiError {
TooManyRetries(Box<ApiError>),
}

/// Siblings are simplified file descriptions of remote files on the hub
#[derive(Debug, Clone, Deserialize, PartialEq)]
pub struct Siblings {
/// The path within the repo.
pub rfilename: String,
}

/// The description of the repo given by the hub
#[derive(Debug, Clone, Deserialize, PartialEq)]
pub struct ModelInfo {
/// See [`Siblings`]
pub siblings: Vec<Siblings>,
}

/// Helper to create [`Api`] with all the options.
pub struct ApiBuilder {
endpoint: String,
Expand Down Expand Up @@ -179,10 +164,7 @@ impl ApiBuilder {
let user_agent = format!("unkown/None; {NAME}/{VERSION}; rust/unknown");
headers.insert(USER_AGENT, user_agent);
if let Some(token) = &self.token {
headers.insert(
AUTHORIZATION,
format!("Bearer {token}"),
);
headers.insert(AUTHORIZATION, format!("Bearer {token}"));
}
Ok(headers)
}
Expand All @@ -191,9 +173,7 @@ impl ApiBuilder {
pub fn build(self) -> Result<Api, ApiError> {
let headers = self.build_headers()?;
let client = HeaderAgent::new(ureq::builder().build(), headers.clone());
let no_redirect_client = HeaderAgent::new(ureq::builder()
.redirects(0)
.build(), headers);
let no_redirect_client = HeaderAgent::new(ureq::builder().redirects(0).build(), headers);
Ok(Api {
endpoint: self.endpoint,
url_template: self.url_template,
Expand Down Expand Up @@ -456,10 +436,7 @@ impl Api {
let range = format!("bytes={start}-{stop}");
let mut file = std::fs::OpenOptions::new().write(true).open(filename)?;
file.seek(SeekFrom::Start(start as u64))?;
let response = client
.get(url)
.set(RANGE, &range)
.call()?;
let response = client.get(url).set(RANGE, &range).call()?;

const MAX: usize = 4096;
let mut buffer: [u8; MAX] = [0; MAX];
Expand Down Expand Up @@ -553,25 +530,50 @@ impl Api {
/// let repo = Repo::model("gpt2".to_string());
/// api.info(&repo);
/// ```
pub fn info(&self, repo: &Repo) -> Result<ModelInfo, ApiError> {
pub fn info(&self, repo: &Repo) -> Result<RepoInfo, ApiError> {
let url = format!("{}/api/{}", self.endpoint, repo.api_url());
let response = self.client.get(&url).call()?;

let model_info = response.into_json()?;
let repo_info = response.into_json()?;

Ok(model_info)
Ok(repo_info)
}
}

/// Check whether the [`Repo`]'s revision is the latest commit rev on main
///
/// Requests the repo description from the hub, asking only for its `sha`
/// (`expand[]=sha`), and compares that sha with the revision stored in
/// `repo`.
///
/// # Errors
///
/// Returns an [`ApiError`] if the request fails or if the response body
/// cannot be deserialized.
///
/// ```
/// use hf_hub::{api::sync::Api, Repo, RepoType};
/// let api = Api::new().unwrap();
/// let repo = Repo::with_revision(
///     "mcpotato/42".to_owned(),
///     RepoType::Model,
///     "b161dce5978d64da247bedd293b0c55fb4adb949".to_owned(),
/// );
/// if api
///     .is_main(&repo)
///     .expect("api call to go through successfully")
/// {
///     println!(
///         "repo's revision ({}) is the latest on main",
///         repo.revision()
///     );
/// }
/// ```
pub fn is_main(&self, repo: &Repo) -> Result<bool, ApiError> {
    let url = format!("{}/api/{}", self.endpoint, repo.api_url());
    // Only the `sha` field is requested, so the payload maps onto `RepoSha`.
    let remote: RepoSha = self
        .client
        .get(&url)
        .query("expand[]", "sha")
        .call()?
        .into_json()?;

    Ok(remote.sha == repo.revision)
}
}

#[cfg(test)]
mod tests {
use super::*;
use crate::RepoType;
use crate::{api::Siblings, RepoType};
use hex_literal::hex;
use rand::{distributions::Alphanumeric, Rng};
use sha2::{Digest, Sha256};
use hex_literal::hex;


struct TempDir {
path: PathBuf,
Expand Down Expand Up @@ -654,12 +656,13 @@ mod tests {
let repo = Repo::with_revision(
"wikitext".to_string(),
RepoType::Dataset,
"refs/convert/parquet".to_string(),
"2dd3f79917d431e9af1c81bfa96a575741774077".to_string(),
);
let model_info = api.info(&repo).unwrap();
assert_eq!(
model_info,
ModelInfo {
RepoInfo {
sha: "2dd3f79917d431e9af1c81bfa96a575741774077".to_owned(),
siblings: vec![
Siblings {
rfilename: ".gitattributes".to_string()
Expand Down Expand Up @@ -729,4 +732,22 @@ mod tests {
}
)
}

/// XXX: this test may break eventually: it talks to the live hub and only
/// passes while the pinned sha below is what the hub reports for the repo.
/// The repo was chosen because it shouldn't receive commits.
#[test]
fn is_main() {
    let tmp = TempDir::new();
    let builder = ApiBuilder::new()
        .with_progress(false)
        .with_cache_dir(tmp.path.clone());
    let api = builder.build().unwrap();

    let revision = "b161dce5978d64da247bedd293b0c55fb4adb949".to_owned();
    let repo = Repo::with_revision("mcpotato/42".to_owned(), RepoType::Model, revision);

    let is_latest = api.is_main(&repo).unwrap();
    assert!(is_latest);
}
}
82 changes: 62 additions & 20 deletions src/api/tokio.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
use crate::api::{RepoInfo, RepoSha};
use crate::{Cache, Repo};
use indicatif::{ProgressBar, ProgressStyle};
use rand::{distributions::Alphanumeric, thread_rng, Rng};
Expand All @@ -9,7 +10,6 @@ use reqwest::{
redirect::Policy,
Client, Error as ReqwestError,
};
use serde::Deserialize;
use std::num::ParseIntError;
use std::path::{Component, Path, PathBuf};
use std::sync::Arc;
Expand Down Expand Up @@ -69,20 +69,6 @@ pub enum ApiError {
// InvalidResponse(Response),
}

/// Siblings are simplified file descriptions of remote files on the hub
#[derive(Debug, Clone, Deserialize, PartialEq)]
pub struct Siblings {
/// The path within the repo.
pub rfilename: String,
}

/// The description of the repo given by the hub
#[derive(Debug, Clone, Deserialize, PartialEq)]
pub struct ModelInfo {
/// See [`Siblings`]
pub siblings: Vec<Siblings>,
}

/// Helper to create [`Api`] with all the options.
pub struct ApiBuilder {
endpoint: String,
Expand Down Expand Up @@ -551,7 +537,7 @@ impl Api {
/// api.info(&repo);
/// # })
/// ```
pub async fn info(&self, repo: &Repo) -> Result<ModelInfo, ApiError> {
pub async fn info(&self, repo: &Repo) -> Result<RepoInfo, ApiError> {
let url = format!("{}/api/{}", self.endpoint, repo.api_url());
let response = self.client.get(url).send().await?;
let response = response.error_for_status()?;
Expand All @@ -560,15 +546,52 @@ impl Api {

Ok(model_info)
}

/// Check whether the [`Repo`]'s revision is the latest commit rev on main
///
/// Requests the repo description from the hub, asking only for its `sha`
/// (`expand[]=sha`), and compares that sha with the revision stored in
/// `repo`.
///
/// # Errors
///
/// Returns an [`ApiError`] if the request fails, if the hub answers with an
/// error status, or if the response body cannot be deserialized.
///
/// ```no_run
/// # use hf_hub::{api::tokio::ApiBuilder, Repo, RepoType};
/// # tokio_test::block_on(async {
/// let api = ApiBuilder::new().build().unwrap();
/// let repo = Repo::with_revision(
///     "mcpotato/42".to_owned(),
///     RepoType::Model,
///     "b161dce5978d64da247bedd293b0c55fb4adb949".to_owned(),
/// );
/// if api
///     .is_main(&repo)
///     .await
///     .expect("api call to go through successfully")
/// {
///     println!(
///         "repo's revision ({}) is the latest on main",
///         repo.revision()
///     );
/// }
/// # })
/// ```
pub async fn is_main(&self, repo: &Repo) -> Result<bool, ApiError> {
    let url = format!("{}/api/{}", self.endpoint, repo.api_url());
    // Only the `sha` field is requested, so the payload maps onto `RepoSha`.
    let request = self.client.get(&url).query(&[("expand[]", "sha")]);
    let response = request.send().await?;
    let response = response.error_for_status()?;
    let remote: RepoSha = response.json().await?;

    Ok(remote.sha == repo.revision)
}
}

#[cfg(test)]
mod tests {
use super::*;
use crate::RepoType;
use crate::{api::Siblings, RepoType};
use hex_literal::hex;
use rand::{distributions::Alphanumeric, Rng};
use sha2::{Digest, Sha256};
use hex_literal::hex;

struct TempDir {
path: PathBuf,
Expand Down Expand Up @@ -652,12 +675,13 @@ mod tests {
let repo = Repo::with_revision(
"wikitext".to_string(),
RepoType::Dataset,
"refs/convert/parquet".to_string(),
"2dd3f79917d431e9af1c81bfa96a575741774077".to_string(),
);
let model_info = api.info(&repo).await.unwrap();
assert_eq!(
model_info,
ModelInfo {
RepoInfo {
sha: "2dd3f79917d431e9af1c81bfa96a575741774077".to_owned(),
siblings: vec![
Siblings {
rfilename: ".gitattributes".to_string()
Expand Down Expand Up @@ -727,4 +751,22 @@ mod tests {
}
)
}

/// XXX: this test may break eventually: it talks to the live hub and only
/// passes while the pinned sha below is what the hub reports for the repo.
/// The repo was chosen because it shouldn't receive commits.
#[tokio::test]
async fn is_main() {
    let tmp = TempDir::new();
    let builder = ApiBuilder::new()
        .with_progress(false)
        .with_cache_dir(tmp.path.clone());
    let api = builder.build().unwrap();

    let revision = "b161dce5978d64da247bedd293b0c55fb4adb949".to_owned();
    let repo = Repo::with_revision("mcpotato/42".to_owned(), RepoType::Model, revision);

    let is_latest = api.is_main(&repo).await.unwrap();
    assert!(is_latest);
}
}