diff --git a/aggregator/src/aggregator.rs b/aggregator/src/aggregator.rs index 827dc9219..d4a1e1619 100644 --- a/aggregator/src/aggregator.rs +++ b/aggregator/src/aggregator.rs @@ -496,6 +496,11 @@ impl Aggregator { Ok(Arc::clone(task_aggs.entry(*task_id).or_insert(task_agg))) } } + + #[cfg(feature = "test-util")] + pub async fn refresh_caches(&self) -> Result<(), Error> { + self.global_hpke_keypairs.refresh(&self.datastore).await + } } /// TaskAggregator provides aggregation functionality for a single task. @@ -2823,7 +2828,6 @@ mod tests { }; use rand::random; use std::{collections::HashSet, iter, sync::Arc, time::Duration as StdDuration}; - use tokio::time::sleep; pub(crate) const BATCH_AGGREGATION_SHARD_COUNT: u64 = 32; @@ -3217,9 +3221,7 @@ mod tests { }) .await .unwrap(); - - // Let keypair cache refresh. - sleep(StdDuration::from_millis(750)).await; + aggregator.refresh_caches().await.unwrap(); for report in [ create_report(&task, clock.now()), diff --git a/aggregator/src/aggregator/http_handlers.rs b/aggregator/src/aggregator/http_handlers.rs index c8abf8e1a..3c5d58970 100644 --- a/aggregator/src/aggregator/http_handlers.rs +++ b/aggregator/src/aggregator/http_handlers.rs @@ -208,7 +208,13 @@ pub async fn aggregator_handler( cfg: Config, ) -> Result { let aggregator = Arc::new(Aggregator::new(datastore, clock, meter, cfg).await?); + aggregator_handler_with_aggregator(aggregator, meter).await +} +async fn aggregator_handler_with_aggregator( + aggregator: Arc>, + meter: &Meter, +) -> Result { Ok(( State(aggregator), metrics("janus_aggregator").with_route(|conn| conn.route().map(ToString::to_string)), @@ -554,7 +560,7 @@ mod tests { }, collection_job_tests::setup_collection_job_test_case, empty_batch_aggregations, - http_handlers::aggregator_handler, + http_handlers::{aggregator_handler, aggregator_handler_with_aggregator}, tests::{ create_report, create_report_with_id, default_aggregator_config, generate_helper_report_share, generate_helper_report_share_for_plaintext, @@ -613,7 +619,6 @@ mod tests { use std::{ borrow::Cow, collections::HashMap, io::Cursor, sync::Arc, time::Duration as StdDuration, }; - use tokio::time::sleep; use trillium::{KnownHeaderName, Status}; use trillium_testing::{ assert_headers, @@ -727,7 +732,17 @@ mod tests { ..Default::default() }; - let handler = aggregator_handler(datastore.clone(), clock.clone(), &noop_meter(), cfg) + let aggregator = Arc::new( + crate::aggregator::Aggregator::new( + datastore.clone(), + clock.clone(), + &noop_meter(), + cfg, + ) + .await + .unwrap(), + ); + let handler = aggregator_handler_with_aggregator(aggregator.clone(), &noop_meter()) .await .unwrap(); @@ -756,7 +771,7 @@ mod tests { }) .await .unwrap(); - sleep(StdDuration::from_millis(750)).await; + aggregator.refresh_caches().await.unwrap(); let mut test_conn = get("/hpke_config").run_async(&handler).await; assert_eq!(test_conn.status(), Some(Status::Ok)); let bytes = take_response_body(&mut test_conn).await; @@ -777,7 +792,7 @@ mod tests { }) .await .unwrap(); - sleep(StdDuration::from_millis(750)).await; + aggregator.refresh_caches().await.unwrap(); let mut test_conn = get("/hpke_config").run_async(&handler).await; assert_eq!(test_conn.status(), Some(Status::Ok)); let bytes = take_response_body(&mut test_conn).await; @@ -813,7 +828,7 @@ mod tests { }) .await .unwrap(); - sleep(StdDuration::from_millis(750)).await; + aggregator.refresh_caches().await.unwrap(); let mut test_conn = get("/hpke_config").run_async(&handler).await; assert_eq!(test_conn.status(), Some(Status::Ok)); let bytes = take_response_body(&mut test_conn).await; @@ -831,7 +846,7 @@ mod tests { }) .await .unwrap(); - sleep(StdDuration::from_millis(750)).await; + aggregator.refresh_caches().await.unwrap(); let test_conn = get("/hpke_config").run_async(&handler).await; assert_eq!(test_conn.status(), Some(Status::BadRequest)); } diff --git a/aggregator/src/cache.rs b/aggregator/src/cache.rs index 855e8d96b..bdeac444c 100644 --- a/aggregator/src/cache.rs +++ b/aggregator/src/cache.rs @@ -1,20 +1,17 @@ //! Various in-memory caches that can be used by an aggregator. use crate::aggregator::Error; -use janus_aggregator_core::datastore::{ - models::{GlobalHpkeKeypair, HpkeKeyState}, - Datastore, -}; +use janus_aggregator_core::datastore::{models::HpkeKeyState, Datastore}; use janus_core::{hpke::HpkeKeypair, time::Clock}; use janus_messages::{HpkeConfig, HpkeConfigId}; use std::{ collections::HashMap, fmt::Debug, sync::{Arc, Mutex as StdMutex}, - time::Duration as StdDuration, + time::{Duration as StdDuration, Instant}, }; use tokio::{spawn, task::JoinHandle, time::sleep}; -use tracing::error; +use tracing::{debug, error}; type HpkeConfigs = Arc>; type HpkeKeypairs = HashMap>; @@ -41,34 +38,29 @@ impl GlobalHpkeKeypairCache { datastore: Arc>, refresh_interval: StdDuration, ) -> Result { + let keypairs = Arc::new(StdMutex::new(HashMap::new())); + let configs = Arc::new(StdMutex::new(Arc::new(Vec::new()))); + // Initial cache load. - let global_keypairs = Self::get_global_keypairs(&datastore).await?; - let configs = Arc::new(StdMutex::new(Self::filter_active_configs(&global_keypairs))); - let keypairs = Arc::new(StdMutex::new(Self::map_keypairs(&global_keypairs))); + Self::refresh_inner(&datastore, &configs, &keypairs).await?; // Start refresh task. let refresh_configs = configs.clone(); let refresh_keypairs = keypairs.clone(); + let refresh_datastore = datastore.clone(); let refresh_handle = spawn(async move { loop { sleep(refresh_interval).await; - match Self::get_global_keypairs(&datastore).await { - Ok(global_keypairs) => { - let new_configs = Self::filter_active_configs(&global_keypairs); - let new_keypairs = Self::map_keypairs(&global_keypairs); - { - let mut configs = refresh_configs.lock().unwrap(); - *configs = new_configs; - } - { - let mut keypairs = refresh_keypairs.lock().unwrap(); - *keypairs = new_keypairs; - } - } - Err(err) => { - error!(?err, "failed to refresh HPKE config cache"); - } + let now = Instant::now(); + let result = + Self::refresh_inner(&refresh_datastore, &refresh_configs, &refresh_keypairs) + .await; + let elapsed = now.elapsed(); + + match result { + Ok(_) => debug!(?elapsed, "successfully refreshed HPKE keypair cache"), + Err(err) => error!(?err, ?elapsed, "failed to refresh HPKE keypair cache"), } } }); @@ -80,8 +72,18 @@ impl GlobalHpkeKeypairCache { }) } - fn filter_active_configs(global_keypairs: &[GlobalHpkeKeypair]) -> HpkeConfigs { - Arc::new( + async fn refresh_inner( + datastore: &Datastore, + configs: &StdMutex, + keypairs: &StdMutex, + ) -> Result<(), Error> { + let global_keypairs = datastore + .run_tx_with_name("refresh_global_hpke_keypairs_cache", |tx| { + Box::pin(async move { tx.get_global_hpke_keypairs().await }) + }) + .await?; + + let new_configs = Arc::new( global_keypairs .iter() .filter_map(|keypair| match keypair.state() { @@ -89,27 +91,30 @@ impl GlobalHpkeKeypairCache { _ => None, }) .collect(), - ) - } + ); - fn map_keypairs(global_keypairs: &[GlobalHpkeKeypair]) -> HpkeKeypairs { - global_keypairs + let new_keypairs = global_keypairs .iter() .map(|keypair| { let keypair = keypair.hpke_keypair().clone(); (*keypair.config().id(), Arc::new(keypair)) }) - .collect() + .collect(); + + { + let mut configs = configs.lock().unwrap(); + *configs = new_configs; + } + { + let mut keypairs = keypairs.lock().unwrap(); + *keypairs = new_keypairs; + } + Ok(()) } - async fn get_global_keypairs( - datastore: &Datastore, - ) -> Result, Error> { - Ok(datastore - .run_tx_with_name("refresh_global_hpke_configs_cache", |tx| { - Box::pin(async move { tx.get_global_hpke_keypairs().await }) - }) - .await?) + #[cfg(feature = "test-util")] + pub async fn refresh(&self, datastore: &Datastore) -> Result<(), Error> { + Self::refresh_inner(datastore, &self.configs, &self.keypairs).await } /// Retrieve active configs for config advertisement. This only returns configs