diff --git a/aggregator_api/src/lib.rs b/aggregator_api/src/lib.rs index 3b27f9fbc..28f82bfff 100644 --- a/aggregator_api/src/lib.rs +++ b/aggregator_api/src/lib.rs @@ -16,9 +16,10 @@ use janus_core::{ time::Clock, }; use janus_messages::{ - query_type::Code as QueryTypeId, Duration, HpkeAeadId, HpkeKdfId, HpkeKemId, Role, TaskId, + query_type::Code as SupportedQueryType, Duration, HpkeAeadId, HpkeKdfId, HpkeKemId, Role, + TaskId, }; -use models::{AggregatorApiConfig, AggregatorRole, GetTaskMetricsResp, TaskResp, VdafId}; +use models::{AggregatorApiConfig, AggregatorRole, GetTaskMetricsResp, SupportedVdaf, TaskResp}; use querystring::querify; use rand::random; use ring::{ @@ -105,32 +106,39 @@ pub fn aggregator_api_handler(ds: Arc>, cfg: Config) -> i ) } -async fn auth_check(conn: &mut Conn, State(cfg): State>) -> impl Handler { - let bearer_token = match extract_bearer_token(conn) { - Ok(Some(t)) => t, - _ => { - return Some((Status::Unauthorized, Halt)); - } +async fn auth_check(conn: &mut Conn, (): ()) -> impl Handler { + let (Some(cfg), Ok(Some(bearer_token))) = + (conn.state::>(), extract_bearer_token(conn)) + else { + return Some((Status::Unauthorized, Halt)); }; if cfg.auth_tokens.iter().any(|key| { constant_time::verify_slices_are_equal(bearer_token.as_ref(), key.as_ref()).is_ok() }) { // Authorization succeeds. - conn.set_state(cfg); - return None; + None + } else { + // Authorization fails. + Some((Status::Unauthorized, Halt)) } - - // Authorization fails. - Some((Status::Unauthorized, Halt)) } async fn get_config(_: &mut Conn, State(config): State>) -> Json { Json(AggregatorApiConfig { dap_url: config.public_dap_url.clone(), role: AggregatorRole::Either, - vdafs: vec![VdafId::Prio3Count, VdafId::Prio3Sum, VdafId::Prio3Histogram], - query_types: vec![QueryTypeId::TimeInterval, QueryTypeId::FixedSize], + vdafs: vec![ + SupportedVdaf::Prio3Count, + SupportedVdaf::Prio3Sum, + SupportedVdaf::Prio3Histogram, + SupportedVdaf::Prio3CountVec, + SupportedVdaf::Prio3SumVec, + ], + query_types: vec![ + SupportedQueryType::TimeInterval, + SupportedQueryType::FixedSize, + ], }) } @@ -420,7 +428,7 @@ mod models { use janus_aggregator_core::task::{QueryType, Task}; use janus_core::task::VdafInstance; use janus_messages::{ - query_type::Code as QueryTypeId, Duration, HpkeConfig, Role, TaskId, Time, + query_type::Code as SupportedQueryType, Duration, HpkeConfig, Role, TaskId, Time, }; use serde::{Deserialize, Serialize}; use url::Url; @@ -438,18 +446,19 @@ mod models { pub(crate) struct AggregatorApiConfig { pub dap_url: Url, pub role: AggregatorRole, - pub vdafs: Vec, - pub query_types: Vec, + pub vdafs: Vec, + pub query_types: Vec, } #[allow(clippy::enum_variant_names)] // ^^ allowed because it just happens to be the case that all of the supported vdafs are prio3 #[derive(Serialize, PartialEq, Eq, Debug)] - #[repr(u8)] - pub(crate) enum VdafId { - Prio3Count = 0, - Prio3Sum = 1, - Prio3Histogram = 2, + pub(crate) enum SupportedVdaf { + Prio3Count, + Prio3Sum, + Prio3Histogram, + Prio3SumVec, + Prio3CountVec, } #[derive(Serialize)] @@ -706,7 +715,10 @@ mod tests { .run_async(&handler) .await, Status::Ok, - r#"{"dap_url":"https://dap.url/","role":"Either","vdafs":[1,2,3],"query_types":[1,2]}"# + concat!( + r#"{"dap_url":"https://dap.url/","role":"Either","vdafs":["Prio3Count","#, + r#""Prio3Sum","Prio3Histogram"],"query_types":["TimeInterval","FixedSize"]}"# + ) ); } diff --git a/aggregator_core/src/lib.rs b/aggregator_core/src/lib.rs index adbd575ad..8e9179b22 100644 --- a/aggregator_core/src/lib.rs +++ b/aggregator_core/src/lib.rs @@ -68,8 +68,8 @@ impl InstrumentedHandler { self.0.run(conn).instrument(span).await } - async fn before_send(&self, conn: Conn) -> Conn { - if let Some(span) = conn.state::() { + async fn before_send(&self, mut conn: Conn) -> Conn { + if let Some(span) = conn.take_state::() { let conn = self.0.before_send(conn).instrument(span.0.clone()).await; span.0.in_scope(|| { let status = conn