Skip to content

Commit

Permalink
Add Valgrind-based cycle count benchmarks (#295)
Browse files Browse the repository at this point in the history
* Add cycle count benchmarks, using Valgrind

* Clarify sizes in speed_tests benchmark output

* Clippy fixes

* Back out FFT benchmarks
  • Loading branch information
divergentdave committed Aug 30, 2022
1 parent 5b6cba8 commit 4c159d3
Show file tree
Hide file tree
Showing 4 changed files with 262 additions and 23 deletions.
8 changes: 8 additions & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

6 changes: 6 additions & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ ring = { version = "0.16.20", optional = true }
[dev-dependencies]
assert_matches = "1.5.0"
criterion = "0.3"
iai = "0.1"
itertools = "0.10.3"
modinverse = "0.1.0"
num-bigint = "0.4.3"
Expand All @@ -42,6 +43,7 @@ hex = { version = "0.4.3" , features = ["serde"] }
# Enable test_vector module for test targets
# https://github.com/rust-lang/cargo/issues/2911#issuecomment-749580481
prio = { path = ".", features = ["test-util"] }
cfg-if = "1.0.0"

[features]
default = ["crypto-dependencies"]
Expand All @@ -57,6 +59,10 @@ members = [".", "binaries"]
name = "speed_tests"
harness = false

[[bench]]
name = "cycle_counts"
harness = false

[[example]]
name = "sum"
required-features = ["prio2"]
Expand Down
219 changes: 219 additions & 0 deletions benches/cycle_counts.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,219 @@
#![cfg_attr(windows, allow(dead_code))]

use cfg_if::cfg_if;
use iai::black_box;
use prio::{
field::{random_vector, Field128, Field64},
vdaf::{
prio3::{Prio3, Prio3InputShare},
Client,
},
};
#[cfg(feature = "prio2")]
use prio::{
field::{FieldElement, FieldPrio2},
server::VerificationMessage,
};

fn prng(size: usize) -> Vec<Field128> {
random_vector(size).unwrap()
}

fn prng_16() -> Vec<Field128> {
prng(16)
}

fn prng_256() -> Vec<Field128> {
prng(256)
}

fn prng_1024() -> Vec<Field128> {
prng(1024)
}

fn prng_4096() -> Vec<Field128> {
prng(4096)
}

#[cfg(feature = "prio2")]
const PRIO2_PUBKEY1: &str =
"BIl6j+J6dYttxALdjISDv6ZI4/VWVEhUzaS05LgrsfswmbLOgNt9HUC2E0w+9RqZx3XMkdEHBHfNuCSMpOwofVQ=";
#[cfg(feature = "prio2")]
const PRIO2_PUBKEY2: &str =
"BNNOqoU54GPo+1gTPv+hCgA9U2ZCKd76yOMrWa1xTWgeb4LhFLMQIQoRwDVaW64g/WTdcxT4rDULoycUNFB60LE=";

#[cfg(feature = "prio2")]
fn prio2_prove(size: usize) -> Vec<FieldPrio2> {
use prio::{benchmarked::benchmarked_v2_prove, client::Client, encrypt::PublicKey};

let input = vec![FieldPrio2::zero(); size];
let pk1 = PublicKey::from_base64(PRIO2_PUBKEY1).unwrap();
let pk2 = PublicKey::from_base64(PRIO2_PUBKEY2).unwrap();
let mut client: Client<FieldPrio2> = Client::new(input.len(), pk1, pk2).unwrap();
benchmarked_v2_prove(&black_box(input), &mut client)
}

#[cfg(feature = "prio2")]
fn prio2_prove_10() -> Vec<FieldPrio2> {
prio2_prove(10)
}

#[cfg(feature = "prio2")]
fn prio2_prove_100() -> Vec<FieldPrio2> {
prio2_prove(100)
}

#[cfg(feature = "prio2")]
fn prio2_prove_1000() -> Vec<FieldPrio2> {
prio2_prove(1_000)
}

#[cfg(feature = "prio2")]
fn prio2_prove_and_verify(size: usize) -> VerificationMessage<FieldPrio2> {
use prio::{
benchmarked::benchmarked_v2_prove,
client::Client,
encrypt::PublicKey,
server::{generate_verification_message, ValidationMemory},
};

let input = vec![FieldPrio2::zero(); size];
let pk1 = PublicKey::from_base64(PRIO2_PUBKEY1).unwrap();
let pk2 = PublicKey::from_base64(PRIO2_PUBKEY2).unwrap();
let mut client: Client<FieldPrio2> = Client::new(input.len(), pk1, pk2).unwrap();
let input_and_proof = benchmarked_v2_prove(&input, &mut client);
let mut validator = ValidationMemory::new(input.len());
let eval_at = random_vector(1).unwrap()[0];
generate_verification_message(
input.len(),
eval_at,
&black_box(input_and_proof),
true,
&mut validator,
)
.unwrap()
}

#[cfg(feature = "prio2")]
fn prio2_prove_and_verify_10() -> VerificationMessage<FieldPrio2> {
prio2_prove_and_verify(10)
}

#[cfg(feature = "prio2")]
fn prio2_prove_and_verify_100() -> VerificationMessage<FieldPrio2> {
prio2_prove_and_verify(100)
}

#[cfg(feature = "prio2")]
fn prio2_prove_and_verify_1000() -> VerificationMessage<FieldPrio2> {
prio2_prove_and_verify(1_000)
}

fn prio3_client_count() -> Vec<Prio3InputShare<Field64, 16>> {
let prio3 = Prio3::new_aes128_count(2).unwrap();
let measurement = 1;
prio3.shard(&black_box(measurement)).unwrap().1
}

fn prio3_client_histogram_11() -> Vec<Prio3InputShare<Field128, 16>> {
let buckets: Vec<u64> = (1..10).collect();
let prio3 = Prio3::new_aes128_histogram(2, &buckets).unwrap();
let measurement = 17;
prio3.shard(&black_box(measurement)).unwrap().1
}

fn prio3_client_sum_32() -> Vec<Prio3InputShare<Field128, 16>> {
let prio3 = Prio3::new_aes128_sum(2, 16).unwrap();
let measurement = 1337;
prio3.shard(&black_box(measurement)).unwrap().1
}

fn prio3_client_count_vec_1000() -> Vec<Prio3InputShare<Field128, 16>> {
let len = 1000;
let prio3 = Prio3::new_aes128_count_vec(2, len).unwrap();
let measurement = vec![0; len];
prio3.shard(&black_box(measurement)).unwrap().1
}

#[cfg(feature = "multithreaded")]
fn prio3_client_count_vec_multithreaded_1000() -> Vec<Prio3InputShare<Field128, 16>> {
let len = 1000;
let prio3 = Prio3::new_aes128_count_vec_multithreaded(2, len).unwrap();
let measurement = vec![0; len];
prio3.shard(&black_box(measurement)).unwrap().1
}

cfg_if! {
if #[cfg(windows)] {
fn main() {
eprintln!("Cycle count benchmarks are not supported on Windows.");
}
}
else if #[cfg(feature = "prio2")] {
cfg_if! {
if #[cfg(feature = "multithreaded")] {
iai::main!(
prng_16,
prng_256,
prng_1024,
prng_4096,
prio2_prove_10,
prio2_prove_100,
prio2_prove_1000,
prio2_prove_and_verify_10,
prio2_prove_and_verify_100,
prio2_prove_and_verify_1000,
prio3_client_count,
prio3_client_histogram_11,
prio3_client_sum_32,
prio3_client_count_vec_1000,
prio3_client_count_vec_multithreaded_1000,
);
} else {
iai::main!(
prng_16,
prng_256,
prng_1024,
prng_4096,
prio2_prove_10,
prio2_prove_100,
prio2_prove_1000,
prio2_prove_and_verify_10,
prio2_prove_and_verify_100,
prio2_prove_and_verify_1000,
prio3_client_count,
prio3_client_histogram_11,
prio3_client_sum_32,
prio3_client_count_vec_1000,
);
}
}
} else {
cfg_if! {
if #[cfg(feature = "multithreaded")] {
iai::main!(
prng_16,
prng_256,
prng_1024,
prng_4096,
prio3_client_count,
prio3_client_histogram_11,
prio3_client_sum_32,
prio3_client_count_vec_1000,
prio3_client_count_vec_multithreaded_1000,
);
} else {
iai::main!(
prng_16,
prng_256,
prng_1024,
prng_4096,
prio3_client_count,
prio3_client_histogram_11,
prio3_client_sum_32,
prio3_client_count_vec_1000,
);
}
}
}
}
52 changes: 29 additions & 23 deletions benches/speed_tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,7 @@ pub fn count_vec(c: &mut Criterion) {
benchmarked_v2_prove(&input, &mut client).len()
);

