From f27f255d92f7278d18caa5fb4f8bb80d3e904d6f Mon Sep 17 00:00:00 2001 From: Vladislav Ivanov Date: Thu, 10 Nov 2022 16:38:09 +0700 Subject: [PATCH] Refactor: use a sum type instead of a bool to indicate need for auth This makes it less confusing and more readable Change: I91c0784beda86e9d66bf62c77b0e568a3811432c commit-id:8b1d1b05 --- josh-proxy/src/auth.rs | 1 + josh-proxy/src/bin/josh-proxy.rs | 94 ++++++++++++++++++-------------- josh-proxy/src/lib.rs | 39 +++++++++++-- 3 files changed, 87 insertions(+), 47 deletions(-) diff --git a/josh-proxy/src/auth.rs b/josh-proxy/src/auth.rs index 2d4d77a66..ffd44e4f8 100644 --- a/josh-proxy/src/auth.rs +++ b/josh-proxy/src/auth.rs @@ -26,6 +26,7 @@ impl std::fmt::Debug for Handle { } impl Handle { + // Returns a pair: (username, password) pub fn parse(&self) -> josh::JoshResult<(String, String)> { let line = josh::some_or!( AUTH.lock() diff --git a/josh-proxy/src/bin/josh-proxy.rs b/josh-proxy/src/bin/josh-proxy.rs index 510f48783..80b59b47b 100644 --- a/josh-proxy/src/bin/josh-proxy.rs +++ b/josh-proxy/src/bin/josh-proxy.rs @@ -2,7 +2,7 @@ #[macro_use] extern crate lazy_static; -use josh_proxy::{MetaConfig, RepoConfig, RepoUpdate}; +use josh_proxy::{FetchError, MetaConfig, RepoConfig, RepoUpdate}; use opentelemetry::global; use opentelemetry::sdk::propagation::TraceContextPropagator; use tracing_opentelemetry::OpenTelemetrySpanExt; @@ -77,7 +77,7 @@ async fn fetch_upstream( remote_url: String, headref: &str, force: bool, -) -> josh::JoshResult { +) -> Result<(), FetchError> { let auth = auth.clone(); let key = remote_url.clone(); @@ -118,7 +118,7 @@ async fn fetch_upstream( tracing::trace!("fetch_cached_ok {:?}", fetch_cached_ok); if fetch_cached_ok && headref.is_empty() { - return Ok(true); + return Ok(()); } if fetch_cached_ok && !headref.is_empty() { @@ -128,13 +128,14 @@ async fn fetch_upstream( "refs/josh/upstream/{}/", &josh::to_ns(&upstream_repo), )), - )?; + ) + .map_err(FetchError::from_josh_error)?; let id = transaction .repo() .refname_to_id(&transaction.refname(headref)); tracing::trace!("refname_to_id: {:?}", id); if id.is_ok() { - return Ok(true); + return Ok(()); } } @@ -142,13 +143,13 @@ async fn fetch_upstream( let heads_map = service.heads_map.clone(); let br_path = service.repo_path.join("mirror"); - let s = tracing::span!(tracing::Level::TRACE, "fetch worker"); + let span = tracing::span!(tracing::Level::TRACE, "fetch worker"); let us = upstream_repo.clone(); let a = auth.clone(); let ru = remote_url.clone(); let permit = service.fetch_permits.acquire().await; - let res = tokio::task::spawn_blocking(move || { - let _e = s.enter(); + let fetch_result = tokio::task::spawn_blocking(move || { + let _span_guard = span.enter(); josh_proxy::fetch_refs_from_url(&br_path, &us, &ru, &refs_to_fetch, &a) }) .await?; @@ -170,20 +171,24 @@ async fn fetch_upstream( std::mem::drop(permit); - if let Ok(res) = res { - if res { + match fetch_result { + Ok(_) => { fetch_timers.write()?.insert(key, std::time::Instant::now()); - if ARGS.get_one::("poll").map(|v| v.as_str()) == Some(&auth.parse()?.0) { + let poll_user = ARGS.get_one::("poll"); + let (auth_user, _) = auth.parse().map_err(FetchError::from_josh_error)?; + + if matches!(poll_user, Some(user) if auth_user == user.as_str()) { service .poll .lock()? .insert((upstream_repo, auth, remote_url)); } + + Ok(()) } - return Ok(res); + Err(_) => fetch_result, } - res } async fn static_paths( @@ -422,8 +427,9 @@ async fn query_meta_repo( .in_current_span() .await { - Ok(true) => {} - _ => return Err(josh::josh_error("meta fetch failed")), + Ok(_) => {} + Err(FetchError::AuthRequired) => return Err(josh_error("meta fetch: auth failed")), + Err(FetchError::Other(e)) => return Err(josh_error(&format!("meta fetch failed: {}", e))), } let transaction = josh::cache::Transaction::open( @@ -865,20 +871,19 @@ async fn call_service( .in_current_span() .await { - Ok(res) => { - if !res { - let builder = Response::builder() - .header( - hyper::header::WWW_AUTHENTICATE, - "Basic realm=User Visible Realm", - ) - .status(hyper::StatusCode::UNAUTHORIZED); - return Ok(builder.body(hyper::Body::empty())?); - } + Ok(_) => {} + Err(FetchError::AuthRequired) => { + let builder = Response::builder() + .header( + hyper::header::WWW_AUTHENTICATE, + "Basic realm=User Visible Realm", + ) + .status(hyper::StatusCode::UNAUTHORIZED); + return Ok(builder.body(hyper::Body::empty())?); } - Err(res) => { + Err(FetchError::Other(e)) => { let builder = Response::builder().status(hyper::StatusCode::INTERNAL_SERVER_ERROR); - return Ok(builder.body(hyper::Body::from(res.0))?); + return Ok(builder.body(hyper::Body::from(e.0))?); } } @@ -1175,7 +1180,7 @@ async fn run_polling(serv: Arc) -> josh::JoshResult<()> { let polls = serv.poll.lock()?.clone(); for (upstream_repo, auth, url) in polls { - fetch_upstream( + let fetch_result = fetch_upstream( serv.clone(), upstream_repo.clone(), &auth, @@ -1184,7 +1189,15 @@ async fn run_polling(serv: Arc) -> josh::JoshResult<()> { true, ) .in_current_span() - .await?; + .await; + + match fetch_result { + Ok(()) => {} + Err(FetchError::Other(e)) => return Err(e), + Err(FetchError::AuthRequired) => { + return Err(josh_error("auth: access denied while polling")) + } + } } tokio::time::sleep(tokio::time::Duration::from_secs(10)).await; } @@ -1383,21 +1396,20 @@ async fn serve_graphql( .in_current_span() .await { - Ok(res) => { - if !res { - let builder = Response::builder() - .header( - hyper::header::WWW_AUTHENTICATE, - "Basic realm=User Visible Realm", - ) - .status(hyper::StatusCode::UNAUTHORIZED); - return Ok(builder.body(hyper::Body::empty())?); - } + Ok(_) => {} + Err(FetchError::AuthRequired) => { + let builder = Response::builder() + .header( + hyper::header::WWW_AUTHENTICATE, + "Basic realm=User Visible Realm", + ) + .status(hyper::StatusCode::UNAUTHORIZED); + return Ok(builder.body(hyper::Body::empty())?); } - Err(res) => { + Err(FetchError::Other(e)) => { let builder = Response::builder().status(hyper::StatusCode::INTERNAL_SERVER_ERROR); - return Ok(builder.body(hyper::Body::from(res.0))?); + return Ok(builder.body(hyper::Body::from(e.0))?); } }; diff --git a/josh-proxy/src/lib.rs b/josh-proxy/src/lib.rs index 9d00fcbff..08344e767 100644 --- a/josh-proxy/src/lib.rs +++ b/josh-proxy/src/lib.rs @@ -4,6 +4,7 @@ pub mod juniper_hyper; #[macro_use] extern crate lazy_static; +use josh::JoshError; use std::path::PathBuf; #[derive(PartialEq)] @@ -546,13 +547,33 @@ pub fn get_head( Ok(head) } +pub enum FetchError { + AuthRequired, + Other(JoshError), +} + +impl From for FetchError +where + T: std::error::Error, +{ + fn from(e: T) -> Self { + FetchError::Other(JoshError::from(e)) + } +} + +impl FetchError { + pub fn from_josh_error(e: JoshError) -> Self { + FetchError::Other(e) + } +} + pub fn fetch_refs_from_url( path: &std::path::Path, upstream_repo: &str, url: &str, refs_prefixes: &[String], auth: &auth::Handle, -) -> josh::JoshResult { +) -> Result<(), FetchError> { let specs: Vec<_> = refs_prefixes .iter() .map(|r| { @@ -572,7 +593,7 @@ pub fn fetch_refs_from_url( let cmd = format!("git fetch --prune --no-tags {} {}", &url, &specs.join(" ")); tracing::info!("fetch_refs_from_url {:?} {:?} {:?}", cmd, path, ""); - let (username, password) = auth.parse()?; + let (username, password) = auth.parse().map_err(FetchError::from_josh_error)?; let (_stdout, stderr, _) = shell.command_env( &cmd, &[], @@ -580,17 +601,23 @@ pub fn fetch_refs_from_url( ); tracing::debug!("fetch_refs_from_url done {:?} {:?} {:?}", cmd, path, stderr); if stderr.contains("fatal: Authentication failed") { - return Ok(false); + return Err(FetchError::AuthRequired); } if stderr.contains("fatal:") { tracing::error!("{:?}", stderr); - return Err(josh::josh_error(&format!("git error: {:?}", stderr))); + return Err(FetchError::Other(josh::josh_error(&format!( + "git error: {:?}", + stderr + )))); } if stderr.contains("error:") { tracing::error!("{:?}", stderr); - return Err(josh::josh_error(&format!("git error: {:?}", stderr))); + return Err(FetchError::Other( + josh::josh_error(&format!("git error: {:?}", stderr)).into(), + )); } - Ok(true) + + Ok(()) } pub struct TmpGitNamespace {