Skip to content

Commit

Permalink
Remove non-constant-time comparisons of secret values.
Browse files Browse the repository at this point in the history
This affects the types representing:
  * Input shares
  * Preparation states
  * Output shares
  * Aggregate shares

Mostly, the comparisons were either dropped entirely or updated to be
test-only.  Input shares were instead given a constant-time equality
implementation, as it is believed this is required for DAP
implementations.
  • Loading branch information
branlwyd committed Sep 6, 2023
1 parent aa1a9a4 commit 55ba66b
Show file tree
Hide file tree
Showing 5 changed files with 112 additions and 22 deletions.
21 changes: 18 additions & 3 deletions src/vdaf.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ use crate::{
};
use serde::{Deserialize, Serialize};
use std::{fmt::Debug, io::Cursor};
use subtle::{Choice, ConstantTimeEq};

/// A component of the domain-separation tag, used to bind the VDAF operations to the document
/// version. This will be revised with each draft with breaking changes.
Expand Down Expand Up @@ -57,7 +58,9 @@ pub enum VdafError {
}

/// An additive share of a vector of field elements.
#[derive(Clone, Debug, Eq, PartialEq)]
#[derive(Clone, Debug)]
// Only derive equality checks in test code, as the content of this type is a secret.
#[cfg_attr(test, derive(PartialEq, Eq))]
pub enum Share<F, const SEED_SIZE: usize> {
/// An uncompressed share, typically sent to the leader.
Leader(Vec<F>),
Expand All @@ -78,6 +81,18 @@ impl<F: Clone, const SEED_SIZE: usize> Share<F, SEED_SIZE> {
}
}

impl<F: ConstantTimeEq, const SEED_SIZE: usize> ConstantTimeEq for Share<F, SEED_SIZE> {
fn ct_eq(&self, other: &Self) -> subtle::Choice {
// We allow short-circuiting on the type (Leader vs Helper) of the value, but not the types'
// contents.
match (self, other) {
(Share::Leader(self_val), Share::Leader(other_val)) => self_val.ct_eq(other_val),
(Share::Helper(self_val), Share::Helper(other_val)) => self_val.ct_eq(other_val),
_ => Choice::from(0),
}
}
}

