From 6340bd892c7fd33c7e7daa93f90e040e45a6c00f Mon Sep 17 00:00:00 2001 From: Luc Georges Date: Fri, 21 Jul 2023 16:12:09 +0200 Subject: [PATCH] feat(api): add `is_main()` --- src/api/mod.rs | 23 +++++++++++ src/api/sync.rs | 105 ++++++++++++++++++++++++++++------------------- src/api/tokio.rs | 82 +++++++++++++++++++++++++++--------- 3 files changed, 148 insertions(+), 62 deletions(-) diff --git a/src/api/mod.rs b/src/api/mod.rs index 779dc4f..a10ba93 100644 --- a/src/api/mod.rs +++ b/src/api/mod.rs @@ -1,6 +1,29 @@ +use serde::Deserialize; + /// The asynchronous version of the API #[cfg(feature = "tokio")] pub mod tokio; /// The synchronous version of the API pub mod sync; + +/// Siblings are simplified file descriptions of remote files on the hub +#[derive(Debug, Clone, Deserialize, PartialEq)] +pub struct Siblings { + /// The path within the repo. + pub rfilename: String, +} + +/// The description of a repo given by the hub +#[derive(Debug, Clone, Deserialize, PartialEq)] +pub struct RepoInfo { + /// Git commit sha + pub sha: String, + /// See [`Siblings`] + pub siblings: Vec, +} + +#[derive(Debug, Deserialize)] +pub(crate) struct RepoSha { + pub sha: String, +} diff --git a/src/api/sync.rs b/src/api/sync.rs index 8a25985..453a3ba 100644 --- a/src/api/sync.rs +++ b/src/api/sync.rs @@ -11,7 +11,7 @@ use std::collections::HashMap; // redirect::Policy, // Error as ReqwestError, // }; -use serde::Deserialize; +use crate::api::{RepoInfo, RepoSha}; use std::io::{Seek, SeekFrom, Write}; use std::num::ParseIntError; use std::path::{Component, Path, PathBuf}; @@ -33,19 +33,19 @@ type HeaderName = &'static str; /// Simple wrapper over [`ureq::Agent`] to include default headers #[derive(Clone)] -pub struct HeaderAgent{ +pub struct HeaderAgent { agent: Agent, headers: HeaderMap, } -impl HeaderAgent{ - fn new(agent: Agent, headers:HeaderMap) -> Self{ - Self{agent, headers} +impl HeaderAgent { + fn new(agent: Agent, headers: HeaderMap) -> Self { + Self { agent, headers } } - fn get(&self, url: &str) -> ureq::Request{ + fn get(&self, url: &str) -> ureq::Request { let mut request = self.agent.get(url); - for (header, value) in &self.headers{ + for (header, value) in &self.headers { request = request.set(header, &value); } request @@ -70,7 +70,6 @@ pub enum ApiError { // /// The header value is not valid utf-8 // #[error("header value is not a string")] // ToStr(#[from] ToStrError), - /// Error in the request #[error("request error: {0}")] RequestError(#[from] ureq::Error), @@ -88,20 +87,6 @@ pub enum ApiError { TooManyRetries(Box), } -/// Siblings are simplified file descriptions of remote files on the hub -#[derive(Debug, Clone, Deserialize, PartialEq)] -pub struct Siblings { - /// The path within the repo. - pub rfilename: String, -} - -/// The description of the repo given by the hub -#[derive(Debug, Clone, Deserialize, PartialEq)] -pub struct ModelInfo { - /// See [`Siblings`] - pub siblings: Vec, -} - /// Helper to create [`Api`] with all the options. pub struct ApiBuilder { endpoint: String, @@ -179,10 +164,7 @@ impl ApiBuilder { let user_agent = format!("unkown/None; {NAME}/{VERSION}; rust/unknown"); headers.insert(USER_AGENT, user_agent); if let Some(token) = &self.token { - headers.insert( - AUTHORIZATION, - format!("Bearer {token}"), - ); + headers.insert(AUTHORIZATION, format!("Bearer {token}")); } Ok(headers) } @@ -191,9 +173,7 @@ impl ApiBuilder { pub fn build(self) -> Result { let headers = self.build_headers()?; let client = HeaderAgent::new(ureq::builder().build(), headers.clone()); - let no_redirect_client = HeaderAgent::new(ureq::builder() - .redirects(0) - .build(), headers); + let no_redirect_client = HeaderAgent::new(ureq::builder().redirects(0).build(), headers); Ok(Api { endpoint: self.endpoint, url_template: self.url_template, @@ -456,10 +436,7 @@ impl Api { let range = format!("bytes={start}-{stop}"); let mut file = std::fs::OpenOptions::new().write(true).open(filename)?; file.seek(SeekFrom::Start(start as u64))?; - let response = client - .get(url) - .set(RANGE, &range) - .call()?; + let response = client.get(url).set(RANGE, &range).call()?; const MAX: usize = 4096; let mut buffer: [u8; MAX] = [0; MAX]; @@ -553,25 +530,50 @@ impl Api { /// let repo = Repo::model("gpt2".to_string()); /// api.info(&repo); /// ``` - pub fn info(&self, repo: &Repo) -> Result { + pub fn info(&self, repo: &Repo) -> Result { let url = format!("{}/api/{}", self.endpoint, repo.api_url()); let response = self.client.get(&url).call()?; - let model_info = response.into_json()?; + let repo_info = response.into_json()?; - Ok(model_info) + Ok(repo_info) } -} + /// Check if the [`Repo`]'s revision is the latest commit rev on main + /// ``` + /// use hf_hub::{api::sync::Api, Repo, RepoType}; + /// let api = Api::new().unwrap(); + /// let repo = Repo::with_revision( + /// "mcpotato/42".to_owned(), + /// RepoType::Model, + /// "b161dce5978d64da247bedd293b0c55fb4adb949".to_owned(), + /// ); + /// if api + /// .is_main(&repo) + /// .expect("api call to go through successfully") + /// { + /// println!( + /// "repo's revision ({}) is the latest on main", + /// repo.revision() + /// ); + /// } + /// ``` + pub fn is_main(&self, repo: &Repo) -> Result { + let url = format!("{}/api/{}", self.endpoint, repo.api_url()); + let response = self.client.get(&url).query("expand[]", "sha").call()?; + let repo_sha: RepoSha = response.into_json()?; + + Ok(repo_sha.sha == repo.revision) + } +} #[cfg(test)] mod tests { use super::*; - use crate::RepoType; + use crate::{api::Siblings, RepoType}; + use hex_literal::hex; use rand::{distributions::Alphanumeric, Rng}; use sha2::{Digest, Sha256}; - use hex_literal::hex; - struct TempDir { path: PathBuf, @@ -654,12 +656,13 @@ mod tests { let repo = Repo::with_revision( "wikitext".to_string(), RepoType::Dataset, - "refs/convert/parquet".to_string(), + "2dd3f79917d431e9af1c81bfa96a575741774077".to_string(), ); let model_info = api.info(&repo).unwrap(); assert_eq!( model_info, - ModelInfo { + RepoInfo { + sha: "2dd3f79917d431e9af1c81bfa96a575741774077".to_owned(), siblings: vec![ Siblings { rfilename: ".gitattributes".to_string() @@ -729,4 +732,22 @@ mod tests { } ) } + + /// XXX: this test may break eventually, + /// I'll choose a repo that shouldn't receive commits + #[test] + fn is_main() { + let tmp = TempDir::new(); + let api = ApiBuilder::new() + .with_progress(false) + .with_cache_dir(tmp.path.clone()) + .build() + .unwrap(); + let repo = Repo::with_revision( + "mcpotato/42".to_owned(), + RepoType::Model, + "b161dce5978d64da247bedd293b0c55fb4adb949".to_owned(), + ); + assert!(api.is_main(&repo).unwrap()); + } } diff --git a/src/api/tokio.rs b/src/api/tokio.rs index 6877e71..0a8c05d 100644 --- a/src/api/tokio.rs +++ b/src/api/tokio.rs @@ -1,3 +1,4 @@ +use crate::api::{RepoInfo, RepoSha}; use crate::{Cache, Repo}; use indicatif::{ProgressBar, ProgressStyle}; use rand::{distributions::Alphanumeric, thread_rng, Rng}; @@ -9,7 +10,6 @@ use reqwest::{ redirect::Policy, Client, Error as ReqwestError, }; -use serde::Deserialize; use std::num::ParseIntError; use std::path::{Component, Path, PathBuf}; use std::sync::Arc; @@ -69,20 +69,6 @@ pub enum ApiError { // InvalidResponse(Response), } -/// Siblings are simplified file descriptions of remote files on the hub -#[derive(Debug, Clone, Deserialize, PartialEq)] -pub struct Siblings { - /// The path within the repo. - pub rfilename: String, -} - -/// The description of the repo given by the hub -#[derive(Debug, Clone, Deserialize, PartialEq)] -pub struct ModelInfo { - /// See [`Siblings`] - pub siblings: Vec, -} - /// Helper to create [`Api`] with all the options. pub struct ApiBuilder { endpoint: String, @@ -551,7 +537,7 @@ impl Api { /// api.info(&repo); /// # }) /// ``` - pub async fn info(&self, repo: &Repo) -> Result { + pub async fn info(&self, repo: &Repo) -> Result { let url = format!("{}/api/{}", self.endpoint, repo.api_url()); let response = self.client.get(url).send().await?; let response = response.error_for_status()?; @@ -560,15 +546,52 @@ impl Api { Ok(model_info) } + + /// Check if the [`Repo`]'s revision is the latest commit rev on main + /// ```no_run + /// # use hf_hub::{api::tokio::ApiBuilder, Repo, RepoType}; + /// # tokio_test::block_on(async { + /// let api = ApiBuilder::new().build().unwrap(); + /// let repo = Repo::with_revision( + /// "mcpotato/42".to_owned(), + /// RepoType::Model, + /// "b161dce5978d64da247bedd293b0c55fb4adb949".to_owned(), + /// ); + /// if api + /// .is_main(&repo) + /// .await + /// .expect("api call to go through successfully") + /// { + /// println!( + /// "repo's revision ({}) is the latest on main", + /// repo.revision() + /// ); + /// } + /// # }) + /// ``` + pub async fn is_main(&self, repo: &Repo) -> Result { + let url = format!("{}/api/{}", self.endpoint, repo.api_url()); + let repo_sha: RepoSha = self + .client + .get(&url) + .query(&[("expand[]", "sha")]) + .send() + .await? + .error_for_status()? + .json() + .await?; + + Ok(repo_sha.sha == repo.revision) + } } #[cfg(test)] mod tests { use super::*; - use crate::RepoType; + use crate::{api::Siblings, RepoType}; + use hex_literal::hex; use rand::{distributions::Alphanumeric, Rng}; use sha2::{Digest, Sha256}; - use hex_literal::hex; struct TempDir { path: PathBuf, @@ -652,12 +675,13 @@ mod tests { let repo = Repo::with_revision( "wikitext".to_string(), RepoType::Dataset, - "refs/convert/parquet".to_string(), + "2dd3f79917d431e9af1c81bfa96a575741774077".to_string(), ); let model_info = api.info(&repo).await.unwrap(); assert_eq!( model_info, - ModelInfo { + RepoInfo { + sha: "2dd3f79917d431e9af1c81bfa96a575741774077".to_owned(), siblings: vec![ Siblings { rfilename: ".gitattributes".to_string() @@ -727,4 +751,22 @@ mod tests { } ) } + + /// XXX: this test may break eventually, + /// I'll choose a repo that shouldn't receive commits + #[tokio::test] + async fn is_main() { + let tmp = TempDir::new(); + let api = ApiBuilder::new() + .with_progress(false) + .with_cache_dir(tmp.path.clone()) + .build() + .unwrap(); + let repo = Repo::with_revision( + "mcpotato/42".to_owned(), + RepoType::Model, + "b161dce5978d64da247bedd293b0c55fb4adb949".to_owned(), + ); + assert!(api.is_main(&repo).await.unwrap()); + } }