c.bench_function(&format!("prio2 prove, size={}", *size), |b| {
c.bench_function(&format!("prio2 prove, input size={}", *size), |b| {
b.iter(|| {
benchmarked_v2_prove(&input, &mut client);
})
Expand All @@ -111,7 +111,7 @@ pub fn count_vec(c: &mut Criterion) {
let mut validator: ValidationMemory<F> = ValidationMemory::new(input.len());
let eval_at = random_vector(1).unwrap()[0];

c.bench_function(&format!("prio2 query, size={}", *size), |b| {
c.bench_function(&format!("prio2 query, input size={}", *size), |b| {
b.iter(|| {
generate_verification_message(
input.len(),
Expand All @@ -133,29 +133,35 @@ pub fn count_vec(c: &mut Criterion) {

println!("prio3 countvec proof size={}\n", proof.len());

c.bench_function(&format!("prio3 countvec prove, size={}", *size), |b| {
b.iter(|| {
let prove_rand = random_vector(count_vec.prove_rand_len()).unwrap();
count_vec.prove(&input, &prove_rand, &joint_rand).unwrap();
})
});
c.bench_function(
&format!("prio3 countvec prove, input size={}", *size),
|b| {
b.iter(|| {
let prove_rand = random_vector(count_vec.prove_rand_len()).unwrap();
count_vec.prove(&input, &prove_rand, &joint_rand).unwrap();
})
},
);

c.bench_function(&format!("prio3 countvec query, size={}", *size), |b| {
b.iter(|| {
let query_rand = random_vector(count_vec.query_rand_len()).unwrap();
count_vec
.query(&input, &proof, &query_rand, &joint_rand, 1)
.unwrap();
})
});
c.bench_function(
&format!("prio3 countvec query, input size={}", *size),
|b| {
b.iter(|| {
let query_rand = random_vector(count_vec.query_rand_len()).unwrap();
count_vec
.query(&input, &proof, &query_rand, &joint_rand, 1)
.unwrap();
})
},
);

#[cfg(feature = "multithreaded")]
{
let count_vec: CountVec<F, ParallelSumMultithreaded<F, BlindPolyEval<F>>> =
CountVec::new(*size);

c.bench_function(
&format!("prio3 countvec multithreaded prove, size={}", *size),
&format!("prio3 countvec multithreaded prove, input size={}", *size),
|b| {
b.iter(|| {
let prove_rand = random_vector(count_vec.prove_rand_len()).unwrap();
Expand All @@ -165,7 +171,7 @@ pub fn count_vec(c: &mut Criterion) {
);

c.bench_function(
&format!("prio3 countvec multithreaded query, size={}", *size),
&format!("prio3 countvec multithreaded query, input size={}", *size),
|b| {
b.iter(|| {
let query_rand = random_vector(count_vec.query_rand_len()).unwrap();
Expand All @@ -186,7 +192,7 @@ pub fn prio3_client(c: &mut Criterion) {
let prio3 = Prio3::new_aes128_count(num_shares).unwrap();
let measurement = 1;
println!(
"prio3 count size = {}",
"prio3 count share size = {}",
prio3_input_share_size(&prio3.shard(&measurement).unwrap().1)
);
c.bench_function("prio3 count", |b| {
Expand All @@ -199,7 +205,7 @@ pub fn prio3_client(c: &mut Criterion) {
let prio3 = Prio3::new_aes128_histogram(num_shares, &buckets).unwrap();
let measurement = 17;
println!(
"prio3 histogram ({} buckets) size = {}",
"prio3 histogram ({} buckets) share size = {}",
buckets.len() + 1,
prio3_input_share_size(&prio3.shard(&measurement).unwrap().1)
);
Expand All @@ -216,7 +222,7 @@ pub fn prio3_client(c: &mut Criterion) {
let prio3 = Prio3::new_aes128_sum(num_shares, bits).unwrap();
let measurement = 1337;
println!(
"prio3 sum ({} bits) size = {}",
"prio3 sum ({} bits) share size = {}",
bits,
prio3_input_share_size(&prio3.shard(&measurement).unwrap().1)
);
Expand All @@ -230,7 +236,7 @@ pub fn prio3_client(c: &mut Criterion) {
let prio3 = Prio3::new_aes128_count_vec(num_shares, len).unwrap();
let measurement = vec![0; len];
println!(
"prio3 countvec ({} len) size = {}",
"prio3 countvec ({} len) share size = {}",
len,
prio3_input_share_size(&prio3.shard(&measurement).unwrap().1)
);
Expand All @@ -245,7 +251,7 @@ pub fn prio3_client(c: &mut Criterion) {
let prio3 = Prio3::new_aes128_count_vec_multithreaded(num_shares, len).unwrap();
let measurement = vec![0; len];
println!(
"prio3 countvec multithreaded ({} len) size = {}",
"prio3 countvec multithreaded ({} len) share size = {}",
len,
prio3_input_share_size(&prio3.shard(&measurement).unwrap().1)
);
Expand Down

0 comments on commit 4c159d3

Please sign in to comment.