Skip to content

Commit

Permalink
Task rewrite: adopt AggregatorTask in datastore
Browse files Browse the repository at this point in the history
Adopts `janus_aggregator_core::task::test_util::NewTaskBuilder` and
`janus_aggregator_core::task::AggregatorTask` in the
`janus_aggregator_core::datastore` module. Much as the previous change
provides two kinds of `Task` structure, we now provide two sets of
methods for reading and writing tasks: one that deals in the new
`AggregatorTask` and the other which deals in the old `Task`.

We add routines for converting between `task::Task` and
`task::AggregatorTask` to make it easier for these two paths through the
datastore to co-exist. This conversion is lossy because `AggregatorTask`
only retains one of the aggregator endpoints, but this doesn't cause
substantial problems in Janus, and we can live with it transitionally.

Part of #1524
  • Loading branch information
tgeoghegan committed Sep 28, 2023
1 parent 13132ea commit 460e3ec
Show file tree
Hide file tree
Showing 6 changed files with 518 additions and 343 deletions.
2 changes: 1 addition & 1 deletion aggregator/src/aggregator/http_handlers.rs
Original file line number Diff line number Diff line change
Expand Up @@ -971,7 +971,7 @@ mod tests {
let task = TaskBuilder::new(
QueryType::TimeInterval,
VdafInstance::Prio3Count,
Role::Leader,
Role::Helper,
)
.build();
let task_id = *task.id();
Expand Down
4 changes: 3 additions & 1 deletion aggregator/src/aggregator/taskprov_tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -333,7 +333,9 @@ async fn taskprov_aggregate_init() {
.state()
.eq(&AggregationJobState::InProgress)
);
assert_eq!(test.task, got_task.unwrap());
// TODO(#1524): This assertion temporarily just checks the task ID because of the lossy
// conversion between task::Task and task::AggregatorTask.
assert_eq!(test.task.id(), got_task.unwrap().id());
}

#[tokio::test]
Expand Down
170 changes: 109 additions & 61 deletions aggregator_core/src/datastore.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,8 @@ use self::models::{
};
use crate::{
query_type::{AccumulableQueryType, CollectableQueryType},
task::{self, Task},
taskprov::{self, PeerAggregator},
task::{self, AggregatorTask, AggregatorTaskParameters, Task},
taskprov::PeerAggregator,
SecretBytes,
};
use chrono::NaiveDateTime;
Expand Down Expand Up @@ -306,6 +306,7 @@ impl<C: Clock> Datastore<C> {
}

/// Write a task into the datastore.
// TODO(#1524): remove this once everything has migrated to put_aggregator_task
#[cfg(feature = "test-util")]
#[cfg_attr(docsrs, doc(cfg(feature = "test-util")))]
pub async fn put_task(&self, task: &Task) -> Result<(), Error> {
Expand All @@ -315,6 +316,17 @@ impl<C: Clock> Datastore<C> {
})
.await
}

/// Write an [`AggregatorTask`] into the datastore.
///
/// Convenience wrapper that opens a transaction via `run_tx` and delegates to the
/// transaction-level `put_aggregator_task`. Only compiled when the `test-util` feature is
/// enabled.
///
/// # Errors
///
/// Returns whatever error `run_tx` or the transaction-level `put_aggregator_task` yields.
#[cfg(feature = "test-util")]
#[cfg_attr(docsrs, doc(cfg(feature = "test-util")))]
pub async fn put_aggregator_task(&self, task: &AggregatorTask) -> Result<(), Error> {
self.run_tx(|tx| {
// Clone inside the closure so the `async move` block below owns its own copy of the task.
// NOTE(review): presumably needed because `run_tx` may invoke the closure more than once
// -- confirm against `run_tx`.
let task = task.clone();
Box::pin(async move { tx.put_aggregator_task(&task).await })
})
.await
}
}

fn check_error<T>(
Expand Down Expand Up @@ -525,20 +537,34 @@ impl<C: Clock> Transaction<'_, C> {
}

/// Writes a legacy [`Task`] into the datastore.
///
/// The task is first converted to the [`AggregatorTask`] view matching its role and then
/// handed to `put_aggregator_task`:
///
/// * leader tasks use `leader_view`;
/// * helper tasks use `helper_view`, falling back to `taskprov_helper_view` if that fails.
///
/// # Errors
///
/// Returns `Error::InvalidParameter` when the task's role is neither leader nor helper, and
/// propagates any error from the view conversion or from `put_aggregator_task`.
// TODO(#1524): remove this once everything has migrated to put_aggregator_task
#[tracing::instrument(skip(self, task), fields(task_id = ?task.id()), err)]
pub async fn put_task(&self, task: &Task) -> Result<(), Error> {
    let view = match task.role() {
        Role::Leader => task.leader_view()?,
        Role::Helper => match task.helper_view() {
            Ok(helper) => helper,
            // The error from the regular helper view is discarded; only the taskprov
            // fallback's error is reported to the caller.
            Err(_) => task.taskprov_helper_view()?,
        },
        _ => return Err(Error::InvalidParameter("role must be aggregator")),
    };

    self.put_aggregator_task(&view).await
}

