From 640b9ba30c43b90454df85ef2f42aeda4e244d43 Mon Sep 17 00:00:00 2001 From: Nicolas Patry Date: Thu, 27 Jul 2023 17:21:15 +0200 Subject: [PATCH 1/2] Move `ModelInfo` up and rename as `RepoInfo`. --- .github/workflows/rust.yml | 36 ++++++++++++++++++++++++++++++++++++ src/api/mod.rs | 19 +++++++++++++++++++ src/api/sync.rs | 22 +++++----------------- src/api/tokio.rs | 22 +++++----------------- 4 files changed, 65 insertions(+), 34 deletions(-) create mode 100644 .github/workflows/rust.yml diff --git a/.github/workflows/rust.yml b/.github/workflows/rust.yml new file mode 100644 index 0000000..b42ebb1 --- /dev/null +++ b/.github/workflows/rust.yml @@ -0,0 +1,36 @@ +name: Rust + +on: + pull_request: + +jobs: + build: + runs-on: ${{ matrix.os }} + strategy: + matrix: + os: [ubuntu-latest, windows-latest, macOS-latest] + + steps: + - uses: actions/checkout@v3 + + - name: Install Rust Stable + uses: actions-rs/toolchain@v1 + with: + toolchain: stable + components: rustfmt, clippy, llvm-tools-preview + override: true + + - uses: Swatinem/rust-cache@v2 + + - name: Build + run: cargo build --all-targets --verbose + + - name: Lint with Clippy + run: cargo clippy --all-targets --all-features -- -D warnings + + - name: Run Tests + run: cargo test --all-features --verbose + + - name: Run Audit + # RUSTSEC-2021-0145 is criterion so only within benchmarks + run: cargo audit -D warnings diff --git a/src/api/mod.rs b/src/api/mod.rs index 779dc4f..3d1ed78 100644 --- a/src/api/mod.rs +++ b/src/api/mod.rs @@ -1,6 +1,25 @@ +use serde::Deserialize; + /// The asynchronous version of the API #[cfg(feature = "tokio")] pub mod tokio; /// The synchronous version of the API pub mod sync; + +/// Siblings are simplified file descriptions of remote files on the hub +#[derive(Debug, Clone, Deserialize, PartialEq)] +pub struct Siblings { + /// The path within the repo. + pub rfilename: String, +} + +/// The description of the repo given by the hub +#[derive(Debug, Clone, Deserialize, PartialEq)] +pub struct RepoInfo { + /// See [`Siblings`] + pub siblings: Vec, + + /// The commit sha of the repo. + pub sha: String, +} diff --git a/src/api/sync.rs b/src/api/sync.rs index 14e6910..78095db 100644 --- a/src/api/sync.rs +++ b/src/api/sync.rs @@ -11,7 +11,7 @@ use std::collections::HashMap; // redirect::Policy, // Error as ReqwestError, // }; -use serde::Deserialize; +use super::RepoInfo; use std::io::{Seek, SeekFrom, Write}; use std::num::ParseIntError; use std::path::{Component, Path, PathBuf}; @@ -87,20 +87,6 @@ pub enum ApiError { TooManyRetries(Box), } -/// Siblings are simplified file descriptions of remote files on the hub -#[derive(Debug, Clone, Deserialize, PartialEq)] -pub struct Siblings { - /// The path within the repo. - pub rfilename: String, -} - -/// The description of the repo given by the hub -#[derive(Debug, Clone, Deserialize, PartialEq)] -pub struct ModelInfo { - /// See [`Siblings`] - pub siblings: Vec, -} - /// Helper to create [`Api`] with all the options. pub struct ApiBuilder { endpoint: String, @@ -543,7 +529,7 @@ impl Api { /// let repo = Repo::model("gpt2".to_string()); /// api.info(&repo); /// ``` - pub fn info(&self, repo: &Repo) -> Result { + pub fn info(&self, repo: &Repo) -> Result { let url = format!("{}/api/{}", self.endpoint, repo.api_url()); let response = self.client.get(&url).call()?; @@ -556,6 +542,7 @@ impl Api { #[cfg(test)] mod tests { use super::*; + use crate::api::Siblings; use crate::RepoType; use hex_literal::hex; use rand::{distributions::Alphanumeric, Rng}; @@ -647,7 +634,8 @@ mod tests { let model_info = api.info(&repo).unwrap(); assert_eq!( model_info, - ModelInfo { + RepoInfo { + sha: "2dd3f79917d431e9af1c81bfa96a575741774077".to_string(), siblings: vec![ Siblings { rfilename: ".gitattributes".to_string() diff --git a/src/api/tokio.rs b/src/api/tokio.rs index c1aac74..9c233ef 100644 --- a/src/api/tokio.rs +++ b/src/api/tokio.rs @@ -1,3 +1,4 @@ +use super::RepoInfo; use crate::{Cache, Repo}; use indicatif::{ProgressBar, ProgressStyle}; use rand::{distributions::Alphanumeric, thread_rng, Rng}; @@ -9,7 +10,6 @@ use reqwest::{ redirect::Policy, Client, Error as ReqwestError, }; -use serde::Deserialize; use std::num::ParseIntError; use std::path::{Component, Path, PathBuf}; use std::sync::Arc; @@ -69,20 +69,6 @@ pub enum ApiError { // InvalidResponse(Response), } -/// Siblings are simplified file descriptions of remote files on the hub -#[derive(Debug, Clone, Deserialize, PartialEq)] -pub struct Siblings { - /// The path within the repo. - pub rfilename: String, -} - -/// The description of the repo given by the hub -#[derive(Debug, Clone, Deserialize, PartialEq)] -pub struct ModelInfo { - /// See [`Siblings`] - pub siblings: Vec, -} - /// Helper to create [`Api`] with all the options. pub struct ApiBuilder { endpoint: String, @@ -550,7 +536,7 @@ impl Api { /// api.info(&repo); /// # }) /// ``` - pub async fn info(&self, repo: &Repo) -> Result { + pub async fn info(&self, repo: &Repo) -> Result { let url = format!("{}/api/{}", self.endpoint, repo.api_url()); let response = self.client.get(url).send().await?; let response = response.error_for_status()?; @@ -564,6 +550,7 @@ impl Api { #[cfg(test)] mod tests { use super::*; + use crate::api::Siblings; use crate::RepoType; use hex_literal::hex; use rand::{distributions::Alphanumeric, Rng}; @@ -656,7 +643,8 @@ mod tests { let model_info = api.info(&repo).await.unwrap(); assert_eq!( model_info, - ModelInfo { + RepoInfo { + sha: "2dd3f79917d431e9af1c81bfa96a575741774077".to_string(), siblings: vec![ Siblings { rfilename: ".gitattributes".to_string() From fdf731662c425819dfeed810dce3db5fa6eb71c1 Mon Sep 17 00:00:00 2001 From: Nicolas Patry Date: Thu, 27 Jul 2023 17:32:03 +0200 Subject: [PATCH 2/2] Fixing clippy. Co-Author: Luc Georges --- .pre-commit-config.yaml | 15 +++++++++++++++ src/api/sync.rs | 26 ++++++++++++++++---------- 2 files changed, 31 insertions(+), 10 deletions(-) create mode 100644 .pre-commit-config.yaml diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml new file mode 100644 index 0000000..83e3e68 --- /dev/null +++ b/.pre-commit-config.yaml @@ -0,0 +1,15 @@ +repos: + - repo: https://github.com/Narsil/pre-commit-rust + rev: 2eed6366172ef2a5186e8785ec0e67243d7d73d0 + hooks: + - id: fmt + name: "Rust (fmt)" + - id: clippy + name: "Rust (clippy)" + args: + [ + "--tests", + "--examples", + "--", + "-Dwarnings", + ] diff --git a/src/api/sync.rs b/src/api/sync.rs index 78095db..50909eb 100644 --- a/src/api/sync.rs +++ b/src/api/sync.rs @@ -46,7 +46,7 @@ impl HeaderAgent { fn get(&self, url: &str) -> ureq::Request { let mut request = self.agent.get(url); for (header, value) in &self.headers { - request = request.set(header, &value); + request = request.set(header, value); } request } @@ -72,7 +72,7 @@ pub enum ApiError { // ToStr(#[from] ToStrError), /// Error in the request #[error("request error: {0}")] - RequestError(#[from] ureq::Error), + RequestError(#[from] Box), /// Error parsing some range value #[error("Cannot parse int")] @@ -315,34 +315,36 @@ impl Api { .no_redirect_client .get(url) .set(RANGE, "bytes=0-0") - .call()?; + .call() + .map_err(Box::new)?; // let headers = response.headers(); let header_commit = "x-repo-commit"; let header_linked_etag = "x-linked-etag"; let header_etag = "etag"; - let etag = match response.header(&header_linked_etag) { + let etag = match response.header(header_linked_etag) { Some(etag) => etag, None => response - .header(&header_etag) + .header(header_etag) .ok_or(ApiError::MissingHeader(header_etag))?, }; // Cleaning extra quotes let etag = etag.to_string().replace('"', ""); let commit_hash = response - .header(&header_commit) + .header(header_commit) .ok_or(ApiError::MissingHeader(header_commit))? .to_string(); // The response was redirected o S3 most likely which will // know about the size of the file let status = response.status(); - let is_redirection = status >= 300 && status < 400; + let is_redirection = (300..400).contains(&status); let response = if is_redirection { self.client .get(response.header(LOCATION).unwrap()) .set(RANGE, "bytes=0-0") - .call()? + .call() + .map_err(Box::new)? } else { response }; @@ -435,7 +437,11 @@ impl Api { let range = format!("bytes={start}-{stop}"); let mut file = std::fs::OpenOptions::new().write(true).open(filename)?; file.seek(SeekFrom::Start(start as u64))?; - let response = client.get(url).set(RANGE, &range).call()?; + let response = client + .get(url) + .set(RANGE, &range) + .call() + .map_err(Box::new)?; const MAX: usize = 4096; let mut buffer: [u8; MAX] = [0; MAX]; @@ -531,7 +537,7 @@ impl Api { /// ``` pub fn info(&self, repo: &Repo) -> Result { let url = format!("{}/api/{}", self.endpoint, repo.api_url()); - let response = self.client.get(&url).call()?; + let response = self.client.get(&url).call().map_err(Box::new)?; let model_info = response.into_json()?;