diff --git a/aggregator_api/src/models.rs b/aggregator_api/src/models.rs
index 83615b8fb..ad480b8a2 100644
--- a/aggregator_api/src/models.rs
+++ b/aggregator_api/src/models.rs
@@ -1,7 +1,7 @@
 use base64::{engine::general_purpose::URL_SAFE_NO_PAD, Engine};
 use janus_aggregator_core::{
     datastore::models::{GlobalHpkeKeypair, HpkeKeyState},
-    task::{QueryType, Task},
+    task::{AggregatorTask, QueryType},
     taskprov::{PeerAggregator, VerifyKeyInit},
 };
 use janus_core::{auth_tokens::AuthenticationToken, vdaf::VdafInstance};
@@ -125,23 +125,10 @@ pub(crate) struct TaskResp {
     pub(crate) aggregator_hpke_configs: Vec<HpkeConfig>,
 }
 
-impl TryFrom<&Task> for TaskResp {
+impl TryFrom<&AggregatorTask> for TaskResp {
     type Error = &'static str;
 
-    fn try_from(task: &Task) -> Result<Self, Self::Error> {
-        // We have to resolve impedance mismatches between the aggregator API's view of a task
-        // and `aggregator_core::task::Task`. For now, we deal with this in code, but someday
-        // the two representations will be harmonized.
-        // https://github.com/divviup/janus/issues/1524
-
-        // Return the aggregator endpoint URL for the role opposite our own
-        let peer_aggregator_endpoint = match task.role() {
-            Role::Leader => task.helper_aggregator_endpoint(),
-            Role::Helper => task.leader_aggregator_endpoint(),
-            _ => return Err("illegal aggregator role in task"),
-        }
-        .clone();
-
+    fn try_from(task: &AggregatorTask) -> Result<Self, Self::Error> {
         let mut aggregator_hpke_configs: Vec<_> = task
             .hpke_keys()
             .values()
@@ -151,7 +138,7 @@ impl TryFrom<&Task> for TaskResp {
 
         Ok(Self {
             task_id: *task.id(),
-            peer_aggregator_endpoint,
+            peer_aggregator_endpoint: task.peer_aggregator_endpoint().clone(),
             query_type: *task.query_type(),
             vdaf: task.vdaf().clone(),
             role: *task.role(),
diff --git a/aggregator_api/src/routes.rs b/aggregator_api/src/routes.rs
index 0d93586be..95d7f67b2 100644
--- a/aggregator_api/src/routes.rs
+++ b/aggregator_api/src/routes.rs
@@ -10,7 +10,7 @@ use crate::{
 use base64::{engine::general_purpose::URL_SAFE_NO_PAD, Engine};
 use janus_aggregator_core::{
     datastore::{self, Datastore},
-    task::Task,
+    task::{AggregatorTask, AggregatorTaskParameters},
     taskprov::PeerAggregator,
     SecretBytes,
 };
@@ -26,7 +26,6 @@ use ring::digest::{digest, SHA256};
 use std::{str::FromStr, sync::Arc, unreachable};
 use trillium::{Conn, Status};
 use trillium_api::{Json, State};
-use url::Url;
 
 pub(super) async fn get_config(
     _: &mut Conn,
@@ -79,27 +78,10 @@ pub(super) async fn post_task(
     _: &mut Conn,
     (State(ds), Json(req)): (State<Arc<Datastore<C>>>, Json<PostTaskReq>),
 ) -> Result<Json<TaskResp>, Error> {
-    // We have to resolve impedance mismatches between the aggregator API's view of a task and
-    // `aggregator_core::task::Task`. For now, we deal with this in code, but someday the two
-    // representations will be harmonized.
-    // https://github.com/divviup/janus/issues/1524
-
     if !matches!(req.role, Role::Leader | Role::Helper) {
         return Err(Error::BadRequest(format!("invalid role {}", req.role)));
     }
 
-    // struct `aggregator_core::task::Task` expects to get two aggregator endpoint URLs, but only
-    // the one for the peer aggregator is in the incoming request (or for that matter, is ever used
-    // by Janus), so we insert a fake URL for "self".
-    // TODO(#1524): clean this up with `aggregator_core::task::Task` changes
-    // unwrap safety: this fake URL is valid
-    let fake_aggregator_url = Url::parse("http://never-used.example.com").unwrap();
-    let (leader_aggregator_endpoint, helper_aggregator_endpoint) = match req.role {
-        Role::Leader => (fake_aggregator_url, req.peer_aggregator_endpoint),
-        Role::Helper => (req.peer_aggregator_endpoint, fake_aggregator_url),
-        _ => unreachable!(),
-    };
-
     let vdaf_verify_key_bytes = URL_SAFE_NO_PAD
         .decode(&req.vdaf_verify_key)
         .map_err(|err| {
@@ -121,7 +103,7 @@ pub(super) async fn post_task(
 
     let vdaf_verify_key = SecretBytes::new(vdaf_verify_key_bytes);
 
-    let (aggregator_auth_token, collector_auth_token) = match req.role {
+    let aggregator_parameters = match req.role {
         Role::Leader => {
             let aggregator_auth_token = req.aggregator_auth_token.ok_or_else(|| {
                 Error::BadRequest(
@@ -129,7 +111,11 @@ pub(super) async fn post_task(
                         .to_string(),
                 )
             })?;
-            (Some(aggregator_auth_token), Some(random()))
+            AggregatorTaskParameters::Leader {
+                aggregator_auth_token,
+                collector_auth_token: random(),
+                collector_hpke_config: req.collector_hpke_config,
+            }
         }
 
         Role::Helper => {
@@ -140,29 +126,21 @@ pub(super) async fn post_task(
                 ));
             }
 
-            (Some(random()), None)
+            AggregatorTaskParameters::Helper {
+                aggregator_auth_token: random(),
+                collector_hpke_config: req.collector_hpke_config,
+            }
         }
         _ => unreachable!(),
     };
 
-    // Unwrap safety: we always use a supported KEM.
-    let hpke_keys = Vec::from([generate_hpke_config_and_private_key(
-        random(),
-        HpkeKemId::X25519HkdfSha256,
-        HpkeKdfId::HkdfSha256,
-        HpkeAeadId::Aes128Gcm,
-    )
-    .unwrap()]);
-
     let task = Arc::new(
-        Task::new(
+        AggregatorTask::new(
             task_id,
-            leader_aggregator_endpoint,
-            helper_aggregator_endpoint,
+            /* peer_aggregator_endpoint */ req.peer_aggregator_endpoint,
             /* query_type */ req.query_type,
             /* vdaf */ req.vdaf,
-            /* role */ req.role,
             vdaf_verify_key,
             /* max_batch_query_count */ req.max_batch_query_count,
             /* task_expiration */ req.task_expiration,
@@ -172,10 +150,16 @@ pub(super) async fn post_task(
             /* time_precision */ req.time_precision,
             /* tolerable_clock_skew */ Duration::from_seconds(60), // 1 minute,
-            /* collector_hpke_config */ req.collector_hpke_config,
-            aggregator_auth_token,
-            collector_auth_token,
-            hpke_keys,
+            // hpke_keys
+            // Unwrap safety: we always use a supported KEM.
+            [generate_hpke_config_and_private_key(
+                random(),
+                HpkeKemId::X25519HkdfSha256,
+                HpkeKdfId::HkdfSha256,
+                HpkeAeadId::Aes128Gcm,
+            )
+            .unwrap()],
+            aggregator_parameters,
         )
         .map_err(|err| Error::BadRequest(format!("Error constructing task: {err}")))?,
     );
@@ -183,11 +167,10 @@ pub(super) async fn post_task(
 
     ds.run_tx_with_name("post_task", |tx| {
         let task = Arc::clone(&task);
         Box::pin(async move {
-            if let Some(existing_task) = tx.get_task(task.id()).await? {
+            if let Some(existing_task) = tx.get_aggregator_task(task.id()).await? {
                 // Check whether the existing task in the DB corresponds to the incoming task, ignoring
                 // those fields that are randomly generated.
-                if existing_task.leader_aggregator_endpoint() == task.leader_aggregator_endpoint()
-                    && existing_task.helper_aggregator_endpoint() == task.helper_aggregator_endpoint()
+                if existing_task.peer_aggregator_endpoint() == task.peer_aggregator_endpoint()
                     && existing_task.query_type() == task.query_type()
                     && existing_task.vdaf() == task.vdaf()
                     && existing_task.opaque_vdaf_verify_key() == task.opaque_vdaf_verify_key()
@@ -206,7 +189,7 @@ pub(super) async fn post_task(
                     return Err(datastore::Error::User(err.into()));
                 }
 
-                tx.put_task(&task).await
+                tx.put_aggregator_task(&task).await
             })
         })
         .await?;
@@ -224,7 +207,7 @@ pub(super) async fn get_task(
 
     let task = ds
         .run_tx_with_name("get_task", |tx| {
-            Box::pin(async move { tx.get_task(&task_id).await })
+            Box::pin(async move { tx.get_aggregator_task(&task_id).await })
         })
         .await?
         .ok_or(Error::NotFound)?;
diff --git a/aggregator_api/src/tests.rs b/aggregator_api/src/tests.rs
index c787a0c9a..5565f60c1 100644
--- a/aggregator_api/src/tests.rs
+++ b/aggregator_api/src/tests.rs
@@ -18,7 +18,10 @@ use janus_aggregator_core::{
         test_util::{ephemeral_datastore, EphemeralDatastore},
         Datastore,
     },
-    task::{test_util::TaskBuilder, QueryType, Task},
+    task::{
+        test_util::NewTaskBuilder as TaskBuilder, AggregatorTask, AggregatorTaskParameters,
+        QueryType,
+    },
     taskprov::test_util::PeerAggregatorBuilder,
     SecretBytes,
 };
@@ -99,13 +102,15 @@ async fn get_task_ids() {
         .run_tx(|tx| {
             Box::pin(async move {
                 let tasks: Vec<_> = iter::repeat_with(|| {
-                    TaskBuilder::new(QueryType::TimeInterval, VdafInstance::Fake, Role::Leader)
+                    TaskBuilder::new(QueryType::TimeInterval, VdafInstance::Fake)
                         .build()
+                        .leader_view()
+                        .unwrap()
                 })
                 .take(10)
                 .collect();
 
-                try_join_all(tasks.iter().map(|task| tx.put_task(task))).await?;
+                try_join_all(tasks.iter().map(|task| tx.put_aggregator_task(task))).await?;
 
                 Ok(tasks.into_iter().map(|task| *task.id()).collect())
             })
@@ -316,7 +321,7 @@ async fn post_task_helper_no_optional_fields() {
     let got_task = ds
         .run_tx(|tx| {
             let got_task_resp = got_task_resp.clone();
-            Box::pin(async move { tx.get_task(&got_task_resp.task_id).await })
+            Box::pin(async move { tx.get_aggregator_task(&got_task_resp.task_id).await })
         })
         .await
         .unwrap()
@@ -324,9 +329,8 @@ async fn post_task_helper_no_optional_fields() {
 
     // Verify that the task written to the datastore matches the request...
     assert_eq!(
-        // The other aggregator endpoint in the datastore task is fake
         &req.peer_aggregator_endpoint,
-        got_task.leader_aggregator_endpoint()
+        got_task.peer_aggregator_endpoint()
     );
     assert_eq!(&req.query_type, got_task.query_type());
     assert_eq!(&req.vdaf, got_task.vdaf());
@@ -521,7 +525,7 @@ async fn post_task_leader_all_optional_fields() {
     let got_task = ds
         .run_tx(|tx| {
             let got_task_resp = got_task_resp.clone();
-            Box::pin(async move { tx.get_task(&got_task_resp.task_id).await })
+            Box::pin(async move { tx.get_aggregator_task(&got_task_resp.task_id).await })
         })
         .await
         .unwrap()
@@ -529,9 +533,8 @@ async fn post_task_leader_all_optional_fields() {
 
     // Verify that the task written to the datastore matches the request...
    assert_eq!(
-        // The other aggregator endpoint in the datastore task is fake
         &req.peer_aggregator_endpoint,
-        got_task.helper_aggregator_endpoint()
+        got_task.peer_aggregator_endpoint()
     );
     assert_eq!(&req.query_type, got_task.query_type());
     assert_eq!(&req.vdaf, got_task.vdaf());
@@ -603,12 +606,15 @@ async fn get_task() {
     // Setup: write a task to the datastore.
     let (handler, _ephemeral_datastore, ds) = setup_api_test().await;
 
-    let task = TaskBuilder::new(QueryType::TimeInterval, VdafInstance::Fake, Role::Leader).build();
+    let task = TaskBuilder::new(QueryType::TimeInterval, VdafInstance::Fake)
+        .build()
+        .leader_view()
+        .unwrap();
 
     ds.run_tx(|tx| {
         let task = task.clone();
         Box::pin(async move {
-            tx.put_task(&task).await?;
+            tx.put_aggregator_task(&task).await?;
             Ok(())
         })
     })
@@ -664,11 +670,12 @@ async fn delete_task() {
     let task_id = ds
         .run_tx(|tx| {
             Box::pin(async move {
-                let task =
-                    TaskBuilder::new(QueryType::TimeInterval, VdafInstance::Fake, Role::Leader)
-                        .build();
+                let task = TaskBuilder::new(QueryType::TimeInterval, VdafInstance::Fake)
+                    .build()
+                    .leader_view()
+                    .unwrap();
 
-                tx.put_task(&task).await?;
+                tx.put_aggregator_task(&task).await?;
 
                 Ok(*task.id())
             })
@@ -739,11 +746,12 @@ async fn get_task_metrics() {
     let task_id = ds
         .run_tx(|tx| {
             Box::pin(async move {
-                let task =
-                    TaskBuilder::new(QueryType::TimeInterval, VdafInstance::Fake, Role::Leader)
-                        .build();
+                let task = TaskBuilder::new(QueryType::TimeInterval, VdafInstance::Fake)
+                    .build()
+                    .leader_view()
+                    .unwrap();
                 let task_id = *task.id();
-                tx.put_task(&task).await?;
+                tx.put_aggregator_task(&task).await?;
 
                 let reports: Vec<_> = iter::repeat_with(|| {
                     LeaderStoredReport::new_dummy(task_id, Time::from_seconds_since_epoch(0))
@@ -1742,9 +1750,8 @@ fn post_task_req_serialization() {
 
 #[test]
 fn task_resp_serialization() {
-    let task = Task::new(
+    let task = AggregatorTask::new(
         TaskId::from([0u8; 32]),
-        "https://leader.com/".parse().unwrap(),
         "https://helper.com/".parse().unwrap(),
         QueryType::FixedSize {
             max_batch_size: 999,
@@ -1754,7 +1761,6 @@ fn task_resp_serialization() {
             length: 5,
             chunk_length: 2,
         },
-        Role::Leader,
         SecretBytes::new(b"vdaf verify key!".to_vec()),
         1,
         None,
@@ -1762,21 +1768,6 @@ fn task_resp_serialization() {
         100,
         Duration::from_seconds(3600),
         Duration::from_seconds(60),
-        HpkeConfig::new(
-            HpkeConfigId::from(7),
-            HpkeKemId::X25519HkdfSha256,
-            HpkeKdfId::HkdfSha256,
-            HpkeAeadId::Aes128Gcm,
-            HpkePublicKey::from([0u8; 32].to_vec()),
-        ),
-        Some(
-            AuthenticationToken::new_dap_auth_token_from_string("Y29sbGVjdG9yLWFiY2RlZjAw")
-                .unwrap(),
-        ),
-        Some(
-            AuthenticationToken::new_dap_auth_token_from_string("Y29sbGVjdG9yLWFiY2RlZjAw")
-                .unwrap(),
-        ),
         [(HpkeKeypair::new(
             HpkeConfig::new(
                 HpkeConfigId::from(13),
@@ -1787,6 +1778,23 @@ fn task_resp_serialization() {
             ),
             HpkePrivateKey::new(b"unused".to_vec()),
         ))],
+        AggregatorTaskParameters::Leader {
+            aggregator_auth_token: AuthenticationToken::new_dap_auth_token_from_string(
+                "Y29sbGVjdG9yLWFiY2RlZjAw",
+            )
+            .unwrap(),
+            collector_auth_token: AuthenticationToken::new_dap_auth_token_from_string(
+                "Y29sbGVjdG9yLWFiY2RlZjAw",
+            )
+            .unwrap(),
+            collector_hpke_config: HpkeConfig::new(
+                HpkeConfigId::from(7),
+                HpkeKemId::X25519HkdfSha256,
+                HpkeKdfId::HkdfSha256,
+                HpkeAeadId::Aes128Gcm,
+                HpkePublicKey::from([0u8; 32].to_vec()),
+            ),
+        },
     )
     .unwrap();
     assert_tokens(
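
Note for reviewers: `AggregatorTaskParameters` is consumed throughout this diff but its definition is not part of it. The sketch below shows the shape implied by the call sites in routes.rs and tests.rs; the field types are inferred assumptions, the authoritative definition lives in `janus_aggregator_core::task`, and additional variants (for example for taskprov-provisioned tasks) may exist.

// Sketch only, not part of the diff: shape of `AggregatorTaskParameters` as
// inferred from the call sites above. Field types are assumptions.
use janus_core::auth_tokens::AuthenticationToken;
use janus_messages::HpkeConfig;

pub enum AggregatorTaskParameters {
    /// Values held only by the leader aggregator for a task.
    Leader {
        aggregator_auth_token: AuthenticationToken,
        collector_auth_token: AuthenticationToken,
        collector_hpke_config: HpkeConfig,
    },
    /// Values held only by the helper aggregator for a task.
    Helper {
        aggregator_auth_token: AuthenticationToken,
        collector_hpke_config: HpkeConfig,
    },
}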