Feature: Better retry policy in binstalk-downloader (#794)
Fixed #779 and #791.

 - Retry the request on timeout
 - Retry on `StatusCode::{REQUEST_TIMEOUT, GATEWAY_TIMEOUT}`
 - Add `DEFAULT_RETRY_DURATION_FOR_RATE_LIMIT` for 503/429:
   if a 503/429 response does not include a `Retry-After` header, or the
   header is invalid, default to `DEFAULT_RETRY_DURATION_FOR_RATE_LIMIT`
   (see the sketch after this list).
 - Fix `Client::get_redirected_final_url`: retry using `GET` on status codes 400..405 and 410
 - Rename `remote_exists` => `remote_gettable` and support falling back to `GET`
   if `HEAD` fails with status code 400..405 or 410.
 - Improve `Client::get_stream`: include the url and method in the errors
   yielded by the returned stream
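
As a quick illustration of the 503/429 item above, here is a minimal sketch of the delay rule (the constants match the diff below; `retry_delay` is a hypothetical helper for illustration, the crate computes this inline in `do_send_request`):

```rust
use std::time::Duration;

const DEFAULT_RETRY_DURATION_FOR_RATE_LIMIT: Duration = Duration::from_millis(200);
const MAX_RETRY_DURATION: Duration = Duration::from_secs(120);

/// Delay before retrying a 503/429 response: honour the server-provided
/// `Retry-After` value when it parses, otherwise fall back to the default,
/// and never wait longer than `MAX_RETRY_DURATION`.
fn retry_delay(retry_after: Option<Duration>) -> Duration {
    retry_after
        .unwrap_or(DEFAULT_RETRY_DURATION_FOR_RATE_LIMIT)
        .min(MAX_RETRY_DURATION)
}

fn main() {
    // Missing or unparsable header -> default 200ms.
    assert_eq!(retry_delay(None), Duration::from_millis(200));
    // Server asks for 5s -> honoured.
    assert_eq!(retry_delay(Some(Duration::from_secs(5))), Duration::from_secs(5));
    // Server asks for 10 minutes -> capped at 120s.
    assert_eq!(retry_delay(Some(Duration::from_secs(600))), Duration::from_secs(120));
}
```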

Signed-off-by: Jiahao XU <Jiahao_XU@outlook.com>
NobodyXu committed Feb 13, 2023
1 parent 1b2fb08 commit 87686cb
Showing 6 changed files with 196 additions and 85 deletions.
1 change: 1 addition & 0 deletions .github/workflows/ci.yml
@@ -38,6 +38,7 @@ jobs:
runs-on: ${{ matrix.os }}
env:
CARGO_BUILD_TARGET: ${{ matrix.target }}
CARGO_BINSTALL_LOG_LEVEL: debug

steps:
- uses: actions/checkout@v3
7 changes: 5 additions & 2 deletions crates/bin/src/logging.rs
@@ -193,8 +193,11 @@ pub fn logging(log_level: LevelFilter, json_output: bool) {
// Calculate log_level
let log_level = min(log_level, STATIC_MAX_LEVEL);

let allowed_targets =
(log_level != LevelFilter::Trace).then_some(["binstalk", "cargo_binstall"]);
let allowed_targets = (log_level != LevelFilter::Trace).then_some([
"binstalk",
"binstalk_downloader",
"cargo_binstall",
]);

// Forward log to tracing
Logger::init(log_level);
183 changes: 132 additions & 51 deletions crates/binstalk-downloader/src/remote.rs
@@ -1,5 +1,6 @@
use std::{
num::NonZeroU64,
num::{NonZeroU64, NonZeroU8},
ops::ControlFlow,
sync::Arc,
time::{Duration, SystemTime},
};
@@ -12,7 +13,6 @@ use reqwest::{
Request, Response, StatusCode,
};
use thiserror::Error as ThisError;
use tokio::time::Instant;
use tower::{limit::rate::RateLimit, Service, ServiceBuilder, ServiceExt};
use tracing::{debug, info};

@@ -24,6 +24,8 @@ use delay_request::DelayRequest;

const MAX_RETRY_DURATION: Duration = Duration::from_secs(120);
const MAX_RETRY_COUNT: u8 = 3;
const DEFAULT_RETRY_DURATION_FOR_RATE_LIMIT: Duration = Duration::from_millis(200);
const RETRY_DURATION_FOR_TIMEOUT: Duration = Duration::from_millis(200);
const DEFAULT_MIN_TLS: tls::Version = tls::Version::TLS_1_2;

#[derive(Debug, ThisError)]
@@ -100,46 +102,88 @@ impl Client {
&self.0.client
}

async fn send_request_inner(
/// Return `Err(_)` for a fatal error that cannot be retried.
///
/// Return `Ok(ControlFlow::Continue(res))` for retryable error, `res`
/// will contain the previous `Result<Response, ReqwestError>`.
/// A retryable error could be a `ReqwestError` or `Response` with
/// unsuccessful status code.
///
/// Return `Ok(ControlFlow::Break(response))` when succeeds and no need
/// to retry.
async fn do_send_request(
&self,
method: &Method,
url: &Url,
) -> Result<Response, ReqwestError> {
let mut count = 0;
) -> Result<ControlFlow<Response, Result<Response, ReqwestError>>, ReqwestError> {
let request = Request::new(method.clone(), url.clone());

loop {
let request = Request::new(method.clone(), url.clone());
let future = (&self.0.service).ready().await?.call(request);

let future = (&self.0.service).ready().await?.call(request);
let response = match future.await {
Err(err) if err.is_timeout() => {
let duration = RETRY_DURATION_FOR_TIMEOUT;

let response = future.await?;
info!("Received timeout error from reqwest. Delay future request by {duration:#?}");

let status = response.status();
self.0.service.add_urls_to_delay(&[url], duration);

match (status, parse_header_retry_after(response.headers())) {
(
// 503 429
StatusCode::SERVICE_UNAVAILABLE | StatusCode::TOO_MANY_REQUESTS,
Some(duration),
) => {
let duration = duration.min(MAX_RETRY_DURATION);
return Ok(ControlFlow::Continue(Err(err)));
}
res => res?,
};

info!("Received status code {status}, will wait for {duration:#?} and retry");
let status = response.status();

let deadline = Instant::now() + duration;
let add_delay_and_continue = |response: Response, duration| {
info!("Received status code {status}, will wait for {duration:#?} and retry");

self.0
.service
.add_urls_to_delay(dedup([url, response.url()]), deadline);
self.0
.service
.add_urls_to_delay(&[url, response.url()], duration);

if count >= MAX_RETRY_COUNT {
break Ok(response);
}
}
_ => break Ok(response),
Ok(ControlFlow::Continue(Ok(response)))
};

match status {
// Delay further request on rate limit
StatusCode::SERVICE_UNAVAILABLE | StatusCode::TOO_MANY_REQUESTS => {
let duration = parse_header_retry_after(response.headers())
.unwrap_or(DEFAULT_RETRY_DURATION_FOR_RATE_LIMIT)
.min(MAX_RETRY_DURATION);

add_delay_and_continue(response, duration)
}

// Delay further request on timeout
StatusCode::REQUEST_TIMEOUT | StatusCode::GATEWAY_TIMEOUT => {
add_delay_and_continue(response, RETRY_DURATION_FOR_TIMEOUT)
}

_ => Ok(ControlFlow::Break(response)),
}
}

async fn send_request_inner(
&self,
method: &Method,
url: &Url,
) -> Result<Response, ReqwestError> {
let mut count = 0;
let max_retry_count = NonZeroU8::new(MAX_RETRY_COUNT).unwrap();

// Since max_retry_count is non-zero, there is at least one iteration.
loop {
// Increment the counter before checking for terminal condition.
count += 1;

match self.do_send_request(method, url).await? {
ControlFlow::Break(response) => break Ok(response),
ControlFlow::Continue(res) if count >= max_retry_count.get() => {
break res;
}
_ => (),
}
}
}

@@ -161,22 +205,57 @@ impl Client {
.map_err(|err| Error::Http(Box::new(HttpError { method, url, err })))
}

/// Check if remote exists using `method`.
pub async fn remote_exists(&self, url: Url, method: Method) -> Result<bool, Error> {
Ok(self
.send_request(method, url, false)
.await?
.status()
.is_success())
async fn head_or_fallback_to_get(
&self,
url: Url,
error_for_status: bool,
) -> Result<Response, Error> {
let res = self
.send_request(Method::HEAD, url.clone(), error_for_status)
.await;

let retry_with_get = move || async move {
// Retry using GET
info!("HEAD on {url} is not allowed, fallback to GET");
self.send_request(Method::GET, url, error_for_status).await
};

let is_retryable = |status| {
matches!(
status,
StatusCode::BAD_REQUEST // 400
| StatusCode::UNAUTHORIZED // 401
| StatusCode::FORBIDDEN // 403
| StatusCode::NOT_FOUND // 404
| StatusCode::METHOD_NOT_ALLOWED // 405
| StatusCode::GONE // 410
)
};

match res {
Err(Error::Http(http_error))
if http_error.err.status().map(is_retryable).unwrap_or(false) =>
{
retry_with_get().await
}
Ok(response) if is_retryable(response.status()) => retry_with_get().await,
res => res,
}
}

/// Attempt to get final redirected url.
/// Check if remote exists using `Method::HEAD` or `Method::GET` as fallback.
pub async fn remote_gettable(&self, url: Url) -> Result<bool, Error> {
self.head_or_fallback_to_get(url, false)
.await
.map(|response| response.status().is_success())
}

/// Attempt to get final redirected url using `Method::HEAD` or fallback
/// to `Method::GET`.
pub async fn get_redirected_final_url(&self, url: Url) -> Result<Url, Error> {
Ok(self
.send_request(Method::HEAD, url, true)
.await?
.url()
.clone())
self.head_or_fallback_to_get(url, true)
.await
.map(|response| response.url().clone())
}

/// Create `GET` request to `url` and return a stream of the response data.
@@ -187,9 +266,19 @@ impl Client {
) -> Result<impl Stream<Item = Result<Bytes, Error>>, Error> {
debug!("Downloading from: '{url}'");

self.send_request(Method::GET, url, true)
.await
.map(|response| response.bytes_stream().map(|res| res.map_err(Error::from)))
let response = self.send_request(Method::GET, url.clone(), true).await?;

let url = Box::new(url);

Ok(response.bytes_stream().map(move |res| {
res.map_err(|err| {
Error::Http(Box::new(HttpError {
method: Method::GET,
url: Url::clone(&*url),
err,
}))
})
}))
}
}

@@ -219,11 +308,3 @@ fn parse_header_retry_after(headers: &HeaderMap) -> Option<Duration> {
}
}
}

fn dedup(urls: [&Url; 2]) -> impl Iterator<Item = &Url> {
if urls[0] == urls[1] {
Some(urls[0]).into_iter().chain(None)
} else {
Some(urls[0]).into_iter().chain(Some(urls[1]))
}
}
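
The doc comment on `do_send_request` in the hunk above spells out a `ControlFlow`-based retry contract. As a reading aid, here is a self-contained toy model of how a retry loop like `send_request_inner` consumes that contract; the `Response`/`Error` aliases and the `attempt` function are stand-ins for illustration, not the crate's types:

```rust
use std::ops::ControlFlow;

// Stand-in types: the real code uses reqwest's `Response` and `Error`.
type Response = &'static str;
type Error = &'static str;

const MAX_RETRY_COUNT: u8 = 3;

/// Stand-in for `do_send_request`: `Err(_)` is fatal, `Continue(_)` is a
/// retryable outcome of this attempt, `Break(_)` is a success.
fn attempt(try_no: u8) -> Result<ControlFlow<Response, Result<Response, Error>>, Error> {
    if try_no < 3 {
        // Pretend the first two attempts hit a retryable condition.
        Ok(ControlFlow::Continue(Err("request timed out")))
    } else {
        Ok(ControlFlow::Break("200 OK"))
    }
}

/// Same loop shape as `send_request_inner`: stop on success, give up and
/// surface the last retryable outcome once the retry budget is spent.
fn send_with_retry(max_retry_count: u8) -> Result<Response, Error> {
    let mut count = 0;
    loop {
        count += 1;
        match attempt(count)? {
            ControlFlow::Break(response) => break Ok(response),
            ControlFlow::Continue(res) if count >= max_retry_count => break res,
            _ => (),
        }
    }
}

fn main() {
    // The third attempt succeeds within a budget of MAX_RETRY_COUNT attempts.
    assert_eq!(send_with_retry(MAX_RETRY_COUNT), Ok("200 OK"));
    // With a budget of 2, the last retryable error is surfaced instead.
    assert_eq!(send_with_retry(2), Err("request timed out"));
}
```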
68 changes: 51 additions & 17 deletions crates/binstalk-downloader/src/remote/delay_request.rs
@@ -1,6 +1,7 @@
use std::{
collections::HashMap,
future::Future,
iter::Peekable,
pin::Pin,
sync::Mutex,
task::{Context, Poll},
@@ -10,10 +11,41 @@ use compact_str::{CompactString, ToCompactString};
use reqwest::{Request, Url};
use tokio::{
sync::Mutex as AsyncMutex,
time::{sleep_until, Instant},
time::{sleep_until, Duration, Instant},
};
use tower::{Service, ServiceExt};

trait IterExt: Iterator {
fn dedup(self) -> Dedup<Self>
where
Self: Sized,
Self::Item: PartialEq,
{
Dedup(self.peekable())
}
}

impl<It: Iterator> IterExt for It {}

struct Dedup<It: Iterator>(Peekable<It>);

impl<It> Iterator for Dedup<It>
where
It: Iterator,
It::Item: PartialEq,
{
type Item = It::Item;

fn next(&mut self) -> Option<Self::Item> {
let curr = self.0.next()?;

// Drop all consecutive dup values
while self.0.next_if_eq(&curr).is_some() {}

Some(curr)
}
}

#[derive(Debug)]
pub(super) struct DelayRequest<S> {
inner: AsyncMutex<S>,
@@ -28,31 +60,33 @@ impl<S> DelayRequest<S> {
}
}

pub(super) fn add_urls_to_delay<'a, Urls>(&self, urls: Urls, deadline: Instant)
where
Urls: IntoIterator<Item = &'a Url>,
{
pub(super) fn add_urls_to_delay(&self, urls: &[&Url], delay_duration: Duration) {
let deadline = Instant::now() + delay_duration;

let mut hosts_to_delay = self.hosts_to_delay.lock().unwrap();

urls.into_iter().filter_map(Url::host_str).for_each(|host| {
hosts_to_delay
.entry(host.to_compact_string())
.and_modify(|old_dl| {
*old_dl = deadline.max(*old_dl);
})
.or_insert(deadline);
});
urls.iter()
.filter_map(|url| url.host_str())
.dedup()
.for_each(|host| {
hosts_to_delay
.entry(host.to_compact_string())
.and_modify(|old_dl| {
*old_dl = deadline.max(*old_dl);
})
.or_insert(deadline);
});
}

fn wait_until_available(&self, url: &Url) -> impl Future<Output = ()> + Send + 'static {
let mut hosts_to_delay = self.hosts_to_delay.lock().unwrap();

let sleep = url
let deadline = url
.host_str()
.and_then(|host| hosts_to_delay.get(host).map(|deadline| (*deadline, host)))
.and_then(|(deadline, host)| {
if deadline.elapsed().is_zero() {
Some(sleep_until(deadline))
Some(deadline)
} else {
// We have already gone past the deadline,
// so we should remove it instead.
@@ -62,8 +96,8 @@ impl<S> DelayRequest<S> {
});

async move {
if let Some(sleep) = sleep {
sleep.await;
if let Some(deadline) = deadline {
sleep_until(deadline).await;
}
}
}
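
The `Dedup` adapter added above only collapses consecutive equal items, which is enough for the two-element `[url, response.url()]` slice it is applied to. A standalone sketch of that behaviour, with the adapter reproduced (and the hostnames invented) so it compiles on its own:

```rust
use std::iter::Peekable;

// Same consecutive-dedup adapter as in the hunk above, reproduced here so the
// sketch is self-contained.
struct Dedup<It: Iterator>(Peekable<It>);

impl<It> Iterator for Dedup<It>
where
    It: Iterator,
    It::Item: PartialEq,
{
    type Item = It::Item;

    fn next(&mut self) -> Option<Self::Item> {
        let curr = self.0.next()?;
        // Drop all consecutive duplicate values.
        while self.0.next_if_eq(&curr).is_some() {}
        Some(curr)
    }
}

fn main() {
    // Only *consecutive* duplicates are dropped; a repeated host that is not
    // adjacent is kept.
    let hosts = ["github.com", "github.com", "example.com", "github.com"];
    let deduped: Vec<_> = Dedup(hosts.into_iter().peekable()).collect();
    assert_eq!(deduped, ["github.com", "example.com", "github.com"]);
}
```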
9 changes: 4 additions & 5 deletions crates/binstalk/src/fetchers/gh_crate_meta.rs
@@ -13,9 +13,7 @@ use url::Url;
use crate::{
errors::{BinstallError, InvalidPkgFmtError},
helpers::{
download::Download,
futures_resolver::FuturesResolver,
remote::{Client, Method},
download::Download, futures_resolver::FuturesResolver, remote::Client,
tasks::AutoAbortJoinHandle,
},
manifests::cargo_toml_binstall::{PkgFmt, PkgMeta},
@@ -68,8 +66,9 @@ impl GhCrateMeta {
async move {
debug!("Checking for package at: '{url}'");

Ok((client.remote_exists(url.clone(), Method::HEAD).await?
|| client.remote_exists(url.clone(), Method::GET).await?)
Ok(client
.remote_gettable(url.clone())
.await?
.then_some((url, pkg_fmt)))
}
})