Skip to content

Commit

Permalink
Use API version provider in aggregator router
Browse files Browse the repository at this point in the history
  • Loading branch information
jpraynaud committed Mar 17, 2023
1 parent 053f46e commit bb52009
Showing 1 changed file with 73 additions and 32 deletions.
105 changes: 73 additions & 32 deletions mithril-aggregator/src/http_server/routes/router.rs
Expand Up @@ -4,11 +4,9 @@ use crate::http_server::routes::{
use crate::http_server::SERVER_BASE_PATH;
use crate::DependencyManager;

use mithril_common::{
MITHRIL_API_VERSION, MITHRIL_API_VERSION_HEADER, MITHRIL_API_VERSION_REQUIREMENT,
};
use mithril_common::api::APIVersionProvider;
use mithril_common::MITHRIL_API_VERSION_HEADER;

use reqwest::header::{HeaderMap, HeaderValue};
use reqwest::StatusCode;
use slog_scope::warn;
use std::sync::Arc;
Expand All @@ -34,42 +32,62 @@ pub fn routes(
.allow_any_origin()
.allow_headers(vec!["content-type"])
.allow_methods(vec![Method::GET, Method::POST, Method::OPTIONS]);
let mut headers = HeaderMap::new();
headers.insert(
MITHRIL_API_VERSION_HEADER,
HeaderValue::from_static(MITHRIL_API_VERSION),
);

warp::any()
.and(header_must_be())
.and(header_must_be(
dependency_manager.clone().api_version_provider.clone(),
))
.and(warp::path(SERVER_BASE_PATH))
.and(
certificate_routes::routes(dependency_manager.clone())
.or(snapshot_routes::routes(dependency_manager.clone()))
.or(signer_routes::routes(dependency_manager.clone()))
.or(signatures_routes::routes(dependency_manager.clone()))
.or(epoch_routes::routes(dependency_manager))
.or(epoch_routes::routes(dependency_manager.clone()))
.with(cors),
)
.recover(handle_custom)
.with(warp::reply::with::headers(headers))
.and(warp::any().map(move || dependency_manager.clone().api_version_provider.clone()))
.map(|reply, api_version_provider: Arc<APIVersionProvider>| {
warp::reply::with_header(
reply,
MITHRIL_API_VERSION_HEADER,
&api_version_provider
.compute_current_version()
.unwrap()
.to_string(),
)
})
}

/// API Version verification
fn header_must_be() -> impl Filter<Extract = (), Error = Rejection> + Copy {
fn header_must_be(
api_version_provider: Arc<APIVersionProvider>,
) -> impl Filter<Extract = (), Error = Rejection> + Clone {
warp::header::optional(MITHRIL_API_VERSION_HEADER)
.and_then(|maybe_header: Option<String>| async move {
match maybe_header {
None => Ok(()),
Some(version) => match semver::Version::parse(&version) {
Ok(version) if MITHRIL_API_VERSION_REQUIREMENT.matches(&version) => Ok(()),
Ok(_version) => Err(warp::reject::custom(VersionMismatchError)),
Err(err) => {
warn!("⇄ HTTP SERVER::api_version_check::parse_error"; "error" => ?err);
Err(warp::reject::custom(VersionParseError))
}
},
}
})
.and(warp::any().map(move || api_version_provider.clone()))
.and_then(
move |maybe_header: Option<String>, api_version_provider: Arc<APIVersionProvider>| async move {
match maybe_header {
None => Ok(()),
Some(version) => match semver::Version::parse(&version) {
Ok(version)
if (api_version_provider
.compute_current_version_requirement().unwrap()
.matches(&version))
.to_owned() =>
{
Ok(())
}
Ok(_version) => Err(warp::reject::custom(VersionMismatchError)),
Err(err) => {
warn!("⇄ HTTP SERVER::api_version_check::parse_error"; "error" => ?err);
Err(warp::reject::custom(VersionParseError))
}
},
}
},
)
.untuple_one()
}

Expand All @@ -83,11 +101,20 @@ pub async fn handle_custom(reject: Rejection) -> Result<impl Reply, Rejection> {

#[cfg(test)]
mod tests {
use std::collections::HashMap;

use mithril_common::{
entities::Epoch,
era::{EraChecker, SupportedEra},
};

use super::*;

#[tokio::test]
async fn test_no_version() {
let filters = header_must_be();
let era_checker = EraChecker::new(SupportedEra::dummy(), Epoch(1));
let api_version_provider = Arc::new(APIVersionProvider::new(Arc::new(era_checker)));
let filters = header_must_be(api_version_provider);
warp::test::request()
.path("/aggregator/whatever")
.filter(&filters)
Expand All @@ -97,7 +124,9 @@ mod tests {

#[tokio::test]
async fn test_parse_version_error() {
let filters = header_must_be();
let era_checker = EraChecker::new(SupportedEra::dummy(), Epoch(1));
let api_version_provider = Arc::new(APIVersionProvider::new(Arc::new(era_checker)));
let filters = header_must_be(api_version_provider);
warp::test::request()
.header(MITHRIL_API_VERSION_HEADER, "not_a_version")
.path("/aggregator/whatever")
Expand All @@ -110,7 +139,13 @@ mod tests {

#[tokio::test]
async fn test_bad_version() {
let filters = header_must_be();
let era_checker = EraChecker::new(SupportedEra::dummy(), Epoch(1));
let mut version_provider = APIVersionProvider::new(Arc::new(era_checker));
let mut open_api_versions = HashMap::new();
open_api_versions.insert("openapi.yaml".to_string(), "1.0.0".to_string());
version_provider.update_open_api_versions(open_api_versions);
let api_version_provider = Arc::new(version_provider);
let filters = header_must_be(api_version_provider);
warp::test::request()
.header(MITHRIL_API_VERSION_HEADER, "0.0.999")
.path("/aggregator/whatever")
Expand All @@ -121,12 +156,18 @@ mod tests {

#[tokio::test]
async fn test_good_version() {
let filters = header_must_be();
let era_checker = EraChecker::new(SupportedEra::dummy(), Epoch(1));
let mut version_provider = APIVersionProvider::new(Arc::new(era_checker));
let mut open_api_versions = HashMap::new();
open_api_versions.insert("openapi.yaml".to_string(), "0.1.0".to_string());
version_provider.update_open_api_versions(open_api_versions);
let api_version_provider = Arc::new(version_provider);
let filters = header_must_be(api_version_provider);
warp::test::request()
.header(MITHRIL_API_VERSION_HEADER, MITHRIL_API_VERSION)
.header(MITHRIL_API_VERSION_HEADER, "0.1.2")
.path("/aggregator/whatever")
.filter(&filters)
.await
.expect("request with the current api version should not be rejected");
.expect(r#"request with the good version "0.1.2" should not be rejected"#);
}
}

0 comments on commit bb52009

Please sign in to comment.