Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 36 additions & 0 deletions .github/workflows/rust.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
name: Rust

on:
pull_request:

jobs:
build:
runs-on: ${{ matrix.os }}
strategy:
matrix:
os: [ubuntu-latest, windows-latest, macOS-latest]

steps:
- uses: actions/checkout@v3

- name: Install Rust Stable
uses: actions-rs/toolchain@v1
with:
toolchain: stable
components: rustfmt, clippy, llvm-tools-preview
override: true

- uses: Swatinem/rust-cache@v2

- name: Build
run: cargo build --all-targets --verbose

- name: Lint with Clippy
run: cargo clippy --all-targets --all-features -- -D warnings

- name: Run Tests
run: cargo test --all-features --verbose

- name: Run Audit
# RUSTSEC-2021-0145 is criterion so only within benchmarks
run: cargo audit -D warnings
15 changes: 15 additions & 0 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
repos:
- repo: https://github.com/Narsil/pre-commit-rust
rev: 2eed6366172ef2a5186e8785ec0e67243d7d73d0
hooks:
- id: fmt
name: "Rust (fmt)"
- id: clippy
name: "Rust (clippy)"
args:
[
"--tests",
"--examples",
"--",
"-Dwarnings",
]
19 changes: 19 additions & 0 deletions src/api/mod.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,25 @@
use serde::Deserialize;

/// The asynchronous version of the API
#[cfg(feature = "tokio")]
pub mod tokio;

/// The synchronous version of the API
pub mod sync;

/// Siblings are simplified file descriptions of remote files on the hub
#[derive(Debug, Clone, Deserialize, PartialEq)]
pub struct Siblings {
/// The path within the repo.
pub rfilename: String,
}

