Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix Helper aggregation-initialization report-replayed check. #3143

Merged
merged 2 commits into from
May 17, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 9 additions & 24 deletions aggregator/src/aggregator.rs
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ use fixed::{
types::extra::{U15, U31},
FixedI16, FixedI32,
};
use futures::{future::try_join_all, TryFutureExt as _};
use futures::future::try_join_all;
use http::{header::CONTENT_TYPE, Method};
use itertools::iproduct;
use janus_aggregator_core::{
Expand Down Expand Up @@ -2336,33 +2336,18 @@ impl VdafOps {
}

// Write report shares, and ensure this isn't a repeated report aggregation.
// TODO(#225): on repeated aggregation, verify input share matches previously-received input share
try_join_all(report_share_data.iter_mut().map(|rsd| {
let task = Arc::clone(&task);
let aggregation_job = Arc::clone(&aggregation_job);

async move {
let put_report_share_fut = tx
.put_scrubbed_report(task.id(), &rsd.report_share)
.or_else(|err| async move {
match err {
datastore::Error::MutationTargetAlreadyExists => Ok(()),
_ => Err(err),
}
});
// Verify that we haven't seen this report ID and aggregation parameter
// before in another aggregation job.
let report_aggregation_exists_fut = tx
.check_other_report_aggregation_exists::<SEED_SIZE, A>(
task.id(),
rsd.report_share.metadata().id(),
aggregation_job.aggregation_parameter(),
aggregation_job.id(),
);
let (_, report_aggregation_exists) =
try_join!(put_report_share_fut, report_aggregation_exists_fut)?;

if report_aggregation_exists {
let report_already_aggregated =
match tx.put_scrubbed_report(task.id(), &rsd.report_share).await {
Ok(()) => false,
Err(datastore::Error::MutationTargetAlreadyExists) => true,
Err(err) => return Err(err),
};

if report_already_aggregated {
rsd.report_aggregation = rsd
.report_aggregation
.clone()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -354,29 +354,21 @@ async fn aggregate_init() {
transcript_7.leader_prepare_transitions[0].message.clone(),
);

// prepare_init_8 has already been aggregated in another aggregation job, with a different
// aggregation parameter.
let (prepare_init_8, transcript_8) = prep_init_generator.next(&measurement);

let mut batch_aggregations_results = vec![];
let mut aggregation_jobs_results = vec![];
let (conflicting_aggregation_job, non_conflicting_aggregation_job) = datastore
let conflicting_aggregation_job = datastore
.run_unnamed_tx(|tx| {
let task = helper_task.clone();
let report_share_4 = prepare_init_4.report_share().clone();
let report_share_8 = prepare_init_8.report_share().clone();

Box::pin(async move {
tx.put_aggregator_task(&task).await.unwrap();

// report_share_4 and report_share_8 are already in the datastore as they were
// referenced by existing aggregation jobs.
// report_share_4 is already in the datastore as it was referenced by an existing
// aggregation job.
tx.put_scrubbed_report(task.id(), &report_share_4)
.await
.unwrap();
tx.put_scrubbed_report(task.id(), &report_share_8)
.await
.unwrap();

// Put in an aggregation job and report aggregation for report_share_4. It uses
// the same aggregation parameter as the aggregation job this test will later
Expand Down Expand Up @@ -408,37 +400,6 @@ async fn aggregate_init() {
.await
.unwrap();

// Put in an aggregation job and report aggregation for report_share_8, using a
// different aggregation parameter. As the aggregation parameter differs,
// report_share_8 should prepare successfully in the aggregation job we'll PUT
// later.
let non_conflicting_aggregation_job = AggregationJob::new(
*task.id(),
random(),
dummy::AggregationParam(1),
(),
Interval::new(Time::from_seconds_since_epoch(0), Duration::from_seconds(1))
.unwrap(),
AggregationJobState::InProgress,
AggregationJobStep::from(0),
);
tx.put_aggregation_job::<0, TimeInterval, dummy::Vdaf>(
&non_conflicting_aggregation_job,
)
.await
.unwrap();
tx.put_report_aggregation::<0, dummy::Vdaf>(&ReportAggregation::new(
*task.id(),
*non_conflicting_aggregation_job.id(),
*report_share_8.metadata().id(),
*report_share_8.metadata().time(),
0,
None,
ReportAggregationState::Finished,
))
.await
.unwrap();

// Write collected batch aggregations for the interval that report_share_5 falls
// into, which will cause it to fail to prepare.
try_join_all(
Expand All @@ -456,7 +417,7 @@ async fn aggregate_init() {
.await
.unwrap();

Ok((conflicting_aggregation_job, non_conflicting_aggregation_job))
Ok(conflicting_aggregation_job)
})
})
.await
Expand All @@ -475,7 +436,6 @@ async fn aggregate_init() {
prepare_init_5.clone(),
prepare_init_6.clone(),
prepare_init_7.clone(),
prepare_init_8.clone(),
]),
);

Expand All @@ -492,7 +452,7 @@ async fn aggregate_init() {
let aggregate_resp: AggregationJobResp = decode_response_body(&mut test_conn).await;

// Validate response.
assert_eq!(aggregate_resp.prepare_resps().len(), 9);
assert_eq!(aggregate_resp.prepare_resps().len(), 8);

let prepare_step_0 = aggregate_resp.prepare_resps().first().unwrap();
assert_eq!(
Expand Down Expand Up @@ -573,15 +533,6 @@ async fn aggregate_init() {
&PrepareStepResult::Reject(PrepareError::InvalidMessage),
);

let prepare_step_8 = aggregate_resp.prepare_resps().get(8).unwrap();
assert_eq!(
prepare_step_8.report_id(),
prepare_init_8.report_share().metadata().id()
);
assert_matches!(prepare_step_8.result(), PrepareStepResult::Continue { message } => {
assert_eq!(message, &transcript_8.helper_prepare_transitions[0].message);
});

// Check aggregation job in datastore.
let (aggregation_jobs, batch_aggregations) = datastore
.run_unnamed_tx(|tx| {
Expand All @@ -601,17 +552,14 @@ async fn aggregate_init() {
.await
.unwrap();

assert_eq!(aggregation_jobs.len(), 3);
assert_eq!(aggregation_jobs.len(), 2);

let mut saw_conflicting_aggregation_job = false;
let mut saw_non_conflicting_aggregation_job = false;
let mut saw_new_aggregation_job = false;

for aggregation_job in &aggregation_jobs {
if aggregation_job.eq(&conflicting_aggregation_job) {
saw_conflicting_aggregation_job = true;
} else if aggregation_job.eq(&non_conflicting_aggregation_job) {
saw_non_conflicting_aggregation_job = true;
} else if aggregation_job.task_id().eq(task.id())
&& aggregation_job.id().eq(&aggregation_job_id)
&& aggregation_job.partial_batch_identifier().eq(&())
Expand All @@ -622,7 +570,6 @@ async fn aggregate_init() {
}

assert!(saw_conflicting_aggregation_job);
assert!(saw_non_conflicting_aggregation_job);
assert!(saw_new_aggregation_job);

aggregation_jobs_results.push(aggregation_jobs);
Expand Down
47 changes: 0 additions & 47 deletions aggregator_core/src/datastore.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2170,53 +2170,6 @@ WHERE aggregation_jobs.task_id = $6
)
}

/// Check whether the report has ever been aggregated with the given parameter, for an
/// aggregation job besides the given one.
///
/// Returns `Ok(false)` if the task is unknown to the datastore (an unknown task cannot
/// have any report aggregations). Only report aggregations whose aggregation job's
/// client timestamp interval ends at or after the task's report-expiry threshold are
/// considered, so expired aggregations do not count as conflicts.
///
/// # Errors
///
/// Returns an error if the underlying datastore query fails.
#[tracing::instrument(skip(self), err(level = Level::DEBUG))]
pub async fn check_other_report_aggregation_exists<const SEED_SIZE: usize, A>(
&self,
task_id: &TaskId,
report_id: &ReportId,
aggregation_param: &A::AggregationParam,
aggregation_job_id: &AggregationJobId,
) -> Result<bool, Error>
where
A: vdaf::Aggregator<SEED_SIZE, 16>,
{
// An unknown task can have no report aggregations; report "not found" rather than
// surfacing an error.
let task_info = match self.task_info_for(task_id).await? {
Some(task_info) => task_info,
None => return Ok(false),
};

// Look for any report aggregation for this (report ID, aggregation parameter) pair
// that belongs to a *different* aggregation job and has not passed the task's
// report-expiry threshold.
let stmt = self
.prepare_cached(
"-- check_other_report_aggregation_exists()
SELECT 1 FROM report_aggregations
JOIN aggregation_jobs ON aggregation_jobs.id = report_aggregations.aggregation_job_id
WHERE report_aggregations.task_id = $1
AND aggregation_jobs.task_id = $1
AND report_aggregations.client_report_id = $2
AND aggregation_jobs.aggregation_param = $3
AND aggregation_jobs.aggregation_job_id != $4
AND UPPER(aggregation_jobs.client_timestamp_interval) >= $5",
)
.await?;
// Existence-only query: any returned row means a conflicting aggregation exists.
Ok(self
.query_opt(
&stmt,
&[
/* task_id */ &task_info.pkey,
/* report_id */ &report_id.as_ref(),
/* aggregation_param */ &aggregation_param.get_encoded()?,
/* aggregation_job_id */ &aggregation_job_id.as_ref(),
/* threshold */
&task_info.report_expiry_threshold(&self.clock.now().as_naive_date_time()?)?,
],
)
.await
.map(|row| row.is_some())?)
}

/// get_report_aggregations_for_aggregation_job retrieves all report aggregations associated
/// with a given aggregation job, ordered by their natural ordering.
#[tracing::instrument(skip(self), err(level = Level::DEBUG))]
Expand Down
Loading
Loading