From 2203c3a567374994df349563d0ddb21be2d5b396 Mon Sep 17 00:00:00 2001
From: Jacob Rothstein
Date: Tue, 16 May 2023 13:30:26 -0700
Subject: [PATCH] updates for janus interop

* add sumvec and countvec vdaf types
* make histogram buckets u64, not i32
* use leader_endpoint and helper_endpoint instead of a `Vec<Url>` of
  aggregator endpoints
* use a `Vec<HpkeConfig>` instead of a `HashMap<HpkeConfigId, HpkeConfig>`
* update test fixtures, output directly from janus aggregator api tests
---
 src/aggregator_api_mock.rs                 |  14 +-
 src/clients/aggregator_client/api_types.rs | 218 ++++++++++++---------
 src/entity/task.rs                         |   9 +-
 src/entity/task/vdaf.rs                    |  29 ++-
 tests/harness/fixtures.rs                  |   2 +-
 tests/tasks.rs                             |   2 +-
 6 files changed, 161 insertions(+), 113 deletions(-)

diff --git a/src/aggregator_api_mock.rs b/src/aggregator_api_mock.rs
index 1c538568..f9f2ee83 100644
--- a/src/aggregator_api_mock.rs
+++ b/src/aggregator_api_mock.rs
@@ -1,5 +1,5 @@
 use crate::clients::aggregator_client::api_types::{
-    HpkeAeadId, HpkeConfig, HpkeKdfId, HpkeKemId, HpkePublicKey, JanusDuration, JanusTime,
+    HpkeAeadId, HpkeKdfId, HpkeKemId, HpkePublicKey, JanusDuration, JanusHpkeConfig, JanusTime,
     TaskCreate, TaskIds, TaskMetrics, TaskResponse,
 };
 use fastrand::alphanumeric;
@@ -37,7 +37,8 @@ async fn post_task(_: &mut Conn, Json(task_create): Json<Ta
 pub fn task_response(task_create: TaskCreate) -> TaskResponse {
     TaskResponse {
         task_id: random(),
-        aggregator_endpoints: task_create.aggregator_endpoints,
+        leader_endpoint: task_create.leader_endpoint,
+        helper_endpoint: task_create.helper_endpoint,
         query_type: task_create.query_type,
         vdaf: task_create.vdaf,
         role: task_create.role,
@@ -51,15 +52,12 @@ pub fn task_response(task_create: TaskCreate) -> TaskResponse {
         collector_hpke_config: random_hpke_config(),
         aggregator_auth_tokens: vec![],
         collector_auth_tokens: vec![],
-        aggregator_hpke_configs: std::iter::repeat_with(random_hpke_config)
-            .take(5)
-            .map(|config| (*config.id(), config))
-            .collect(),
+        aggregator_hpke_configs: std::iter::repeat_with(random_hpke_config).take(5).collect(),
     }
 }
 
-pub fn random_hpke_config() -> HpkeConfig {
-    HpkeConfig::new(
+pub fn random_hpke_config() -> JanusHpkeConfig {
+    JanusHpkeConfig::new(
         random(),
         HpkeKemId::P256HkdfSha256,
         HpkeKdfId::HkdfSha512,
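A note on the `aggregator_hpke_configs` change above: dropping the `.map(|config| (*config.id(), config))` step is what turns the collected `HashMap` into a plain `Vec`, matching the JSON array the aggregator API now returns. A minimal standalone sketch of the difference, using a stand-in `Cfg` type rather than the crate's `JanusHpkeConfig`:

```rust
use std::collections::HashMap;

#[derive(Clone, Debug)]
struct Cfg {
    id: u8, // stand-in for HpkeConfigId
}

fn main() {
    let mut next_id = 0u8;
    let mut random_cfg = || {
        next_id += 1;
        Cfg { id: next_id }
    };

    // Old shape: keyed by config id, unordered.
    let as_map: HashMap<u8, Cfg> = std::iter::repeat_with(&mut random_cfg)
        .take(5)
        .map(|cfg| (cfg.id, cfg))
        .collect();

    // New shape: an order-preserving list, which serializes as a JSON array.
    let as_vec: Vec<Cfg> = std::iter::repeat_with(&mut random_cfg).take(5).collect();

    assert_eq!(as_map.len(), 5);
    assert_eq!(as_vec.len(), 5);
}
```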
diff --git a/src/clients/aggregator_client/api_types.rs b/src/clients/aggregator_client/api_types.rs
index 9bf83e5e..3f05a8f6 100644
--- a/src/clients/aggregator_client/api_types.rs
+++ b/src/clients/aggregator_client/api_types.rs
@@ -1,6 +1,9 @@
 use crate::{
     entity::{
-        task::{self, Histogram, Sum, Vdaf},
+        task::{
+            vdaf::{CountVec, Histogram, Sum, SumVec, Vdaf},
+            HpkeConfig,
+        },
         NewTask,
     },
     handler::Error,
@@ -8,11 +11,10 @@ use crate::{
 };
 use base64::{engine::general_purpose::URL_SAFE_NO_PAD, Engine};
 pub use janus_messages::{
-    Duration as JanusDuration, HpkeAeadId, HpkeConfig, HpkeConfigId, HpkeConfigList, HpkeKdfId,
-    HpkeKemId, HpkePublicKey, Role, TaskId, Time as JanusTime,
+    Duration as JanusDuration, HpkeAeadId, HpkeConfig as JanusHpkeConfig, HpkeConfigId,
+    HpkeConfigList, HpkeKdfId, HpkeKemId, HpkePublicKey, Role, TaskId, Time as JanusTime,
 };
 use serde::{Deserialize, Serialize};
-use std::collections::HashMap;
 use url::Url;
 
 #[derive(Debug, Clone, Serialize, Deserialize)]
@@ -20,7 +22,9 @@ pub enum VdafInstance {
     Prio3Count,
     Prio3Sum { bits: u8 },
-    Prio3Histogram { buckets: Vec<i32> },
+    Prio3Histogram { buckets: Vec<u64> },
+    Prio3CountVec { length: u64 },
+    Prio3SumVec { bits: u8, length: u64 },
 }
 
 impl From<VdafInstance> for Vdaf {
@@ -31,6 +35,13 @@ impl From<VdafInstance> for Vdaf {
             VdafInstance::Prio3Histogram { buckets } => Self::Histogram(Histogram {
                 buckets: Some(buckets),
             }),
+            VdafInstance::Prio3CountVec { length } => Self::CountVec(CountVec {
+                length: Some(length),
+            }),
+            VdafInstance::Prio3SumVec { bits, length } => Self::SumVec(SumVec {
+                length: Some(length),
+                bits: Some(bits),
+            }),
         }
     }
 }
@@ -45,14 +56,21 @@ impl From<Vdaf> for VdafInstance {
             Vdaf::Sum(Sum { bits }) => Self::Prio3Sum {
                 bits: bits.unwrap(),
             },
+            Vdaf::CountVec(CountVec { length }) => Self::Prio3CountVec {
+                length: length.unwrap(),
+            },
+            Vdaf::SumVec(SumVec { length, bits }) => Self::Prio3SumVec {
+                bits: bits.unwrap(),
+                length: length.unwrap(),
+            },
             Vdaf::Unrecognized => unreachable!(),
         }
     }
 }
 
-impl TryFrom<task::HpkeConfig> for HpkeConfig {
+impl TryFrom<HpkeConfig> for JanusHpkeConfig {
     type Error = Box<dyn std::error::Error>;
-    fn try_from(value: task::HpkeConfig) -> Result<Self, Self::Error> {
+    fn try_from(value: HpkeConfig) -> Result<Self, Self::Error> {
         Ok(Self::new(
             value.id.unwrap().into(),
             value.kem_id.unwrap().try_into()?,
@@ -62,8 +80,8 @@ impl TryFrom<task::HpkeConfig> for HpkeConfig {
         ))
     }
 }
-impl From<HpkeConfig> for task::HpkeConfig {
-    fn from(hpke_config: HpkeConfig) -> Self {
+impl From<JanusHpkeConfig> for HpkeConfig {
+    fn from(hpke_config: JanusHpkeConfig) -> Self {
         Self {
             id: Some((*hpke_config.id()).into()),
             kem_id: Some((*hpke_config.kem_id()) as u16),
@@ -111,7 +129,8 @@ impl From<Option<u64>> for QueryType {
 
 #[derive(Serialize, Deserialize, Debug)]
 pub struct TaskCreate {
-    pub aggregator_endpoints: Vec<Url>,
+    pub leader_endpoint: Url,
+    pub helper_endpoint: Url,
     pub query_type: QueryType,
     pub vdaf: VdafInstance,
     pub role: Role,
@@ -119,22 +138,21 @@ pub struct TaskCreate {
     pub task_expiration: u64,
     pub min_batch_size: u64,
     pub time_precision: u64,
-    pub collector_hpke_config: HpkeConfig,
+    pub collector_hpke_config: JanusHpkeConfig,
 }
 
 impl TaskCreate {
     pub fn build(new_task: NewTask, config: &ApiConfig) -> Result<Self, Error> {
         Ok(Self {
-            aggregator_endpoints: if new_task.is_leader.unwrap() {
-                vec![
-                    config.aggregator_dap_url.clone(),
-                    new_task.partner_url.unwrap().parse()?,
-                ]
+            leader_endpoint: if new_task.is_leader.unwrap() {
+                config.aggregator_dap_url.clone()
             } else {
-                vec![
-                    new_task.partner_url.unwrap().parse()?,
-                    config.aggregator_dap_url.clone(),
-                ]
+                new_task.partner_url.as_deref().unwrap().parse()?
+            },
+            helper_endpoint: if new_task.is_leader.unwrap() {
+                new_task.partner_url.as_deref().unwrap().parse()?
+            } else {
+                config.aggregator_dap_url.clone()
             },
             query_type: new_task.max_batch_size.into(),
             vdaf: new_task.vdaf.unwrap().into(),
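The two `if new_task.is_leader` branches in `build` encode the DAP invariant that a task has exactly one leader and one helper: whichever role this aggregator plays, the partner's URL fills the other slot. A standalone sketch of that pairing rule (function and names are illustrative, not the crate's API):

```rust
/// Returns (leader_endpoint, helper_endpoint) for this aggregator and its partner.
fn pair_endpoints(is_leader: bool, own: &str, partner: &str) -> (String, String) {
    if is_leader {
        (own.to_owned(), partner.to_owned())
    } else {
        (partner.to_owned(), own.to_owned())
    }
}

fn main() {
    let own = "https://own.example/";
    let partner = "https://partner.example/";
    // As the leader, our URL takes the leader slot...
    assert_eq!(
        pair_endpoints(true, own, partner),
        (own.to_owned(), partner.to_owned())
    );
    // ...and as the helper, the partner leads.
    assert_eq!(
        pair_endpoints(false, own, partner),
        (partner.to_owned(), own.to_owned())
    );
}
```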
@@ -158,7 +176,8 @@ impl TaskCreate {
 
 #[derive(Clone, Debug, Serialize, Deserialize)]
 pub struct TaskResponse {
     pub task_id: TaskId,
-    pub aggregator_endpoints: Vec<Url>,
+    pub leader_endpoint: Url,
+    pub helper_endpoint: Url,
     pub query_type: QueryType,
     pub vdaf: VdafInstance,
     pub role: Role,
@@ -169,10 +188,10 @@ pub struct TaskResponse {
     pub min_batch_size: u64,
     pub time_precision: JanusDuration,
     pub tolerable_clock_skew: JanusDuration,
-    pub collector_hpke_config: HpkeConfig,
+    pub collector_hpke_config: JanusHpkeConfig,
     pub aggregator_auth_tokens: Vec<String>,
     pub collector_auth_tokens: Vec<String>,
-    pub aggregator_hpke_configs: HashMap<HpkeConfigId, HpkeConfig>,
+    pub aggregator_hpke_configs: Vec<JanusHpkeConfig>,
 }
 
 #[derive(Serialize, Deserialize, Debug, Clone)]
@@ -189,90 +208,99 @@ pub struct TaskMetrics {
 
 #[cfg(test)]
 mod test {
-    use serde_json::{from_value, json, to_value};
-
     use super::{TaskCreate, TaskResponse};
 
+    const TASK_CREATE: &str = r#"{
+  "leader_endpoint": "https://example.com/",
+  "helper_endpoint": "https://example.net/",
+  "query_type": {
+    "FixedSize": {
+      "max_batch_size": 999
+    }
+  },
+  "vdaf": {
+    "Prio3CountVec": {
+      "length": 5
+    }
+  },
+  "role": "Leader",
+  "max_batch_query_count": 1,
+  "task_expiration": 18446744073709551615,
+  "min_batch_size": 100,
+  "time_precision": 3600,
+  "collector_hpke_config": {
+    "id": 7,
+    "kem_id": "X25519HkdfSha256",
+    "kdf_id": "HkdfSha256",
+    "aead_id": "Aes128Gcm",
+    "public_key": "AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA"
+  }
+}"#;
+
     #[test]
     fn task_create_json_serialization() {
-        let task_create_from_janus_aggregator_api_tests = json!({
-            "aggregator_endpoints": [
-                "http://leader.endpoint/",
-                "http://helper.endpoint/"
-            ],
-            "query_type": "TimeInterval",
-            "vdaf": "Prio3Count",
-            "role": "Leader",
-            "max_batch_query_count": 12,
-            "task_expiration": 12345,
-            "min_batch_size": 223,
-            "time_precision": 62,
-            "collector_hpke_config": {
-                "id": 199,
-                "kem_id": "X25519HkdfSha256",
-                "kdf_id": "HkdfSha256",
-                "aead_id": "Aes128Gcm",
-                "public_key": "p2J0ht1GtUa8XW67AKmYbfzU1L1etPlJiRIiRigzhEw"
-            }
-        });
-
-        let task_create: TaskCreate =
-            from_value(task_create_from_janus_aggregator_api_tests.clone()).unwrap();
+        let task_create: TaskCreate = serde_json::from_str(TASK_CREATE).unwrap();
         assert_eq!(
-            to_value(&task_create).unwrap(),
-            task_create_from_janus_aggregator_api_tests
+            serde_json::to_string_pretty(&task_create).unwrap(),
+            TASK_CREATE
         );
     }
 
+    const TASK_RESPONSE: &str = r#"{
+  "task_id": "AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA",
+  "leader_endpoint": "https://example.com/",
+  "helper_endpoint": "https://example.net/",
+  "query_type": {
+    "FixedSize": {
+      "max_batch_size": 999
+    }
+  },
+  "vdaf": {
+    "Prio3CountVec": {
+      "length": 5
+    }
+  },
+  "role": "Leader",
+  "vdaf_verify_keys": [
+    "dmRhZiB2ZXJpZnkga2V5IQ"
+  ],
+  "max_batch_query_count": 1,
+  "task_expiration": 9000000000,
+  "report_expiry_age": null,
+  "min_batch_size": 100,
+  "time_precision": 3600,
+  "tolerable_clock_skew": 60,
+  "collector_hpke_config": {
+    "id": 7,
+    "kem_id": "X25519HkdfSha256",
+    "kdf_id": "HkdfSha256",
+    "aead_id": "Aes128Gcm",
+    "public_key": "AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA"
+  },
+  "aggregator_auth_tokens": [
+    "YWdncmVnYXRvci0xMjM0NTY3OA"
+  ],
+  "collector_auth_tokens": [
+    "Y29sbGVjdG9yLWFiY2RlZjAw"
+  ],
+  "aggregator_hpke_configs": [
+    {
+      "id": 13,
+      "kem_id": "X25519HkdfSha256",
+      "kdf_id": "HkdfSha256",
+      "aead_id": "Aes128Gcm",
+      "public_key": "AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA"
+    }
+  ]
+}"#;
+
     #[test]
     fn task_response_json_serialization() {
-        let task_response_from_janus_aggregator_api_tests = json!({
-            "task_id": "NGTX4o1JP4JLUCmM5Vcdl1Mcz41cOGgRnU1V0gU1Z_M",
-            "aggregator_endpoints": [
-                "http://leader.endpoint/",
-                "http://helper.endpoint/"
-            ],
-            "query_type": "TimeInterval",
-            "vdaf": "Prio3Count",
-            "role": "Leader",
-            "vdaf_verify_keys": [
-                "Fvp4ZzHEbJOMGyTjG4Pctw"
-            ],
-            "max_batch_query_count": 12,
-            "task_expiration": 12345,
-            "report_expiry_age": 1209600,
-            "min_batch_size": 223,
-            "time_precision": 62,
-            "tolerable_clock_skew": 60,
-            "collector_hpke_config": {
-                "id": 177,
-                "kem_id": "X25519HkdfSha256",
-                "kdf_id": "HkdfSha256",
-                "aead_id": "Aes128Gcm",
-                "public_key": "ifb-I8PBdIwuKcylg2_tRZ2_vf1XOWA-Jx5plLAn52Y"
-            },
-            "aggregator_auth_tokens": [
-                "MTlhMzBiZjE3NWMyN2FlZWFlYTI3NmVjMDIxZDM4MWQ"
-            ],
-            "collector_auth_tokens": [
-                "YzMyYzU4YTc0ZjBmOGU5MjU0YWIzMjA0OGZkMTQyNTE"
-            ],
-            "aggregator_hpke_configs": {
-                "43": {
-                    "id": 43,
-                    "kem_id": "X25519HkdfSha256",
-                    "kdf_id": "HkdfSha256",
-                    "aead_id": "Aes128Gcm",
-                    "public_key": "j98s3TCKDutLGPFMULsWFgsQc-keIW8WNxp8aMKEJjk"
-                }
-            }
-        });
-
-        let task_response: TaskResponse =
-            from_value(task_response_from_janus_aggregator_api_tests.clone()).unwrap();
+        let task_response: TaskResponse = serde_json::from_str(TASK_RESPONSE).unwrap();
+
         assert_eq!(
-            to_value(&task_response).unwrap(),
-            task_response_from_janus_aggregator_api_tests
+            serde_json::to_string_pretty(&task_response).unwrap(),
+            TASK_RESPONSE
        );
     }
 }
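The fixtures above can round-trip byte-for-byte because `VdafInstance` relies on serde's default externally tagged enum representation, so each variant appears under its own name, as in `"vdaf": {"Prio3CountVec": {"length": 5}}`. A self-contained sketch with a local copy of the enum (assumes the serde and serde_json crates; not the crate's actual type):

```rust
use serde::{Deserialize, Serialize};

// Local mirror of two VdafInstance variants, for illustration only.
#[derive(Debug, PartialEq, Serialize, Deserialize)]
enum VdafInstance {
    Prio3Count,
    Prio3SumVec { bits: u8, length: u64 },
}

fn main() {
    // Struct variants serialize under the variant name...
    let sum_vec = VdafInstance::Prio3SumVec { bits: 8, length: 5 };
    let json = serde_json::to_string(&sum_vec).unwrap();
    assert_eq!(json, r#"{"Prio3SumVec":{"bits":8,"length":5}}"#);
    assert_eq!(serde_json::from_str::<VdafInstance>(&json).unwrap(), sum_vec);

    // ...while unit variants serialize as a bare string, matching "Prio3Count"
    // in the old fixtures.
    let count = serde_json::to_string(&VdafInstance::Prio3Count).unwrap();
    assert_eq!(count, r#""Prio3Count""#);
}
```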
"AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA" + } + ] +}"#; + #[test] fn task_response_json_serialization() { - let task_response_from_janus_aggregator_api_tests = json!({ - "task_id": "NGTX4o1JP4JLUCmM5Vcdl1Mcz41cOGgRnU1V0gU1Z_M", - "aggregator_endpoints": [ - "http://leader.endpoint/", - "http://helper.endpoint/" - ], - "query_type": "TimeInterval", - "vdaf": "Prio3Count", - "role": "Leader", - "vdaf_verify_keys": [ - "Fvp4ZzHEbJOMGyTjG4Pctw" - ], - "max_batch_query_count": 12, - "task_expiration": 12345, - "report_expiry_age": 1209600, - "min_batch_size": 223, - "time_precision": 62, - "tolerable_clock_skew": 60, - "collector_hpke_config": { - "id": 177, - "kem_id": "X25519HkdfSha256", - "kdf_id": "HkdfSha256", - "aead_id": "Aes128Gcm", - "public_key": "ifb-I8PBdIwuKcylg2_tRZ2_vf1XOWA-Jx5plLAn52Y" - }, - "aggregator_auth_tokens": [ - "MTlhMzBiZjE3NWMyN2FlZWFlYTI3NmVjMDIxZDM4MWQ" - ], - "collector_auth_tokens": [ - "YzMyYzU4YTc0ZjBmOGU5MjU0YWIzMjA0OGZkMTQyNTE" - ], - "aggregator_hpke_configs": { - "43": { - "id": 43, - "kem_id": "X25519HkdfSha256", - "kdf_id": "HkdfSha256", - "aead_id": "Aes128Gcm", - "public_key": "j98s3TCKDutLGPFMULsWFgsQc-keIW8WNxp8aMKEJjk" - } - } - }); + let task_response: TaskResponse = serde_json::from_str(TASK_RESPONSE).unwrap(); - let task_response: TaskResponse = - from_value(task_response_from_janus_aggregator_api_tests.clone()).unwrap(); assert_eq!( - to_value(&task_response).unwrap(), - task_response_from_janus_aggregator_api_tests + serde_json::to_string_pretty(&task_response).unwrap(), + TASK_RESPONSE ); } } diff --git a/src/entity/task.rs b/src/entity/task.rs index db124ee2..3f2fdb28 100644 --- a/src/entity/task.rs +++ b/src/entity/task.rs @@ -8,9 +8,8 @@ use serde::{Deserialize, Serialize}; use time::OffsetDateTime; use validator::{Validate, ValidationError}; -mod vdaf; -pub use vdaf::{Histogram, Sum, Vdaf}; - +pub mod vdaf; +use vdaf::Vdaf; mod url; use self::url::Url; @@ -152,8 +151,8 @@ pub fn build_task(mut task: NewTask, api_response: TaskResponse, account: &Accou id: Set(api_response.task_id.to_string()), account_id: Set(account.id), name: Set(task.name.take().unwrap()), - leader_url: Set(api_response.aggregator_endpoints[0].clone().into()), - helper_url: Set(api_response.aggregator_endpoints[1].clone().into()), + leader_url: Set(api_response.leader_endpoint.clone().into()), + helper_url: Set(api_response.helper_endpoint.clone().into()), vdaf: Set(Vdaf::from(api_response.vdaf)), min_batch_size: Set(api_response.min_batch_size.try_into().unwrap()), max_batch_size: Set(api_response.query_type.into()), diff --git a/src/entity/task/vdaf.rs b/src/entity/task/vdaf.rs index 7b07ba56..b7e3da79 100644 --- a/src/entity/task/vdaf.rs +++ b/src/entity/task/vdaf.rs @@ -5,10 +5,10 @@ use validator::{Validate, ValidationError, ValidationErrors}; #[derive(Serialize, Deserialize, Validate, Debug, Clone, Eq, PartialEq)] pub struct Histogram { #[validate(required, custom = "strictly_increasing")] - pub buckets: Option>, + pub buckets: Option>, } -fn strictly_increasing(buckets: &Vec) -> Result<(), ValidationError> { +fn strictly_increasing(buckets: &Vec) -> Result<(), ValidationError> { let mut last_bucket = None; for bucket in buckets { let bucket = *bucket; @@ -35,6 +35,21 @@ pub struct Sum { pub bits: Option, } +#[derive(Serialize, Deserialize, Validate, Debug, Clone, Copy, Eq, PartialEq)] +pub struct CountVec { + #[validate(required)] + pub length: Option, +} + +#[derive(Serialize, Deserialize, Validate, Debug, Clone, Copy, Eq, PartialEq)] +pub struct SumVec { + 
diff --git a/src/entity/task/vdaf.rs b/src/entity/task/vdaf.rs
index 7b07ba56..b7e3da79 100644
--- a/src/entity/task/vdaf.rs
+++ b/src/entity/task/vdaf.rs
@@ -5,10 +5,10 @@ use validator::{Validate, ValidationError, ValidationErrors};
 #[derive(Serialize, Deserialize, Validate, Debug, Clone, Eq, PartialEq)]
 pub struct Histogram {
     #[validate(required, custom = "strictly_increasing")]
-    pub buckets: Option<Vec<i32>>,
+    pub buckets: Option<Vec<u64>>,
 }
 
-fn strictly_increasing(buckets: &Vec<i32>) -> Result<(), ValidationError> {
+fn strictly_increasing(buckets: &Vec<u64>) -> Result<(), ValidationError> {
     let mut last_bucket = None;
     for bucket in buckets {
         let bucket = *bucket;
@@ -35,6 +35,21 @@ pub struct Sum {
     pub bits: Option<u8>,
 }
 
+#[derive(Serialize, Deserialize, Validate, Debug, Clone, Copy, Eq, PartialEq)]
+pub struct CountVec {
+    #[validate(required)]
+    pub length: Option<u64>,
+}
+
+#[derive(Serialize, Deserialize, Validate, Debug, Clone, Copy, Eq, PartialEq)]
+pub struct SumVec {
+    #[validate(required)]
+    pub bits: Option<u8>,
+
+    #[validate(required)]
+    pub length: Option<u64>,
+}
+
 #[derive(Serialize, Deserialize, Debug, Clone, Eq, PartialEq)]
 #[serde(rename_all = "snake_case", tag = "type")]
 pub enum Vdaf {
@@ -45,7 +60,13 @@ pub enum Vdaf {
     Histogram(Histogram),
 
     #[serde(rename = "sum")]
-    Sum(Sum), // 128 is ceiling
+    Sum(Sum),
+
+    #[serde(rename = "count_vec")]
+    CountVec(CountVec),
+
+    #[serde(rename = "sum_vec")]
+    SumVec(SumVec),
 
     #[serde(other)]
     Unrecognized,
@@ -59,6 +80,8 @@ impl Validate for Vdaf {
             Vdaf::Count => Ok(()),
             Vdaf::Histogram(h) => h.validate(),
             Vdaf::Sum(s) => s.validate(),
+            Vdaf::SumVec(sv) => sv.validate(),
+            Vdaf::CountVec(cv) => cv.validate(),
             Vdaf::Unrecognized => {
                 let mut errors = ValidationErrors::new();
                 errors.add("type", ValidationError::new("unknown"));
diff --git a/tests/harness/fixtures.rs b/tests/harness/fixtures.rs
index 72d77b92..9932b6d0 100644
--- a/tests/harness/fixtures.rs
+++ b/tests/harness/fixtures.rs
@@ -77,7 +77,7 @@ pub async fn task(app: &DivviupApi, account: &Account) -> Task {
     let new_task = NewTask {
         name: Some(random_name()),
         partner_url: Some("https://dap.clodflair.test".into()),
-        vdaf: Some(task::Vdaf::Count),
+        vdaf: Some(task::vdaf::Vdaf::Count),
         min_batch_size: Some(500),
         max_batch_size: Some(10000),
         is_leader: Some(true),
diff --git a/tests/tasks.rs b/tests/tasks.rs
index 5537cfbd..deb938ac 100644
--- a/tests/tasks.rs
+++ b/tests/tasks.rs
@@ -86,7 +86,7 @@ mod index {
 mod create {
     use divviup_api::{
         aggregator_api_mock::random_hpke_config,
-        entity::task::{HpkeConfig, Vdaf},
+        entity::task::{vdaf::Vdaf, HpkeConfig},
     };
 
     use super::{test, *};
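Finally, on the entity side: `Vdaf` is internally tagged (`tag = "type"`) with snake_case variant names, so a stored sum-vector task looks like `{"type": "sum_vec", "bits": 8, "length": 5}`, and the `#[validate(required)]` attributes reject configs whose fields are absent. A self-contained sketch mirroring those shapes (assumes the serde, serde_json, and validator crates; local stand-in types, not the crate's own):

```rust
use serde::{Deserialize, Serialize};
use validator::Validate;

#[derive(Serialize, Deserialize, Validate, Debug, PartialEq)]
struct SumVec {
    #[validate(required)]
    bits: Option<u8>,
    #[validate(required)]
    length: Option<u64>,
}

#[derive(Serialize, Deserialize, Debug, PartialEq)]
#[serde(rename_all = "snake_case", tag = "type")]
enum Vdaf {
    SumVec(SumVec),
}

fn main() {
    // The "type" tag sits inline next to the variant's own fields.
    let vdaf: Vdaf =
        serde_json::from_str(r#"{"type":"sum_vec","bits":8,"length":5}"#).unwrap();
    let Vdaf::SumVec(sum_vec) = &vdaf;
    assert!(sum_vec.validate().is_ok());

    // A missing required field deserializes to None but fails validation,
    // which is what lets the API surface a clean error instead of unwrapping.
    let incomplete = SumVec { bits: None, length: Some(5) };
    assert!(incomplete.validate().is_err());
}
```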