Skip to content

Commit

Permalink
Task rewrite: Aggregator API
Browse files Browse the repository at this point in the history
Adopt `janus_aggregator_core::task::AggregatorTask` throughout
`janus_aggregator_api`. Gratifyingly, this removes some hacks that
resolved impedance mismatches between `PostTaskReq` and
`janus_aggregator_core::task::Task`; those hacks are no longer needed
because `AggregatorTask` stores only the peer aggregator's endpoint.

Part of #1524
  • Loading branch information
tgeoghegan committed Sep 29, 2023
1 parent 7224d51 commit 2d1e6c2
Show file tree
Hide file tree
Showing 3 changed files with 76 additions and 98 deletions.
21 changes: 4 additions & 17 deletions aggregator_api/src/models.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
use base64::{engine::general_purpose::URL_SAFE_NO_PAD, Engine};
use janus_aggregator_core::{
datastore::models::{GlobalHpkeKeypair, HpkeKeyState},
task::{QueryType, Task},
task::{AggregatorTask, QueryType},
taskprov::{PeerAggregator, VerifyKeyInit},
};
use janus_core::{auth_tokens::AuthenticationToken, vdaf::VdafInstance};
Expand Down Expand Up @@ -125,23 +125,10 @@ pub(crate) struct TaskResp {
pub(crate) aggregator_hpke_configs: Vec<HpkeConfig>,
}