/// Writes a task into the datastore.
#[tracing::instrument(skip(self, task), fields(task_id = ?task.id()), err)]
pub async fn put_aggregator_task(&self, task: &AggregatorTask) -> Result<(), Error> {
// Main task insert.
let stmt = self
.prepare_cached(
"INSERT INTO tasks (
task_id, aggregator_role, leader_aggregator_endpoint,
helper_aggregator_endpoint, query_type, vdaf, max_batch_query_count,
task_expiration, report_expiry_age, min_batch_size, time_precision,
tolerable_clock_skew, collector_hpke_config, vdaf_verify_key,
task_id, aggregator_role, peer_aggregator_endpoint, query_type, vdaf,
max_batch_query_count, task_expiration, report_expiry_age, min_batch_size,
time_precision, tolerable_clock_skew, collector_hpke_config, vdaf_verify_key,
aggregator_auth_token_type, aggregator_auth_token, collector_auth_token_type,
collector_auth_token)
VALUES (
$1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14, $15, $16, $17, $18
$1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14, $15, $16, $17
)
ON CONFLICT DO NOTHING",
)
Expand All @@ -549,10 +575,8 @@ impl<C: Clock> Transaction<'_, C> {
&[
/* task_id */ &task.id().as_ref(),
/* aggregator_role */ &AggregatorRole::from_role(*task.role())?,
/* leader_aggregator_endpoint */
&task.leader_aggregator_endpoint().as_str(),
/* helper_aggregator_endpoint */
&task.helper_aggregator_endpoint().as_str(),
/* peer_aggregator_endpoint */
&task.peer_aggregator_endpoint().as_str(),
/* query_type */ &Json(task.query_type()),
/* vdaf */ &Json(task.vdaf()),
/* max_batch_query_count */
Expand All @@ -574,9 +598,7 @@ impl<C: Clock> Transaction<'_, C> {
/* tolerable_clock_skew */
&i64::try_from(task.tolerable_clock_skew().as_seconds())?,
/* collector_hpke_config */
&task
.collector_hpke_config()
.map(|config| config.get_encoded()),
&task.collector_hpke_config().map(|cfg| cfg.get_encoded()),
/* vdaf_verify_key */
&self.crypter.encrypt(
"tasks",
Expand Down Expand Up @@ -625,6 +647,7 @@ impl<C: Clock> Transaction<'_, C> {
let mut hpke_config_ids: Vec<i16> = Vec::new();
let mut hpke_configs: Vec<Vec<u8>> = Vec::new();
let mut hpke_private_keys: Vec<Vec<u8>> = Vec::new();

for hpke_keypair in task.hpke_keys().values() {
let mut row_id = [0u8; TaskId::LEN + size_of::<u8>()];
row_id[..TaskId::LEN].copy_from_slice(task.id().as_ref());
Expand Down Expand Up @@ -677,16 +700,26 @@ impl<C: Clock> Transaction<'_, C> {
}

/// Fetch the task parameters corresponding to the provided `task_id`, as a legacy [`Task`].
///
/// Delegates to `get_aggregator_task` and converts the result with `Task::from`. Yields
/// `Ok(None)` when `get_aggregator_task` yields `Ok(None)`.
// TODO(#1524): remove this once everything has migrated to get_aggregator_task
#[tracing::instrument(skip(self), err)]
pub async fn get_task(&self, task_id: &TaskId) -> Result<Option<Task>, Error> {
    let aggregator_task = self.get_aggregator_task(task_id).await?;
    Ok(aggregator_task.map(Task::from))
}

/// Fetch the task parameters corresponding to the provided `task_id`.
#[tracing::instrument(skip(self), err)]
pub async fn get_aggregator_task(
&self,
task_id: &TaskId,
) -> Result<Option<AggregatorTask>, Error> {
let params: &[&(dyn ToSql + Sync)] = &[&task_id.as_ref()];
let stmt = self
.prepare_cached(
"SELECT aggregator_role, leader_aggregator_endpoint, helper_aggregator_endpoint,
query_type, vdaf, max_batch_query_count, task_expiration, report_expiry_age,
min_batch_size, time_precision, tolerable_clock_skew, collector_hpke_config,
vdaf_verify_key, aggregator_auth_token_type, aggregator_auth_token,
collector_auth_token_type, collector_auth_token
"SELECT aggregator_role, peer_aggregator_endpoint, query_type, vdaf,
max_batch_query_count, task_expiration, report_expiry_age, min_batch_size,
time_precision, tolerable_clock_skew, collector_hpke_config, vdaf_verify_key,
aggregator_auth_token_type, aggregator_auth_token, collector_auth_token_type,
collector_auth_token
FROM tasks WHERE task_id = $1",
)
.await?;
Expand All @@ -707,14 +740,25 @@ impl<C: Clock> Transaction<'_, C> {
}

/// Fetch all the tasks in the database, as legacy [`Task`] values.
///
/// Delegates to `get_aggregator_tasks` and converts each element with `Task::from`.
// TODO(#1524): remove this once everything has migrated to get_aggregator_tasks
#[tracing::instrument(skip(self), err)]
pub async fn get_tasks(&self) -> Result<Vec<Task>, Error> {
    let aggregator_tasks = self.get_aggregator_tasks().await?;
    let tasks: Vec<Task> = aggregator_tasks.into_iter().map(Task::from).collect();
    Ok(tasks)
}

/// Fetch all the tasks in the database.
#[tracing::instrument(skip(self), err)]
pub async fn get_aggregator_tasks(&self) -> Result<Vec<AggregatorTask>, Error> {
let stmt = self
.prepare_cached(
"SELECT task_id, aggregator_role, leader_aggregator_endpoint,
helper_aggregator_endpoint, query_type, vdaf, max_batch_query_count,
task_expiration, report_expiry_age, min_batch_size, time_precision,
tolerable_clock_skew, collector_hpke_config, vdaf_verify_key,
"SELECT task_id, aggregator_role, peer_aggregator_endpoint, query_type, vdaf,
max_batch_query_count, task_expiration, report_expiry_age, min_batch_size,
time_precision, tolerable_clock_skew, collector_hpke_config, vdaf_verify_key,
aggregator_auth_token_type, aggregator_auth_token, collector_auth_token_type,
collector_auth_token
FROM tasks",
Expand Down Expand Up @@ -768,13 +812,10 @@ impl<C: Clock> Transaction<'_, C> {
task_id: &TaskId,
row: &Row,
hpke_key_rows: &[Row],
) -> Result<Task, Error> {
) -> Result<AggregatorTask, Error> {
// Scalar task parameters.
let aggregator_role: AggregatorRole = row.get("aggregator_role");
let leader_aggregator_endpoint =
row.get::<_, String>("leader_aggregator_endpoint").parse()?;
let helper_aggregator_endpoint =
row.get::<_, String>("helper_aggregator_endpoint").parse()?;
let peer_aggregator_endpoint = row.get::<_, String>("peer_aggregator_endpoint").parse()?;
let query_type = row.try_get::<_, Json<task::QueryType>>("query_type")?.0;
let vdaf = row.try_get::<_, Json<VdafInstance>>("vdaf")?.0;
let max_batch_query_count = row.get_bigint_and_convert("max_batch_query_count")?;
Expand Down Expand Up @@ -831,7 +872,7 @@ impl<C: Clock> Transaction<'_, C> {
.transpose()?;

// HPKE keys.
let mut hpke_keypairs = Vec::new();
let mut hpke_keys = Vec::new();
for row in hpke_key_rows {
let config_id = u8::try_from(row.get::<_, i16>("config_id"))?;
let config = HpkeConfig::get_decoded(row.get("config"))?;
Expand All @@ -848,50 +889,57 @@ impl<C: Clock> Transaction<'_, C> {
&encrypted_private_key,
)?);

hpke_keypairs.push(HpkeKeypair::new(config, private_key));
hpke_keys.push(HpkeKeypair::new(config, private_key));
}

let task = Task::new_without_validation(
let aggregator_parameters = match (
aggregator_role,
aggregator_auth_token,
collector_auth_token,
collector_hpke_config,
) {
(
AggregatorRole::Leader,
Some(aggregator_auth_token),
Some(collector_auth_token),
Some(collector_hpke_config),
) => AggregatorTaskParameters::Leader {
aggregator_auth_token,
collector_auth_token,
collector_hpke_config,
},
(
AggregatorRole::Helper,
Some(aggregator_auth_token),
None,
Some(collector_hpke_config),
) => AggregatorTaskParameters::Helper {
aggregator_auth_token,
collector_hpke_config,
},
(AggregatorRole::Helper, None, None, None) => AggregatorTaskParameters::TaskProvHelper,
values => {
return Err(Error::DbState(format!(
"found task row with unexpected combination of values {values:?}",
)));
}
};

Ok(AggregatorTask::new(
*task_id,
leader_aggregator_endpoint,
helper_aggregator_endpoint,
peer_aggregator_endpoint,
query_type,
vdaf,
aggregator_role.as_role(),
vdaf_verify_key,
max_batch_query_count,
task_expiration,
report_expiry_age,
min_batch_size,
time_precision,
tolerable_clock_skew,
collector_hpke_config,
aggregator_auth_token,
collector_auth_token,
hpke_keypairs,
);
// Trial validation through all known schemes. This is a workaround to avoid extending the
// schema to track the provenance of tasks. If we do end up implementing a task provenance
// column anyways, we can simplify this logic.
task.validate().or_else(|error| {
taskprov::Task(task.clone())
.validate()
.map_err(|taskprov_error| {
error!(
%task_id,
%error,
%taskprov_error,
?task,
"task has failed all available validation checks",
);
// Choose some error to bubble up to the caller. Either way this error
// occurring is an indication of a bug, which we'll need to go into the
// logs for.
error
})
})?;

Ok(task)
hpke_keys,
aggregator_parameters,
)?)
}

/// Retrieves report & report aggregation metrics for a given task: either a tuple
Expand Down
Loading

0 comments on commit 460e3ec

Please sign in to comment.