diff --git a/.github/workflows/main.yml b/.github/workflows/main.yml index 284c81df..26dc4963 100644 --- a/.github/workflows/main.yml +++ b/.github/workflows/main.yml @@ -210,7 +210,7 @@ jobs: uses: actions-rs/cargo@v1 with: command: clippy - args: --all-targets -- -D warnings + args: --all-targets --features p256,slow-hash,std,x25519_u64 -- -D warnings - name: Run cargo doc uses: actions-rs/cargo@v1 @@ -218,7 +218,7 @@ jobs: RUSTDOCFLAGS: -D warnings with: command: doc - args: --no-deps --document-private-items --features p256,slow-hash,std + args: --no-deps --document-private-items --features p256,slow-hash,std,x25519_u64 format: name: cargo fmt diff --git a/benches/opaque.rs b/benches/opaque.rs index 200a5717..4cdd64f8 100644 --- a/benches/opaque.rs +++ b/benches/opaque.rs @@ -28,7 +28,7 @@ struct Default; #[cfg(feature = "ristretto255")] impl CipherSuite for Default { type OprfGroup = curve25519_dalek::ristretto::RistrettoPoint; - type KeGroup = curve25519_dalek::ristretto::RistrettoPoint; + type KeGroup = opaque_ke::Ristretto255; type KeyExchange = opaque_ke::key_exchange::tripledh::TripleDH; type Hash = sha2::Sha512; type SlowHash = opaque_ke::slow_hash::NoOpHash; @@ -37,7 +37,7 @@ impl CipherSuite for Default { #[cfg(not(feature = "ristretto255"))] impl CipherSuite for Default { type OprfGroup = p256_::ProjectivePoint; - type KeGroup = p256_::PublicKey; + type KeGroup = p256_::NistP256; type KeyExchange = opaque_ke::key_exchange::tripledh::TripleDH; type Hash = sha2::Sha256; type SlowHash = opaque_ke::slow_hash::NoOpHash; diff --git a/examples/digital_locker.rs b/examples/digital_locker.rs index c04effc0..f6118cca 100644 --- a/examples/digital_locker.rs +++ b/examples/digital_locker.rs @@ -50,7 +50,7 @@ struct Default; #[cfg(feature = "ristretto255")] impl CipherSuite for Default { type OprfGroup = curve25519_dalek::ristretto::RistrettoPoint; - type KeGroup = curve25519_dalek::ristretto::RistrettoPoint; + type KeGroup = opaque_ke::Ristretto255; type KeyExchange = opaque_ke::key_exchange::tripledh::TripleDH; type Hash = sha2::Sha512; type SlowHash = opaque_ke::slow_hash::NoOpHash; @@ -59,7 +59,7 @@ impl CipherSuite for Default { #[cfg(not(feature = "ristretto255"))] impl CipherSuite for Default { type OprfGroup = p256_::ProjectivePoint; - type KeGroup = p256_::PublicKey; + type KeGroup = p256_::NistP256; type KeyExchange = opaque_ke::key_exchange::tripledh::TripleDH; type Hash = sha2::Sha256; type SlowHash = opaque_ke::slow_hash::NoOpHash; diff --git a/examples/simple_login.rs b/examples/simple_login.rs index 4cb50220..7b55eaba 100644 --- a/examples/simple_login.rs +++ b/examples/simple_login.rs @@ -44,7 +44,7 @@ struct Default; #[cfg(feature = "ristretto255")] impl CipherSuite for Default { type OprfGroup = curve25519_dalek::ristretto::RistrettoPoint; - type KeGroup = curve25519_dalek::ristretto::RistrettoPoint; + type KeGroup = opaque_ke::Ristretto255; type KeyExchange = opaque_ke::key_exchange::tripledh::TripleDH; type Hash = sha2::Sha512; type SlowHash = opaque_ke::slow_hash::NoOpHash; @@ -53,7 +53,7 @@ impl CipherSuite for Default { #[cfg(not(feature = "ristretto255"))] impl CipherSuite for Default { type OprfGroup = p256_::ProjectivePoint; - type KeGroup = p256_::PublicKey; + type KeGroup = p256_::NistP256; type KeyExchange = opaque_ke::key_exchange::tripledh::TripleDH; type Hash = sha2::Sha256; type SlowHash = opaque_ke::slow_hash::NoOpHash; diff --git a/src/envelope.rs b/src/envelope.rs index 31e7c7e9..93d8f8eb 100755 --- a/src/envelope.rs +++ b/src/envelope.rs @@ -36,13 +36,18 @@ const STR_PRIVATE_KEY: [u8; 10] = *b"PrivateKey"; const STR_OPAQUE_DERIVE_AUTH_KEY_PAIR: [u8; 24] = *b"OPAQUE-DeriveAuthKeyPair"; type NonceLen = U32; -#[derive(Clone, Debug, Eq, Hash, Ord, PartialEq, PartialOrd, Zeroize)] -#[zeroize(drop)] +#[derive(Clone, Debug, Eq, Hash, Ord, PartialEq, PartialOrd)] pub(crate) enum InnerEnvelopeMode { Zero = 0, Internal = 1, } +impl Zeroize for InnerEnvelopeMode { + fn zeroize(&mut self) { + *self = Self::Zero + } +} + impl TryFrom for InnerEnvelopeMode { type Error = ProtocolError; fn try_from(x: u8) -> Result { @@ -69,7 +74,7 @@ where <::Core as BlockSizeUser>::BlockSize: IsLess, Le<<::Core as BlockSizeUser>::BlockSize, U256>: NonZero, { - mode: InnerEnvelopeMode, + pub(crate) mode: InnerEnvelopeMode, nonce: GenericArray, hmac: Output, } @@ -136,12 +141,13 @@ where build_inner_envelope_internal::(randomized_pwd_hasher.clone(), nonce)?, ); + let server_s_pk_bytes = server_s_pk.to_bytes(); let (id_u, id_s) = bytestrings_from_identifiers::( ids, - client_s_pk.to_arr(), - server_s_pk.to_arr(), + client_s_pk.to_bytes(), + server_s_pk_bytes.clone(), )?; - let aad = construct_aad(id_u.iter(), id_s.iter(), server_s_pk); + let aad = construct_aad(id_u.iter(), id_s.iter(), &server_s_pk_bytes); let result = Self::seal_raw(randomized_pwd_hasher, nonce, aad, mode)?; Ok(( @@ -206,12 +212,13 @@ where } }; + let server_s_pk_bytes = server_s_pk.to_bytes(); let (id_u, id_s) = bytestrings_from_identifiers::( optional_ids, - client_static_keypair.public().to_arr(), - server_s_pk.to_arr(), + client_static_keypair.public().to_bytes(), + server_s_pk_bytes.clone(), )?; - let aad = construct_aad(id_u.iter(), id_s.iter(), &server_s_pk); + let aad = construct_aad(id_u.iter(), id_s.iter(), &server_s_pk_bytes); let opened = self.open_raw(randomized_pwd_hasher, aad)?; @@ -318,6 +325,7 @@ where .expand(&nonce.concat(STR_PRIVATE_KEY.into()), &mut keypair_seed) .map_err(|_| InternalError::HkdfError)?; let client_static_keypair = KeyPair::::from_private_key_slice( + // TODO: Use `KeGroup` instead of `OprfGroup` here. &CS::OprfGroup::scalar_as_bytes(CS::OprfGroup::hash_to_scalar::( [keypair_seed.as_slice()], GenericArray::from(STR_OPAQUE_DERIVE_AUTH_KEY_PAIR), diff --git a/src/key_exchange/group/mod.rs b/src/key_exchange/group/mod.rs index dad51636..db2bc991 100644 --- a/src/key_exchange/group/mod.rs +++ b/src/key_exchange/group/mod.rs @@ -13,26 +13,39 @@ use rand::{CryptoRng, RngCore}; use crate::errors::InternalError; /// A group representation for use in the key exchange -pub trait KeGroup: Sized + Clone { +pub trait KeGroup { + /// Public key + type Pk: Clone + Sized; /// Length of the public key type PkLen: ArrayLength + 'static; + /// Secret key + type Sk: Clone + Sized; /// Length of the secret key type SkLen: ArrayLength + 'static; + /// Serializes `self` + fn serialize_pk(pk: &Self::Pk) -> GenericArray; + /// Return a public key from its fixed-length bytes representation - fn from_pk_slice(element_bits: &GenericArray) -> Result; + fn deserialize_pk(bytes: &GenericArray) -> Result; /// Generate a random secret key - fn random_sk(rng: &mut R) -> GenericArray; + fn random_sk(rng: &mut R) -> Self::Sk; /// Return a public key from its secret key - fn public_key(sk: &GenericArray) -> Self; + fn public_key(sk: &Self::Sk) -> Self::Pk; + + /// Diffie-Hellman key exchange + fn diffie_hellman(pk: &Self::Pk, sk: &Self::Sk) -> GenericArray; + + /// Zeroize secret key on drop. + fn zeroize_sk_on_drop(sk: &mut Self::Sk); /// Serializes `self` - fn to_arr(&self) -> GenericArray; + fn serialize_sk(sk: &Self::Sk) -> GenericArray; - /// Diffie-Hellman key exchange - fn diffie_hellman(&self, sk: &GenericArray) -> GenericArray; + /// Return a public key from its fixed-length bytes representation + fn deserialize_sk(bytes: &GenericArray) -> Result; } #[cfg(feature = "p256")] diff --git a/src/key_exchange/group/p256.rs b/src/key_exchange/group/p256.rs index 50eb9cb7..6b43ad47 100644 --- a/src/key_exchange/group/p256.rs +++ b/src/key_exchange/group/p256.rs @@ -9,44 +9,49 @@ use generic_array::typenum::{U32, U33}; use generic_array::GenericArray; -use p256_::elliptic_curve::group::GroupEncoding; -use p256_::elliptic_curve::sec1::ToEncodedPoint; -use p256_::elliptic_curve::{PublicKey, SecretKey}; -use p256_::NistP256; +use p256::elliptic_curve::group::GroupEncoding; +use p256::elliptic_curve::sec1::ToEncodedPoint; +use p256::elliptic_curve::{PublicKey, SecretKey}; +use p256::NistP256; use rand::{CryptoRng, RngCore}; use super::KeGroup; use crate::errors::InternalError; -impl KeGroup for PublicKey { +impl KeGroup for NistP256 { + type Pk = PublicKey; type PkLen = U33; + type Sk = SecretKey; type SkLen = U32; + fn serialize_pk(pk: &Self::Pk) -> GenericArray { + GenericArray::clone_from_slice(pk.to_encoded_point(true).as_bytes()) + } + + fn deserialize_pk(bytes: &GenericArray) -> Result { + Self::Pk::from_sec1_bytes(bytes).map_err(|_| InternalError::PointError) + } - fn from_pk_slice(element_bits: &GenericArray) -> Result { - Self::from_sec1_bytes(element_bits).map_err(|_| InternalError::PointError) + fn random_sk(rng: &mut R) -> Self::Sk { + SecretKey::::random(rng) } - fn random_sk(rng: &mut R) -> GenericArray { - SecretKey::::random(rng).to_be_bytes() + fn public_key(sk: &Self::Sk) -> Self::Pk { + sk.public_key() } - fn public_key(sk: &GenericArray) -> Self { - SecretKey::::from_be_bytes(sk) - .unwrap() - .public_key() + fn diffie_hellman(pk: &Self::Pk, sk: &Self::Sk) -> GenericArray { + (pk.to_projective() * sk.to_nonzero_scalar().as_ref()) + .to_affine() + .to_bytes() } - fn to_arr(&self) -> GenericArray { - GenericArray::clone_from_slice(self.to_encoded_point(true).as_bytes()) + fn zeroize_sk_on_drop(_: &mut Self::Sk) {} + + fn serialize_sk(sk: &Self::Sk) -> GenericArray { + sk.to_be_bytes() } - fn diffie_hellman(&self, sk: &GenericArray) -> GenericArray { - (self.to_projective() - * SecretKey::::from_be_bytes(sk) - .unwrap() - .to_nonzero_scalar() - .as_ref()) - .to_affine() - .to_bytes() + fn deserialize_sk(bytes: &GenericArray) -> Result { + Self::Sk::from_be_bytes(bytes).map_err(|_| InternalError::PointError) } } diff --git a/src/key_exchange/group/ristretto255.rs b/src/key_exchange/group/ristretto255.rs index c4c6ebd7..53598419 100644 --- a/src/key_exchange/group/ristretto255.rs +++ b/src/key_exchange/group/ristretto255.rs @@ -13,21 +13,31 @@ use curve25519_dalek::scalar::Scalar; use generic_array::typenum::U32; use generic_array::GenericArray; use rand::{CryptoRng, RngCore}; +use zeroize::Zeroize; use super::KeGroup; use crate::errors::InternalError; -impl KeGroup for RistrettoPoint { +/// Implementation for Ristretto255. +pub struct Ristretto255; + +impl KeGroup for Ristretto255 { + type Pk = RistrettoPoint; type PkLen = U32; + type Sk = Scalar; type SkLen = U32; - fn from_pk_slice(element_bits: &GenericArray) -> Result { - CompressedRistretto::from_slice(element_bits) + fn serialize_pk(pk: &Self::Pk) -> GenericArray { + pk.compress().to_bytes().into() + } + + fn deserialize_pk(bytes: &GenericArray) -> Result { + CompressedRistretto::from_slice(bytes) .decompress() .ok_or(InternalError::PointError) } - fn random_sk(rng: &mut R) -> GenericArray { + fn random_sk(rng: &mut R) -> Self::Sk { loop { let scalar = { #[cfg(not(test))] @@ -47,21 +57,33 @@ impl KeGroup for RistrettoPoint { } }; - if scalar != Scalar::zero() { - break scalar.to_bytes().into(); + if scalar != Scalar::zero() && scalar.is_canonical() { + break scalar; } } } - fn public_key(sk: &GenericArray) -> Self { - RISTRETTO_BASEPOINT_POINT * Scalar::from_bits(*sk.as_ref()) + fn public_key(sk: &Self::Sk) -> Self::Pk { + RISTRETTO_BASEPOINT_POINT * sk + } + + fn diffie_hellman(pk: &Self::Pk, sk: &Self::Sk) -> GenericArray { + Self::serialize_pk(&(pk * sk)) + } + + fn zeroize_sk_on_drop(sk: &mut Self::Sk) { + sk.zeroize() } - fn to_arr(&self) -> GenericArray { - self.compress().to_bytes().into() + fn serialize_sk(sk: &Self::Sk) -> GenericArray { + sk.to_bytes().into() } - fn diffie_hellman(&self, sk: &GenericArray) -> GenericArray { - (self * Scalar::from_bits(*sk.as_ref())).to_arr() + fn deserialize_sk(bytes: &GenericArray) -> Result { + // TODO: When we implement `hash_to_field` we can re-enable this again. + //Scalar::from_canonical_bytes((*bytes).into()).ok_or(InternalError:: + // PointError) + + Ok(Scalar::from_bits((*bytes).into())) } } diff --git a/src/key_exchange/group/x25519.rs b/src/key_exchange/group/x25519.rs index 18d2067d..060d0ae2 100644 --- a/src/key_exchange/group/x25519.rs +++ b/src/key_exchange/group/x25519.rs @@ -11,43 +11,76 @@ use generic_array::typenum::U32; use generic_array::GenericArray; use rand::{CryptoRng, RngCore}; use x25519_dalek::{PublicKey, StaticSecret}; +use zeroize::Zeroize; use super::KeGroup; use crate::errors::InternalError; +/// Implementation for X25519. +pub struct X25519; + /// The implementation of such a subgroup for Ristretto -impl KeGroup for PublicKey { +impl KeGroup for X25519 { + type Pk = PublicKey; type PkLen = U32; + type Sk = StaticSecret; type SkLen = U32; - fn from_pk_slice(element_bits: &GenericArray) -> Result { - Ok(Self::from(<[u8; 32]>::from(*element_bits))) + fn serialize_pk(pk: &Self::Pk) -> GenericArray { + pk.to_bytes().into() + } + + fn deserialize_pk(bytes: &GenericArray) -> Result { + if **bytes == [0; 32] { + Err(InternalError::PointError) + } else { + Ok(PublicKey::from(<[_; 32]>::from(*bytes))) + } } - fn random_sk(rng: &mut R) -> GenericArray { + fn random_sk(rng: &mut R) -> Self::Sk { let mut scalar_bytes = [0u8; 32]; loop { rng.fill_bytes(&mut scalar_bytes); if scalar_bytes != [0u8; 32] { - break StaticSecret::from(scalar_bytes).to_bytes().into(); + break StaticSecret::from(scalar_bytes); } } } - fn public_key(sk: &GenericArray) -> Self { - Self::from(&StaticSecret::from(<[u8; 32]>::from(*sk))) + fn public_key(sk: &Self::Sk) -> Self::Pk { + PublicKey::from(sk) + } + + fn diffie_hellman(pk: &Self::Pk, sk: &Self::Sk) -> GenericArray { + sk.diffie_hellman(pk).to_bytes().into() } - fn to_arr(&self) -> GenericArray { - self.to_bytes().into() + fn zeroize_sk_on_drop(sk: &mut Self::Sk) { + sk.zeroize() } - fn diffie_hellman(&self, sk: &GenericArray) -> GenericArray { - StaticSecret::from(<[u8; 32]>::from(*sk)) - .diffie_hellman(self) - .to_bytes() - .into() + fn serialize_sk(sk: &Self::Sk) -> GenericArray { + sk.to_bytes().into() + } + + fn deserialize_sk(bytes: &GenericArray) -> Result { + if **bytes == [0; 32] { + Err(InternalError::PointError) + } else { + let sk = StaticSecret::from(<[u8; 32]>::from(*bytes)); + + // TODO: When we implement `hash_to_field` we can re-enable this again. + // If any clamping was applied. + //if sk.to_bytes() == **bytes { + // Ok(sk) + //} else { + // Err(InternalError::PointError) + //} + + Ok(sk) + } } } diff --git a/src/key_exchange/tripledh.rs b/src/key_exchange/tripledh.rs index 65935f18..204e379e 100755 --- a/src/key_exchange/tripledh.rs +++ b/src/key_exchange/tripledh.rs @@ -55,11 +55,19 @@ pub struct TripleDH; #[cfg_attr( feature = "serde", derive(serde_::Deserialize, serde_::Serialize), - serde(bound = "", crate = "serde_") + serde( + bound( + deserialize = "KG::Sk: serde_::Deserialize<'de>", + serialize = "KG::Sk: serde_::Serialize", + ), + crate = "serde_" + ) )] #[derive(DeriveWhere)] -#[derive_where(Clone, Debug, Eq, Hash, Ord, PartialEq, PartialOrd, Zeroize(drop))] +#[derive_where(Clone, Zeroize(drop))] +#[derive_where(Debug, Eq, Hash, Ord, PartialEq, PartialOrd; KG::Sk)] pub struct Ke1State { + #[derive_where(skip(Zeroize))] client_e_sk: PrivateKey, client_nonce: GenericArray, } @@ -68,12 +76,20 @@ pub struct Ke1State { #[cfg_attr( feature = "serde", derive(serde_::Deserialize, serde_::Serialize), - serde(bound = "", crate = "serde_") + serde( + bound( + deserialize = "KG::Pk: serde_::Deserialize<'de>", + serialize = "KG::Pk: serde_::Serialize", + ), + crate = "serde_" + ) )] #[derive(DeriveWhere)] -#[derive_where(Clone, Debug, Eq, Hash, Ord, PartialEq, PartialOrd, Zeroize)] +#[derive_where(Clone, Zeroize(drop))] +#[derive_where(Debug, Eq, Hash, Ord, PartialEq, PartialOrd; KG::Pk)] pub struct Ke1Message { pub(crate) client_nonce: GenericArray, + #[derive_where(skip(Zeroize))] pub(crate) client_e_pk: PublicKey, } @@ -100,10 +116,17 @@ where #[cfg_attr( feature = "serde", derive(serde_::Deserialize, serde_::Serialize), - serde(bound = "", crate = "serde_") + serde( + bound( + deserialize = "KG::Pk: serde_::Deserialize<'de>", + serialize = "KG::Pk: serde_::Serialize", + ), + crate = "serde_" + ) )] #[derive(DeriveWhere)] -#[derive_where(Clone, Debug, Eq, Hash, Ord, PartialEq, PartialOrd)] +#[derive_where(Clone)] +#[derive_where(Debug, Eq, Hash, Ord, PartialEq, PartialOrd; KG::Pk)] pub struct Ke2Message where D::Core: ProxyHash, @@ -210,13 +233,13 @@ where .chain_iter(id_s.into_iter()) .chain_iter(l2_bytes) .chain(server_nonce) - .chain(&server_e_kp.public().to_arr()); + .chain(&server_e_kp.public().to_bytes()); let result = derive_3dh_keys::( TripleDHComponents { pk1: ke1_message.client_e_pk.clone(), sk1: server_e_kp.private().clone(), - pk2: ke1_message.client_e_pk, + pk2: ke1_message.client_e_pk.clone(), sk2: server_s_sk, pk3: client_s_pk, sk3: server_e_kp.private().clone(), @@ -268,7 +291,7 @@ where .chain_iter(serialized_credential_request) .chain_iter(id_s) .chain_iter(l2_component) - .chain_iter(ke2_message.to_bytes_without_info_or_mac()); + .chain(ke2_message.to_bytes_without_mac()); let result = derive_3dh_keys::>( TripleDHComponents { @@ -489,7 +512,7 @@ impl FromBytes for Ke1State { let checked_bytes = check_slice_size_atleast(bytes, key_len + nonce_len, "ke1_state")?; Ok(Self { - client_e_sk: PrivateKey::from_bytes(&checked_bytes[..key_len])?, + client_e_sk: PrivateKey::deserialize(&checked_bytes[..key_len])?, client_nonce: GenericArray::clone_from_slice( &checked_bytes[key_len..key_len + nonce_len], ), @@ -506,7 +529,7 @@ where type Len = Sum; fn to_bytes(&self) -> GenericArray { - self.client_e_sk.to_arr().concat(self.client_nonce) + self.client_e_sk.serialize().concat(self.client_nonce) } } @@ -521,7 +544,7 @@ impl FromBytes for Ke1Message { Ok(Self { client_nonce: GenericArray::clone_from_slice(&checked_nonce[..nonce_len]), - client_e_pk: PublicKey::from_bytes(&checked_nonce[nonce_len..])?, + client_e_pk: PublicKey::deserialize(&checked_nonce[nonce_len..])?, }) } } @@ -535,7 +558,7 @@ where type Len = Sum; fn to_bytes(&self) -> GenericArray { - self.client_nonce.concat(self.client_e_pk.to_arr()) + self.client_nonce.concat(self.client_e_pk.to_bytes()) } } @@ -602,13 +625,11 @@ where )?; // Check the public key bytes - let server_e_pk = KeyPair::::check_public_key(PublicKey::from_bytes( - &unchecked_server_e_pk[..key_len], - )?)?; + let server_e_pk = PublicKey::deserialize(&unchecked_server_e_pk[..key_len])?; Ok(Self { server_nonce: GenericArray::clone_from_slice(&checked_nonce[..nonce_len]), - server_e_pk: PublicKey::from_bytes(&server_e_pk)?, + server_e_pk, mac: GenericArray::clone_from_slice(checked_mac), }) } @@ -628,7 +649,7 @@ where fn to_bytes(&self) -> GenericArray { self.server_nonce - .concat(self.server_e_pk.to_arr()) + .concat(self.server_e_pk.to_bytes()) .concat(self.mac.clone()) } } @@ -638,9 +659,11 @@ where D::Core: ProxyHash, ::BlockSize: IsLess, Le<::BlockSize, U256>: NonZero, + NonceLen: Add, + Sum: ArrayLength, { - fn to_bytes_without_info_or_mac(&self) -> impl Iterator { - [self.server_nonce.as_slice(), self.server_e_pk.as_slice()].into_iter() + fn to_bytes_without_mac(&self) -> GenericArray> { + self.server_nonce.concat(self.server_e_pk.to_bytes()) } } diff --git a/src/keypair.rs b/src/keypair.rs index 40fa8c87..19bac6ad 100644 --- a/src/keypair.rs +++ b/src/keypair.rs @@ -9,16 +9,13 @@ #![allow(unsafe_code)] -use core::ops::Deref; - use derive_where::DeriveWhere; -use generic_array::typenum::Unsigned; use generic_array::{ArrayLength, GenericArray}; use rand::{CryptoRng, RngCore}; -use zeroize::Zeroize; use crate::errors::{InternalError, ProtocolError}; use crate::key_exchange::group::KeGroup; +use crate::serialization::GenericArrayExt; /// A Keypair trait with public-private verification #[cfg_attr( @@ -26,15 +23,15 @@ use crate::key_exchange::group::KeGroup; derive(serde_::Deserialize, serde_::Serialize), serde( bound( - deserialize = "S: serde_::Deserialize<'de>", - serialize = "S: serde_::Serialize" + deserialize = "KG::Pk: serde_::Deserialize<'de>, S: serde_::Deserialize<'de>", + serialize = "KG::Pk: serde_::Serialize, S: serde_::Serialize" ), crate = "serde_" ) )] #[derive(DeriveWhere)] -#[derive_where(Clone, Zeroize(drop))] -#[derive_where(Debug, Eq, Hash, Ord, PartialEq, PartialOrd; S)] +#[derive_where(Clone)] +#[derive_where(Debug, Eq, Hash, Ord, PartialEq, PartialOrd; KG::Pk, S)] pub struct KeyPair = PrivateKey> { pk: PublicKey, sk: S, @@ -51,14 +48,6 @@ impl> KeyPair { &self.sk } - /// Check whether a public key is valid. This is meant to be applied on - /// material provided through the network which fits the key representation - /// (i.e. can be mapped to a curve point), but presents some risk - e.g. - /// small subgroup check - pub(crate) fn check_public_key(key: PublicKey) -> Result, InternalError> { - KG::from_pk_slice(GenericArray::from_slice(&key.0)).map(|_| key) - } - /// Obtains a KeyPair from a slice representing the private key pub fn from_private_key_slice(input: &[u8]) -> Result> { Self::from_private_key(S::deserialize(input)?) @@ -77,14 +66,18 @@ impl KeyPair { let sk = KG::random_sk(rng); let pk = KG::public_key(&sk); Self { - pk: PublicKey(Key(pk.to_arr())), - sk: PrivateKey(Key(sk)), + pk: PublicKey(pk), + sk: PrivateKey(sk), } } } #[cfg(test)] -impl KeyPair { +impl KeyPair +where + KG::Pk: std::fmt::Debug, + KG::Sk: std::fmt::Debug, +{ /// Test-only strategy returning a proptest Strategy based on /// generate_random fn uniform_keypair_strategy() -> proptest::prelude::BoxedStrategy { @@ -104,70 +97,38 @@ impl KeyPair { } } -/// A minimalist key type built around a \[u8; 32\] -#[cfg_attr( - feature = "serde", - derive(serde_::Deserialize, serde_::Serialize), - serde(bound = "", crate = "serde_") -)] -#[derive(DeriveWhere)] -#[derive_where(Clone, Debug, Eq, Hash, Ord, PartialEq, PartialOrd, Zeroize(drop))] -pub struct Key>(GenericArray); - -impl> Deref for Key { - type Target = GenericArray; - - fn deref(&self) -> &Self::Target { - &self.0 - } -} - -// Don't make it implement SizedBytes so that it's not constructible outside of -// this module. -impl> Key { - /// Convert to bytes - pub fn to_arr(&self) -> GenericArray { - self.0.clone() - } -} - /// Wrapper around a Key to enforce that it's a private one. #[cfg_attr( feature = "serde", derive(serde_::Deserialize, serde_::Serialize), - serde(bound = "", crate = "serde_") + serde( + bound( + deserialize = "KG::Sk: serde_::Deserialize<'de>", + serialize = "KG::Sk: serde_::Serialize" + ), + crate = "serde_" + ) )] #[derive(DeriveWhere)] -#[derive_where(Clone, Debug, Eq, Hash, Ord, PartialEq, PartialOrd, Zeroize(drop))] -pub struct PrivateKey(Key); - -// This can't be derived because of the use of a generic parameter -impl Deref for PrivateKey { - type Target = Key; +#[derive_where(Clone)] +#[derive_where(Debug, Eq, Hash, Ord, PartialEq, PartialOrd; KG::Sk)] +pub struct PrivateKey(KG::Sk); - fn deref(&self) -> &Self::Target { - &self.0 +impl Drop for PrivateKey { + fn drop(&mut self) { + KG::zeroize_sk_on_drop(&mut self.0) } } impl PrivateKey { /// Convert from bytes - pub fn from_arr(key_bytes: GenericArray) -> Self { - PrivateKey(Key(key_bytes)) - } - - /// Convert from slice - pub fn from_bytes(key_bytes: &[u8]) -> Result { - if key_bytes.len() == KG::SkLen::USIZE { - Ok(Self::from_arr(GenericArray::from_slice(key_bytes).clone())) - } else { - Err(InternalError::InvalidByteSequence) - } + pub fn from_bytes(key_bytes: &GenericArray) -> Result { + KG::deserialize_sk(key_bytes).map(Self) } } /// A trait specifying the requirements for a private key container -pub trait SecretKey: Clone + Sized + Zeroize { +pub trait SecretKey: Clone + Sized { /// Custom error type that can be passed down to `InternalError::Custom` type Error; /// Serialization size in bytes. @@ -197,20 +158,19 @@ impl SecretKey for PrivateKey { &self, pk: PublicKey, ) -> Result, InternalError> { - let pk = KG::from_pk_slice(&pk)?; - Ok(pk.diffie_hellman(self)) + Ok(KG::diffie_hellman(&pk.0, &self.0)) } fn public_key(&self) -> Result, InternalError> { - Ok(PublicKey(Key(KG::public_key(&self.0).to_arr()))) + Ok(PublicKey(KG::public_key(&self.0))) } fn serialize(&self) -> GenericArray { - self.to_arr() + KG::serialize_sk(&self.0) } fn deserialize(input: &[u8]) -> Result { - PrivateKey::from_bytes(input).map_err(InternalError::from) + GenericArray::try_from_slice(input).and_then(Self::from_bytes) } } @@ -218,93 +178,56 @@ impl SecretKey for PrivateKey { #[cfg_attr( feature = "serde", derive(serde_::Deserialize, serde_::Serialize), - serde(bound = "", crate = "serde_") + serde( + bound( + deserialize = "KG::Pk: serde_::Deserialize<'de>", + serialize = "KG::Pk: serde_::Serialize" + ), + crate = "serde_" + ) )] #[derive(DeriveWhere)] -#[derive_where(Clone, Debug, Eq, Hash, Ord, PartialEq, PartialOrd, Zeroize(drop))] -pub struct PublicKey(Key); - -impl Deref for PublicKey { - type Target = Key; - - fn deref(&self) -> &Self::Target { - &self.0 - } -} +#[derive_where(Clone)] +#[derive_where(Debug, Eq, Hash, Ord, PartialEq, PartialOrd; KG::Pk)] +pub struct PublicKey(KG::Pk); impl PublicKey { /// Convert from bytes - pub fn from_arr(key_bytes: GenericArray) -> Self { - Self(Key(key_bytes)) + pub fn from_bytes(key_bytes: &GenericArray) -> Result { + KG::deserialize_pk(key_bytes).map(Self) + } + + /// Convert to bytes + pub fn to_bytes(&self) -> GenericArray { + KG::serialize_pk(&self.0) } /// Convert from slice - pub fn from_bytes(key_bytes: &[u8]) -> Result { - if key_bytes.len() == KG::PkLen::USIZE { - Ok(Self::from_arr(GenericArray::from_slice(key_bytes).clone())) - } else { - Err(InternalError::InvalidByteSequence) - } + pub fn deserialize(input: &[u8]) -> Result { + GenericArray::try_from_slice(input).and_then(Self::from_bytes) } } #[cfg(test)] mod tests { - use core::slice::from_raw_parts; - use std::vec; - - use generic_array::typenum::Unsigned; use rand::rngs::OsRng; use super::*; use crate::errors::*; + use crate::util; #[test] - fn test_zeroize_key() -> Result<(), ProtocolError> { - fn inner() -> Result<(), ProtocolError> { - let key_len = G::PkLen::USIZE; - let mut key = Key::(GenericArray::clone_from_slice(&vec![1u8; key_len])); - let ptr = key.as_ptr(); - - Zeroize::zeroize(&mut key); - - let bytes = unsafe { from_raw_parts(ptr, key_len) }; - assert!(bytes.iter().all(|&x| x == 0)); - - Ok(()) - } - - #[cfg(feature = "ristretto255")] - inner::()?; - #[cfg(feature = "p256")] - inner::()?; - - Ok(()) - } - - #[test] - fn test_zeroize_keypair() { + fn test_zeroize_key() { fn inner() { let mut rng = OsRng; - let mut keypair = KeyPair::::generate_random(&mut rng); - let pk_ptr = keypair.pk.as_ptr(); - let sk_ptr = keypair.sk.as_ptr(); - let pk_len = G::PkLen::USIZE; - let sk_len = G::SkLen::USIZE; - - Zeroize::zeroize(&mut keypair); - - let pk_bytes = unsafe { from_raw_parts(pk_ptr, pk_len) }; - let sk_bytes = unsafe { from_raw_parts(sk_ptr, sk_len) }; - - assert!(pk_bytes.iter().all(|&x| x == 0)); - assert!(sk_bytes.iter().all(|&x| x == 0)); + let mut key = PrivateKey::(G::random_sk(&mut rng)); + util::test_zeroize_on_drop(&mut key); } #[cfg(feature = "ristretto255")] - inner::(); + inner::(); #[cfg(feature = "p256")] - inner::(); + inner::<::p256::NistP256>(); } macro_rules! test { @@ -317,12 +240,6 @@ mod tests { use super::*; proptest! { - #[test] - fn check(kp in KeyPair::<$point>::uniform_keypair_strategy()) { - let pk = kp.public(); - prop_assert!(KeyPair::<$point>::check_public_key(pk.clone()).is_ok()); - } - #[test] fn pub_from_priv(kp in KeyPair::<$point>::uniform_keypair_strategy()) { let pk = kp.public(); @@ -342,10 +259,10 @@ mod tests { #[test] fn private_key_slice(kp in KeyPair::<$point>::uniform_keypair_strategy()) { - let sk_bytes = kp.private().to_vec(); + let sk_bytes = kp.private().serialize().to_vec(); let kp2 = KeyPair::<$point>::from_private_key_slice(&sk_bytes)?; - let kp2_private_bytes = kp2.private().to_vec(); + let kp2_private_bytes = kp2.private().serialize().to_vec(); prop_assert_eq!(sk_bytes, kp2_private_bytes); } @@ -355,16 +272,12 @@ mod tests { } #[cfg(feature = "ristretto255")] - test!(ristretto, curve25519_dalek::ristretto::RistrettoPoint); + test!(ristretto, crate::Ristretto255); #[cfg(feature = "p256")] - test!(p256, p256_::PublicKey); + test!(p256, ::p256::NistP256); #[test] fn remote_key() { - #[cfg(feature = "ristretto255")] - use curve25519_dalek::ristretto::RistrettoPoint as KeCurve; - #[cfg(not(feature = "ristretto255"))] - use p256_::PublicKey as KeCurve; use rand::rngs::OsRng; use crate::{ @@ -379,10 +292,13 @@ mod tests { impl CipherSuite for Default { #[cfg(feature = "ristretto255")] - type OprfGroup = KeCurve; + type OprfGroup = curve25519_dalek::ristretto::RistrettoPoint; + #[cfg(not(feature = "ristretto255"))] + type OprfGroup = ::p256::ProjectivePoint; + #[cfg(feature = "ristretto255")] + type KeGroup = crate::Ristretto255; #[cfg(not(feature = "ristretto255"))] - type OprfGroup = p256_::ProjectivePoint; - type KeGroup = KeCurve; + type KeGroup = ::p256::NistP256; type KeyExchange = crate::key_exchange::tripledh::TripleDH; #[cfg(feature = "ristretto255")] type Hash = sha2::Sha512; @@ -391,7 +307,9 @@ mod tests { type SlowHash = crate::slow_hash::NoOpHash; } - #[derive(Clone, Zeroize)] + type KeCurve = ::KeGroup; + + #[derive(Clone)] struct RemoteKey(PrivateKey); impl SecretKey for RemoteKey { @@ -422,7 +340,7 @@ mod tests { const PASSWORD: &str = "password"; let sk = KeCurve::random_sk(&mut OsRng); - let sk = RemoteKey(PrivateKey(Key(sk))); + let sk = RemoteKey(PrivateKey(sk)); let keypair = KeyPair::from_private_key(sk).unwrap(); let server_setup = ServerSetup::::new_with_key(&mut OsRng, keypair); diff --git a/src/lib.rs b/src/lib.rs index 37645129..41040647 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -34,7 +34,7 @@ //! struct Default; //! impl CipherSuite for Default { //! type OprfGroup = curve25519_dalek::ristretto::RistrettoPoint; -//! type KeGroup = curve25519_dalek::ristretto::RistrettoPoint; +//! type KeGroup = opaque_ke::Ristretto255; //! type KeyExchange = opaque_ke::key_exchange::tripledh::TripleDH; //! type Hash = sha2::Sha512; //! type SlowHash = opaque_ke::slow_hash::NoOpHash; @@ -60,7 +60,7 @@ //! # #[cfg(feature = "ristretto255")] //! # impl CipherSuite for Default { //! # type OprfGroup = curve25519_dalek::ristretto::RistrettoPoint; -//! # type KeGroup = curve25519_dalek::ristretto::RistrettoPoint; +//! # type KeGroup = opaque_ke::Ristretto255; //! # type KeyExchange = opaque_ke::key_exchange::tripledh::TripleDH; //! # type Hash = sha2::Sha512; //! # type SlowHash = opaque_ke::slow_hash::NoOpHash; @@ -68,7 +68,7 @@ //! # #[cfg(not(feature = "ristretto255"))] //! # impl CipherSuite for Default { //! # type OprfGroup = p256_::ProjectivePoint; -//! # type KeGroup = p256_::PublicKey; +//! # type KeGroup = p256_::NistP256; //! # type KeyExchange = opaque_ke::key_exchange::tripledh::TripleDH; //! # type Hash = sha2::Sha256; //! # type SlowHash = opaque_ke::slow_hash::NoOpHash; @@ -111,7 +111,7 @@ //! # #[cfg(feature = "ristretto255")] //! # impl CipherSuite for Default { //! # type OprfGroup = curve25519_dalek::ristretto::RistrettoPoint; -//! # type KeGroup = curve25519_dalek::ristretto::RistrettoPoint; +//! # type KeGroup = opaque_ke::Ristretto255; //! # type KeyExchange = opaque_ke::key_exchange::tripledh::TripleDH; //! # type Hash = sha2::Sha512; //! # type SlowHash = opaque_ke::slow_hash::NoOpHash; @@ -119,7 +119,7 @@ //! # #[cfg(not(feature = "ristretto255"))] //! # impl CipherSuite for Default { //! # type OprfGroup = p256_::ProjectivePoint; -//! # type KeGroup = p256_::PublicKey; +//! # type KeGroup = p256_::NistP256; //! # type KeyExchange = opaque_ke::key_exchange::tripledh::TripleDH; //! # type Hash = sha2::Sha256; //! # type SlowHash = opaque_ke::slow_hash::NoOpHash; @@ -151,7 +151,7 @@ //! # #[cfg(feature = "ristretto255")] //! # impl CipherSuite for Default { //! # type OprfGroup = curve25519_dalek::ristretto::RistrettoPoint; -//! # type KeGroup = curve25519_dalek::ristretto::RistrettoPoint; +//! # type KeGroup = opaque_ke::Ristretto255; //! # type KeyExchange = opaque_ke::key_exchange::tripledh::TripleDH; //! # type Hash = sha2::Sha512; //! # type SlowHash = opaque_ke::slow_hash::NoOpHash; @@ -159,7 +159,7 @@ //! # #[cfg(not(feature = "ristretto255"))] //! # impl CipherSuite for Default { //! # type OprfGroup = p256_::ProjectivePoint; -//! # type KeGroup = p256_::PublicKey; +//! # type KeGroup = p256_::NistP256; //! # type KeyExchange = opaque_ke::key_exchange::tripledh::TripleDH; //! # type Hash = sha2::Sha256; //! # type SlowHash = opaque_ke::slow_hash::NoOpHash; @@ -200,7 +200,7 @@ //! # #[cfg(feature = "ristretto255")] //! # impl CipherSuite for Default { //! # type OprfGroup = curve25519_dalek::ristretto::RistrettoPoint; -//! # type KeGroup = curve25519_dalek::ristretto::RistrettoPoint; +//! # type KeGroup = opaque_ke::Ristretto255; //! # type KeyExchange = opaque_ke::key_exchange::tripledh::TripleDH; //! # type Hash = sha2::Sha512; //! # type SlowHash = opaque_ke::slow_hash::NoOpHash; @@ -208,7 +208,7 @@ //! # #[cfg(not(feature = "ristretto255"))] //! # impl CipherSuite for Default { //! # type OprfGroup = p256_::ProjectivePoint; -//! # type KeGroup = p256_::PublicKey; +//! # type KeGroup = p256_::NistP256; //! # type KeyExchange = opaque_ke::key_exchange::tripledh::TripleDH; //! # type Hash = sha2::Sha256; //! # type SlowHash = opaque_ke::slow_hash::NoOpHash; @@ -250,7 +250,7 @@ //! # #[cfg(feature = "ristretto255")] //! # impl CipherSuite for Default { //! # type OprfGroup = curve25519_dalek::ristretto::RistrettoPoint; -//! # type KeGroup = curve25519_dalek::ristretto::RistrettoPoint; +//! # type KeGroup = opaque_ke::Ristretto255; //! # type KeyExchange = opaque_ke::key_exchange::tripledh::TripleDH; //! # type Hash = sha2::Sha512; //! # type SlowHash = opaque_ke::slow_hash::NoOpHash; @@ -258,7 +258,7 @@ //! # #[cfg(not(feature = "ristretto255"))] //! # impl CipherSuite for Default { //! # type OprfGroup = p256_::ProjectivePoint; -//! # type KeGroup = p256_::PublicKey; +//! # type KeGroup = p256_::NistP256; //! # type KeyExchange = opaque_ke::key_exchange::tripledh::TripleDH; //! # type Hash = sha2::Sha256; //! # type SlowHash = opaque_ke::slow_hash::NoOpHash; @@ -305,7 +305,7 @@ //! # #[cfg(feature = "ristretto255")] //! # impl CipherSuite for Default { //! # type OprfGroup = curve25519_dalek::ristretto::RistrettoPoint; -//! # type KeGroup = curve25519_dalek::ristretto::RistrettoPoint; +//! # type KeGroup = opaque_ke::Ristretto255; //! # type KeyExchange = opaque_ke::key_exchange::tripledh::TripleDH; //! # type Hash = sha2::Sha512; //! # type SlowHash = opaque_ke::slow_hash::NoOpHash; @@ -313,7 +313,7 @@ //! # #[cfg(not(feature = "ristretto255"))] //! # impl CipherSuite for Default { //! # type OprfGroup = p256_::ProjectivePoint; -//! # type KeGroup = p256_::PublicKey; +//! # type KeGroup = p256_::NistP256; //! # type KeyExchange = opaque_ke::key_exchange::tripledh::TripleDH; //! # type Hash = sha2::Sha256; //! # type SlowHash = opaque_ke::slow_hash::NoOpHash; @@ -344,7 +344,7 @@ //! # #[cfg(feature = "ristretto255")] //! # impl CipherSuite for Default { //! # type OprfGroup = curve25519_dalek::ristretto::RistrettoPoint; -//! # type KeGroup = curve25519_dalek::ristretto::RistrettoPoint; +//! # type KeGroup = opaque_ke::Ristretto255; //! # type KeyExchange = opaque_ke::key_exchange::tripledh::TripleDH; //! # type Hash = sha2::Sha512; //! # type SlowHash = opaque_ke::slow_hash::NoOpHash; @@ -352,7 +352,7 @@ //! # #[cfg(not(feature = "ristretto255"))] //! # impl CipherSuite for Default { //! # type OprfGroup = p256_::ProjectivePoint; -//! # type KeGroup = p256_::PublicKey; +//! # type KeGroup = p256_::NistP256; //! # type KeyExchange = opaque_ke::key_exchange::tripledh::TripleDH; //! # type Hash = sha2::Sha256; //! # type SlowHash = opaque_ke::slow_hash::NoOpHash; @@ -409,7 +409,7 @@ //! # #[cfg(feature = "ristretto255")] //! # impl CipherSuite for Default { //! # type OprfGroup = curve25519_dalek::ristretto::RistrettoPoint; -//! # type KeGroup = curve25519_dalek::ristretto::RistrettoPoint; +//! # type KeGroup = opaque_ke::Ristretto255; //! # type KeyExchange = opaque_ke::key_exchange::tripledh::TripleDH; //! # type Hash = sha2::Sha512; //! # type SlowHash = opaque_ke::slow_hash::NoOpHash; @@ -417,7 +417,7 @@ //! # #[cfg(not(feature = "ristretto255"))] //! # impl CipherSuite for Default { //! # type OprfGroup = p256_::ProjectivePoint; -//! # type KeGroup = p256_::PublicKey; +//! # type KeGroup = p256_::NistP256; //! # type KeyExchange = opaque_ke::key_exchange::tripledh::TripleDH; //! # type Hash = sha2::Sha256; //! # type SlowHash = opaque_ke::slow_hash::NoOpHash; @@ -467,7 +467,7 @@ //! # #[cfg(feature = "ristretto255")] //! # impl CipherSuite for Default { //! # type OprfGroup = curve25519_dalek::ristretto::RistrettoPoint; -//! # type KeGroup = curve25519_dalek::ristretto::RistrettoPoint; +//! # type KeGroup = opaque_ke::Ristretto255; //! # type KeyExchange = opaque_ke::key_exchange::tripledh::TripleDH; //! # type Hash = sha2::Sha512; //! # type SlowHash = opaque_ke::slow_hash::NoOpHash; @@ -475,7 +475,7 @@ //! # #[cfg(not(feature = "ristretto255"))] //! # impl CipherSuite for Default { //! # type OprfGroup = p256_::ProjectivePoint; -//! # type KeGroup = p256_::PublicKey; +//! # type KeGroup = p256_::NistP256; //! # type KeyExchange = opaque_ke::key_exchange::tripledh::TripleDH; //! # type Hash = sha2::Sha256; //! # type SlowHash = opaque_ke::slow_hash::NoOpHash; @@ -565,7 +565,7 @@ //! # #[cfg(feature = "ristretto255")] //! # impl CipherSuite for Default { //! # type OprfGroup = curve25519_dalek::ristretto::RistrettoPoint; -//! # type KeGroup = curve25519_dalek::ristretto::RistrettoPoint; +//! # type KeGroup = opaque_ke::Ristretto255; //! # type KeyExchange = opaque_ke::key_exchange::tripledh::TripleDH; //! # type Hash = sha2::Sha512; //! # type SlowHash = opaque_ke::slow_hash::NoOpHash; @@ -573,7 +573,7 @@ //! # #[cfg(not(feature = "ristretto255"))] //! # impl CipherSuite for Default { //! # type OprfGroup = p256_::ProjectivePoint; -//! # type KeGroup = p256_::PublicKey; +//! # type KeGroup = p256_::NistP256; //! # type KeyExchange = opaque_ke::key_exchange::tripledh::TripleDH; //! # type Hash = sha2::Sha256; //! # type SlowHash = opaque_ke::slow_hash::NoOpHash; @@ -664,7 +664,7 @@ //! # #[cfg(feature = "ristretto255")] //! # impl CipherSuite for Default { //! # type OprfGroup = curve25519_dalek::ristretto::RistrettoPoint; -//! # type KeGroup = curve25519_dalek::ristretto::RistrettoPoint; +//! # type KeGroup = opaque_ke::Ristretto255; //! # type KeyExchange = opaque_ke::key_exchange::tripledh::TripleDH; //! # type Hash = sha2::Sha512; //! # type SlowHash = opaque_ke::slow_hash::NoOpHash; @@ -672,7 +672,7 @@ //! # #[cfg(not(feature = "ristretto255"))] //! # impl CipherSuite for Default { //! # type OprfGroup = p256_::ProjectivePoint; -//! # type KeGroup = p256_::PublicKey; +//! # type KeGroup = p256_::NistP256; //! # type KeyExchange = opaque_ke::key_exchange::tripledh::TripleDH; //! # type Hash = sha2::Sha256; //! # type SlowHash = opaque_ke::slow_hash::NoOpHash; @@ -747,7 +747,7 @@ //! # #[cfg(feature = "ristretto255")] //! # impl CipherSuite for Default { //! # type OprfGroup = curve25519_dalek::ristretto::RistrettoPoint; -//! # type KeGroup = curve25519_dalek::ristretto::RistrettoPoint; +//! # type KeGroup = opaque_ke::Ristretto255; //! # type KeyExchange = opaque_ke::key_exchange::tripledh::TripleDH; //! # type Hash = sha2::Sha512; //! # type SlowHash = opaque_ke::slow_hash::NoOpHash; @@ -755,7 +755,7 @@ //! # #[cfg(not(feature = "ristretto255"))] //! # impl CipherSuite for Default { //! # type OprfGroup = p256_::ProjectivePoint; -//! # type KeGroup = p256_::PublicKey; +//! # type KeGroup = p256_::NistP256; //! # type KeyExchange = opaque_ke::key_exchange::tripledh::TripleDH; //! # type Hash = sha2::Sha256; //! # type SlowHash = opaque_ke::slow_hash::NoOpHash; @@ -797,7 +797,7 @@ //! # #[cfg(feature = "ristretto255")] //! # impl CipherSuite for Default { //! # type OprfGroup = curve25519_dalek::ristretto::RistrettoPoint; -//! # type KeGroup = curve25519_dalek::ristretto::RistrettoPoint; +//! # type KeGroup = opaque_ke::Ristretto255; //! # type KeyExchange = opaque_ke::key_exchange::tripledh::TripleDH; //! # type Hash = sha2::Sha512; //! # type SlowHash = opaque_ke::slow_hash::NoOpHash; @@ -805,7 +805,7 @@ //! # #[cfg(not(feature = "ristretto255"))] //! # impl CipherSuite for Default { //! # type OprfGroup = p256_::ProjectivePoint; -//! # type KeGroup = p256_::PublicKey; +//! # type KeGroup = p256_::NistP256; //! # type KeyExchange = opaque_ke::key_exchange::tripledh::TripleDH; //! # type Hash = sha2::Sha256; //! # type SlowHash = opaque_ke::slow_hash::NoOpHash; @@ -858,7 +858,7 @@ //! # #[cfg(feature = "ristretto255")] //! # impl CipherSuite for Default { //! # type OprfGroup = curve25519_dalek::ristretto::RistrettoPoint; -//! # type KeGroup = curve25519_dalek::ristretto::RistrettoPoint; +//! # type KeGroup = opaque_ke::Ristretto255; //! # type KeyExchange = opaque_ke::key_exchange::tripledh::TripleDH; //! # type Hash = sha2::Sha512; //! # type SlowHash = opaque_ke::slow_hash::NoOpHash; @@ -866,7 +866,7 @@ //! # #[cfg(not(feature = "ristretto255"))] //! # impl CipherSuite for Default { //! # type OprfGroup = p256_::ProjectivePoint; -//! # type KeGroup = p256_::PublicKey; +//! # type KeGroup = p256_::NistP256; //! # type KeyExchange = opaque_ke::key_exchange::tripledh::TripleDH; //! # type Hash = sha2::Sha256; //! # type SlowHash = opaque_ke::slow_hash::NoOpHash; @@ -956,7 +956,7 @@ //! # #[cfg(feature = "ristretto255")] //! # impl CipherSuite for Default { //! # type OprfGroup = curve25519_dalek::ristretto::RistrettoPoint; -//! # type KeGroup = curve25519_dalek::ristretto::RistrettoPoint; +//! # type KeGroup = opaque_ke::Ristretto255; //! # type KeyExchange = opaque_ke::key_exchange::tripledh::TripleDH; //! # type Hash = sha2::Sha512; //! # type SlowHash = opaque_ke::slow_hash::NoOpHash; @@ -964,18 +964,18 @@ //! # #[cfg(not(feature = "ristretto255"))] //! # impl CipherSuite for Default { //! # type OprfGroup = p256_::ProjectivePoint; -//! # type KeGroup = p256_::PublicKey; +//! # type KeGroup = p256_::NistP256; //! # type KeyExchange = opaque_ke::key_exchange::tripledh::TripleDH; //! # type Hash = sha2::Sha256; //! # type SlowHash = opaque_ke::slow_hash::NoOpHash; //! # } //! # #[derive(Debug)] //! # struct YourRemoteKeyError; -//! # #[derive(Clone, Zeroize)] -//! # struct YourRemoteKey(PrivateKey<::KeGroup>); +//! # #[derive(Clone)] +//! # struct YourRemoteKey(<::KeGroup as KeGroup>::Sk); //! # impl YourRemoteKey { //! # fn diffie_hellman(&self, pk: &[u8]) -> Result::KeGroup as KeGroup>::PkLen>, YourRemoteKeyError> { todo!() } -//! # fn public_key(&self) -> Result::KeGroup as KeGroup>::PkLen>, YourRemoteKeyError> { Ok(GenericArray::default()) } +//! # fn public_key(&self) -> Result::KeGroup as KeGroup>::PkLen>, YourRemoteKeyError> { Ok(<::KeGroup>::serialize_pk(&<::KeGroup>::public_key(&self.0))) } //! # } //! impl SecretKey<::KeGroup> for YourRemoteKey { //! type Error = YourRemoteKeyError; @@ -985,14 +985,13 @@ //! &self, //! pk: PublicKey<::KeGroup>, //! ) -> Result::KeGroup as KeGroup>::PkLen>, InternalError> { -//! YourRemoteKey::diffie_hellman(self, &pk.to_arr()).map_err(InternalError::Custom) +//! YourRemoteKey::diffie_hellman(self, &pk.to_bytes()).map_err(InternalError::Custom) //! } //! //! fn public_key( //! &self //! ) -> Result::KeGroup>, InternalError> { -//! YourRemoteKey::public_key(self).map(PublicKey::from_arr) -//! .map_err(InternalError::Custom) +//! PublicKey::from_bytes(&YourRemoteKey::public_key(self).map_err(InternalError::Custom)?).map_err(InternalError::into_custom) //! } //! //! fn serialize(&self) -> GenericArray { @@ -1006,7 +1005,7 @@ //! } //! } //! -//! # let remote_key = YourRemoteKey(PrivateKey::from_arr(GenericArray::default())); +//! # let remote_key = YourRemoteKey(<::KeGroup>::random_sk(&mut OsRng)); //! let keypair = KeyPair::from_private_key(remote_key).unwrap(); //! let server_setup = ServerSetup::::new_with_key(&mut OsRng, keypair); //! ``` @@ -1032,27 +1031,27 @@ //! `ristretto255_fiat_u32`. Any `ristretto255_*` backend feature will enable //! the `ristretto255` feature, which can be used too, but keep in mind that //! `curve25519-dalek` will fail to compile without a selected backend. This -//! enables the use of `curve25519_dalek::ristretto::RistrettoPoint` as a -//! `KeGroup` and `OprfGroup`. +//! enables the use of [`Ristretto255`] as a `KeGroup` and +//! [`curve25519_dalek::ristretto::RistrettoPoint`] `OprfGroup`. //! //! - The `x25519` feature is similar to the `ristretto255` feature and requires //! to select a backend like `x25519_u64`, other backends are the same as in -//! `ristretto255_*`. This enables `x25519_dalek::PublicKey` as a `KeGroup`. +//! `ristretto255_*`. This enables [`X25519`] as a `KeGroup`. //! //! - The `ristretto255_simd` feature is re-exported from [curve25519-dalek](https://doc.dalek.rs/curve25519_dalek/index.html#backends-and-features) //! and enables parallel formulas, using either AVX2 or AVX512-IFMA. This will //! automatically enable the `ristretto255_u64` feature and requires Rust //! nightly. //! -//! - The `p256` feature enables the use of `p256::PublicKey` as a `KeGroup` and -//! `p256::ProjectivePoint` as a `OprfGroup` for `CipherSuite`. Note that this -//! is currently an experimental feature ⚠️, and is not yet ready for +//! - The `p256` feature enables the use of [`p256::NistP256`] as a `KeGroup` +//! and [`p256::ProjectivePoint`] as a `OprfGroup` for `CipherSuite`. Note +//! that this is currently an experimental feature ⚠️, and is not yet ready for //! production use. //! //! - The `bench` feature is used only for running performance benchmarks for //! this implementation. -#![deny(unsafe_code)] +#![cfg_attr(not(test), deny(unsafe_code))] #![no_std] #![warn(clippy::cargo, missing_docs)] #![allow(clippy::multiple_crate_versions, type_alias_bounds)] @@ -1060,6 +1059,9 @@ #[cfg(any(feature = "std", test))] extern crate std; +#[cfg(feature = "p256")] +extern crate p256_ as p256; + // Error types pub mod errors; @@ -1074,6 +1076,7 @@ mod messages; mod opaque; mod serialization; pub mod slow_hash; +mod util; #[cfg(test)] mod tests; @@ -1083,6 +1086,10 @@ mod tests; pub use ciphersuite::CipherSuite; pub use rand; +#[cfg(feature = "ristretto255")] +pub use crate::key_exchange::group::ristretto255::Ristretto255; +#[cfg(feature = "x25519")] +pub use crate::key_exchange::group::x25519::X25519; pub use crate::messages::{ CredentialFinalization, CredentialFinalizationLen, CredentialRequest, CredentialRequestLen, CredentialResponse, CredentialResponseLen, RegistrationRequest, RegistrationRequestLen, diff --git a/src/messages.rs b/src/messages.rs index c0719977..d94957c6 100755 --- a/src/messages.rs +++ b/src/messages.rs @@ -28,7 +28,7 @@ use crate::key_exchange::traits::{ FromBytes, Ke1MessageLen, Ke2MessageLen, Ke3MessageLen, KeyExchange, ToBytes, }; use crate::key_exchange::tripledh::NonceLen; -use crate::keypair::{KeyPair, PublicKey, SecretKey}; +use crate::keypair::{PublicKey, SecretKey}; use crate::opaque::{MaskedResponse, MaskedResponseLen, ServerSetup}; //////////////////////////// @@ -56,7 +56,7 @@ impl_serialize_and_deserialize_for!(RegistrationRequest); /// registration attempt #[derive(DeriveWhere)] #[derive_where(Clone)] -#[derive_where(Debug, Eq, Hash, Ord, PartialEq, PartialOrd; CS::OprfGroup)] +#[derive_where(Debug, Eq, Hash, Ord, PartialEq, PartialOrd; CS::OprfGroup, ::Pk)] pub struct RegistrationResponse where ::Core: ProxyHash, @@ -80,7 +80,8 @@ impl_serialize_and_deserialize_for!( /// The final message from the client, containing sealed cryptographic /// identifiers #[derive(DeriveWhere)] -#[derive_where(Clone, Debug, Eq, Hash, Ord, PartialEq, PartialOrd, Zeroize(drop))] +#[derive_where(Clone, Zeroize(drop))] +#[derive_where(Debug, Eq, Hash, Ord, PartialEq, PartialOrd; ::Pk)] pub struct RegistrationUpload where ::Core: ProxyHash, @@ -93,6 +94,7 @@ where /// The masking key used to mask the envelope pub(crate) masking_key: Output, /// The user's public key + #[derive_where(skip(Zeroize))] pub(crate) client_s_pk: PublicKey, } @@ -111,7 +113,7 @@ impl_serialize_and_deserialize_for!( /// The message sent by the user to the server, to initiate registration #[derive(DeriveWhere)] -#[derive_where(Clone, Zeroize)] +#[derive_where(Clone, Zeroize(drop))] #[derive_where( Debug, Eq, Hash, PartialEq; CS::OprfGroup, @@ -249,7 +251,7 @@ where self.evaluation_element .value() .to_arr() - .concat(self.server_s_pk.to_arr()) + .concat(self.server_s_pk.to_bytes()) } /// Deserialization from bytes @@ -260,9 +262,7 @@ where check_slice_size(input, elem_len + key_len, "registration_response_bytes")?; // Ensure that public key is valid - let server_s_pk = KeyPair::::check_public_key(PublicKey::from_bytes( - &checked_slice[elem_len..], - )?)?; + let server_s_pk = PublicKey::deserialize(&checked_slice[elem_len..])?; Ok(Self { evaluation_element: voprf::EvaluationElement::deserialize(&checked_slice[..elem_len])?, @@ -304,7 +304,7 @@ where RegistrationUploadLen: ArrayLength, { self.client_s_pk - .to_arr() + .to_bytes() .concat(self.masking_key.clone()) .concat(self.envelope.serialize()) } @@ -321,9 +321,7 @@ where masking_key: GenericArray::clone_from_slice( &checked_slice[key_len..key_len + hash_len], ), - client_s_pk: KeyPair::::check_public_key(PublicKey::from_bytes( - &checked_slice[..key_len], - )?)?, + client_s_pk: PublicKey::deserialize(&checked_slice[..key_len])?, }) } diff --git a/src/opaque.rs b/src/opaque.rs index d3f3911c..8aef2864 100755 --- a/src/opaque.rs +++ b/src/opaque.rs @@ -61,15 +61,17 @@ const STR_OPAQUE_DERIVE_KEY_PAIR: &[u8; 20] = b"OPAQUE-DeriveKeyPair"; derive(serde_::Deserialize, serde_::Serialize), serde( bound( - deserialize = "KeyPair: serde_::Deserialize<'de>", - serialize = "KeyPair: serde_::Serialize" + deserialize = "::Pk: serde_::Deserialize<'de>, ::Sk: serde_::Deserialize<'de>, S: serde_::Deserialize<'de>", + serialize = "::Pk: serde_::Serialize, ::Sk: serde_::Serialize, S: serde_::Serialize" ), crate = "serde_" ) )] #[derive(DeriveWhere)] #[derive_where(Clone)] -#[derive_where(Debug, Eq, Hash, Ord, PartialEq, PartialOrd; S)] +#[derive_where(Debug, Eq, Hash, Ord, PartialEq, PartialOrd; ::Pk, ::Sk, S)] pub struct ServerSetup< CS: CipherSuite, S: SecretKey = PrivateKey<::KeGroup>, @@ -111,8 +113,9 @@ impl_serialize_and_deserialize_for!( /// The state elements the server holds to record a registration #[derive(DeriveWhere)] -#[derive_where(Clone, Debug, Eq, Hash, Ord, PartialEq, PartialOrd, Zeroize(drop))] -pub struct ServerRegistration(RegistrationUpload) +#[derive_where(Clone, Zeroize(drop))] +#[derive_where(Debug, Eq, Hash, Ord, PartialEq, PartialOrd; ::Pk)] +pub struct ServerRegistration(pub(crate) RegistrationUpload) where ::Core: ProxyHash, <::Core as BlockSizeUser>::BlockSize: IsLess, @@ -147,9 +150,9 @@ where <::Core as BlockSizeUser>::BlockSize: IsLess, Le<<::Core as BlockSizeUser>::BlockSize, U256>: NonZero, { - oprf_client: voprf::NonVerifiableClient, - ke1_state: >::KE1State, - credential_request: CredentialRequest, + pub(crate) oprf_client: voprf::NonVerifiableClient, + pub(crate) ke1_state: >::KE1State, + pub(crate) credential_request: CredentialRequest, } impl_serialize_and_deserialize_for!( @@ -242,7 +245,7 @@ where self.oprf_seed .clone() .concat(self.keypair.private().serialize()) - .concat(self.fake_keypair.private().to_arr()) + .concat(self.fake_keypair.private().serialize()) } /// Deserialization from bytes @@ -509,22 +512,6 @@ where ke1_state, }) } - - /// Only used for testing zeroize - #[cfg(test)] - pub(crate) fn to_vec(&self) -> std::vec::Vec - where - // CredentialRequest: KgPk + Ke1Message - ::ElemLen: Add>, - CredentialRequestLen: ArrayLength, - { - [ - self.oprf_client.serialize().to_vec(), - self.credential_request.serialize().to_vec(), - self.ke1_state.to_bytes().to_vec(), - ] - .concat() - } } impl ClientLogin @@ -728,8 +715,8 @@ where let (id_u, id_s) = bytestrings_from_identifiers::( identifiers, - client_s_pk.to_arr(), - server_s_pk.to_arr(), + client_s_pk.to_bytes(), + server_s_pk.to_bytes(), ) .map_err(ProtocolError::into_custom)?; @@ -758,7 +745,7 @@ where rng, credential_request_bytes, credential_response_component, - credential_request.ke1_message, + credential_request.ke1_message.clone(), client_s_pk, server_s_sk.clone(), id_u.iter(), @@ -1184,7 +1171,7 @@ where for (x1, x2) in xor_pad.iter_mut().zip( server_s_pk - .to_arr() + .to_bytes() .as_slice() .iter() .chain(envelope.serialize().iter()), @@ -1221,12 +1208,9 @@ where } let key_len = ::PkLen::USIZE; - let unchecked_server_s_pk = PublicKey::from_bytes(&xor_pad[..key_len])?; - let envelope = Envelope::deserialize(&xor_pad[key_len..])?; - - // Ensure that public key is valid - let server_s_pk = KeyPair::::check_public_key(unchecked_server_s_pk) + let server_s_pk = PublicKey::deserialize(&xor_pad[..key_len]) .map_err(|_| ProtocolError::SerializationError)?; + let envelope = Envelope::deserialize(&xor_pad[key_len..])?; Ok((server_s_pk, envelope)) } diff --git a/src/serialization/mod.rs b/src/serialization/mod.rs index 10f31fd5..81060f90 100644 --- a/src/serialization/mod.rs +++ b/src/serialization/mod.rs @@ -12,7 +12,7 @@ use generic_array::typenum::{U0, U2}; use generic_array::{ArrayLength, GenericArray}; use hmac::Mac; -use crate::errors::ProtocolError; +use crate::errors::{InternalError, ProtocolError}; // Corresponds to the I2OSP() function from RFC8017 pub(crate) fn i2osp>( @@ -157,6 +157,20 @@ impl MacExt for T { } } +pub(crate) trait GenericArrayExt { + fn try_from_slice(slice: &[u8]) -> Result<&Self, InternalError>; +} + +impl> GenericArrayExt for GenericArray { + fn try_from_slice(slice: &[u8]) -> Result<&Self, InternalError> { + if slice.len() == L::USIZE { + Ok(Self::from_slice(slice)) + } else { + Err(InternalError::InvalidByteSequence) + } + } +} + #[cfg(test)] mod tests; diff --git a/src/serialization/tests.rs b/src/serialization/tests.rs index 2a2e3e9c..02364ba0 100755 --- a/src/serialization/tests.rs +++ b/src/serialization/tests.rs @@ -28,7 +28,7 @@ use crate::key_exchange::traits::{ FromBytes, Ke1MessageLen, Ke1StateLen, Ke2MessageLen, KeyExchange, ToBytes, }; use crate::key_exchange::tripledh::{NonceLen, TripleDH}; -use crate::keypair::KeyPair; +use crate::keypair::{KeyPair, SecretKey}; use crate::messages::CredentialResponseWithoutKeLen; use crate::opaque::{ClientLoginLen, ClientRegistrationLen, MaskedResponseLen}; use crate::serialization::{i2osp, os2ip}; @@ -39,7 +39,7 @@ struct Ristretto255; #[cfg(feature = "ristretto255")] impl CipherSuite for Ristretto255 { type OprfGroup = curve25519_dalek::ristretto::RistrettoPoint; - type KeGroup = curve25519_dalek::ristretto::RistrettoPoint; + type KeGroup = crate::Ristretto255; type KeyExchange = TripleDH; type Hash = sha2::Sha512; type SlowHash = crate::slow_hash::NoOpHash; @@ -49,14 +49,14 @@ impl CipherSuite for Ristretto255 { struct P256; #[cfg(feature = "p256")] impl CipherSuite for P256 { - type OprfGroup = p256_::ProjectivePoint; - type KeGroup = p256_::PublicKey; + type OprfGroup = ::p256::ProjectivePoint; + type KeGroup = ::p256::NistP256; type KeyExchange = TripleDH; type Hash = sha2::Sha256; type SlowHash = crate::slow_hash::NoOpHash; } -fn random_point() -> CS::KeGroup +fn random_point() -> ::Pk where ::Core: ProxyHash, <::Core as BlockSizeUser>::BlockSize: IsLess, @@ -140,7 +140,7 @@ fn server_registration_roundtrip() -> Result<(), ProtocolError> { let mock_client_kp = KeyPair::::generate_random(&mut rng); // serialization order: oprf_key, public key, envelope let mut bytes = Vec::::new(); - bytes.extend_from_slice(&mock_client_kp.public().to_arr()); + bytes.extend_from_slice(&mock_client_kp.public().to_bytes()); bytes.extend_from_slice(&masking_key); bytes.extend_from_slice(&mock_envelope_bytes); let reg = ServerRegistration::::deserialize(&bytes)?; @@ -166,7 +166,7 @@ fn registration_request_roundtrip() -> Result<(), ProtocolError> { Le<<::Core as BlockSizeUser>::BlockSize, U256>: NonZero, { let pt = random_point::(); - let pt_bytes = pt.to_arr().to_vec(); + let pt_bytes = CS::KeGroup::serialize_pk(&pt); let mut input = Vec::new(); input.extend_from_slice(&pt_bytes); @@ -209,10 +209,10 @@ fn registration_response_roundtrip() -> Result<(), ProtocolError> { RegistrationResponseLen: ArrayLength, { let pt = random_point::(); - let beta_bytes = pt.to_arr(); + let beta_bytes = CS::KeGroup::serialize_pk(&pt); let mut rng = OsRng; let skp = KeyPair::::generate_random(&mut rng); - let pubkey_bytes = skp.public().to_arr(); + let pubkey_bytes = skp.public().to_bytes(); let mut input = Vec::new(); input.extend_from_slice(&beta_bytes); @@ -264,7 +264,7 @@ fn registration_upload_roundtrip() -> Result<(), ProtocolError> { { let mut rng = OsRng; let skp = KeyPair::::generate_random(&mut rng); - let pubkey_bytes = skp.public().to_arr(); + let pubkey_bytes = skp.public().to_bytes(); let mut key = [0u8; 32]; rng.fill_bytes(&mut key); @@ -318,13 +318,17 @@ fn credential_request_roundtrip() -> Result<(), ProtocolError> { { let mut rng = OsRng; let alpha = random_point::(); - let alpha_bytes = alpha.to_arr(); + let alpha_bytes = CS::KeGroup::serialize_pk(&alpha); let client_e_kp = KeyPair::::generate_random(&mut rng); let mut client_nonce = [0u8; NonceLen::USIZE]; rng.fill_bytes(&mut client_nonce); - let ke1m: Vec = [client_nonce.as_ref(), client_e_kp.public()].concat(); + let ke1m: Vec = [ + client_nonce.as_ref(), + client_e_kp.public().to_bytes().as_ref(), + ] + .concat(); let mut input = Vec::new(); input.extend_from_slice(&alpha_bytes); @@ -377,7 +381,7 @@ fn credential_response_roundtrip() -> Result<(), ProtocolError> { CredentialResponseLen: ArrayLength, { let pt = random_point::(); - let pt_bytes = pt.to_arr(); + let pt_bytes = CS::KeGroup::serialize_pk(&pt); let mut rng = OsRng; @@ -394,7 +398,12 @@ fn credential_response_roundtrip() -> Result<(), ProtocolError> { let mut server_nonce = [0u8; NonceLen::USIZE]; rng.fill_bytes(&mut server_nonce); - let ke2m: Vec = [server_nonce.as_ref(), server_e_kp.public(), &mac].concat(); + let ke2m: Vec = [ + server_nonce.as_ref(), + server_e_kp.public().to_bytes().as_ref(), + &mac, + ] + .concat(); let mut input = Vec::new(); input.extend_from_slice(&pt_bytes); @@ -489,7 +498,7 @@ fn client_login_roundtrip() -> Result<(), ProtocolError> { rng.fill_bytes(&mut client_nonce); let l1_data = [ - client_e_kp.private().to_arr().to_vec(), + client_e_kp.private().serialize().to_vec(), client_nonce.to_vec(), ] .concat(); @@ -501,7 +510,11 @@ fn client_login_roundtrip() -> Result<(), ProtocolError> { blinded_element: blind_result.message, ke1_message: >::KE1Message::from_bytes( - &[client_nonce.as_ref(), client_e_kp.public()].concat(), + &[ + client_nonce.as_ref(), + client_e_kp.public().to_bytes().as_ref(), + ] + .concat(), )?, }; @@ -541,7 +554,11 @@ fn ke1_message_roundtrip() -> Result<(), ProtocolError> { let mut client_nonce = vec![0u8; NonceLen::USIZE]; rng.fill_bytes(&mut client_nonce); - let ke1m = [client_nonce.as_slice(), client_e_kp.public()].concat(); + let ke1m = [ + client_nonce.as_slice(), + client_e_kp.public().to_bytes().as_ref(), + ] + .concat(); let reg = >::KE1Message::from_bytes(&ke1m)?; let reg_bytes = reg.to_bytes(); @@ -574,7 +591,12 @@ fn ke2_message_roundtrip() -> Result<(), ProtocolError> { let mut server_nonce = vec![0u8; NonceLen::USIZE]; rng.fill_bytes(&mut server_nonce); - let ke2m: Vec = [server_nonce.as_slice(), server_e_kp.public(), &mac].concat(); + let ke2m: Vec = [ + server_nonce.as_slice(), + server_e_kp.public().to_bytes().as_ref(), + &mac, + ] + .concat(); let reg = >::KE2Message::from_bytes(&ke2m)?; diff --git a/src/tests/full_test.rs b/src/tests/full_test.rs index 5644f679..d2206ad3 100755 --- a/src/tests/full_test.rs +++ b/src/tests/full_test.rs @@ -29,6 +29,7 @@ use crate::hash::{OutputSize, ProxyHash}; use crate::key_exchange::group::KeGroup; use crate::key_exchange::traits::{Ke1MessageLen, Ke1StateLen, Ke2MessageLen}; use crate::key_exchange::tripledh::{NonceLen, TripleDH}; +use crate::keypair::SecretKey; use crate::messages::{ CredentialRequestLen, CredentialResponseLen, CredentialResponseWithoutKeLen, RegistrationResponseLen, RegistrationUploadLen, @@ -46,7 +47,7 @@ struct Ristretto255; #[cfg(feature = "ristretto255")] impl CipherSuite for Ristretto255 { type OprfGroup = curve25519_dalek::ristretto::RistrettoPoint; - type KeGroup = curve25519_dalek::ristretto::RistrettoPoint; + type KeGroup = crate::Ristretto255; type KeyExchange = TripleDH; type Hash = sha2::Sha512; type SlowHash = NoOpHash; @@ -56,8 +57,8 @@ impl CipherSuite for Ristretto255 { struct P256; #[cfg(feature = "p256")] impl CipherSuite for P256 { - type OprfGroup = p256_::ProjectivePoint; - type KeGroup = p256_::PublicKey; + type OprfGroup = p256::ProjectivePoint; + type KeGroup = p256::NistP256; type KeyExchange = TripleDH; type Hash = sha2::Sha256; type SlowHash = NoOpHash; @@ -68,7 +69,7 @@ struct X25519Ristretto255; #[cfg(all(feature = "x25519", feature = "ristretto255"))] impl CipherSuite for X25519Ristretto255 { type OprfGroup = curve25519_dalek::ristretto::RistrettoPoint; - type KeGroup = x25519_dalek::PublicKey; + type KeGroup = crate::X25519; type KeyExchange = TripleDH; type Hash = sha2::Sha512; type SlowHash = NoOpHash; @@ -78,8 +79,8 @@ impl CipherSuite for X25519Ristretto255 { struct X25519P256; #[cfg(all(feature = "x25519", feature = "p256"))] impl CipherSuite for X25519P256 { - type OprfGroup = p256_::ProjectivePoint; - type KeGroup = x25519_dalek::PublicKey; + type OprfGroup = p256::ProjectivePoint; + type KeGroup = crate::X25519; type KeyExchange = TripleDH; type Hash = sha2::Sha256; type SlowHash = NoOpHash; @@ -517,11 +518,11 @@ where let mut server_nonce = [0u8; NonceLen::USIZE]; rng.fill_bytes(&mut server_nonce); - let fake_sk: Vec = fake_kp.private().to_vec(); + let fake_sk: Vec = fake_kp.private().serialize().to_vec(); let server_setup = ServerSetup::::deserialize( &[ oprf_seed.as_ref(), - &server_s_kp.private().to_arr(), + &server_s_kp.private().serialize(), &fake_sk, ] .concat(), @@ -557,7 +558,7 @@ where let registration_response_bytes = server_registration_start_result.message.serialize(); let mut client_s_sk_and_nonce: Vec = Vec::new(); - client_s_sk_and_nonce.extend_from_slice(&client_s_kp.private().to_arr()); + client_s_sk_and_nonce.extend_from_slice(&client_s_kp.private().serialize()); client_s_sk_and_nonce.extend_from_slice(&envelope_nonce); let mut finish_registration_rng = CycleRng::new(client_s_sk_and_nonce); @@ -583,7 +584,7 @@ where let mut client_login_start: Vec = Vec::new(); client_login_start.extend_from_slice(&blinding_factor_bytes); - client_login_start.extend_from_slice(&client_e_kp.private().to_arr()); + client_login_start.extend_from_slice(&client_e_kp.private().serialize()); client_login_start.extend_from_slice(&client_nonce); let mut client_login_start_rng = CycleRng::new(client_login_start); @@ -595,7 +596,7 @@ where let mut server_e_sk_and_nonce_rng = CycleRng::new( [ masking_nonce.to_vec(), - server_e_kp.private().to_arr().to_vec(), + server_e_kp.private().serialize().to_vec(), server_nonce.to_vec(), ] .concat(), @@ -636,14 +637,14 @@ where let credential_finalization_bytes = client_login_finish_result.message.serialize(); Ok(TestVectorParameters { - client_s_pk: client_s_kp.public().to_arr().to_vec(), - client_s_sk: client_s_kp.private().to_arr().to_vec(), - client_e_pk: client_e_kp.public().to_arr().to_vec(), - client_e_sk: client_e_kp.private().to_arr().to_vec(), - server_s_pk: server_s_kp.public().to_arr().to_vec(), - server_s_sk: server_s_kp.private().to_arr().to_vec(), - server_e_pk: server_e_kp.public().to_arr().to_vec(), - server_e_sk: server_e_kp.private().to_arr().to_vec(), + client_s_pk: client_s_kp.public().to_bytes().to_vec(), + client_s_sk: client_s_kp.private().serialize().to_vec(), + client_e_pk: client_e_kp.public().to_bytes().to_vec(), + client_e_sk: client_e_kp.private().serialize().to_vec(), + server_s_pk: server_s_kp.public().to_bytes().to_vec(), + server_s_sk: server_s_kp.private().serialize().to_vec(), + server_e_pk: server_e_kp.public().to_bytes().to_vec(), + server_e_sk: server_e_kp.private().serialize().to_vec(), fake_sk, credential_identifier: credential_identifier.to_vec(), id_u: id_u.to_vec(), @@ -1090,7 +1091,7 @@ fn test_credential_finalization() -> Result<(), ProtocolError> { assert_eq!( hex::encode(¶meters.server_s_pk), - hex::encode(&client_login_finish_result.server_s_pk.to_arr().to_vec()) + hex::encode(&client_login_finish_result.server_s_pk.to_bytes().to_vec()) ); assert_eq!( hex::encode(¶meters.session_key), @@ -1371,10 +1372,9 @@ fn test_zeroize_server_registration_finish() -> Result<(), ProtocolError> { let p_file = ServerRegistration::finish(client_registration_finish_result.message); let mut state = p_file; - Zeroize::zeroize(&mut state); - for byte in state.serialize() { - assert_eq!(byte, 0); - } + util::drop_manually(&mut state); + util::test_zeroized(&mut state.0.envelope.mode); + util::test_zeroized(&mut state.0.masking_key); Ok(()) } @@ -1393,24 +1393,39 @@ fn test_zeroize_server_registration_finish() -> Result<(), ProtocolError> { #[test] fn test_zeroize_client_login_start() -> Result<(), ProtocolError> { - fn inner() -> Result<(), ProtocolError> + fn inner>() -> Result<(), ProtocolError> where ::Core: ProxyHash, <::Core as BlockSizeUser>::BlockSize: IsLess, Le<<::Core as BlockSizeUser>::BlockSize, U256>: NonZero, // CredentialRequest: KgPk + Ke1Message - ::ElemLen: Add>, + ::ElemLen: Add::PkLen>>, CredentialRequestLen: ArrayLength, + // Ke1State: KeSk + Nonce + ::SkLen: Add, + Sum<::SkLen, NonceLen>: ArrayLength, + // Ke1Message: Nonce + KePk + NonceLen: Add<::PkLen>, + Sum::PkLen>: ArrayLength, + // Ke2State: (Hash + Hash) + Hash + OutputSize: Add>, + Sum, OutputSize>: + ArrayLength + Add>, + Sum, OutputSize>, OutputSize>: ArrayLength, + // Ke2Message: (Nonce + KePk) + Hash + NonceLen: Add<::PkLen>, + Sum::PkLen>: ArrayLength + Add>, + Sum::PkLen>, OutputSize>: ArrayLength, { let mut client_rng = OsRng; let client_login_start_result = ClientLogin::::start(&mut client_rng, STR_PASSWORD.as_bytes())?; let mut state = client_login_start_result.state; - Zeroize::zeroize(&mut state); - for byte in state.to_vec() { - assert_eq!(byte, 0); - } + util::drop_manually(&mut state); + util::test_zeroized(&mut state.oprf_client); + util::test_zeroized(&mut state.ke1_state); + util::test_zeroized(&mut state.credential_request.ke1_message.client_nonce); Ok(()) } @@ -1490,7 +1505,7 @@ fn test_zeroize_server_login_start() -> Result<(), ProtocolError> { #[test] fn test_zeroize_client_login_finish() -> Result<(), ProtocolError> { - fn inner() -> Result<(), ProtocolError> + fn inner>() -> Result<(), ProtocolError> where ::Core: ProxyHash, <::Core as BlockSizeUser>::BlockSize: IsLess, @@ -1500,8 +1515,23 @@ fn test_zeroize_client_login_finish() -> Result<(), ProtocolError> { Sum>: ArrayLength + Add<::PkLen>, MaskedResponseLen: ArrayLength, // CredentialRequest: KgPk + Ke1Message - ::ElemLen: Add>, + ::ElemLen: Add::PkLen>>, CredentialRequestLen: ArrayLength, + // Ke1State: KeSk + Nonce + ::SkLen: Add, + Sum<::SkLen, NonceLen>: ArrayLength, + // Ke1Message: Nonce + KePk + NonceLen: Add<::PkLen>, + Sum::PkLen>: ArrayLength, + // Ke2State: (Hash + Hash) + Hash + OutputSize: Add>, + Sum, OutputSize>: + ArrayLength + Add>, + Sum, OutputSize>, OutputSize>: ArrayLength, + // Ke2Message: (Nonce + KePk) + Hash + NonceLen: Add<::PkLen>, + Sum::PkLen>: ArrayLength + Add>, + Sum::PkLen>, OutputSize>: ArrayLength, { let mut client_rng = OsRng; let mut server_rng = OsRng; @@ -1537,10 +1567,10 @@ fn test_zeroize_client_login_finish() -> Result<(), ProtocolError> { )?; let mut state = client_login_finish_result.state; - Zeroize::zeroize(&mut state); - for byte in state.to_vec() { - assert_eq!(byte, 0); - } + util::drop_manually(&mut state); + util::test_zeroized(&mut state.oprf_client); + util::test_zeroized(&mut state.ke1_state); + util::test_zeroized(&mut state.credential_request.ke1_message.client_nonce); Ok(()) } diff --git a/src/tests/test_opaque_vectors.rs b/src/tests/test_opaque_vectors.rs index 7c2e9a0a..de1b882b 100755 --- a/src/tests/test_opaque_vectors.rs +++ b/src/tests/test_opaque_vectors.rs @@ -207,7 +207,7 @@ fn tests() -> Result<(), ProtocolError> { struct Ristretto255Sha512NoSlowHash; impl CipherSuite for Ristretto255Sha512NoSlowHash { type OprfGroup = curve25519_dalek::ristretto::RistrettoPoint; - type KeGroup = curve25519_dalek::ristretto::RistrettoPoint; + type KeGroup = crate::Ristretto255; type KeyExchange = TripleDH; type Hash = sha2::Sha512; type SlowHash = NoOpHash; @@ -237,8 +237,8 @@ fn tests() -> Result<(), ProtocolError> { struct P256Sha256NoSlowHash; impl CipherSuite for P256Sha256NoSlowHash { - type OprfGroup = p256_::ProjectivePoint; - type KeGroup = p256_::PublicKey; + type OprfGroup = p256::ProjectivePoint; + type KeGroup = p256::NistP256; type KeyExchange = TripleDH; type Hash = sha2::Sha256; type SlowHash = NoOpHash; diff --git a/src/util.rs b/src/util.rs new file mode 100644 index 00000000..2aaff7d9 --- /dev/null +++ b/src/util.rs @@ -0,0 +1,41 @@ +// Copyright (c) Facebook, Inc. and its affiliates. +// +// This source code is licensed under both the MIT license found in the +// LICENSE-MIT file in the root directory of this source tree and the Apache +// License, Version 2.0 found in the LICENSE-APACHE file in the root directory +// of this source tree. + +//! Utility functions. + +#[cfg(test)] +pub(crate) fn test_zeroize_on_drop(value: &mut T) { + drop_manually(value); + + test_zeroized(value); +} + +#[cfg(test)] +pub(crate) fn test_zeroized(value: &mut T) { + use std::{mem, slice, vec}; + + let test = + unsafe { slice::from_raw_parts(value as *const _ as *const u8, mem::size_of::()) }; + + assert_eq!(test, vec![0; mem::size_of::()]); +} + +#[cfg(test)] +pub(crate) fn drop_manually(value: &mut T) { + use std::{mem, ptr, vec}; + + assert!(mem::needs_drop::()); + let mut test_holder = vec![value]; + let ptr = &mut *test_holder[0] as *mut T; + + unsafe { + test_holder.set_len(0); + ptr::drop_in_place(ptr); + } + + assert_eq!(test_holder.capacity(), 1); +}