From 4c159d32b08ee446da282e26255415c10777cec0 Mon Sep 17 00:00:00 2001 From: David Cook Date: Tue, 30 Aug 2022 09:53:37 -0500 Subject: [PATCH] Add Valgrind-based cycle count benchmarks (#295) * Add cycle count benchmarks, using Valgrind * Clarify sizes in speed_tests benchmark output * Clippy fixes * Back out FFT benchmarks --- Cargo.lock | 8 ++ Cargo.toml | 6 ++ benches/cycle_counts.rs | 219 ++++++++++++++++++++++++++++++++++++++++ benches/speed_tests.rs | 52 +++++----- 4 files changed, 262 insertions(+), 23 deletions(-) create mode 100644 benches/cycle_counts.rs diff --git a/Cargo.lock b/Cargo.lock index 9cca9ab8..6e2027ca 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -499,6 +499,12 @@ dependencies = [ "serde", ] +[[package]] +name = "iai" +version = "0.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "71a816c97c42258aa5834d07590b718b4c9a598944cd39a52dc25b351185d678" + [[package]] name = "indenter" version = "0.3.3" @@ -733,11 +739,13 @@ dependencies = [ "assert_matches", "base64", "byteorder", + "cfg-if", "cmac", "criterion", "ctr 0.9.1", "getrandom", "hex", + "iai", "itertools", "modinverse", "num-bigint", diff --git a/Cargo.toml b/Cargo.toml index d2336370..7dbf89d0 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -34,6 +34,7 @@ ring = { version = "0.16.20", optional = true } [dev-dependencies] assert_matches = "1.5.0" criterion = "0.3" +iai = "0.1" itertools = "0.10.3" modinverse = "0.1.0" num-bigint = "0.4.3" @@ -42,6 +43,7 @@ hex = { version = "0.4.3" , features = ["serde"] } # Enable test_vector module for test targets # https://github.com/rust-lang/cargo/issues/2911#issuecomment-749580481 prio = { path = ".", features = ["test-util"] } +cfg-if = "1.0.0" [features] default = ["crypto-dependencies"] @@ -57,6 +59,10 @@ members = [".", "binaries"] name = "speed_tests" harness = false +[[bench]] +name = "cycle_counts" +harness = false + [[example]] name = "sum" required-features = ["prio2"] diff --git a/benches/cycle_counts.rs 
b/benches/cycle_counts.rs
new file mode 100644
index 00000000..243b90e9
--- /dev/null
+++ b/benches/cycle_counts.rs
@@ -0,0 +1,258 @@
+//! Cycle count benchmarks, run through the `iai` harness.
+//!
+//! Every benchmark is a zero-argument function registered with `iai::main!` at the bottom of
+//! this file. Which functions are registered depends on the `prio2` and `multithreaded`
+//! features; on Windows a stub `main` prints a notice instead.
+
+// On Windows `main` references none of the benchmark functions, so silence dead-code lints.
+#![cfg_attr(windows, allow(dead_code))]
+
+use cfg_if::cfg_if;
+use iai::black_box;
+use prio::{
+    field::{random_vector, Field128, Field64},
+    vdaf::{
+        prio3::{Prio3, Prio3InputShare},
+        Client,
+    },
+};
+#[cfg(feature = "prio2")]
+use prio::{
+    field::{FieldElement, FieldPrio2},
+    server::VerificationMessage,
+};
+
+// NOTE(review): the generic arguments on `Vec`, `Client`, `VerificationMessage` and
+// `Prio3InputShare` in this file were reconstructed from the imports above; confirm them
+// against the `prio` API.
+
+/// Returns `size` random field elements produced by `random_vector`.
+fn prng(size: usize) -> Vec<Field128> {
+    random_vector(size).unwrap()
+}
+
+/// Benchmark: `prng` with 16 elements.
+fn prng_16() -> Vec<Field128> {
+    prng(16)
+}
+
+/// Benchmark: `prng` with 256 elements.
+fn prng_256() -> Vec<Field128> {
+    prng(256)
+}
+
+/// Benchmark: `prng` with 1024 elements.
+fn prng_1024() -> Vec<Field128> {
+    prng(1024)
+}
+
+/// Benchmark: `prng` with 4096 elements.
+fn prng_4096() -> Vec<Field128> {
+    prng(4096)
+}
+
+/// Base64-encoded public key passed as the first key to the Prio2 `Client::new`.
+#[cfg(feature = "prio2")]
+const PRIO2_PUBKEY1: &str =
+    "BIl6j+J6dYttxALdjISDv6ZI4/VWVEhUzaS05LgrsfswmbLOgNt9HUC2E0w+9RqZx3XMkdEHBHfNuCSMpOwofVQ=";
+/// Base64-encoded public key passed as the second key to the Prio2 `Client::new`.
+#[cfg(feature = "prio2")]
+const PRIO2_PUBKEY2: &str =
+    "BNNOqoU54GPo+1gTPv+hCgA9U2ZCKd76yOMrWa1xTWgeb4LhFLMQIQoRwDVaW64g/WTdcxT4rDULoycUNFB60LE=";
+
+/// Builds a Prio2 client for an all-zero input of `size` elements and proves that input.
+#[cfg(feature = "prio2")]
+fn prio2_prove(size: usize) -> Vec<FieldPrio2> {
+    use prio::{benchmarked::benchmarked_v2_prove, client::Client, encrypt::PublicKey};
+
+    let input = vec![FieldPrio2::zero(); size];
+    let pk1 = PublicKey::from_base64(PRIO2_PUBKEY1).unwrap();
+    let pk2 = PublicKey::from_base64(PRIO2_PUBKEY2).unwrap();
+    let mut client: Client<FieldPrio2> = Client::new(input.len(), pk1, pk2).unwrap();
+    benchmarked_v2_prove(&black_box(input), &mut client)
+}
+
+/// Benchmark: `prio2_prove` with 10 input elements.
+#[cfg(feature = "prio2")]
+fn prio2_prove_10() -> Vec<FieldPrio2> {
+    prio2_prove(10)
+}
+
+/// Benchmark: `prio2_prove` with 100 input elements.
+#[cfg(feature = "prio2")]
+fn prio2_prove_100() -> Vec<FieldPrio2> {
+    prio2_prove(100)
+}
+
+/// Benchmark: `prio2_prove` with 1000 input elements.
+#[cfg(feature = "prio2")]
+fn prio2_prove_1000() -> Vec<FieldPrio2> {
+    prio2_prove(1_000)
+}
+
+/// Proves an all-zero Prio2 input of `size` elements, then generates a verification message
+/// for the resulting input-and-proof vector.
+#[cfg(feature = "prio2")]
+fn prio2_prove_and_verify(size: usize) -> VerificationMessage<FieldPrio2> {
+    use prio::{
+        benchmarked::benchmarked_v2_prove,
+        client::Client,
+        encrypt::PublicKey,
+        server::{generate_verification_message, ValidationMemory},
+    };
+
+    let input = vec![FieldPrio2::zero(); size];
+    let pk1 = PublicKey::from_base64(PRIO2_PUBKEY1).unwrap();
+    let pk2 = PublicKey::from_base64(PRIO2_PUBKEY2).unwrap();
+    let mut client: Client<FieldPrio2> = Client::new(input.len(), pk1, pk2).unwrap();
+    let input_and_proof = benchmarked_v2_prove(&input, &mut client);
+    let mut validator = ValidationMemory::new(input.len());
+    let eval_at = random_vector(1).unwrap()[0];
+    generate_verification_message(
+        input.len(),
+        eval_at,
+        &black_box(input_and_proof),
+        true,
+        &mut validator,
+    )
+    .unwrap()
+}
+
+/// Benchmark: `prio2_prove_and_verify` with 10 input elements.
+#[cfg(feature = "prio2")]
+fn prio2_prove_and_verify_10() -> VerificationMessage<FieldPrio2> {
+    prio2_prove_and_verify(10)
+}
+
+/// Benchmark: `prio2_prove_and_verify` with 100 input elements.
+#[cfg(feature = "prio2")]
+fn prio2_prove_and_verify_100() -> VerificationMessage<FieldPrio2> {
+    prio2_prove_and_verify(100)
+}
+
+/// Benchmark: `prio2_prove_and_verify` with 1000 input elements.
+#[cfg(feature = "prio2")]
+fn prio2_prove_and_verify_1000() -> VerificationMessage<FieldPrio2> {
+    prio2_prove_and_verify(1_000)
+}
+
+/// Benchmark: shards the measurement `1` into two input shares with the Prio3 count type.
+fn prio3_client_count() -> Vec<Prio3InputShare<Field64, 16>> {
+    let prio3 = Prio3::new_aes128_count(2).unwrap();
+    let measurement = 1;
+    prio3.shard(&black_box(measurement)).unwrap().1
+}
+
+/// Benchmark: shards the measurement `17` with an 11-bucket Prio3 histogram.
+fn prio3_client_histogram_11() -> Vec<Prio3InputShare<Field128, 16>> {
+    // Ten boundaries give eleven buckets (speed_tests.rs reports `buckets.len() + 1` buckets),
+    // matching the `_11` in this benchmark's name.
+    let buckets: Vec<u64> = (1..=10).collect();
+    let prio3 = Prio3::new_aes128_histogram(2, &buckets).unwrap();
+    let measurement = 17;
+    prio3.shard(&black_box(measurement)).unwrap().1
+}
+
+/// Benchmark: shards the measurement `1337` with the Prio3 sum type.
+///
+/// NOTE(review): the name says 32 but the second argument to `new_aes128_sum` is 16 --
+/// confirm which one is intended.
+fn prio3_client_sum_32() -> Vec<Prio3InputShare<Field128, 16>> {
+    let prio3 = Prio3::new_aes128_sum(2, 16).unwrap();
+    let measurement = 1337;
+    prio3.shard(&black_box(measurement)).unwrap().1
+}
+
+/// Benchmark: shards an all-zero measurement of length 1000 with the Prio3 count vector type.
+fn prio3_client_count_vec_1000() -> Vec<Prio3InputShare<Field128, 16>> {
+    let len = 1000;
+    let prio3 = Prio3::new_aes128_count_vec(2, len).unwrap();
+    let measurement = vec![0; len];
+    prio3.shard(&black_box(measurement)).unwrap().1
+}
+
+/// Benchmark: like `prio3_client_count_vec_1000`, using the multithreaded count vector type.
+#[cfg(feature = "multithreaded")]
+fn prio3_client_count_vec_multithreaded_1000() -> Vec<Prio3InputShare<Field128, 16>> {
+    let len = 1000;
+    let prio3 = Prio3::new_aes128_count_vec_multithreaded(2, len).unwrap();
+    let measurement = vec![0; len];
+    prio3.shard(&black_box(measurement)).unwrap().1
+}
+
+// `iai::main!` takes a plain list of function names, so each feature combination gets its
+// own invocation listing exactly the benchmarks compiled for that combination.
+cfg_if! {
+    if #[cfg(windows)] {
+        fn main() {
+            eprintln!("Cycle count benchmarks are not supported on Windows.");
+        }
+    }
+    else if #[cfg(feature = "prio2")] {
+        cfg_if! {
+            if #[cfg(feature = "multithreaded")] {
+                iai::main!(
+                    prng_16,
+                    prng_256,
+                    prng_1024,
+                    prng_4096,
+                    prio2_prove_10,
+                    prio2_prove_100,
+                    prio2_prove_1000,
+                    prio2_prove_and_verify_10,
+                    prio2_prove_and_verify_100,
+                    prio2_prove_and_verify_1000,
+                    prio3_client_count,
+                    prio3_client_histogram_11,
+                    prio3_client_sum_32,
+                    prio3_client_count_vec_1000,
+                    prio3_client_count_vec_multithreaded_1000,
+                );
+            } else {
+                iai::main!(
+                    prng_16,
+                    prng_256,
+                    prng_1024,
+                    prng_4096,
+                    prio2_prove_10,
+                    prio2_prove_100,
+                    prio2_prove_1000,
+                    prio2_prove_and_verify_10,
+                    prio2_prove_and_verify_100,
+                    prio2_prove_and_verify_1000,
+                    prio3_client_count,
+                    prio3_client_histogram_11,
+                    prio3_client_sum_32,
+                    prio3_client_count_vec_1000,
+                );
+            }
+        }
+    } else {
+        cfg_if! {
+            if #[cfg(feature = "multithreaded")] {
+                iai::main!(
+                    prng_16,
+                    prng_256,
+                    prng_1024,
+                    prng_4096,
+                    prio3_client_count,
+                    prio3_client_histogram_11,
+                    prio3_client_sum_32,
+                    prio3_client_count_vec_1000,
+                    prio3_client_count_vec_multithreaded_1000,
+                );
+            } else {
+                iai::main!(
+                    prng_16,
+                    prng_256,
+                    prng_1024,
+                    prng_4096,
+                    prio3_client_count,
+                    prio3_client_histogram_11,
+                    prio3_client_sum_32,
+                    prio3_client_count_vec_1000,
+                );
+            }
+        }
+    }
+}
diff --git a/benches/speed_tests.rs b/benches/speed_tests.rs index 7dd1f866..0a6d679a 100644 --- a/benches/speed_tests.rs +++ b/benches/speed_tests.rs @@ -101,7 +101,7 @@ pub fn count_vec(c: &mut Criterion) { benchmarked_v2_prove(&input, &mut client).len() ); - c.bench_function(&format!("prio2 prove, size={}", *size), |b| { + c.bench_function(&format!("prio2 prove, input size={}", *size), |b| { b.iter(|| { benchmarked_v2_prove(&input, &mut client); }) @@ -111,7 +111,7 @@ pub fn count_vec(c: &mut Criterion) { let mut validator: ValidationMemory = ValidationMemory::new(input.len()); let eval_at = 
random_vector(1).unwrap()[0]; - c.bench_function(&format!("prio2 query, size={}", *size), |b| { + c.bench_function(&format!("prio2 query, input size={}", *size), |b| { b.iter(|| { generate_verification_message( input.len(), @@ -133,21 +133,27 @@ pub fn count_vec(c: &mut Criterion) { println!("prio3 countvec proof size={}\n", proof.len()); - c.bench_function(&format!("prio3 countvec prove, size={}", *size), |b| { - b.iter(|| { - let prove_rand = random_vector(count_vec.prove_rand_len()).unwrap(); - count_vec.prove(&input, &prove_rand, &joint_rand).unwrap(); - }) - }); + c.bench_function( + &format!("prio3 countvec prove, input size={}", *size), + |b| { + b.iter(|| { + let prove_rand = random_vector(count_vec.prove_rand_len()).unwrap(); + count_vec.prove(&input, &prove_rand, &joint_rand).unwrap(); + }) + }, + ); - c.bench_function(&format!("prio3 countvec query, size={}", *size), |b| { - b.iter(|| { - let query_rand = random_vector(count_vec.query_rand_len()).unwrap(); - count_vec - .query(&input, &proof, &query_rand, &joint_rand, 1) - .unwrap(); - }) - }); + c.bench_function( + &format!("prio3 countvec query, input size={}", *size), + |b| { + b.iter(|| { + let query_rand = random_vector(count_vec.query_rand_len()).unwrap(); + count_vec + .query(&input, &proof, &query_rand, &joint_rand, 1) + .unwrap(); + }) + }, + ); #[cfg(feature = "multithreaded")] { @@ -155,7 +161,7 @@ pub fn count_vec(c: &mut Criterion) { CountVec::new(*size); c.bench_function( - &format!("prio3 countvec multithreaded prove, size={}", *size), + &format!("prio3 countvec multithreaded prove, input size={}", *size), |b| { b.iter(|| { let prove_rand = random_vector(count_vec.prove_rand_len()).unwrap(); @@ -165,7 +171,7 @@ pub fn count_vec(c: &mut Criterion) { ); c.bench_function( - &format!("prio3 countvec multithreaded query, size={}", *size), + &format!("prio3 countvec multithreaded query, input size={}", *size), |b| { b.iter(|| { let query_rand = random_vector(count_vec.query_rand_len()).unwrap(); 
@@ -186,7 +192,7 @@ pub fn prio3_client(c: &mut Criterion) { let prio3 = Prio3::new_aes128_count(num_shares).unwrap(); let measurement = 1; println!( - "prio3 count size = {}", + "prio3 count share size = {}", prio3_input_share_size(&prio3.shard(&measurement).unwrap().1) ); c.bench_function("prio3 count", |b| { @@ -199,7 +205,7 @@ pub fn prio3_client(c: &mut Criterion) { let prio3 = Prio3::new_aes128_histogram(num_shares, &buckets).unwrap(); let measurement = 17; println!( - "prio3 histogram ({} buckets) size = {}", + "prio3 histogram ({} buckets) share size = {}", buckets.len() + 1, prio3_input_share_size(&prio3.shard(&measurement).unwrap().1) ); @@ -216,7 +222,7 @@ pub fn prio3_client(c: &mut Criterion) { let prio3 = Prio3::new_aes128_sum(num_shares, bits).unwrap(); let measurement = 1337; println!( - "prio3 sum ({} bits) size = {}", + "prio3 sum ({} bits) share size = {}", bits, prio3_input_share_size(&prio3.shard(&measurement).unwrap().1) ); @@ -230,7 +236,7 @@ pub fn prio3_client(c: &mut Criterion) { let prio3 = Prio3::new_aes128_count_vec(num_shares, len).unwrap(); let measurement = vec![0; len]; println!( - "prio3 countvec ({} len) size = {}", + "prio3 countvec ({} len) share size = {}", len, prio3_input_share_size(&prio3.shard(&measurement).unwrap().1) ); @@ -245,7 +251,7 @@ pub fn prio3_client(c: &mut Criterion) { let prio3 = Prio3::new_aes128_count_vec_multithreaded(num_shares, len).unwrap(); let measurement = vec![0; len]; println!( - "prio3 countvec multithreaded ({} len) size = {}", + "prio3 countvec multithreaded ({} len) share size = {}", len, prio3_input_share_size(&prio3.shard(&measurement).unwrap().1) );