diff --git a/mithril-aggregator/src/http_server/routes/router.rs b/mithril-aggregator/src/http_server/routes/router.rs index 1b1b43e1526..43888d27622 100644 --- a/mithril-aggregator/src/http_server/routes/router.rs +++ b/mithril-aggregator/src/http_server/routes/router.rs @@ -4,11 +4,9 @@ use crate::http_server::routes::{ use crate::http_server::SERVER_BASE_PATH; use crate::DependencyManager; -use mithril_common::{ - MITHRIL_API_VERSION, MITHRIL_API_VERSION_HEADER, MITHRIL_API_VERSION_REQUIREMENT, -}; +use mithril_common::api::APIVersionProvider; +use mithril_common::MITHRIL_API_VERSION_HEADER; -use reqwest::header::{HeaderMap, HeaderValue}; use reqwest::StatusCode; use slog_scope::warn; use std::sync::Arc; @@ -34,42 +32,62 @@ pub fn routes( .allow_any_origin() .allow_headers(vec!["content-type"]) .allow_methods(vec![Method::GET, Method::POST, Method::OPTIONS]); - let mut headers = HeaderMap::new(); - headers.insert( - MITHRIL_API_VERSION_HEADER, - HeaderValue::from_static(MITHRIL_API_VERSION), - ); + warp::any() - .and(header_must_be()) + .and(header_must_be( + dependency_manager.clone().api_version_provider.clone(), + )) .and(warp::path(SERVER_BASE_PATH)) .and( certificate_routes::routes(dependency_manager.clone()) .or(snapshot_routes::routes(dependency_manager.clone())) .or(signer_routes::routes(dependency_manager.clone())) .or(signatures_routes::routes(dependency_manager.clone())) - .or(epoch_routes::routes(dependency_manager)) + .or(epoch_routes::routes(dependency_manager.clone())) .with(cors), ) .recover(handle_custom) - .with(warp::reply::with::headers(headers)) + .and(warp::any().map(move || dependency_manager.clone().api_version_provider.clone())) + .map(|reply, api_version_provider: Arc| { + warp::reply::with_header( + reply, + MITHRIL_API_VERSION_HEADER, + &api_version_provider + .compute_current_version() + .unwrap() + .to_string(), + ) + }) } /// API Version verification -fn header_must_be() -> impl Filter + Copy { +fn header_must_be( + api_version_provider: Arc, +) -> impl Filter + Clone { warp::header::optional(MITHRIL_API_VERSION_HEADER) - .and_then(|maybe_header: Option| async move { - match maybe_header { - None => Ok(()), - Some(version) => match semver::Version::parse(&version) { - Ok(version) if MITHRIL_API_VERSION_REQUIREMENT.matches(&version) => Ok(()), - Ok(_version) => Err(warp::reject::custom(VersionMismatchError)), - Err(err) => { - warn!("⇄ HTTP SERVER::api_version_check::parse_error"; "error" => ?err); - Err(warp::reject::custom(VersionParseError)) - } - }, - } - }) + .and(warp::any().map(move || api_version_provider.clone())) + .and_then( + move |maybe_header: Option, api_version_provider: Arc| async move { + match maybe_header { + None => Ok(()), + Some(version) => match semver::Version::parse(&version) { + Ok(version) + if (api_version_provider + .compute_current_version_requirement().unwrap() + .matches(&version)) + .to_owned() => + { + Ok(()) + } + Ok(_version) => Err(warp::reject::custom(VersionMismatchError)), + Err(err) => { + warn!("⇄ HTTP SERVER::api_version_check::parse_error"; "error" => ?err); + Err(warp::reject::custom(VersionParseError)) + } + }, + } + }, + ) .untuple_one() } @@ -83,11 +101,20 @@ pub async fn handle_custom(reject: Rejection) -> Result { #[cfg(test)] mod tests { + use std::collections::HashMap; + + use mithril_common::{ + entities::Epoch, + era::{EraChecker, SupportedEra}, + }; + use super::*; #[tokio::test] async fn test_no_version() { - let filters = header_must_be(); + let era_checker = EraChecker::new(SupportedEra::dummy(), Epoch(1)); + let api_version_provider = Arc::new(APIVersionProvider::new(Arc::new(era_checker))); + let filters = header_must_be(api_version_provider); warp::test::request() .path("/aggregator/whatever") .filter(&filters) @@ -97,7 +124,9 @@ mod tests { #[tokio::test] async fn test_parse_version_error() { - let filters = header_must_be(); + let era_checker = EraChecker::new(SupportedEra::dummy(), Epoch(1)); + let api_version_provider = Arc::new(APIVersionProvider::new(Arc::new(era_checker))); + let filters = header_must_be(api_version_provider); warp::test::request() .header(MITHRIL_API_VERSION_HEADER, "not_a_version") .path("/aggregator/whatever") @@ -110,7 +139,13 @@ mod tests { #[tokio::test] async fn test_bad_version() { - let filters = header_must_be(); + let era_checker = EraChecker::new(SupportedEra::dummy(), Epoch(1)); + let mut version_provider = APIVersionProvider::new(Arc::new(era_checker)); + let mut open_api_versions = HashMap::new(); + open_api_versions.insert("openapi.yaml".to_string(), "1.0.0".to_string()); + version_provider.update_open_api_versions(open_api_versions); + let api_version_provider = Arc::new(version_provider); + let filters = header_must_be(api_version_provider); warp::test::request() .header(MITHRIL_API_VERSION_HEADER, "0.0.999") .path("/aggregator/whatever") @@ -121,12 +156,18 @@ mod tests { #[tokio::test] async fn test_good_version() { - let filters = header_must_be(); + let era_checker = EraChecker::new(SupportedEra::dummy(), Epoch(1)); + let mut version_provider = APIVersionProvider::new(Arc::new(era_checker)); + let mut open_api_versions = HashMap::new(); + open_api_versions.insert("openapi.yaml".to_string(), "0.1.0".to_string()); + version_provider.update_open_api_versions(open_api_versions); + let api_version_provider = Arc::new(version_provider); + let filters = header_must_be(api_version_provider); warp::test::request() - .header(MITHRIL_API_VERSION_HEADER, MITHRIL_API_VERSION) + .header(MITHRIL_API_VERSION_HEADER, "0.1.2") .path("/aggregator/whatever") .filter(&filters) .await - .expect("request with the current api version should not be rejected"); + .expect(r#"request with the good version "0.1.2" should not be rejected"#); } }