diff --git a/aggregator/src/aggregator.rs b/aggregator/src/aggregator.rs
index 0e59a9aea..bc9e841c6 100644
--- a/aggregator/src/aggregator.rs
+++ b/aggregator/src/aggregator.rs
@@ -33,7 +33,7 @@ use janus_aggregator_core::{
         Datastore, Error as DatastoreError, Transaction,
     },
     query_type::AccumulableQueryType,
-    task::{self, AggregatorTask, Task, VerifyKey},
+    task::{self, AggregatorTask, VerifyKey},
     taskprov::PeerAggregator,
 };
 #[cfg(feature = "test-util")]
@@ -615,7 +615,7 @@ impl<C: Clock> Aggregator<C> {
             .datastore
             .run_tx_with_name("task_aggregator_get_task", |tx| {
                 let task_id = *task_id;
-                Box::pin(async move { tx.get_task(&task_id).await })
+                Box::pin(async move { tx.get_aggregator_task(&task_id).await })
             })
             .await?
         {
@@ -774,7 +774,7 @@ impl<C: Clock> Aggregator<C> {
 // Aggregate requests) using a parallelized library like Rayon.
 pub struct TaskAggregator<C: Clock> {
     /// The task being aggregated.
-    task: Arc<Task>,
+    task: Arc<AggregatorTask>,
     /// VDAF-specific operations.
     vdaf_ops: VdafOps,
     /// Report writer, with support for batching.
@@ -784,7 +784,7 @@ pub struct TaskAggregator<C: Clock> {
 impl<C: Clock> TaskAggregator<C> {
     /// Create a new aggregator. `report_recipient` is used to decrypt reports received by this
     /// aggregator.
-    fn new(task: Task, report_writer: Arc<ReportWriteBatcher<C>>) -> Result<Self, Error> {
+    fn new(task: AggregatorTask, report_writer: Arc<ReportWriteBatcher<C>>) -> Result<Self, Error> {
         let vdaf_ops = match task.vdaf() {
             VdafInstance::Prio3Count => {
                 let vdaf = Prio3::new_count(2)?;
@@ -1172,7 +1172,7 @@ impl VdafOps {
         global_hpke_keypairs: &GlobalHpkeKeypairCache,
         upload_decrypt_failure_counter: &Counter<u64>,
         upload_decode_failure_counter: &Counter<u64>,
-        task: &Task,
+        task: &AggregatorTask,
         report_writer: &ReportWriteBatcher<C>,
         report: Report,
     ) -> Result<(), Arc<Error>> {
@@ -1222,7 +1222,7 @@ impl VdafOps {
         datastore: &Datastore<C>,
         global_hpke_keypairs: &GlobalHpkeKeypairCache,
         aggregate_step_failure_counter: &Counter<u64>,
-        task: Arc<Task>,
+        task: Arc<AggregatorTask>,
         batch_aggregation_shard_count: u64,
         aggregation_job_id: &AggregationJobId,
         req_bytes: &[u8],
@@ -1272,7 +1272,7 @@ impl VdafOps {
         &self,
         datastore: &Datastore<C>,
         aggregate_step_failure_counter: &Counter<u64>,
-        task: Arc<Task>,
+        task: Arc<AggregatorTask>,
         batch_aggregation_shard_count: u64,
         aggregation_job_id: &AggregationJobId,
         req: Arc<AggregationJobContinueReq>,
@@ -1318,7 +1318,7 @@ impl VdafOps {
         global_hpke_keypairs: &GlobalHpkeKeypairCache,
         upload_decrypt_failure_counter: &Counter<u64>,
         upload_decode_failure_counter: &Counter<u64>,
-        task: &Task,
+        task: &AggregatorTask,
         report_writer: &ReportWriteBatcher<C>,
         report: Report,
     ) -> Result<(), Arc<Error>>
@@ -1515,7 +1515,7 @@ impl VdafOps {
     /// the sense that no new rows would need to be written to service the job.
     async fn check_aggregation_job_idempotence<'b, const SEED_SIZE: usize, Q, A, C>(
         tx: &Transaction<'b, C>,
-        task: &Task,
+        task: &AggregatorTask,
         incoming_aggregation_job: &AggregationJob<SEED_SIZE, Q, A>,
     ) -> Result<bool, datastore::Error>
     where
@@ -1553,7 +1553,7 @@ impl VdafOps {
         global_hpke_keypairs: &GlobalHpkeKeypairCache,
         vdaf: &A,
         aggregate_step_failure_counter: &Counter<u64>,
-        task: Arc<Task>,
+        task: Arc<AggregatorTask>,
         batch_aggregation_shard_count: u64,
         aggregation_job_id: &AggregationJobId,
         verify_key: &VerifyKey<SEED_SIZE>,
@@ -1595,7 +1595,7 @@ impl VdafOps {
         let agg_param = A::AggregationParam::get_decoded(req.aggregation_parameter())?;
 
         let mut accumulator = Accumulator::<SEED_SIZE, Q, A>::new(
-            Arc::new(task.view_for_role()?),
+            Arc::clone(&task),
             batch_aggregation_shard_count,
             agg_param.clone(),
         );
@@ -1603,7 +1603,7 @@ impl VdafOps {
         for (ord, prepare_init) in req.prepare_inits().iter().enumerate() {
             // Compute intervals for each batch identifier included in this aggregation job.
             let batch_identifier = Q::to_batch_identifier(
-                &task.view_for_role()?,
+                &task,
                 req.batch_selector().batch_identifier(),
                 prepare_init.report_share().metadata().time(),
             )?;
@@ -2043,7 +2043,7 @@ impl VdafOps {
         datastore: &Datastore<C>,
         vdaf: Arc<A>,
         aggregate_step_failure_counter: &Counter<u64>,
-        task: Arc<Task>,
+        task: Arc<AggregatorTask>,
         batch_aggregation_shard_count: u64,
         aggregation_job_id: &AggregationJobId,
         leader_aggregation_job: Arc<AggregationJobContinueReq>,
@@ -2065,8 +2065,6 @@ impl VdafOps {
             ));
         }
 
-        let task = Arc::new(task.view_for_role()?);
-
         // TODO(#224): don't hold DB transaction open while computing VDAF updates?
         // TODO(#224): don't do O(n) network round-trips (where n is the number of prepare steps)
         Ok(datastore
@@ -2177,7 +2175,7 @@ impl VdafOps {
     async fn handle_create_collection_job<C: Clock>(
         &self,
         datastore: &Datastore<C>,
-        task: Arc<Task>,
+        task: Arc<AggregatorTask>,
         collection_job_id: &CollectionJobId,
         collection_req_bytes: &[u8],
     ) -> Result<(), Error> {
@@ -2214,7 +2212,7 @@ impl VdafOps {
         C: Clock,
     >(
         datastore: &Datastore<C>,
-        task: Arc<Task>,
+        task: Arc<AggregatorTask>,
         vdaf: Arc<A>,
         collection_job_id: &CollectionJobId,
         req_bytes: &[u8],
@@ -2263,9 +2261,8 @@ impl VdafOps {
                 }
             }
 
-            let aggregator_task = task.view_for_role()?;
             let collection_identifier =
-                Q::collection_identifier_for_query(tx, &aggregator_task, req.query())
+                Q::collection_identifier_for_query(tx, &task, req.query())
                     .await?
                     .ok_or_else(|| {
                         datastore::Error::User(
@@ -2279,8 +2276,7 @@ impl VdafOps {
 
             // Check that the batch interval is valid for the task
             // https://www.ietf.org/archive/id/draft-ietf-ppm-dap-02.html#section-4.5.6.1.1
-            if !Q::validate_collection_identifier(&aggregator_task, &collection_identifier)
-            {
+            if !Q::validate_collection_identifier(&task, &collection_identifier) {
                 return Err(datastore::Error::User(
                     Error::BatchInvalid(*task.id(), format!("{collection_identifier}"))
                         .into(),
@@ -2292,14 +2288,14 @@ impl VdafOps {
                 Q::validate_query_count::<SEED_SIZE, C, A>(
                     tx,
                     &vdaf,
-                    &aggregator_task,
+                    &task,
                     &collection_identifier,
                     &aggregation_param,
                 ),
-                Q::count_client_reports(tx, &aggregator_task, &collection_identifier),
+                Q::count_client_reports(tx, &task, &collection_identifier),
                 try_join_all(
                     Q::batch_identifiers_for_collection_identifier(
-                        &aggregator_task,
+                        &task,
                         &collection_identifier
                     )
                     .map(|batch_identifier| {
@@ -2319,7 +2315,7 @@ impl VdafOps {
                 ),
                 try_join_all(
                     Q::batch_identifiers_for_collection_identifier(
-                        &aggregator_task,
+                        &task,
                         &collection_identifier
                     )
                     .map(|batch_identifier| {
@@ -2480,7 +2476,7 @@ impl VdafOps {
     async fn handle_get_collection_job<C: Clock>(
         &self,
         datastore: &Datastore<C>,
-        task: Arc<Task>,
+        task: Arc<AggregatorTask>,
         collection_job_id: &CollectionJobId,
     ) -> Result<Option<Vec<u8>>, Error> {
         match task.query_type() {
@@ -2517,7 +2513,7 @@ impl VdafOps {
         C: Clock,
     >(
         datastore: &Datastore<C>,
-        task: Arc<Task>,
+        task: Arc<AggregatorTask>,
         vdaf: Arc<A>,
         collection_job_id: &CollectionJobId,
     ) -> Result<Option<Vec<u8>>, Error>
@@ -2539,11 +2535,10 @@ impl VdafOps {
                         )
                     })?;
 
-                let aggregator_task = task.view_for_role()?;
                 let (batches, _) = try_join!(
                     Q::get_batches_for_collection_identifier(
                         tx,
-                        &aggregator_task,
+                        &task,
                         collection_job.batch_identifier(),
                         collection_job.aggregation_parameter()
                     ),
@@ -2657,7 +2652,7 @@ impl VdafOps {
     async fn handle_delete_collection_job<C: Clock>(
         &self,
         datastore: &Datastore<C>,
-        task: Arc<Task>,
+        task: Arc<AggregatorTask>,
         collection_job_id: &CollectionJobId,
     ) -> Result<(), Error> {
         match task.query_type() {
@@ -2693,7 +2688,7 @@ impl VdafOps {
         C: Clock,
     >(
         datastore: &Datastore<C>,
-        task: Arc<Task>,
+        task: Arc<AggregatorTask>,
         vdaf: Arc<A>,
         collection_job_id: &CollectionJobId,
     ) -> Result<(), Error>
@@ -2743,7 +2738,7 @@ impl VdafOps {
         &self,
         datastore: &Datastore<C>,
         clock: &C,
-        task: Arc<Task>,
+        task: Arc<AggregatorTask>,
         batch_aggregation_shard_count: u64,
         req_bytes: &[u8],
         collector_hpke_config: &HpkeConfig,
@@ -2782,7 +2777,7 @@ impl VdafOps {
     >(
         datastore: &Datastore<C>,
         clock: &C,
-        task: Arc<Task>,
+        task: Arc<AggregatorTask>,
         vdaf: Arc<A>,
         req_bytes: &[u8],
         batch_aggregation_shard_count: u64,
@@ -2798,7 +2793,7 @@ impl VdafOps {
 
         // §4.4.4.3: check that the batch interval meets the requirements from §4.6
         if !Q::validate_collection_identifier(
-            &task.view_for_role()?,
+            &task,
             aggregate_share_req.batch_selector().batch_identifier(),
         ) {
             return Err(Error::BatchInvalid(
@@ -2834,7 +2829,6 @@ impl VdafOps {
                     Arc::clone(&aggregate_share_req),
                 );
                 Box::pin(async move {
-                    let aggregator_task = task.view_for_role()?;
                     // Check if we have already serviced an aggregate share request with these
                     // parameters and serve the cached results if so.
                     let aggregation_param = A::AggregationParam::get_decoded(
@@ -2867,7 +2861,7 @@ impl VdafOps {
                     let (batch_aggregations, _) = try_join!(
                         Q::get_batch_aggregations_for_collection_identifier(
                             tx,
-                            &aggregator_task,
+                            &task,
                             vdaf.as_ref(),
                             aggregate_share_req.batch_selector().batch_identifier(),
                             &aggregation_param
@@ -2875,7 +2869,7 @@ impl VdafOps {
                         Q::validate_query_count::<SEED_SIZE, C, A>(
                             tx,
                             vdaf.as_ref(),
-                            &aggregator_task,
+                            &task,
                             aggregate_share_req.batch_selector().batch_identifier(),
                             &aggregation_param,
                         )
@@ -2885,7 +2879,7 @@ impl VdafOps {
                     // currently-nonexistent batch aggregation, we write (empty) batch
                     // aggregations for any that have not already been written to storage.
                     let empty_batch_aggregations = empty_batch_aggregations(
-                        &aggregator_task,
+                        &task,
                         batch_aggregation_shard_count,
                         aggregate_share_req.batch_selector().batch_identifier(),
                         &aggregation_param,
@@ -2894,7 +2888,7 @@ impl VdafOps {
 
                     let (helper_aggregate_share, report_count, checksum) =
                         compute_aggregate_share::<SEED_SIZE, Q, A>(
-                            &task.view_for_role()?,
+                            &task,
                             &batch_aggregations,
                         )
                         .await
@@ -3115,7 +3109,10 @@ mod tests {
             test_util::{ephemeral_datastore, EphemeralDatastore},
             Datastore,
         },
-        task::{test_util::TaskBuilder, QueryType, Task},
+        task::{
+            test_util::{NewTaskBuilder as TaskBuilder, Task},
+            AggregatorTask, QueryType,
+        },
         test_util::noop_meter,
     };
     use janus_core::{
@@ -3153,7 +3150,7 @@ mod tests {
     }
 
     pub(super) fn create_report_custom(
-        task: &Task,
+        task: &AggregatorTask,
         report_timestamp: Time,
         id: ReportId,
         hpke_key: &HpkeKeypair,
@@ -3194,7 +3191,7 @@ mod tests {
         )
     }
 
-    pub(super) fn create_report(task: &Task, report_timestamp: Time) -> Report {
+    pub(super) fn create_report(task: &AggregatorTask, report_timestamp: Time) -> Report {
         create_report_custom(task, report_timestamp, random(), task.current_hpke_key())
     }
 
@@ -3210,17 +3207,14 @@ mod tests {
     ) {
         let clock = MockClock::default();
         let vdaf = Prio3Count::new_count(2).unwrap();
-        let task = TaskBuilder::new(
-            QueryType::TimeInterval,
-            VdafInstance::Prio3Count,
-            Role::Leader,
-        )
-        .build();
+        let task = TaskBuilder::new(QueryType::TimeInterval, VdafInstance::Prio3Count).build();
+
+        let leader_task = task.leader_view().unwrap();
 
         let ephemeral_datastore = ephemeral_datastore().await;
         let datastore = Arc::new(ephemeral_datastore.datastore(clock.clone()).await);
 
-        datastore.put_task(&task).await.unwrap();
+        datastore.put_aggregator_task(&leader_task).await.unwrap();
 
         let aggregator = Aggregator::new(Arc::clone(&datastore), clock.clone(), &noop_meter(), cfg)
             .await
@@ -3247,7 +3241,8 @@ mod tests {
             ..Default::default()
         })
         .await;
-        let report = create_report(&task, clock.now());
+        let leader_task = task.leader_view().unwrap();
+        let report = create_report(&leader_task, clock.now());
 
         aggregator
             .handle_upload(task.id(), &report.get_encoded())
@@ -3274,10 +3269,10 @@ mod tests {
 
         // Reports may not be mutated
         let mutated_report = create_report_custom(
-            &task,
+            &leader_task,
             clock.now(),
             *report.metadata().id(),
-            task.current_hpke_key(),
+            leader_task.current_hpke_key(),
         );
         let error = aggregator
             .handle_upload(task.id(), &mutated_report.get_encoded())
@@ -3303,9 +3298,10 @@ mod tests {
         })
         .await;
 
-        let reports: Vec<_> = iter::repeat_with(|| create_report(&task, clock.now()))
-            .take(BATCH_SIZE)
-            .collect();
+        let reports: Vec<_> =
+            iter::repeat_with(|| create_report(&task.leader_view().unwrap(), clock.now()))
+                .take(BATCH_SIZE)
+                .collect();
         let want_report_ids: HashSet<_> = reports.iter().map(|r| *r.metadata().id()).collect();
 
         let aggregator = Arc::new(aggregator);
@@ -3339,11 +3335,12 @@ mod tests {
         let (_, aggregator, clock, task, _, _ephemeral_datastore) =
             setup_upload_test(default_aggregator_config()).await;
 
-        let report = create_report(&task, clock.now());
+        let leader_task = task.leader_view().unwrap();
+        let report = create_report(&leader_task, clock.now());
 
         let unused_hpke_config_id = (0..)
             .map(HpkeConfigId::from)
-            .find(|id| !task.hpke_keys().contains_key(id))
+            .find(|id| !leader_task.hpke_keys().contains_key(id))
             .unwrap();
 
         let report = Report::new(
@@ -3372,7 +3369,10 @@ mod tests {
         let (vdaf, aggregator, clock, task, datastore, _ephemeral_datastore) =
             setup_upload_test(default_aggregator_config()).await;
 
-        let report = create_report(&task, clock.now().add(task.tolerable_clock_skew()).unwrap());
+        let report = create_report(
+            &task.leader_view().unwrap(),
+            clock.now().add(task.tolerable_clock_skew()).unwrap(),
+        );
 
         aggregator
             .handle_upload(task.id(), &report.get_encoded())
@@ -3399,7 +3399,7 @@ mod tests {
         let (_, aggregator, clock, task, _, _ephemeral_datastore) =
             setup_upload_test(default_aggregator_config()).await;
         let report = create_report(
-            &task,
+            &task.leader_view().unwrap(),
             clock
                 .now()
                 .add(task.tolerable_clock_skew())
@@ -3426,7 +3426,7 @@ mod tests {
         let (_, aggregator, clock, task, datastore, _ephemeral_datastore) =
            setup_upload_test(default_aggregator_config()).await;
 
-        let report = create_report(&task, clock.now());
+        let report = create_report(&task.leader_view().unwrap(), clock.now());
 
         // Insert a collection job for the batch interval including our report.
         let batch_interval = Interval::new(
@@ -3479,16 +3479,17 @@ mod tests {
             ..Default::default()
        })
         .await;
+        let leader_task = task.leader_view().unwrap();
 
         // Same ID as the task to test having both keys to choose from.
         let global_hpke_keypair_same_id = generate_test_hpke_config_and_private_key_with_id(
-            (*task.current_hpke_key().config().id()).into(),
+            (*leader_task.current_hpke_key().config().id()).into(),
         );
         // Different ID to test misses on the task key.
         let global_hpke_keypair_different_id = generate_test_hpke_config_and_private_key_with_id(
             (0..)
                 .map(HpkeConfigId::from)
-                .find(|id| !task.hpke_keys().contains_key(id))
+                .find(|id| !leader_task.hpke_keys().contains_key(id))
                 .unwrap()
                 .into(),
         );
@@ -3511,10 +3512,15 @@ mod tests {
         aggregator.refresh_caches().await.unwrap();
 
         for report in [
-            create_report(&task, clock.now()),
-            create_report_custom(&task, clock.now(), random(), &global_hpke_keypair_same_id),
+            create_report(&leader_task, clock.now()),
             create_report_custom(
-                &task,
+                &leader_task,
+                clock.now(),
+                random(),
+                &global_hpke_keypair_same_id,
+            ),
+            create_report_custom(
+                &leader_task,
                 clock.now(),
                 random(),
                 &global_hpke_keypair_different_id,
diff --git a/aggregator/src/aggregator/http_handlers.rs b/aggregator/src/aggregator/http_handlers.rs
index bf162d8f5..48c80d9e5 100644
--- a/aggregator/src/aggregator/http_handlers.rs
+++ b/aggregator/src/aggregator/http_handlers.rs
@@ -674,7 +674,7 @@ mod tests {
             Datastore,
         },
         query_type::{AccumulableQueryType, CollectableQueryType},
-        task::{test_util::NewTaskBuilder as TaskBuilder, QueryType, Task, VerifyKey},
+        task::{test_util::NewTaskBuilder as TaskBuilder, QueryType, VerifyKey},
         test_util::noop_meter,
     };
     use janus_core::{
@@ -1098,7 +1098,7 @@ mod tests {
         let leader_task = task.leader_view().unwrap();
         datastore.put_aggregator_task(&leader_task).await.unwrap();
 
-        let report = create_report(&Task::from(leader_task.clone()), clock.now());
+        let report = create_report(&leader_task, clock.now());
 
         // Upload a report. Do this twice to prove that PUT is idempotent.
         for _ in 0..2 {
@@ -1116,7 +1116,7 @@ mod tests {
 
         // Verify that new reports using an existing report ID are rejected with reportRejected
         let duplicate_id_report = create_report_custom(
-            &Task::from(leader_task.clone()),
+            &leader_task,
             clock.now(),
             *accepted_report_id,
             leader_task.current_hpke_key(),
@@ -1233,7 +1233,7 @@ mod tests {
             .await
             .unwrap();
         let report_2 = create_report(
-            &Task::from(leader_task_expire_soon),
+            &leader_task_expire_soon,
             clock.now().add(&Duration::from_seconds(120)).unwrap(),
         );
         let mut test_conn = put(task_expire_soon.report_upload_uri().unwrap().path())
@@ -1306,7 +1306,7 @@ mod tests {
         let task = TaskBuilder::new(QueryType::TimeInterval, VdafInstance::Prio3Count).build();
         let helper_task = task.helper_view().unwrap();
         datastore.put_aggregator_task(&helper_task).await.unwrap();
-        let report = create_report(&Task::from(helper_task), clock.now());
+        let report = create_report(&helper_task, clock.now());
 
         let mut test_conn = put(task.report_upload_uri().unwrap().path())
             .with_request_header(KnownHeaderName::ContentType, Report::MEDIA_TYPE)
diff --git a/aggregator/tests/graceful_shutdown.rs b/aggregator/tests/graceful_shutdown.rs
index b1c90e064..dc8155cfa 100644
--- a/aggregator/tests/graceful_shutdown.rs
+++ b/aggregator/tests/graceful_shutdown.rs
@@ -6,10 +6,9 @@
 use base64::{engine::general_purpose::STANDARD_NO_PAD, Engine};
 use janus_aggregator_core::{
     datastore::test_util::ephemeral_datastore,
-    task::{test_util::TaskBuilder, QueryType},
+    task::{test_util::NewTaskBuilder as TaskBuilder, QueryType},
 };
 use janus_core::{test_util::install_test_trace_subscriber, time::RealClock, vdaf::VdafInstance};
-use janus_messages::Role;
 use reqwest::Url;
 use serde_yaml::{Mapping, Value};
 use std::{
@@ -124,13 +123,11 @@ async fn graceful_shutdown(binary: &Path, mut config: Mapping) {
         format!("{health_check_listen_address}").into(),
     );
 
-    let task = TaskBuilder::new(
-        QueryType::TimeInterval,
-        VdafInstance::Prio3Count,
-        Role::Leader,
-    )
-    .build();
-    datastore.put_task(&task).await.unwrap();
+    let task = TaskBuilder::new(QueryType::TimeInterval, VdafInstance::Prio3Count)
+        .build()
+        .leader_view()
+        .unwrap();
+    datastore.put_aggregator_task(&task).await.unwrap();
 
     // Save the above configuration to a temporary file, so that we can pass
     // the file's path to the binary under test on the command line.
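
Note: the task-provisioning pattern this change converges on looks roughly like
the sketch below. It uses only janus_aggregator_core APIs visible in the diff
(NewTaskBuilder, leader_view, put_aggregator_task); the helper function itself
and its RealClock-backed datastore parameter are illustrative, not code from
the repository.

    use janus_aggregator_core::{
        datastore::Datastore,
        task::{test_util::NewTaskBuilder as TaskBuilder, QueryType},
    };
    use janus_core::{time::RealClock, vdaf::VdafInstance};

    // Hypothetical helper: build the full two-aggregator task description, then
    // narrow it to the single-aggregator view (AggregatorTask) that the leader
    // holds, and persist only that view.
    async fn provision_leader_task(datastore: &Datastore<RealClock>) {
        let task = TaskBuilder::new(QueryType::TimeInterval, VdafInstance::Prio3Count).build();
        let leader_task = task.leader_view().unwrap();
        datastore.put_aggregator_task(&leader_task).await.unwrap();
    }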