impl TryFrom<&Task> for TaskResp {
impl TryFrom<&AggregatorTask> for TaskResp {
type Error = &'static str;

fn try_from(task: &Task) -> Result<Self, Self::Error> {
// We have to resolve impedance mismatches between the aggregator API's view of a task
// and `aggregator_core::task::Task`. For now, we deal with this in code, but someday
// the two representations will be harmonized.
// https://github.com/divviup/janus/issues/1524

// Return the aggregator endpoint URL for the role opposite our own
let peer_aggregator_endpoint = match task.role() {
Role::Leader => task.helper_aggregator_endpoint(),
Role::Helper => task.leader_aggregator_endpoint(),
_ => return Err("illegal aggregator role in task"),
}
.clone();

fn try_from(task: &AggregatorTask) -> Result<Self, Self::Error> {
let mut aggregator_hpke_configs: Vec<_> = task
.hpke_keys()
.values()
Expand All @@ -151,7 +138,7 @@ impl TryFrom<&Task> for TaskResp {

Ok(Self {
task_id: *task.id(),
peer_aggregator_endpoint,
peer_aggregator_endpoint: task.peer_aggregator_endpoint().clone(),
query_type: *task.query_type(),
vdaf: task.vdaf().clone(),
role: *task.role(),
Expand Down
71 changes: 27 additions & 44 deletions aggregator_api/src/routes.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ use crate::{
use base64::{engine::general_purpose::URL_SAFE_NO_PAD, Engine};
use janus_aggregator_core::{
datastore::{self, Datastore},
task::Task,
task::{AggregatorTask, AggregatorTaskParameters},
taskprov::PeerAggregator,
SecretBytes,
};
Expand All @@ -26,7 +26,6 @@ use ring::digest::{digest, SHA256};
use std::{str::FromStr, sync::Arc, unreachable};
use trillium::{Conn, Status};
use trillium_api::{Json, State};
use url::Url;

pub(super) async fn get_config(
_: &mut Conn,
Expand Down Expand Up @@ -79,27 +78,10 @@ pub(super) async fn post_task<C: Clock>(
_: &mut Conn,
(State(ds), Json(req)): (State<Arc<Datastore<C>>>, Json<PostTaskReq>),
) -> Result<Json<TaskResp>, Error> {
// We have to resolve impedance mismatches between the aggregator API's view of a task and
// `aggregator_core::task::Task`. For now, we deal with this in code, but someday the two
// representations will be harmonized.
// https://github.com/divviup/janus/issues/1524

if !matches!(req.role, Role::Leader | Role::Helper) {
return Err(Error::BadRequest(format!("invalid role {}", req.role)));
}

// struct `aggregator_core::task::Task` expects to get two aggregator endpoint URLs, but only
// the one for the peer aggregator is in the incoming request (or for that matter, is ever used
// by Janus), so we insert a fake URL for "self".
// TODO(#1524): clean this up with `aggregator_core::task::Task` changes
// unwrap safety: this fake URL is valid
let fake_aggregator_url = Url::parse("http://never-used.example.com").unwrap();
let (leader_aggregator_endpoint, helper_aggregator_endpoint) = match req.role {
Role::Leader => (fake_aggregator_url, req.peer_aggregator_endpoint),
Role::Helper => (req.peer_aggregator_endpoint, fake_aggregator_url),
_ => unreachable!(),
};

let vdaf_verify_key_bytes = URL_SAFE_NO_PAD
.decode(&req.vdaf_verify_key)
.map_err(|err| {
Expand All @@ -121,15 +103,19 @@ pub(super) async fn post_task<C: Clock>(

let vdaf_verify_key = SecretBytes::new(vdaf_verify_key_bytes);

let (aggregator_auth_token, collector_auth_token) = match req.role {
let aggregator_parameters = match req.role {
Role::Leader => {
let aggregator_auth_token = req.aggregator_auth_token.ok_or_else(|| {
Error::BadRequest(
"aggregator acting in leader role must be provided an aggregator auth token"
.to_string(),
)
})?;
(Some(aggregator_auth_token), Some(random()))
AggregatorTaskParameters::Leader {
aggregator_auth_token,
collector_auth_token: random(),
collector_hpke_config: req.collector_hpke_config,
}
}

Role::Helper => {
Expand All @@ -140,29 +126,21 @@ pub(super) async fn post_task<C: Clock>(
));
}

(Some(random()), None)
AggregatorTaskParameters::Helper {
aggregator_auth_token: random(),
collector_hpke_config: req.collector_hpke_config,
}
}

_ => unreachable!(),
};

// Unwrap safety: we always use a supported KEM.
let hpke_keys = Vec::from([generate_hpke_config_and_private_key(
random(),
HpkeKemId::X25519HkdfSha256,
HpkeKdfId::HkdfSha256,
HpkeAeadId::Aes128Gcm,
)
.unwrap()]);

let task = Arc::new(
Task::new(
AggregatorTask::new(
task_id,
leader_aggregator_endpoint,
helper_aggregator_endpoint,
/* peer_aggregator_endpoint */ req.peer_aggregator_endpoint,
/* query_type */ req.query_type,
/* vdaf */ req.vdaf,
/* role */ req.role,
vdaf_verify_key,
/* max_batch_query_count */ req.max_batch_query_count,
/* task_expiration */ req.task_expiration,
Expand All @@ -172,22 +150,27 @@ pub(super) async fn post_task<C: Clock>(
/* time_precision */ req.time_precision,
/* tolerable_clock_skew */
Duration::from_seconds(60), // 1 minute,
/* collector_hpke_config */ req.collector_hpke_config,
aggregator_auth_token,
collector_auth_token,
hpke_keys,
// hpke_keys
// Unwrap safety: we always use a supported KEM.
[generate_hpke_config_and_private_key(
random(),
HpkeKemId::X25519HkdfSha256,
HpkeKdfId::HkdfSha256,
HpkeAeadId::Aes128Gcm,
)
.unwrap()],
aggregator_parameters,
)
.map_err(|err| Error::BadRequest(format!("Error constructing task: {err}")))?,
);

ds.run_tx_with_name("post_task", |tx| {
let task = Arc::clone(&task);
Box::pin(async move {
if let Some(existing_task) = tx.get_task(task.id()).await? {
if let Some(existing_task) = tx.get_aggregator_task(task.id()).await? {
// Check whether the existing task in the DB corresponds to the incoming task, ignoring
// those fields that are randomly generated.
if existing_task.leader_aggregator_endpoint() == task.leader_aggregator_endpoint()
&& existing_task.helper_aggregator_endpoint() == task.helper_aggregator_endpoint()
if existing_task.peer_aggregator_endpoint() == task.peer_aggregator_endpoint()
&& existing_task.query_type() == task.query_type()
&& existing_task.vdaf() == task.vdaf()
&& existing_task.opaque_vdaf_verify_key() == task.opaque_vdaf_verify_key()
Expand All @@ -206,7 +189,7 @@ pub(super) async fn post_task<C: Clock>(
return Err(datastore::Error::User(err.into()));
}

tx.put_task(&task).await
tx.put_aggregator_task(&task).await
})
})
.await?;
Expand All @@ -224,7 +207,7 @@ pub(super) async fn get_task<C: Clock>(

let task = ds
.run_tx_with_name("get_task", |tx| {
Box::pin(async move { tx.get_task(&task_id).await })
Box::pin(async move { tx.get_aggregator_task(&task_id).await })
})
.await?
.ok_or(Error::NotFound)?;
Expand Down
82 changes: 45 additions & 37 deletions aggregator_api/src/tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,10 @@ use janus_aggregator_core::{
test_util::{ephemeral_datastore, EphemeralDatastore},
Datastore,
},
task::{test_util::TaskBuilder, QueryType, Task},
task::{
test_util::NewTaskBuilder as TaskBuilder, AggregatorTask, AggregatorTaskParameters,
QueryType,
},
taskprov::test_util::PeerAggregatorBuilder,
SecretBytes,
};
Expand Down Expand Up @@ -99,13 +102,15 @@ async fn get_task_ids() {
.run_tx(|tx| {
Box::pin(async move {
let tasks: Vec<_> = iter::repeat_with(|| {
TaskBuilder::new(QueryType::TimeInterval, VdafInstance::Fake, Role::Leader)
TaskBuilder::new(QueryType::TimeInterval, VdafInstance::Fake)
.build()
.leader_view()
.unwrap()
})
.take(10)
.collect();

try_join_all(tasks.iter().map(|task| tx.put_task(task))).await?;
try_join_all(tasks.iter().map(|task| tx.put_aggregator_task(task))).await?;

Ok(tasks.into_iter().map(|task| *task.id()).collect())
})
Expand Down Expand Up @@ -316,17 +321,16 @@ async fn post_task_helper_no_optional_fields() {
let got_task = ds
.run_tx(|tx| {
let got_task_resp = got_task_resp.clone();
Box::pin(async move { tx.get_task(&got_task_resp.task_id).await })
Box::pin(async move { tx.get_aggregator_task(&got_task_resp.task_id).await })
})
.await
.unwrap()
.expect("task was not created");

// Verify that the task written to the datastore matches the request...
assert_eq!(
// The other aggregator endpoint in the datastore task is fake
&req.peer_aggregator_endpoint,
got_task.leader_aggregator_endpoint()
got_task.peer_aggregator_endpoint()
);
assert_eq!(&req.query_type, got_task.query_type());
assert_eq!(&req.vdaf, got_task.vdaf());
Expand Down Expand Up @@ -521,17 +525,16 @@ async fn post_task_leader_all_optional_fields() {
let got_task = ds
.run_tx(|tx| {
let got_task_resp = got_task_resp.clone();
Box::pin(async move { tx.get_task(&got_task_resp.task_id).await })
Box::pin(async move { tx.get_aggregator_task(&got_task_resp.task_id).await })
})
.await
.unwrap()
.expect("task was not created");

// Verify that the task written to the datastore matches the request...
assert_eq!(
// The other aggregator endpoint in the datastore task is fake
&req.peer_aggregator_endpoint,
got_task.helper_aggregator_endpoint()
got_task.peer_aggregator_endpoint()
);
assert_eq!(&req.query_type, got_task.query_type());
assert_eq!(&req.vdaf, got_task.vdaf());
Expand Down Expand Up @@ -603,12 +606,15 @@ async fn get_task() {
// Setup: write a task to the datastore.
let (handler, _ephemeral_datastore, ds) = setup_api_test().await;

let task = TaskBuilder::new(QueryType::TimeInterval, VdafInstance::Fake, Role::Leader).build();
let task = TaskBuilder::new(QueryType::TimeInterval, VdafInstance::Fake)
.build()
.leader_view()
.unwrap();

ds.run_tx(|tx| {
let task = task.clone();
Box::pin(async move {
tx.put_task(&task).await?;
tx.put_aggregator_task(&task).await?;
Ok(())
})
})
Expand Down Expand Up @@ -664,11 +670,12 @@ async fn delete_task() {
let task_id = ds
.run_tx(|tx| {
Box::pin(async move {
let task =
TaskBuilder::new(QueryType::TimeInterval, VdafInstance::Fake, Role::Leader)
.build();
let task = TaskBuilder::new(QueryType::TimeInterval, VdafInstance::Fake)
.build()
.leader_view()
.unwrap();

tx.put_task(&task).await?;
tx.put_aggregator_task(&task).await?;

Ok(*task.id())
})
Expand Down Expand Up @@ -739,11 +746,12 @@ async fn get_task_metrics() {
let task_id = ds
.run_tx(|tx| {
Box::pin(async move {
let task =
TaskBuilder::new(QueryType::TimeInterval, VdafInstance::Fake, Role::Leader)
.build();
let task = TaskBuilder::new(QueryType::TimeInterval, VdafInstance::Fake)
.build()
.leader_view()
.unwrap();
let task_id = *task.id();
tx.put_task(&task).await?;
tx.put_aggregator_task(&task).await?;

let reports: Vec<_> = iter::repeat_with(|| {
LeaderStoredReport::new_dummy(task_id, Time::from_seconds_since_epoch(0))
Expand Down Expand Up @@ -1742,9 +1750,8 @@ fn post_task_req_serialization() {

#[test]
fn task_resp_serialization() {
let task = Task::new(
let task = AggregatorTask::new(
TaskId::from([0u8; 32]),
"https://leader.com/".parse().unwrap(),
"https://helper.com/".parse().unwrap(),
QueryType::FixedSize {
max_batch_size: 999,
Expand All @@ -1754,29 +1761,13 @@ fn task_resp_serialization() {
length: 5,
chunk_length: 2,
},
Role::Leader,
SecretBytes::new(b"vdaf verify key!".to_vec()),
1,
None,
None,
100,
Duration::from_seconds(3600),
Duration::from_seconds(60),
HpkeConfig::new(
HpkeConfigId::from(7),
HpkeKemId::X25519HkdfSha256,
HpkeKdfId::HkdfSha256,
HpkeAeadId::Aes128Gcm,
HpkePublicKey::from([0u8; 32].to_vec()),
),
Some(
AuthenticationToken::new_dap_auth_token_from_string("Y29sbGVjdG9yLWFiY2RlZjAw")
.unwrap(),
),
Some(
AuthenticationToken::new_dap_auth_token_from_string("Y29sbGVjdG9yLWFiY2RlZjAw")
.unwrap(),
),
[(HpkeKeypair::new(
HpkeConfig::new(
HpkeConfigId::from(13),
Expand All @@ -1787,6 +1778,23 @@ fn task_resp_serialization() {
),
HpkePrivateKey::new(b"unused".to_vec()),
))],
AggregatorTaskParameters::Leader {
aggregator_auth_token: AuthenticationToken::new_dap_auth_token_from_string(
"Y29sbGVjdG9yLWFiY2RlZjAw",
)
.unwrap(),
collector_auth_token: AuthenticationToken::new_dap_auth_token_from_string(
"Y29sbGVjdG9yLWFiY2RlZjAw",
)
.unwrap(),
collector_hpke_config: HpkeConfig::new(
HpkeConfigId::from(7),
HpkeKemId::X25519HkdfSha256,
HpkeKdfId::HkdfSha256,
HpkeAeadId::Aes128Gcm,
HpkePublicKey::from([0u8; 32].to_vec()),
),
},
)
.unwrap();
assert_tokens(
Expand Down

0 comments on commit 2d1e6c2

Please sign in to comment.