From 38340047ce540a5f30ced690d3894c8ae8eff34c Mon Sep 17 00:00:00 2001 From: eschorn1 Date: Mon, 1 Apr 2024 11:07:43 -0500 Subject: [PATCH] code opts 2 --- fuzz/README.md | 2 +- src/byte_fns.rs | 18 ++++++++++-------- src/helpers.rs | 10 +++++----- src/k_pke.rs | 24 +++++++++++------------- src/lib.rs | 34 ++++++++++++++++++++++++++++------ src/ml_kem.rs | 18 +++++++++--------- src/ntt.rs | 6 ++---- src/sampling.rs | 2 +- src/traits.rs | 33 +++++++++++++++++++-------------- src/types.rs | 13 +++++++------ tests/fails.rs | 13 ++++++------- tests/native.rs | 4 ++-- 12 files changed, 101 insertions(+), 76 deletions(-) diff --git a/fuzz/README.md b/fuzz/README.md index 1a0baf6..32c764d 100644 --- a/fuzz/README.md +++ b/fuzz/README.md @@ -18,7 +18,7 @@ $ cargo fuzz run fuzz_all -j 4 Coverage status of ml_kem_512 is robust, see: ~~~ -#7851: cov: 6312 ft: 3969 corp: 26 exec/s 4 oom/timeout/crash: 0/0/0 time: 843s job: 55 dft_time: 0 +#3543: cov: 6156 ft: 4187 corp: 31 exec/s 5 oom/timeout/crash: 0/0/0 time: 170s job: 33 dft_time: 0 # Warning: the following tools are tricky to install/configure $ cargo install cargo-cov diff --git a/src/byte_fns.rs b/src/byte_fns.rs index 80dc94d..9d5322d 100644 --- a/src/byte_fns.rs +++ b/src/byte_fns.rs @@ -25,10 +25,13 @@ use crate::Q; /// Input: integer array `F ∈ Z^{256}_m`, where `m = 2^d if d < 12` and `m = q if d = 12`
/// Output: byte array B ∈ B^{32·d} pub(crate) fn byte_encode(d: u32, integers_f: &[Z; 256], bytes_b: &mut [u8]) { - debug_assert_eq!(bytes_b.len(), 32 * d as usize, "Alg 4: bytes len is not 32 * d"); - debug_assert!(integers_f - .iter() - .all(|f| f.get_u16() <= if d < 12 { 1 << d } else { Q })); + debug_assert_eq!(bytes_b.len(), 32 * d as usize, "Alg 4: bytes_b len is not 32 * d"); + debug_assert!( + integers_f + .iter() + .all(|f| f.get_u16() <= if d < 12 { 1 << d } else { Q }), + "Alg 4: integers_f out of range" + ); // // Our "working" register, from which to drop bytes out of let mut temp = 0u32; @@ -47,11 +50,10 @@ pub(crate) fn byte_encode(d: u32, integers_f: &[Z; 256], bytes_b: &mut [u8]) { bit_index += d as usize; // While we have enough bits to drop a byte, do so - #[allow(clippy::cast_possible_truncation)] // Intentional truncation while bit_index > 7 { // // Drop the byte - bytes_b[byte_index] = temp as u8; + bytes_b[byte_index] = temp.to_le_bytes()[0]; // Update the indices temp >>= 8; @@ -87,7 +89,7 @@ pub(crate) fn byte_decode( bit_index += 8; // If we have enough bits to drop an int, do so - #[allow(clippy::cast_possible_truncation)] // Intentional truncation + #[allow(clippy::cast_possible_truncation)] // Intentional truncation, temp as u16 while bit_index >= d { // // Mask off the upper portion and drop it in @@ -103,7 +105,7 @@ pub(crate) fn byte_decode( } let m = if d < 12 { 1 << d } else { Q }; - ensure!(integers_f.iter().all(|e| e.get_u16() < m), "Alg5: integers out of range"); + ensure!(integers_f.iter().all(|e| e.get_u16() < m), "Alg 5: integers out of range"); Ok(()) } diff --git a/src/helpers.rs b/src/helpers.rs index 033c5f8..c98c592 100644 --- a/src/helpers.rs +++ b/src/helpers.rs @@ -39,7 +39,7 @@ pub(crate) fn mul_mat_vec( ) -> [[Z; 256]; K] { let mut w_hat = [[Z::default(); 256]; K]; for i in 0..K { - #[allow(clippy::needless_range_loop)] + #[allow(clippy::needless_range_loop)] // alternative is harder to understand for j in 0..K { let tmp = multiply_ntts(&a_hat[i][j], &u_hat[j]); for k in 0..256 { @@ -57,9 +57,9 @@ pub(crate) fn mul_mat_t_vec( a_hat: &[[[Z; 256]; K]; K], u_hat: &[[Z; 256]; K], ) -> [[Z; 256]; K] { let mut y_hat = [[Z::default(); 256]; K]; - #[allow(clippy::needless_range_loop)] + #[allow(clippy::needless_range_loop)] // alternative is harder to understand for i in 0..K { - #[allow(clippy::needless_range_loop)] + #[allow(clippy::needless_range_loop)] // alternative is harder to understand for j in 0..K { let tmp = multiply_ntts(&a_hat[j][i], &u_hat[j]); for k in 0..256 { @@ -152,7 +152,7 @@ pub(crate) fn compress(d: u32, inout: &mut [Z]) { // Barrett constants should be resolved at compile time let q64 = u64::from(Q); let k = 32; - let m = 2u64.pow(k) / q64; + let m = (1 << k) / q64; // Barrett division, quotient could be too small by one let top = u64::from(x_ref.get_u32()) << d; let quot = (top * m) >> k; @@ -171,7 +171,7 @@ pub(crate) fn compress(d: u32, inout: &mut [Z]) { #[allow(clippy::cast_possible_truncation)] // last line pub(crate) fn decompress(d: u32, inout: &mut [Z]) { for y_ref in &mut *inout { - let qy = u32::from(Q) * y_ref.get_u32() + 2u32.pow(d) - 1; + let qy = u32::from(Q) * y_ref.get_u32() + (1 << d) - 1; y_ref.set_u16((qy >> d) as u16); } } diff --git a/src/k_pke.rs b/src/k_pke.rs index 23644be..7aaeac8 100644 --- a/src/k_pke.rs +++ b/src/k_pke.rs @@ -13,7 +13,7 @@ use crate::types::Z; /// /// Output: encryption key `ekPKE ∈ B^{384·k+32}`
/// Output: decryption key `dkPKE ∈ B^{384·k}` -#[allow(clippy::similar_names, clippy::module_name_repetitions)] +#[allow(clippy::similar_names)] pub(crate) fn k_pke_key_gen( rng: &mut impl CryptoRngCore, eta1: u32, ek_pke: &mut [u8], dk_pke: &mut [u8], ) -> Result<(), &'static str> { @@ -36,12 +36,11 @@ pub(crate) fn k_pke_key_gen( for (i, row) in a_hat.iter_mut().enumerate().take(K) { // // 5: for (j ← 0; j < k; j++) - #[allow(clippy::cast_possible_truncation)] // i and j as u8 for (j, entry) in row.iter_mut().enumerate().take(K) { // // 6: A_hat[i, j] ← SampleNTT(XOF(ρ, i, j)) ▷ each entry of  uniform in NTT domain // See page 21 regarding transpose of i, j -? j, i in XOF() https://csrc.nist.gov/files/pubs/fips/203/ipd/docs/fips-203-initial-public-comments-2023.pdf - *entry = sample_ntt(xof(&rho, j as u8, i as u8))?; + *entry = sample_ntt(xof(&rho, j.to_le_bytes()[0], i.to_le_bytes()[0]))?; // 7: end for } @@ -118,10 +117,10 @@ pub(crate) fn k_pke_encrypt Result<(), &'static str> { - debug_assert_eq!(ek.len(), 384 * K + 32, "Alg13: ek len not 384 * K + 32"); - debug_assert_eq!(m.len(), 32, "Alg13: m len not 32"); - debug_assert_eq!(eta1 as usize * 64, ETA1_64, "Alg13: eta1 size mismatch"); - debug_assert_eq!(eta2 as usize * 64, ETA2_64, "Alg13: eta2 size mismatch"); + debug_assert_eq!(ek.len(), 384 * K + 32, "Alg 13: ek len not 384 * K + 32"); + debug_assert_eq!(m.len(), 32, "Alg 13: m len not 32"); + debug_assert_eq!(eta1 as usize * 64, ETA1_64, "Alg 13: eta1 size mismatch"); + debug_assert_eq!(eta2 as usize * 64, ETA2_64, "Alg 13: eta2 size mismatch"); // 1: N ← 0 let mut n = 0; @@ -132,7 +131,7 @@ pub(crate) fn k_pke_encrypt( du: u32, dv: u32, dk: &[u8], ct: &[u8], ) -> Result<[u8; 32], &'static str> { - debug_assert_eq!(dk.len(), 384 * K, "Alg14: dk len not 384 * K"); + debug_assert_eq!(dk.len(), 384 * K, "Alg 14: dk len not 384 * K"); debug_assert_eq!( ct.len(), 32 * (du as usize * K + dv as usize), - "Alg14: ct len not 32 * (DU * K + DV)" + "Alg 14: ct len not 32 * (DU * K + DV)" ); // 1: c1 ← c[0 : 32du k] diff --git a/src/lib.rs b/src/lib.rs index 895d7f9..62b55c4 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -16,6 +16,12 @@ // Implements FIPS 203 draft Module-Lattice-based Key-Encapsulation Mechanism Standard. // See +// TODO Roadmap +// 0. Stay current with FIPS 203 updates +// 1. Perf: optimize/minimize modular reductions, minimize u16 arith, consider avx2/aarch64 +// (currently, code is 'optimized' for safety and change-support, with reasonable perf) +// 2. Slightly more intelligent fuzzing (as dk contains h(ek)) + // Functionality map per FIPS 203 draft // // Algorithm 2 BitsToBytes(b) on page 17 --> optimized out (byte_fns.rs) @@ -66,6 +72,7 @@ pub mod traits; const Q: u16 = 3329; const ZETA: u16 = 17; + /// Shared Secret Key length for all ML-KEM variants (in bytes) pub const SSK_LEN: usize = 32; @@ -103,7 +110,7 @@ impl PartialEq for SharedSecretKey { macro_rules! functionality { () => { use crate::byte_fns::byte_decode; - use crate::helpers::h; + use crate::helpers::{ensure, h}; use crate::ml_kem::{ml_kem_decaps, ml_kem_encaps, ml_kem_key_gen}; use crate::traits::{Decaps, Encaps, KeyGen, SerDes}; use crate::types::Z; @@ -139,9 +146,12 @@ macro_rules! functionality { } fn validate_keypair_vt(ek: &Self::EncapsByteArray, dk: &Self::DecapsByteArray) -> bool { + // Note that size is checked by only accepting ref to correctly sized byte array let len_ek_pke = 384 * K + 32; let len_dk_pke = 384 * K; + // dk should contain ek let same_ek = (*ek == dk[len_dk_pke..(len_dk_pke + len_ek_pke)]); + // dk should contain hash of ek let same_h = (h(ek) == dk[(len_dk_pke + len_ek_pke)..(len_dk_pke + len_ek_pke + 32)]); same_ek & same_h @@ -208,8 +218,17 @@ macro_rules! functionality { fn try_from_bytes(dk: Self::ByteArray) -> Result { // Validation per pg 31. Note that the two checks specify fixed sizes, and these - // functions take only byte arrays of correct size. Nonetheless, we use a Result - // here in case future opportunities for further validation arise. + // functions take only byte arrays of correct size. Nonetheless, we take the + // opportunity to validate the ek and h(ek). + let len_ek_pke = 384 * K + 32; + let len_dk_pke = 384 * K; + let ek = &dk[len_dk_pke..len_dk_pke + EK_LEN]; + let _res = + EncapsKey::try_from_bytes(ek.try_into().map_err(|_| "Malformed encaps key")?)?; + ensure!( + h(ek) == dk[(len_dk_pke + len_ek_pke)..(len_dk_pke + len_ek_pke + 32)], + "Encaps hash wrong" + ); Ok(DecapsKey { 0: dk }) } } @@ -232,17 +251,20 @@ macro_rules! functionality { #[cfg(test)] mod tests { use super::*; + use crate::types::EncapsKey; use rand_chacha::rand_core::SeedableRng; #[test] fn smoke_test() { let mut rng = rand_chacha::ChaCha8Rng::seed_from_u64(123); - for _i in 0..10 { + for _i in 0..100 { let (ek, dk) = KG::try_keygen_with_rng_vt(&mut rng).unwrap(); let (ssk1, ct) = ek.try_encaps_with_rng_vt(&mut rng).unwrap(); let ssk2 = dk.try_decaps_vt(&ct).unwrap(); - assert!(KG::validate_keypair_vt(&ek.into_bytes(), &dk.into_bytes())); - assert_eq!(ssk1, ssk2, "Shared secrets differ"); + assert!(KG::validate_keypair_vt(&ek.clone().into_bytes(), &dk.into_bytes())); + assert_eq!(ssk1, ssk2); + assert_eq!(ek.clone().0, EncapsKey::try_from_bytes(ek.into_bytes()).unwrap().0); + // the other SerDes routines don't really have logic... } } } diff --git a/src/ml_kem.rs b/src/ml_kem.rs index 0b79fdd..be095ba 100644 --- a/src/ml_kem.rs +++ b/src/ml_kem.rs @@ -1,5 +1,5 @@ use crate::byte_fns::{byte_decode, byte_encode}; -use crate::helpers::{ensure, g, h, j}; +use crate::helpers::{g, h, j}; use crate::k_pke::{k_pke_decrypt, k_pke_encrypt, k_pke_key_gen}; use crate::types::Z; use crate::SharedSecretKey; @@ -13,13 +13,13 @@ use rand_core::CryptoRngCore; pub(crate) fn ml_kem_key_gen( rng: &mut impl CryptoRngCore, eta1: u32, ek: &mut [u8], dk: &mut [u8], ) -> Result<(), &'static str> { - debug_assert_eq!(ek.len(), 384 * K + 32, "Alg15: ek len not 384 * K + 32"); - debug_assert_eq!(dk.len(), 768 * K + 96, "Alg15: dk len not 768 * K + 96"); + debug_assert_eq!(ek.len(), 384 * K + 32, "Alg 15: ek len not 384 * K + 32"); + debug_assert_eq!(dk.len(), 768 * K + 96, "Alg 15: dk len not 768 * K + 96"); // 1: z ←− B32 ▷ z is 32 random bytes (see Section 3.3) let mut z = [0u8; 32]; rng.try_fill_bytes(&mut z) - .map_err(|_| "Alg15: Random number generator failed")?; + .map_err(|_| "Alg 15: Random number generator failed")?; // 2: (ek_{PKE}, dk_{PKE}) ← K-PKE.KeyGen() ▷ run key generation for K-PKE let p1 = 384 * K; @@ -47,11 +47,11 @@ pub(crate) fn ml_kem_key_gen( pub(crate) fn ml_kem_encaps( rng: &mut impl CryptoRngCore, du: u32, dv: u32, eta1: u32, eta2: u32, ek: &[u8], ct: &mut [u8], ) -> Result { - debug_assert_eq!(ek.len(), 384 * K + 32, "Alg16: ek len not 384 * K + 32"); // also: size check at top level + debug_assert_eq!(ek.len(), 384 * K + 32, "Alg 16: ek len not 384 * K + 32"); // also: size check at top level debug_assert_eq!( ct.len(), 32 * (du as usize * K + dv as usize), - "Alg16: ct len not 32*(DU*K+DV)" + "Alg 16: ct len not 32*(DU*K+DV)" ); // also: size check at top level // modulus check: perform the computation ek ← ByteEncode12(ByteDecode12(ek_tidle) @@ -68,7 +68,7 @@ pub(crate) fn ml_kem_encaps Result { // These length checks are a bit redundant...but present for completeness and paranoia - ensure!(ct.len() == 32 * (du as usize * K + dv as usize), "Alg17: ct len not 32 * ..."); + debug_assert_eq!(ct.len(), 32 * (du as usize * K + dv as usize), "Alg17: ct len not 32 * ..."); // Ciphertext type check - ensure!(dk.len() == 768 * K + 96, "Alg17: dk len not 768 ..."); // Decapsulation key type check + debug_assert_eq!(dk.len(), 768 * K + 96, "Alg17: dk len not 768 ..."); // Decapsulation key type check // 1019 For some applications, further validation of the decapsulation key dk_tilde may be appropriate. For // 1020 instance, in cases where dk_tilde was generated by a third party, users may want to ensure that the four diff --git a/src/ntt.rs b/src/ntt.rs index 59c783e..c6cbf61 100644 --- a/src/ntt.rs +++ b/src/ntt.rs @@ -121,7 +121,6 @@ pub(crate) fn ntt_inv(f_hat: &[Z; 256]) -> [Z; 256] { /// Input: Two arrays `f_hat` ∈ `Z^{256}_q` and `g_hat` ∈ `Z^{256}_q` ▷ the coefficients of two NTT representations
/// Output: An array `h_hat` ∈ `Z^{256}_q` ▷ the coefficients of the product of the inputs #[must_use] -#[allow(clippy::cast_possible_truncation)] pub(crate) fn multiply_ntts(f_hat: &[Z; 256], g_hat: &[Z; 256]) -> [Z; 256] { let mut h_hat: [Z; 256] = [Z::default(); 256]; @@ -168,7 +167,7 @@ pub(crate) fn base_case_multiply(a0: Z, a1: Z, b0: Z, b1: Z, gamma: Z) -> (Z, Z) /// HAC Algorithm 14.76 Right-to-left binary exponentiation mod Q. #[must_use] -#[allow(clippy::cast_possible_truncation)] // on result +#[allow(clippy::cast_possible_truncation)] // on result as u16 (try_from not const) const fn pow_mod_q(g: u16, e: u8) -> u16 { let g = g as u64; let mut result = 1; @@ -187,12 +186,11 @@ const fn pow_mod_q(g: u16, e: u8) -> u16 { } -#[allow(clippy::cast_possible_truncation)] // i as u8 const fn gen_zeta_table() -> [u16; 256] { let mut result = [0u16; 256]; let mut i = 0; while i < 256u16 { - result[i as usize] = pow_mod_q(ZETA, (i as u8).reverse_bits()); + result[i as usize] = pow_mod_q(ZETA, (i.to_le_bytes()[0]).reverse_bits()); i += 1; } result diff --git a/src/sampling.rs b/src/sampling.rs index efda021..09c537e 100644 --- a/src/sampling.rs +++ b/src/sampling.rs @@ -24,7 +24,6 @@ pub(crate) fn sample_ntt(mut byte_stream_b: impl XofReader) -> Result<[Z; 256], // The proportion of fails is approx 3.098e-12 or 2**{-38}; re-run with fresh randomness. // See cdf at https://www.wolframalpha.com/input?i=binomial+distribution+calculator&assumption=%7B%22F%22%2C+%22BinomialProbabilities%22%2C+%22x%22%7D+-%3E%22256%22&assumption=%7B%22F%22%2C+%22BinomialProbabilities%22%2C+%22n%22%7D+-%3E%22384%22&assumption=%7B%22F%22%2C+%22BinomialProbabilities%22%2C+%22p%22%7D+-%3E%223329%2F4095%22 // 3: while j < 256 do --> this is adapted for constant-time operation - // #[allow(clippy::cast_possible_truncation)] // mask as u16 for _k in 0..192 { // // Note: two samples (d1, d2) are drawn per loop iteration @@ -110,6 +109,7 @@ pub(crate) fn sample_poly_cbd(eta: u32, byte_array_b: &[u8]) -> [Z; 256] { } +// the u types below and above could use a bit more thought // Count u8 ones in constant time (u32 helps perf) #[allow(clippy::cast_possible_truncation)] // return res as u16 fn count_ones(x: u32) -> u16 { diff --git a/src/traits.rs b/src/traits.rs index 68c8246..14160a8 100644 --- a/src/traits.rs +++ b/src/traits.rs @@ -1,4 +1,5 @@ use rand_core::CryptoRngCore; + #[cfg(feature = "default-rng")] use rand_core::OsRng; @@ -16,10 +17,10 @@ pub trait KeyGen { /// Generates an encapsulation and decapsulation key key pair specific to this security parameter set.
- /// This function utilizes the OS default random number generator, and makes no (constant) - /// timing assurances. + /// This function utilizes the OS default random number generator and is intended to operate in constant + /// time. (the function suffix will change to `_ct` in the forthcoming 0.2.0 release) /// # Errors - /// Returns an error when the random number generator fails; propagates internal errors. + /// Returns an error when the random number generator fails, or when the internal sampling overruns. /// # Examples /// ```rust /// # use std::error::Error; @@ -51,10 +52,10 @@ pub trait KeyGen { /// Generates an encapsulation and decapsulation key key pair specific to this security parameter set.
- /// This function utilizes a supplied random number generator, and makes no (constant) - /// timing assurances. + /// This function utilizes a provided random number generator and is intended to operate in constant + /// time. (the function suffix will change to `_ct` in the forthcoming 0.2.0 release) /// # Errors - /// Returns an error when the random number generator fails; propagates internal errors. + /// Returns an error when the random number generator fails, or when the internal sampling overruns. /// # Examples /// ```rust /// # use std::error::Error; @@ -85,7 +86,8 @@ pub trait KeyGen { ) -> Result<(Self::EncapsKey, Self::DecapsKey), &'static str>; - /// Performs validation between an encapsulation key and a decapsulation key. + /// Performs validation between an encapsulation key and a decapsulation key (both in bytes). This function is + /// not intended to operate in constant-time. /// # Examples /// ```rust /// # use std::error::Error; @@ -114,10 +116,11 @@ pub trait Encaps { /// Generates a shared secret and ciphertext from an encapsulation key specific to this security parameter set.
- /// This function utilizes the OS default random number generator, and makes no (constant) - /// timing assurances. + /// This function utilizes the OS default random number generator and is intended to operate in constant + /// time. (the function suffix will change to `_ct` in the forthcoming 0.2.0 release) /// # Errors - /// Returns an error when the random number generator fails; propagates internal errors. + /// Returns an error when the random number generator fails, a malformed encaps key is provided, an internal + /// sampling overrun occurs, along with any other internal errors. /// # Examples /// ```rust /// # use std::error::Error; @@ -150,10 +153,11 @@ pub trait Encaps { /// Generates a shared secret and ciphertext from an encapsulation key specific to this security parameter set.
- /// This function utilizes a supplied random number generator, and makes no (constant) - /// timing assurances. + /// This function utilizes a provided random number generator and is intended to operate in constant + /// time. (the function suffix will change to `_ct` in the forthcoming 0.2.0 release) /// # Errors - /// Returns an error when the random number generator fails; propagates internal errors. + /// Returns an error when the random number generator fails, a malformed encaps key is provided, an internal + /// sampling overrun occurs, along with any other internal errors. /// # Examples /// ```rust /// # use std::error::Error; @@ -194,7 +198,8 @@ pub trait Decaps { /// Generates a shared secret from a decapsulation key and ciphertext specific to this security parameter set.
- /// This function makes no (constant) timing assurances. + /// This function is intended to operate in constant-time. (the function suffix will change to `_ct` in the + /// forthcoming 0.2.0 release) /// # Errors /// Returns an error when the random number generator fails; propagates internal errors. /// # Examples diff --git a/src/types.rs b/src/types.rs index 3320869..835f60c 100644 --- a/src/types.rs +++ b/src/types.rs @@ -1,7 +1,7 @@ +use crate::Q; use subtle::ConditionallySelectable; use zeroize::{Zeroize, ZeroizeOnDrop}; -use crate::Q; /// Correctly sized encapsulation key specific to the target security parameter set. #[derive(Clone, Zeroize, ZeroizeOnDrop)] @@ -22,7 +22,7 @@ pub struct CipherText(pub(crate) [u8; CT_LEN]); // While Z is nice, simple and correct, the performance is suboptimal. -// This will be addressed (particularly in matrix operations etc) over, +// This will be addressed (particularly in matrix operations etc) over // the medium-term - potentially as a 256-entry row. /// Stored as u16, but arithmetic as u32 (so we can multiply/reduce/etc) @@ -34,8 +34,6 @@ pub(crate) struct Z(u16); impl Z { const M: u64 = 2u64.pow(32) / (Q as u64); - #[allow(clippy::cast_possible_truncation)] - pub(crate) fn get_u16(self) -> u16 { self.0 } pub(crate) fn get_u32(self) -> u32 { u32::from(self.0) } @@ -47,17 +45,19 @@ impl Z { let sum = self.0.wrapping_add(other.0); let (trial, borrow) = sum.overflowing_sub(Q); let result = u16::conditional_select(&trial, &sum, u8::from(borrow).into()); + debug_assert!(result < Q); Self(result) } #[inline(always)] - #[allow(clippy::cast_possible_truncation)] + #[allow(clippy::cast_possible_truncation)] // for perf pub(crate) fn or(self, other: u32) -> Self { Self(self.0 | other as u16) } #[inline(always)] pub(crate) fn sub(self, other: Self) -> Self { let (diff, borrow) = self.0.overflowing_sub(other.0); let result = u16::conditional_select(&diff, &diff.wrapping_add(Q), u8::from(borrow).into()); + debug_assert!(result < Q); Self(result) } @@ -66,10 +66,11 @@ impl Z { pub(crate) fn mul(self, other: Self) -> Self { let prod = u64::from(self.0) * u64::from(other.0); let quot = prod * Self::M; - let quot = quot >> (32); + let quot = quot >> 32; let rem = prod - quot * u64::from(Q); let (diff, borrow) = (rem as u16).overflowing_sub(Q); let result = u16::conditional_select(&diff, &diff.wrapping_add(Q), u8::from(borrow).into()); + debug_assert!(result < Q); Self(result) } } diff --git a/tests/fails.rs b/tests/fails.rs index 2a034e5..1442fad 100644 --- a/tests/fails.rs +++ b/tests/fails.rs @@ -1,9 +1,8 @@ +use fips203::ml_kem_512; +use fips203::traits::{KeyGen, SerDes}; use rand_chacha::rand_core::SeedableRng; use rand_core::RngCore; -use fips203::ml_kem_512; -use fips203::traits::{Decaps, KeyGen, SerDes}; - // Highlights potential validation opportunities #[test] fn fails_512() { @@ -16,7 +15,7 @@ fn fails_512() { let mut bad_ct_bytes = [0u8; ml_kem_512::CT_LEN]; rng.fill_bytes(&mut bad_ct_bytes); - let bad_ct = ml_kem_512::CipherText::try_from_bytes(bad_ct_bytes); + let _bad_ct = ml_kem_512::CipherText::try_from_bytes(bad_ct_bytes); // Note: FIPS 203 validation per page 31 only puts size constraints on the ciphertext. // A Result is used to allow for future expansion of validation... // assert!(bad_ct.is_err()); @@ -26,12 +25,12 @@ fn fails_512() { let bad_dk = ml_kem_512::DecapsKey::try_from_bytes(bad_dk_bytes); // Note: FIPS 203 validation per page 31 only puts size constraints on the decaps key. // A Result is used to allow for future expansion of validation... - // assert!(bad_dk.is_err()); + assert!(bad_dk.is_err()); // We can validate the non-correspondence of these serialized keypair assert!(!ml_kem_512::KG::validate_keypair_vt(&bad_ek_bytes, &bad_dk_bytes)); - let bad_ssk_bytes = bad_dk.unwrap().try_decaps_vt(&bad_ct.unwrap()); - assert!(bad_ssk_bytes.is_err()); + // let bad_ssk_bytes = bad_dk.unwrap().try_decaps_vt(&bad_ct.unwrap()); + // assert!(bad_ssk_bytes.is_err()); } } diff --git a/tests/native.rs b/tests/native.rs index b71398b..12e02b9 100644 --- a/tests/native.rs +++ b/tests/native.rs @@ -1,8 +1,8 @@ -use rand_core::SeedableRng; - use fips203::ml_kem_512; use fips203::traits::{Decaps, Encaps, KeyGen, SerDes}; use hex_literal::hex; +use rand_core::SeedableRng; + #[test] fn wasm_match() {