Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

updates for janus interop #104

Merged
merged 1 commit into from
May 16, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 6 additions & 8 deletions src/aggregator_api_mock.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
use crate::clients::aggregator_client::api_types::{
HpkeAeadId, HpkeConfig, HpkeKdfId, HpkeKemId, HpkePublicKey, JanusDuration, JanusTime,
HpkeAeadId, HpkeKdfId, HpkeKemId, HpkePublicKey, JanusDuration, JanusHpkeConfig, JanusTime,
TaskCreate, TaskIds, TaskMetrics, TaskResponse,
};
use fastrand::alphanumeric;
Expand Down Expand Up @@ -37,7 +37,8 @@ async fn post_task(_: &mut Conn, Json(task_create): Json<TaskCreate>) -> Json<Ta
pub fn task_response(task_create: TaskCreate) -> TaskResponse {
TaskResponse {
task_id: random(),
aggregator_endpoints: task_create.aggregator_endpoints,
leader_endpoint: task_create.leader_endpoint,
helper_endpoint: task_create.helper_endpoint,
query_type: task_create.query_type,
vdaf: task_create.vdaf,
role: task_create.role,
Expand All @@ -51,15 +52,12 @@ pub fn task_response(task_create: TaskCreate) -> TaskResponse {
collector_hpke_config: random_hpke_config(),
aggregator_auth_tokens: vec![],
collector_auth_tokens: vec![],
aggregator_hpke_configs: std::iter::repeat_with(random_hpke_config)
.take(5)
.map(|config| (*config.id(), config))
.collect(),
aggregator_hpke_configs: std::iter::repeat_with(random_hpke_config).take(5).collect(),
}
}

pub fn random_hpke_config() -> HpkeConfig {
HpkeConfig::new(
pub fn random_hpke_config() -> JanusHpkeConfig {
JanusHpkeConfig::new(
random(),
HpkeKemId::P256HkdfSha256,
HpkeKdfId::HkdfSha512,
Expand Down
218 changes: 123 additions & 95 deletions src/clients/aggregator_client/api_types.rs
Original file line number Diff line number Diff line change
@@ -1,26 +1,30 @@
use crate::{
entity::{
task::{self, Histogram, Sum, Vdaf},
task::{
vdaf::{CountVec, Histogram, Sum, SumVec, Vdaf},
HpkeConfig,
},
NewTask,
},
handler::Error,
ApiConfig,
};
use base64::{engine::general_purpose::URL_SAFE_NO_PAD, Engine};
pub use janus_messages::{
Duration as JanusDuration, HpkeAeadId, HpkeConfig, HpkeConfigId, HpkeConfigList, HpkeKdfId,
HpkeKemId, HpkePublicKey, Role, TaskId, Time as JanusTime,
Duration as JanusDuration, HpkeAeadId, HpkeConfig as JanusHpkeConfig, HpkeConfigId,
HpkeConfigList, HpkeKdfId, HpkeKemId, HpkePublicKey, Role, TaskId, Time as JanusTime,
};
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use url::Url;

#[derive(Debug, Clone, Serialize, Deserialize)]
#[non_exhaustive]
pub enum VdafInstance {
Prio3Count,
Prio3Sum { bits: u8 },
Prio3Histogram { buckets: Vec<i32> },
Prio3Histogram { buckets: Vec<u64> },
Prio3CountVec { length: u64 },
Prio3SumVec { bits: u8, length: u64 },
}

impl From<VdafInstance> for Vdaf {
Expand All @@ -31,6 +35,13 @@ impl From<VdafInstance> for Vdaf {
VdafInstance::Prio3Histogram { buckets } => Self::Histogram(Histogram {
buckets: Some(buckets),
}),
VdafInstance::Prio3CountVec { length } => Self::CountVec(CountVec {
length: Some(length),
}),
VdafInstance::Prio3SumVec { bits, length } => Self::SumVec(SumVec {
length: Some(length),
bits: Some(bits),
}),
}
}
}
Expand All @@ -45,14 +56,21 @@ impl From<Vdaf> for VdafInstance {
Vdaf::Sum(Sum { bits }) => Self::Prio3Sum {
bits: bits.unwrap(),
},
Vdaf::CountVec(CountVec { length }) => Self::Prio3CountVec {
length: length.unwrap(),
},
Vdaf::SumVec(SumVec { length, bits }) => Self::Prio3SumVec {
bits: bits.unwrap(),
length: length.unwrap(),
},
Vdaf::Unrecognized => unreachable!(),
}
}
}

impl TryFrom<task::HpkeConfig> for HpkeConfig {
impl TryFrom<HpkeConfig> for JanusHpkeConfig {
type Error = Box<dyn std::error::Error + Send + Sync>;
fn try_from(value: task::HpkeConfig) -> Result<Self, Self::Error> {
fn try_from(value: HpkeConfig) -> Result<Self, Self::Error> {
Ok(Self::new(
value.id.unwrap().into(),
value.kem_id.unwrap().try_into()?,
Expand All @@ -62,8 +80,8 @@ impl TryFrom<task::HpkeConfig> for HpkeConfig {
))
}
}
impl From<HpkeConfig> for task::HpkeConfig {
fn from(hpke_config: HpkeConfig) -> Self {
impl From<JanusHpkeConfig> for HpkeConfig {
fn from(hpke_config: JanusHpkeConfig) -> Self {
Self {
id: Some((*hpke_config.id()).into()),
kem_id: Some((*hpke_config.kem_id()) as u16),
Expand Down Expand Up @@ -111,30 +129,30 @@ impl From<Option<i64>> for QueryType {

#[derive(Serialize, Deserialize, Debug)]
pub struct TaskCreate {
pub aggregator_endpoints: Vec<Url>,
pub leader_endpoint: Url,
pub helper_endpoint: Url,
pub query_type: QueryType,
pub vdaf: VdafInstance,
pub role: Role,
pub max_batch_query_count: u64,
pub task_expiration: u64,
pub min_batch_size: u64,
pub time_precision: u64,
pub collector_hpke_config: HpkeConfig,
pub collector_hpke_config: JanusHpkeConfig,
}

impl TaskCreate {
pub fn build(new_task: NewTask, config: &ApiConfig) -> Result<Self, Error> {
Ok(Self {
aggregator_endpoints: if new_task.is_leader.unwrap() {
vec![
config.aggregator_dap_url.clone(),
new_task.partner_url.unwrap().parse()?,
]
leader_endpoint: if new_task.is_leader.unwrap() {
config.aggregator_dap_url.clone()
} else {
vec![
new_task.partner_url.unwrap().parse()?,
config.aggregator_dap_url.clone(),
]
new_task.partner_url.as_deref().unwrap().parse()?
},
helper_endpoint: if new_task.is_leader.unwrap() {
new_task.partner_url.as_deref().unwrap().parse()?
} else {
config.aggregator_dap_url.clone()
},
query_type: new_task.max_batch_size.into(),
vdaf: new_task.vdaf.unwrap().into(),
Expand All @@ -158,7 +176,8 @@ impl TaskCreate {
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct TaskResponse {
pub task_id: TaskId,
pub aggregator_endpoints: Vec<Url>,
pub leader_endpoint: Url,
pub helper_endpoint: Url,
pub query_type: QueryType,
pub vdaf: VdafInstance,
pub role: Role,
Expand All @@ -169,10 +188,10 @@ pub struct TaskResponse {
pub min_batch_size: u64,
pub time_precision: JanusDuration,
pub tolerable_clock_skew: JanusDuration,
pub collector_hpke_config: HpkeConfig,
pub collector_hpke_config: JanusHpkeConfig,
pub aggregator_auth_tokens: Vec<String>,
pub collector_auth_tokens: Vec<String>,
pub aggregator_hpke_configs: HashMap<HpkeConfigId, HpkeConfig>,
pub aggregator_hpke_configs: Vec<JanusHpkeConfig>,
}

#[derive(Serialize, Deserialize, Debug, Clone)]
Expand All @@ -189,90 +208,99 @@ pub struct TaskMetrics {

#[cfg(test)]
mod test {
use serde_json::{from_value, json, to_value};

use super::{TaskCreate, TaskResponse};

const TASK_CREATE: &str = r#"{
"leader_endpoint": "https://example.com/",
"helper_endpoint": "https://example.net/",
"query_type": {
"FixedSize": {
"max_batch_size": 999
}
},
"vdaf": {
"Prio3CountVec": {
"length": 5
}
},
"role": "Leader",
"max_batch_query_count": 1,
"task_expiration": 18446744073709551615,
"min_batch_size": 100,
"time_precision": 3600,
"collector_hpke_config": {
"id": 7,
"kem_id": "X25519HkdfSha256",
"kdf_id": "HkdfSha256",
"aead_id": "Aes128Gcm",
"public_key": "AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA"
}
}"#;

#[test]
fn task_create_json_serialization() {
let task_create_from_janus_aggregator_api_tests = json!({
"aggregator_endpoints": [
"http://leader.endpoint/",
"http://helper.endpoint/"
],
"query_type": "TimeInterval",
"vdaf": "Prio3Count",
"role": "Leader",
"max_batch_query_count": 12,
"task_expiration": 12345,
"min_batch_size": 223,
"time_precision": 62,
"collector_hpke_config": {
"id": 199,
"kem_id": "X25519HkdfSha256",
"kdf_id": "HkdfSha256",
"aead_id": "Aes128Gcm",
"public_key": "p2J0ht1GtUa8XW67AKmYbfzU1L1etPlJiRIiRigzhEw"
}
});

let task_create: TaskCreate =
from_value(task_create_from_janus_aggregator_api_tests.clone()).unwrap();
let task_create: TaskCreate = serde_json::from_str(TASK_CREATE).unwrap();
assert_eq!(
to_value(&task_create).unwrap(),
task_create_from_janus_aggregator_api_tests
serde_json::to_string_pretty(&task_create).unwrap(),
TASK_CREATE
);
}

const TASK_RESPONSE: &str = r#"{
"task_id": "AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA",
"leader_endpoint": "https://example.com/",
"helper_endpoint": "https://example.net/",
"query_type": {
"FixedSize": {
"max_batch_size": 999
}
},
"vdaf": {
"Prio3CountVec": {
"length": 5
}
},
"role": "Leader",
"vdaf_verify_keys": [
"dmRhZiB2ZXJpZnkga2V5IQ"
],
"max_batch_query_count": 1,
"task_expiration": 9000000000,
"report_expiry_age": null,
"min_batch_size": 100,
"time_precision": 3600,
"tolerable_clock_skew": 60,
"collector_hpke_config": {
"id": 7,
"kem_id": "X25519HkdfSha256",
"kdf_id": "HkdfSha256",
"aead_id": "Aes128Gcm",
"public_key": "AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA"
},
"aggregator_auth_tokens": [
"YWdncmVnYXRvci0xMjM0NTY3OA"
],
"collector_auth_tokens": [
"Y29sbGVjdG9yLWFiY2RlZjAw"
],
"aggregator_hpke_configs": [
{
"id": 13,
"kem_id": "X25519HkdfSha256",
"kdf_id": "HkdfSha256",
"aead_id": "Aes128Gcm",
"public_key": "AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA"
}
]
}"#;

#[test]
fn task_response_json_serialization() {
let task_response_from_janus_aggregator_api_tests = json!({
"task_id": "NGTX4o1JP4JLUCmM5Vcdl1Mcz41cOGgRnU1V0gU1Z_M",
"aggregator_endpoints": [
"http://leader.endpoint/",
"http://helper.endpoint/"
],
"query_type": "TimeInterval",
"vdaf": "Prio3Count",
"role": "Leader",
"vdaf_verify_keys": [
"Fvp4ZzHEbJOMGyTjG4Pctw"
],
"max_batch_query_count": 12,
"task_expiration": 12345,
"report_expiry_age": 1209600,
"min_batch_size": 223,
"time_precision": 62,
"tolerable_clock_skew": 60,
"collector_hpke_config": {
"id": 177,
"kem_id": "X25519HkdfSha256",
"kdf_id": "HkdfSha256",
"aead_id": "Aes128Gcm",
"public_key": "ifb-I8PBdIwuKcylg2_tRZ2_vf1XOWA-Jx5plLAn52Y"
},
"aggregator_auth_tokens": [
"MTlhMzBiZjE3NWMyN2FlZWFlYTI3NmVjMDIxZDM4MWQ"
],
"collector_auth_tokens": [
"YzMyYzU4YTc0ZjBmOGU5MjU0YWIzMjA0OGZkMTQyNTE"
],
"aggregator_hpke_configs": {
"43": {
"id": 43,
"kem_id": "X25519HkdfSha256",
"kdf_id": "HkdfSha256",
"aead_id": "Aes128Gcm",
"public_key": "j98s3TCKDutLGPFMULsWFgsQc-keIW8WNxp8aMKEJjk"
}
}
});
let task_response: TaskResponse = serde_json::from_str(TASK_RESPONSE).unwrap();

let task_response: TaskResponse =
from_value(task_response_from_janus_aggregator_api_tests.clone()).unwrap();
assert_eq!(
to_value(&task_response).unwrap(),
task_response_from_janus_aggregator_api_tests
serde_json::to_string_pretty(&task_response).unwrap(),
TASK_RESPONSE
);
}
}
9 changes: 4 additions & 5 deletions src/entity/task.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,8 @@ use serde::{Deserialize, Serialize};
use time::OffsetDateTime;
use validator::{Validate, ValidationError};

mod vdaf;
pub use vdaf::{Histogram, Sum, Vdaf};

pub mod vdaf;
use vdaf::Vdaf;
mod url;
use self::url::Url;

Expand Down Expand Up @@ -152,8 +151,8 @@ pub fn build_task(mut task: NewTask, api_response: TaskResponse, account: &Accou
id: Set(api_response.task_id.to_string()),
account_id: Set(account.id),
name: Set(task.name.take().unwrap()),
leader_url: Set(api_response.aggregator_endpoints[0].clone().into()),
helper_url: Set(api_response.aggregator_endpoints[1].clone().into()),
leader_url: Set(api_response.leader_endpoint.clone().into()),
helper_url: Set(api_response.helper_endpoint.clone().into()),
vdaf: Set(Vdaf::from(api_response.vdaf)),
min_batch_size: Set(api_response.min_batch_size.try_into().unwrap()),
max_batch_size: Set(api_response.query_type.into()),
Expand Down
Loading
Loading