Skip to content

Commit

Permalink
Job driver: wait for in-flight jobs upon SIGTERM
Browse files Browse the repository at this point in the history
  • Loading branch information
divergentdave committed Jul 3, 2023
1 parent 710075a commit d635490
Show file tree
Hide file tree
Showing 5 changed files with 179 additions and 129 deletions.
89 changes: 50 additions & 39 deletions aggregator/src/aggregator/aggregation_job_driver.rs
Original file line number Diff line number Diff line change
Expand Up @@ -942,6 +942,7 @@ mod tests {
use rand::random;
use reqwest::Url;
use std::{borrow::Borrow, str, sync::Arc, time::Duration as StdDuration};
use trillium_tokio::Stopper;

#[tokio::test]
async fn aggregation_job_driver() {
Expand Down Expand Up @@ -1117,33 +1118,39 @@ mod tests {
&meter,
32,
));
let stopper = Stopper::new();

// Run. Let the aggregation job driver step aggregation jobs, then kill it.
let aggregation_job_driver = Arc::new(JobDriver::new(
clock,
runtime_manager.with_label("stepper"),
meter,
StdDuration::from_secs(1),
StdDuration::from_secs(1),
10,
StdDuration::from_secs(60),
aggregation_job_driver.make_incomplete_job_acquirer_callback(
Arc::clone(&ds),
StdDuration::from_secs(600),
),
aggregation_job_driver.make_job_stepper_callback(Arc::clone(&ds), 5),
));
let aggregation_job_driver = Arc::new(
JobDriver::new(
clock,
runtime_manager.with_label("stepper"),
meter,
stopper.clone(),
StdDuration::from_secs(1),
StdDuration::from_secs(1),
10,
StdDuration::from_secs(60),
aggregation_job_driver.make_incomplete_job_acquirer_callback(
Arc::clone(&ds),
StdDuration::from_secs(600),
),
aggregation_job_driver.make_job_stepper_callback(Arc::clone(&ds), 5),
)
.unwrap(),
);

let task_handle = runtime_manager.with_label("driver").spawn({
let aggregation_job_driver = aggregation_job_driver.clone();
async move { aggregation_job_driver.run().await }
});
let task_handle = runtime_manager
.with_label("driver")
.spawn(aggregation_job_driver.run());

tracing::info!("awaiting stepper tasks");
// Wait for all of the aggregate job stepper tasks to complete.
// Wait for all of the aggregation job stepper tasks to complete.
runtime_manager.wait_for_completed_tasks("stepper", 2).await;
// Stop the aggregate job driver task.
task_handle.abort();
// Stop the aggregation job driver.
stopper.stop();
// Wait for the aggregation job driver task to complete.
task_handle.await.unwrap();

// Verify.
for mocked_aggregate in mocked_aggregates {
Expand Down Expand Up @@ -2775,6 +2782,7 @@ mod tests {
let mut runtime_manager = TestRuntimeManager::new();
let ephemeral_datastore = ephemeral_datastore().await;
let ds = Arc::new(ephemeral_datastore.datastore(clock.clone()).await);
let stopper = Stopper::new();

let task = TaskBuilder::new(
QueryType::TimeInterval,
Expand Down Expand Up @@ -2876,20 +2884,24 @@ mod tests {
&meter,
32,
));
let job_driver = Arc::new(JobDriver::new(
clock.clone(),
runtime_manager.with_label("stepper"),
meter,
StdDuration::from_secs(1),
StdDuration::from_secs(1),
10,
StdDuration::from_secs(60),
aggregation_job_driver.make_incomplete_job_acquirer_callback(
Arc::clone(&ds),
StdDuration::from_secs(600),
),
aggregation_job_driver.make_job_stepper_callback(Arc::clone(&ds), 3),
));
let job_driver = Arc::new(
JobDriver::new(
clock.clone(),
runtime_manager.with_label("stepper"),
meter,
stopper.clone(),
StdDuration::from_secs(1),
StdDuration::from_secs(1),
10,
StdDuration::from_secs(60),
aggregation_job_driver.make_incomplete_job_acquirer_callback(
Arc::clone(&ds),
StdDuration::from_secs(600),
),
aggregation_job_driver.make_job_stepper_callback(Arc::clone(&ds), 3),
)
.unwrap(),
);

// Set up three error responses from our mock helper. These will cause errors in the
// leader, because the response body is empty and cannot be decoded.
Expand Down Expand Up @@ -2936,9 +2948,7 @@ mod tests {
.await;

// Start up the job driver.
let task_handle = runtime_manager
.with_label("driver")
.spawn(async move { job_driver.run().await });
let task_handle = runtime_manager.with_label("driver").spawn(job_driver.run());

// Run the job driver until we try to step the collection job four times. The first three
// attempts make network requests and fail, while the fourth attempt just marks the job
Expand All @@ -2950,7 +2960,8 @@ mod tests {
// and try again.
clock.advance(&Duration::from_seconds(600));
}
task_handle.abort();
stopper.stop();
task_handle.await.unwrap();

// Check that the job driver made the HTTP requests we expected.
failure_mock.assert_async().await;
Expand Down
41 changes: 23 additions & 18 deletions aggregator/src/aggregator/collection_job_driver.rs
Original file line number Diff line number Diff line change
Expand Up @@ -556,6 +556,7 @@ mod tests {
use prio::codec::{Decode, Encode};
use rand::random;
use std::{str, sync::Arc, time::Duration as StdDuration};
use trillium_tokio::Stopper;
use url::Url;

async fn setup_collection_job_test_case(
Expand Down Expand Up @@ -1007,6 +1008,7 @@ mod tests {
let mut runtime_manager = TestRuntimeManager::new();
let ephemeral_datastore = ephemeral_datastore().await;
let ds = Arc::new(ephemeral_datastore.datastore(clock.clone()).await);
let stopper = Stopper::new();

let (task, _, collection_job) =
setup_collection_job_test_case(&mut server, clock.clone(), Arc::clone(&ds), false)
Expand All @@ -1016,20 +1018,24 @@ mod tests {
let meter = meter("collection_job_driver");
let collection_job_driver =
Arc::new(CollectionJobDriver::new(reqwest::Client::new(), &meter, 1));
let job_driver = Arc::new(JobDriver::new(
clock.clone(),
runtime_manager.with_label("stepper"),
meter,
StdDuration::from_secs(1),
StdDuration::from_secs(1),
10,
StdDuration::from_secs(60),
collection_job_driver.make_incomplete_job_acquirer_callback(
Arc::clone(&ds),
StdDuration::from_secs(600),
),
collection_job_driver.make_job_stepper_callback(Arc::clone(&ds), 3),
));
let job_driver = Arc::new(
JobDriver::new(
clock.clone(),
runtime_manager.with_label("stepper"),
meter,
stopper.clone(),
StdDuration::from_secs(1),
StdDuration::from_secs(1),
10,
StdDuration::from_secs(60),
collection_job_driver.make_incomplete_job_acquirer_callback(
Arc::clone(&ds),
StdDuration::from_secs(600),
),
collection_job_driver.make_job_stepper_callback(Arc::clone(&ds), 3),
)
.unwrap(),
);

// Set up three error responses from our mock helper. These will cause errors in the
// leader, because the response body is empty and cannot be decoded.
Expand All @@ -1050,9 +1056,7 @@ mod tests {
.await;

// Start up the job driver.
let task_handle = runtime_manager
.with_label("driver")
.spawn(async move { job_driver.run().await });
let task_handle = runtime_manager.with_label("driver").spawn(job_driver.run());

// Run the job driver until we try to step the collection job four times. The first three
// attempts make network requests and fail, while the fourth attempt just marks the job
Expand All @@ -1065,7 +1069,8 @@ mod tests {
clock.advance(&Duration::from_seconds(600));
}
// Shut down the job driver.
task_handle.abort();
stopper.stop();
task_handle.await.unwrap();

// Check that the job driver made the HTTP requests we expected.
failure_mock.assert_async().await;
Expand Down
5 changes: 3 additions & 2 deletions aggregator/src/bin/aggregation_job_driver.rs
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ async fn main() -> anyhow::Result<()> {
ctx.clock,
TokioRuntime,
meter,
stopper,
Duration::from_secs(ctx.config.job_driver_config.min_job_discovery_delay_secs),
Duration::from_secs(ctx.config.job_driver_config.max_job_discovery_delay_secs),
ctx.config.job_driver_config.max_concurrent_job_workers,
Expand All @@ -57,8 +58,8 @@ async fn main() -> anyhow::Result<()> {
Arc::clone(&datastore),
ctx.config.job_driver_config.maximum_attempts_before_failure,
),
));
stopper.stop_future(job_driver.run()).await;
)?);
job_driver.run().await;

Ok(())
})
Expand Down
5 changes: 3 additions & 2 deletions aggregator/src/bin/collection_job_driver.rs
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ async fn main() -> anyhow::Result<()> {
ctx.clock,
TokioRuntime,
meter,
stopper,
Duration::from_secs(ctx.config.job_driver_config.min_job_discovery_delay_secs),
Duration::from_secs(ctx.config.job_driver_config.max_job_discovery_delay_secs),
ctx.config.job_driver_config.max_concurrent_job_workers,
Expand All @@ -57,8 +58,8 @@ async fn main() -> anyhow::Result<()> {
Arc::clone(&datastore),
ctx.config.job_driver_config.maximum_attempts_before_failure,
),
));
stopper.stop_future(job_driver.run()).await;
)?);
job_driver.run().await;

Ok(())
})
Expand Down
Loading

0 comments on commit d635490

Please sign in to comment.