/// Parameters needed to decode a [`Share`]
#[derive(Clone, Debug, PartialEq, Eq)]
pub(crate) enum ShareDecodingParameter<const SEED_SIZE: usize> {
Expand Down Expand Up @@ -310,7 +325,7 @@ pub trait Aggregatable: Clone + Debug + From<Self::OutputShare> {
}

/// An output share comprised of a vector of field elements.
#[derive(Clone, Debug, PartialEq, Eq)]
#[derive(Clone, Debug)]
pub struct OutputShare<F>(Vec<F>);

impl<F> AsRef<[F]> for OutputShare<F> {
Expand Down Expand Up @@ -339,7 +354,7 @@ impl<F: FieldElement> Encode for OutputShare<F> {
///
/// This is suitable for VDAFs where both output shares and aggregate shares are vectors of field
/// elements, and output shares need no special transformation to be merged into an aggregate share.
#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct AggregateShare<F>(Vec<F>);

impl<F: FieldElement> AsRef<[F]> for AggregateShare<F> {
Expand Down
38 changes: 33 additions & 5 deletions src/vdaf/poplar1.rs
Original file line number Diff line number Diff line change
Expand Up @@ -110,7 +110,9 @@ impl<P, const SEED_SIZE: usize> ParameterizedDecode<Poplar1<P, SEED_SIZE>> for P
///
/// This is comprised of an IDPF key share and the correlated randomness used to compute the sketch
/// during preparation.
#[derive(Debug, Clone, PartialEq, Eq)]
#[derive(Debug, Clone)]
// Only derive equality checks in test code, as the content of this type is a secret.
#[cfg_attr(test, derive(PartialEq, Eq))]
pub struct Poplar1InputShare<const SEED_SIZE: usize> {
/// IDPF key share.
idpf_key: Seed<16>,
Expand All @@ -128,6 +130,24 @@ pub struct Poplar1InputShare<const SEED_SIZE: usize> {
corr_leaf: [Field255; 2],
}

impl<const SEED_SIZE: usize> ConstantTimeEq for Poplar1InputShare<SEED_SIZE> {
fn ct_eq(&self, other: &Self) -> Choice {
// We short-circuit on the length of corr_inner being different. Only the content is
// protected.
if self.corr_inner.len() != other.corr_inner.len() {
return Choice::from(0);
}

let mut rslt = self.idpf_key.ct_eq(&other.idpf_key)
& self.corr_seed.ct_eq(&other.corr_seed)
& self.corr_leaf.ct_eq(&other.corr_leaf);
for (x, y) in self.corr_inner.iter().zip(other.corr_inner.iter()) {
rslt &= x.ct_eq(y);
}
rslt
}
}

impl<const SEED_SIZE: usize> Encode for Poplar1InputShare<SEED_SIZE> {
fn encode(&self, bytes: &mut Vec<u8>) {
self.idpf_key.encode(bytes);
Expand Down Expand Up @@ -174,7 +194,9 @@ impl<'a, P, const SEED_SIZE: usize> ParameterizedDecode<(&'a Poplar1<P, SEED_SIZ
}

/// Poplar1 preparation state.
#[derive(Clone, Debug, PartialEq, Eq)]
#[derive(Clone, Debug)]
// Only derive equality checks in test code, as the content of this type is a secret.
#[cfg_attr(test, derive(PartialEq, Eq))]
pub struct Poplar1PrepareState(PrepareStateVariant);

impl Encode for Poplar1PrepareState {
Expand All @@ -201,7 +223,9 @@ impl<'a, P, const SEED_SIZE: usize> ParameterizedDecode<(&'a Poplar1<P, SEED_SIZ
}
}

#[derive(Clone, Debug, PartialEq, Eq)]
#[derive(Clone, Debug)]
// Only derive equality checks in test code, as the content of this type is a secret.
#[cfg_attr(test, derive(PartialEq, Eq))]
enum PrepareStateVariant {
Inner(PrepareState<Field64>),
Leaf(PrepareState<Field255>),
Expand Down Expand Up @@ -252,7 +276,9 @@ impl<'a, P, const SEED_SIZE: usize> ParameterizedDecode<(&'a Poplar1<P, SEED_SIZ
}
}

#[derive(Clone, Debug, PartialEq, Eq)]
#[derive(Clone, Debug)]
// Only derive equality checks in test code, as the content of this type is a secret.
#[cfg_attr(test, derive(PartialEq, Eq))]
struct PrepareState<F> {
sketch: SketchState<F>,
output_share: Vec<F>,
Expand Down Expand Up @@ -450,7 +476,9 @@ impl ParameterizedDecode<Poplar1PrepareState> for Poplar1PrepareMessage {
}

/// A vector of field elements transmitted while evaluating Poplar1.
#[derive(Clone, Debug, PartialEq, Eq)]
#[derive(Clone, Debug)]
// Only derive equality checks in test code, as the content of this type is a secret.
#[cfg_attr(test, derive(PartialEq, Eq))]
pub enum Poplar1FieldVec {
/// Field type for inner nodes of the IDPF tree.
Inner(Vec<Field64>),
Expand Down
4 changes: 3 additions & 1 deletion src/vdaf/prio2.rs
Original file line number Diff line number Diff line change
Expand Up @@ -165,7 +165,9 @@ impl Client<16> for Prio2 {
}

/// State of each [`Aggregator`] during the Preparation phase.
#[derive(Clone, Debug, Eq, PartialEq)]
#[derive(Clone, Debug)]
// Only derive equality checks in test code, as the content of this type is a secret.
#[cfg_attr(test, derive(PartialEq, Eq))]
pub struct Prio2PrepareState(Share<FieldPrio2, 32>);

impl Encode for Prio2PrepareState {
Expand Down
61 changes: 55 additions & 6 deletions src/vdaf/prio3.rs
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,7 @@ use std::fmt::Debug;
use std::io::Cursor;
use std::iter::{self, IntoIterator};
use std::marker::PhantomData;
use subtle::{Choice, ConstantTimeEq};

const DST_MEASUREMENT_SHARE: u16 = 1;
const DST_PROOF_SHARE: u16 = 2;
Expand Down Expand Up @@ -595,7 +596,7 @@ where
}

/// Message broadcast by the [`Client`] to every [`Aggregator`] during the Sharding phase.
#[derive(Clone, Debug, PartialEq, Eq)]
#[derive(Clone, Debug)]
pub struct Prio3PublicShare<const SEED_SIZE: usize> {
/// Contributions to the joint randomness from every aggregator's share.
joint_rand_parts: Option<Vec<Seed<SEED_SIZE>>>,
Expand All @@ -620,6 +621,22 @@ impl<const SEED_SIZE: usize> Encode for Prio3PublicShare<SEED_SIZE> {
}
}

impl<const SEED_SIZE: usize> PartialEq for Prio3PublicShare<SEED_SIZE> {
fn eq(&self, other: &Self) -> bool {
// Handle case that both join_rand_parts are populated.
if let Some(self_joint_rand_parts) = &self.joint_rand_parts {
if let Some(other_joint_rand_parts) = &other.joint_rand_parts {
return self_joint_rand_parts.ct_eq(&other_joint_rand_parts).into();
}
}

// Handle case that at least one joint_rand_parts is not populated.
self.joint_rand_parts.is_none() && other.joint_rand_parts.is_none()
}
}

impl<const SEED_SIZE: usize> Eq for Prio3PublicShare<SEED_SIZE> {}

impl<T, P, const SEED_SIZE: usize> ParameterizedDecode<Prio3<T, P, SEED_SIZE>>
for Prio3PublicShare<SEED_SIZE>
where
Expand All @@ -646,7 +663,9 @@ where
}

/// Message sent by the [`Client`] to each [`Aggregator`] during the Sharding phase.
#[derive(Clone, Debug, Eq, PartialEq)]
#[derive(Clone, Debug)]
// Only derive equality checks in test code, as the content of this type is a secret.
#[cfg_attr(test, derive(PartialEq, Eq))]
pub struct Prio3InputShare<F, const SEED_SIZE: usize> {
/// The measurement share.
measurement_share: Share<F, SEED_SIZE>,
Expand All @@ -659,6 +678,24 @@ pub struct Prio3InputShare<F, const SEED_SIZE: usize> {
joint_rand_blind: Option<Seed<SEED_SIZE>>,
}

impl<F: ConstantTimeEq, const SEED_SIZE: usize> ConstantTimeEq for Prio3InputShare<F, SEED_SIZE> {
fn ct_eq(&self, other: &Self) -> Choice {
// We allow short-circuiting on the existence (but not contents) of the joint_rand_blind,
// as its existence is a property of the type in use.
let joint_rand_eq = match (&self.joint_rand_blind, &other.joint_rand_blind) {
(Some(self_joint_rand), Some(other_joint_rand)) => {
self_joint_rand.ct_eq(other_joint_rand)
}
(None, None) => Choice::from(1),
_ => Choice::from(0),
};

joint_rand_eq
& self.measurement_share.ct_eq(&other.measurement_share)
& self.proof_share.ct_eq(&other.proof_share)
}
}

impl<F: FftFriendlyFieldElement, const SEED_SIZE: usize> Encode for Prio3InputShare<F, SEED_SIZE> {
fn encode(&self, bytes: &mut Vec<u8>) {
if matches!(
Expand Down Expand Up @@ -726,7 +763,9 @@ where
}
}

#[derive(Clone, Debug, Eq, PartialEq)]
#[derive(Clone, Debug)]
// Only derive equality checks in test code, as the content of this type is a secret.
#[cfg_attr(test, derive(PartialEq, Eq))]
/// Message broadcast by each [`Aggregator`] in each round of the Preparation phase.
pub struct Prio3PrepareShare<F, const SEED_SIZE: usize> {
/// A share of the FLP verifier message. (See [`Type`].)
Expand Down Expand Up @@ -783,7 +822,9 @@ impl<F: FftFriendlyFieldElement, const SEED_SIZE: usize>
}
}

#[derive(Clone, Debug, Eq, PartialEq)]
#[derive(Clone, Debug)]
// Only derive equality checks in test code, as the content of this type is a secret.
#[cfg_attr(test, derive(PartialEq, Eq))]
/// Result of combining a round of [`Prio3PrepareShare`] messages.
pub struct Prio3PrepareMessage<const SEED_SIZE: usize> {
/// The joint randomness seed computed by the Aggregators.
Expand Down Expand Up @@ -841,7 +882,9 @@ where
}

/// State of each [`Aggregator`] during the Preparation phase.
#[derive(Clone, Debug, Eq, PartialEq)]
#[derive(Clone, Debug)]
// Only derive equality checks in test code, as the content of this type is a secret.
#[cfg_attr(test, derive(PartialEq, Eq))]
pub struct Prio3PrepareState<F, const SEED_SIZE: usize> {
measurement_share: Share<F, SEED_SIZE>,
joint_rand_seed: Option<Seed<SEED_SIZE>>,
Expand Down Expand Up @@ -1111,7 +1154,13 @@ where
) -> Result<PrepareTransition<Self, SEED_SIZE, 16>, VdafError> {
if self.typ.joint_rand_len() > 0 {
// Check that the joint randomness was correct.
if step.joint_rand_seed.as_ref().unwrap() != msg.joint_rand_seed.as_ref().unwrap() {
if (!step
.joint_rand_seed
.as_ref()
.unwrap()
.ct_eq(msg.joint_rand_seed.as_ref().unwrap()))
.into()
{
return Err(VdafError::Uncategorized(
"joint randomness mismatch".to_string(),
));
Expand Down
10 changes: 3 additions & 7 deletions src/vdaf/xof.rs
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,9 @@ use std::{
use subtle::{Choice, ConstantTimeEq};

/// Input of [`Xof`].
#[derive(Clone, Debug, Eq)]
#[derive(Clone, Debug)]
// Only derive equality checks in test code, as the content of this type is sometimes a secret.
#[cfg_attr(test, derive(PartialEq, Eq))]
pub struct Seed<const SEED_SIZE: usize>(pub(crate) [u8; SEED_SIZE]);

impl<const SEED_SIZE: usize> Seed<SEED_SIZE> {
Expand Down Expand Up @@ -67,12 +69,6 @@ impl<const SEED_SIZE: usize> ConstantTimeEq for Seed<SEED_SIZE> {
}
}

impl<const SEED_SIZE: usize> PartialEq for Seed<SEED_SIZE> {
fn eq(&self, other: &Self) -> bool {
self.ct_eq(other).into()
}
}

impl<const SEED_SIZE: usize> Encode for Seed<SEED_SIZE> {
fn encode(&self, bytes: &mut Vec<u8>) {
bytes.extend_from_slice(&self.0[..]);
Expand Down

0 comments on commit 55ba66b

Please sign in to comment.