/// The description of the repo given by the hub
#[derive(Debug, Clone, Deserialize, PartialEq)]
pub struct RepoInfo {
/// See [`Siblings`]
pub siblings: Vec<Siblings>,

/// The commit sha of the repo.
pub sha: String,
}
48 changes: 21 additions & 27 deletions src/api/sync.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ use std::collections::HashMap;
// redirect::Policy,
// Error as ReqwestError,
// };
use serde::Deserialize;
use super::RepoInfo;
use std::io::{Seek, SeekFrom, Write};
use std::num::ParseIntError;
use std::path::{Component, Path, PathBuf};
Expand Down Expand Up @@ -46,7 +46,7 @@ impl HeaderAgent {
fn get(&self, url: &str) -> ureq::Request {
let mut request = self.agent.get(url);
for (header, value) in &self.headers {
request = request.set(header, &value);
request = request.set(header, value);
}
request
}
Expand All @@ -72,7 +72,7 @@ pub enum ApiError {
// ToStr(#[from] ToStrError),
/// Error in the request
#[error("request error: {0}")]
RequestError(#[from] ureq::Error),
RequestError(#[from] Box<ureq::Error>),

/// Error parsing some range value
#[error("Cannot parse int")]
Expand All @@ -87,20 +87,6 @@ pub enum ApiError {
TooManyRetries(Box<ApiError>),
}

/// Siblings are simplified file descriptions of remote files on the hub
#[derive(Debug, Clone, Deserialize, PartialEq)]
pub struct Siblings {
/// The path within the repo.
pub rfilename: String,
}

/// The description of the repo given by the hub
#[derive(Debug, Clone, Deserialize, PartialEq)]
pub struct ModelInfo {
/// See [`Siblings`]
pub siblings: Vec<Siblings>,
}

/// Helper to create [`Api`] with all the options.
pub struct ApiBuilder {
endpoint: String,
Expand Down Expand Up @@ -329,34 +315,36 @@ impl Api {
.no_redirect_client
.get(url)
.set(RANGE, "bytes=0-0")
.call()?;
.call()
.map_err(Box::new)?;
// let headers = response.headers();
let header_commit = "x-repo-commit";
let header_linked_etag = "x-linked-etag";
let header_etag = "etag";

let etag = match response.header(&header_linked_etag) {
let etag = match response.header(header_linked_etag) {
Some(etag) => etag,
None => response
.header(&header_etag)
.header(header_etag)
.ok_or(ApiError::MissingHeader(header_etag))?,
};
// Cleaning extra quotes
let etag = etag.to_string().replace('"', "");
let commit_hash = response
.header(&header_commit)
.header(header_commit)
.ok_or(ApiError::MissingHeader(header_commit))?
.to_string();

// The response was redirected o S3 most likely which will
// know about the size of the file
let status = response.status();
let is_redirection = status >= 300 && status < 400;
let is_redirection = (300..400).contains(&status);
let response = if is_redirection {
self.client
.get(response.header(LOCATION).unwrap())
.set(RANGE, "bytes=0-0")
.call()?
.call()
.map_err(Box::new)?
} else {
response
};
Expand Down Expand Up @@ -449,7 +437,11 @@ impl Api {
let range = format!("bytes={start}-{stop}");
let mut file = std::fs::OpenOptions::new().write(true).open(filename)?;
file.seek(SeekFrom::Start(start as u64))?;
let response = client.get(url).set(RANGE, &range).call()?;
let response = client
.get(url)
.set(RANGE, &range)
.call()
.map_err(Box::new)?;

const MAX: usize = 4096;
let mut buffer: [u8; MAX] = [0; MAX];
Expand Down Expand Up @@ -543,9 +535,9 @@ impl Api {
/// let repo = Repo::model("gpt2".to_string());
/// api.info(&repo);
/// ```
pub fn info(&self, repo: &Repo) -> Result<ModelInfo, ApiError> {
pub fn info(&self, repo: &Repo) -> Result<RepoInfo, ApiError> {
let url = format!("{}/api/{}", self.endpoint, repo.api_url());
let response = self.client.get(&url).call()?;
let response = self.client.get(&url).call().map_err(Box::new)?;

let model_info = response.into_json()?;

Expand All @@ -556,6 +548,7 @@ impl Api {
#[cfg(test)]
mod tests {
use super::*;
use crate::api::Siblings;
use crate::RepoType;
use hex_literal::hex;
use rand::{distributions::Alphanumeric, Rng};
Expand Down Expand Up @@ -647,7 +640,8 @@ mod tests {
let model_info = api.info(&repo).unwrap();
assert_eq!(
model_info,
ModelInfo {
RepoInfo {
sha: "2dd3f79917d431e9af1c81bfa96a575741774077".to_string(),
siblings: vec![
Siblings {
rfilename: ".gitattributes".to_string()
Expand Down
22 changes: 5 additions & 17 deletions src/api/tokio.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
use super::RepoInfo;
use crate::{Cache, Repo};
use indicatif::{ProgressBar, ProgressStyle};
use rand::{distributions::Alphanumeric, thread_rng, Rng};
Expand All @@ -9,7 +10,6 @@ use reqwest::{
redirect::Policy,
Client, Error as ReqwestError,
};
use serde::Deserialize;
use std::num::ParseIntError;
use std::path::{Component, Path, PathBuf};
use std::sync::Arc;
Expand Down Expand Up @@ -69,20 +69,6 @@ pub enum ApiError {
// InvalidResponse(Response),
}

/// Siblings are simplified file descriptions of remote files on the hub
#[derive(Debug, Clone, Deserialize, PartialEq)]
pub struct Siblings {
/// The path within the repo.
pub rfilename: String,
}

/// The description of the repo given by the hub
#[derive(Debug, Clone, Deserialize, PartialEq)]
pub struct ModelInfo {
/// See [`Siblings`]
pub siblings: Vec<Siblings>,
}

/// Helper to create [`Api`] with all the options.
pub struct ApiBuilder {
endpoint: String,
Expand Down Expand Up @@ -550,7 +536,7 @@ impl Api {
/// api.info(&repo);
/// # })
/// ```
pub async fn info(&self, repo: &Repo) -> Result<ModelInfo, ApiError> {
pub async fn info(&self, repo: &Repo) -> Result<RepoInfo, ApiError> {
let url = format!("{}/api/{}", self.endpoint, repo.api_url());
let response = self.client.get(url).send().await?;
let response = response.error_for_status()?;
Expand All @@ -564,6 +550,7 @@ impl Api {
#[cfg(test)]
mod tests {
use super::*;
use crate::api::Siblings;
use crate::RepoType;
use hex_literal::hex;
use rand::{distributions::Alphanumeric, Rng};
Expand Down Expand Up @@ -656,7 +643,8 @@ mod tests {
let model_info = api.info(&repo).await.unwrap();
assert_eq!(
model_info,
ModelInfo {
RepoInfo {
sha: "2dd3f79917d431e9af1c81bfa96a575741774077".to_string(),
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This will eventually break when refs/convert/parquet is updated

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That's OK !

siblings: vec![
Siblings {
rfilename: ".gitattributes".to_string()
Expand Down