diff --git a/fuzz/README.md b/fuzz/README.md
index 1a0baf6..32c764d 100644
--- a/fuzz/README.md
+++ b/fuzz/README.md
@@ -18,7 +18,7 @@ $ cargo fuzz run fuzz_all -j 4
Coverage status of ml_kem_512 is robust, see:
~~~
-#7851: cov: 6312 ft: 3969 corp: 26 exec/s 4 oom/timeout/crash: 0/0/0 time: 843s job: 55 dft_time: 0
+#3543: cov: 6156 ft: 4187 corp: 31 exec/s 5 oom/timeout/crash: 0/0/0 time: 170s job: 33 dft_time: 0
# Warning: the following tools are tricky to install/configure
$ cargo install cargo-cov
diff --git a/src/byte_fns.rs b/src/byte_fns.rs
index 80dc94d..9d5322d 100644
--- a/src/byte_fns.rs
+++ b/src/byte_fns.rs
@@ -25,10 +25,13 @@ use crate::Q;
/// Input: integer array `F ∈ Z^{256}_m`, where `m = 2^d if d < 12` and `m = q if d = 12`
/// Output: byte array B ∈ B^{32·d}
pub(crate) fn byte_encode(d: u32, integers_f: &[Z; 256], bytes_b: &mut [u8]) {
- debug_assert_eq!(bytes_b.len(), 32 * d as usize, "Alg 4: bytes len is not 32 * d");
- debug_assert!(integers_f
- .iter()
- .all(|f| f.get_u16() <= if d < 12 { 1 << d } else { Q }));
+ debug_assert_eq!(bytes_b.len(), 32 * d as usize, "Alg 4: bytes_b len is not 32 * d");
+ debug_assert!(
+ integers_f
+ .iter()
+ .all(|f| f.get_u16() <= if d < 12 { 1 << d } else { Q }),
+ "Alg 4: integers_f out of range"
+ );
//
// Our "working" register, from which to drop bytes out of
let mut temp = 0u32;
@@ -47,11 +50,10 @@ pub(crate) fn byte_encode(d: u32, integers_f: &[Z; 256], bytes_b: &mut [u8]) {
bit_index += d as usize;
// While we have enough bits to drop a byte, do so
- #[allow(clippy::cast_possible_truncation)] // Intentional truncation
while bit_index > 7 {
//
// Drop the byte
- bytes_b[byte_index] = temp as u8;
+ bytes_b[byte_index] = temp.to_le_bytes()[0];
// Update the indices
temp >>= 8;
@@ -87,7 +89,7 @@ pub(crate) fn byte_decode(
bit_index += 8;
// If we have enough bits to drop an int, do so
- #[allow(clippy::cast_possible_truncation)] // Intentional truncation
+ #[allow(clippy::cast_possible_truncation)] // Intentional truncation, temp as u16
while bit_index >= d {
//
// Mask off the upper portion and drop it in
@@ -103,7 +105,7 @@ pub(crate) fn byte_decode(
}
let m = if d < 12 { 1 << d } else { Q };
- ensure!(integers_f.iter().all(|e| e.get_u16() < m), "Alg5: integers out of range");
+ ensure!(integers_f.iter().all(|e| e.get_u16() < m), "Alg 5: integers out of range");
Ok(())
}
diff --git a/src/helpers.rs b/src/helpers.rs
index 033c5f8..c98c592 100644
--- a/src/helpers.rs
+++ b/src/helpers.rs
@@ -39,7 +39,7 @@ pub(crate) fn mul_mat_vec(
) -> [[Z; 256]; K] {
let mut w_hat = [[Z::default(); 256]; K];
for i in 0..K {
- #[allow(clippy::needless_range_loop)]
+ #[allow(clippy::needless_range_loop)] // alternative is harder to understand
for j in 0..K {
let tmp = multiply_ntts(&a_hat[i][j], &u_hat[j]);
for k in 0..256 {
@@ -57,9 +57,9 @@ pub(crate) fn mul_mat_t_vec(
a_hat: &[[[Z; 256]; K]; K], u_hat: &[[Z; 256]; K],
) -> [[Z; 256]; K] {
let mut y_hat = [[Z::default(); 256]; K];
- #[allow(clippy::needless_range_loop)]
+ #[allow(clippy::needless_range_loop)] // alternative is harder to understand
for i in 0..K {
- #[allow(clippy::needless_range_loop)]
+ #[allow(clippy::needless_range_loop)] // alternative is harder to understand
for j in 0..K {
let tmp = multiply_ntts(&a_hat[j][i], &u_hat[j]);
for k in 0..256 {
@@ -152,7 +152,7 @@ pub(crate) fn compress(d: u32, inout: &mut [Z]) {
// Barrett constants should be resolved at compile time
let q64 = u64::from(Q);
let k = 32;
- let m = 2u64.pow(k) / q64;
+ let m = (1 << k) / q64;
// Barrett division, quotient could be too small by one
let top = u64::from(x_ref.get_u32()) << d;
let quot = (top * m) >> k;
@@ -171,7 +171,7 @@ pub(crate) fn compress(d: u32, inout: &mut [Z]) {
#[allow(clippy::cast_possible_truncation)] // last line
pub(crate) fn decompress(d: u32, inout: &mut [Z]) {
for y_ref in &mut *inout {
- let qy = u32::from(Q) * y_ref.get_u32() + 2u32.pow(d) - 1;
+ let qy = u32::from(Q) * y_ref.get_u32() + (1 << d) - 1;
y_ref.set_u16((qy >> d) as u16);
}
}
diff --git a/src/k_pke.rs b/src/k_pke.rs
index 23644be..7aaeac8 100644
--- a/src/k_pke.rs
+++ b/src/k_pke.rs
@@ -13,7 +13,7 @@ use crate::types::Z;
///
/// Output: encryption key `ekPKE ∈ B^{384·k+32}`
/// Output: decryption key `dkPKE ∈ B^{384·k}`
-#[allow(clippy::similar_names, clippy::module_name_repetitions)]
+#[allow(clippy::similar_names)]
pub(crate) fn k_pke_key_gen(
rng: &mut impl CryptoRngCore, eta1: u32, ek_pke: &mut [u8], dk_pke: &mut [u8],
) -> Result<(), &'static str> {
@@ -36,12 +36,11 @@ pub(crate) fn k_pke_key_gen(
for (i, row) in a_hat.iter_mut().enumerate().take(K) {
//
// 5: for (j ← 0; j < k; j++)
- #[allow(clippy::cast_possible_truncation)] // i and j as u8
for (j, entry) in row.iter_mut().enumerate().take(K) {
//
// 6: A_hat[i, j] ← SampleNTT(XOF(ρ, i, j)) ▷ each entry of  uniform in NTT domain
// See page 21 regarding transpose of i, j -? j, i in XOF() https://csrc.nist.gov/files/pubs/fips/203/ipd/docs/fips-203-initial-public-comments-2023.pdf
- *entry = sample_ntt(xof(&rho, j as u8, i as u8))?;
+ *entry = sample_ntt(xof(&rho, j.to_le_bytes()[0], i.to_le_bytes()[0]))?;
// 7: end for
}
@@ -118,10 +117,10 @@ pub(crate) fn k_pke_encrypt Result<(), &'static str> {
- debug_assert_eq!(ek.len(), 384 * K + 32, "Alg13: ek len not 384 * K + 32");
- debug_assert_eq!(m.len(), 32, "Alg13: m len not 32");
- debug_assert_eq!(eta1 as usize * 64, ETA1_64, "Alg13: eta1 size mismatch");
- debug_assert_eq!(eta2 as usize * 64, ETA2_64, "Alg13: eta2 size mismatch");
+ debug_assert_eq!(ek.len(), 384 * K + 32, "Alg 13: ek len not 384 * K + 32");
+ debug_assert_eq!(m.len(), 32, "Alg 13: m len not 32");
+ debug_assert_eq!(eta1 as usize * 64, ETA1_64, "Alg 13: eta1 size mismatch");
+ debug_assert_eq!(eta2 as usize * 64, ETA2_64, "Alg 13: eta2 size mismatch");
// 1: N ← 0
let mut n = 0;
@@ -132,7 +131,7 @@ pub(crate) fn k_pke_encrypt(
du: u32, dv: u32, dk: &[u8], ct: &[u8],
) -> Result<[u8; 32], &'static str> {
- debug_assert_eq!(dk.len(), 384 * K, "Alg14: dk len not 384 * K");
+ debug_assert_eq!(dk.len(), 384 * K, "Alg 14: dk len not 384 * K");
debug_assert_eq!(
ct.len(),
32 * (du as usize * K + dv as usize),
- "Alg14: ct len not 32 * (DU * K + DV)"
+ "Alg 14: ct len not 32 * (DU * K + DV)"
);
// 1: c1 ← c[0 : 32du k]
diff --git a/src/lib.rs b/src/lib.rs
index 895d7f9..62b55c4 100644
--- a/src/lib.rs
+++ b/src/lib.rs
@@ -16,6 +16,12 @@
// Implements FIPS 203 draft Module-Lattice-based Key-Encapsulation Mechanism Standard.
// See
+// TODO Roadmap
+// 0. Stay current with FIPS 203 updates
+// 1. Perf: optimize/minimize modular reductions, minimize u16 arith, consider avx2/aarch64
+// (currently, code is 'optimized' for safety and change-support, with reasonable perf)
+// 2. Slightly more intelligent fuzzing (as dk contains h(ek))
+
// Functionality map per FIPS 203 draft
//
// Algorithm 2 BitsToBytes(b) on page 17 --> optimized out (byte_fns.rs)
@@ -66,6 +72,7 @@ pub mod traits;
const Q: u16 = 3329;
const ZETA: u16 = 17;
+
/// Shared Secret Key length for all ML-KEM variants (in bytes)
pub const SSK_LEN: usize = 32;
@@ -103,7 +110,7 @@ impl PartialEq for SharedSecretKey {
macro_rules! functionality {
() => {
use crate::byte_fns::byte_decode;
- use crate::helpers::h;
+ use crate::helpers::{ensure, h};
use crate::ml_kem::{ml_kem_decaps, ml_kem_encaps, ml_kem_key_gen};
use crate::traits::{Decaps, Encaps, KeyGen, SerDes};
use crate::types::Z;
@@ -139,9 +146,12 @@ macro_rules! functionality {
}
fn validate_keypair_vt(ek: &Self::EncapsByteArray, dk: &Self::DecapsByteArray) -> bool {
+ // Note that size is checked by only accepting ref to correctly sized byte array
let len_ek_pke = 384 * K + 32;
let len_dk_pke = 384 * K;
+ // dk should contain ek
let same_ek = (*ek == dk[len_dk_pke..(len_dk_pke + len_ek_pke)]);
+ // dk should contain hash of ek
let same_h =
(h(ek) == dk[(len_dk_pke + len_ek_pke)..(len_dk_pke + len_ek_pke + 32)]);
same_ek & same_h
@@ -208,8 +218,17 @@ macro_rules! functionality {
fn try_from_bytes(dk: Self::ByteArray) -> Result {
// Validation per pg 31. Note that the two checks specify fixed sizes, and these
- // functions take only byte arrays of correct size. Nonetheless, we use a Result
- // here in case future opportunities for further validation arise.
+ // functions take only byte arrays of correct size. Nonetheless, we take the
+ // opportunity to validate the ek and h(ek).
+ let len_ek_pke = 384 * K + 32;
+ let len_dk_pke = 384 * K;
+ let ek = &dk[len_dk_pke..len_dk_pke + EK_LEN];
+ let _res =
+ EncapsKey::try_from_bytes(ek.try_into().map_err(|_| "Malformed encaps key")?)?;
+ ensure!(
+ h(ek) == dk[(len_dk_pke + len_ek_pke)..(len_dk_pke + len_ek_pke + 32)],
+ "Encaps hash wrong"
+ );
Ok(DecapsKey { 0: dk })
}
}
@@ -232,17 +251,20 @@ macro_rules! functionality {
#[cfg(test)]
mod tests {
use super::*;
+ use crate::types::EncapsKey;
use rand_chacha::rand_core::SeedableRng;
#[test]
fn smoke_test() {
let mut rng = rand_chacha::ChaCha8Rng::seed_from_u64(123);
- for _i in 0..10 {
+ for _i in 0..100 {
let (ek, dk) = KG::try_keygen_with_rng_vt(&mut rng).unwrap();
let (ssk1, ct) = ek.try_encaps_with_rng_vt(&mut rng).unwrap();
let ssk2 = dk.try_decaps_vt(&ct).unwrap();
- assert!(KG::validate_keypair_vt(&ek.into_bytes(), &dk.into_bytes()));
- assert_eq!(ssk1, ssk2, "Shared secrets differ");
+ assert!(KG::validate_keypair_vt(&ek.clone().into_bytes(), &dk.into_bytes()));
+ assert_eq!(ssk1, ssk2);
+ assert_eq!(ek.clone().0, EncapsKey::try_from_bytes(ek.into_bytes()).unwrap().0);
+ // the other SerDes routines don't really have logic...
}
}
}
diff --git a/src/ml_kem.rs b/src/ml_kem.rs
index 0b79fdd..be095ba 100644
--- a/src/ml_kem.rs
+++ b/src/ml_kem.rs
@@ -1,5 +1,5 @@
use crate::byte_fns::{byte_decode, byte_encode};
-use crate::helpers::{ensure, g, h, j};
+use crate::helpers::{g, h, j};
use crate::k_pke::{k_pke_decrypt, k_pke_encrypt, k_pke_key_gen};
use crate::types::Z;
use crate::SharedSecretKey;
@@ -13,13 +13,13 @@ use rand_core::CryptoRngCore;
pub(crate) fn ml_kem_key_gen(
rng: &mut impl CryptoRngCore, eta1: u32, ek: &mut [u8], dk: &mut [u8],
) -> Result<(), &'static str> {
- debug_assert_eq!(ek.len(), 384 * K + 32, "Alg15: ek len not 384 * K + 32");
- debug_assert_eq!(dk.len(), 768 * K + 96, "Alg15: dk len not 768 * K + 96");
+ debug_assert_eq!(ek.len(), 384 * K + 32, "Alg 15: ek len not 384 * K + 32");
+ debug_assert_eq!(dk.len(), 768 * K + 96, "Alg 15: dk len not 768 * K + 96");
// 1: z ←− B32 ▷ z is 32 random bytes (see Section 3.3)
let mut z = [0u8; 32];
rng.try_fill_bytes(&mut z)
- .map_err(|_| "Alg15: Random number generator failed")?;
+ .map_err(|_| "Alg 15: Random number generator failed")?;
// 2: (ek_{PKE}, dk_{PKE}) ← K-PKE.KeyGen() ▷ run key generation for K-PKE
let p1 = 384 * K;
@@ -47,11 +47,11 @@ pub(crate) fn ml_kem_key_gen(
pub(crate) fn ml_kem_encaps(
rng: &mut impl CryptoRngCore, du: u32, dv: u32, eta1: u32, eta2: u32, ek: &[u8], ct: &mut [u8],
) -> Result {
- debug_assert_eq!(ek.len(), 384 * K + 32, "Alg16: ek len not 384 * K + 32"); // also: size check at top level
+ debug_assert_eq!(ek.len(), 384 * K + 32, "Alg 16: ek len not 384 * K + 32"); // also: size check at top level
debug_assert_eq!(
ct.len(),
32 * (du as usize * K + dv as usize),
- "Alg16: ct len not 32*(DU*K+DV)"
+ "Alg 16: ct len not 32*(DU*K+DV)"
); // also: size check at top level
// modulus check: perform the computation ek ← ByteEncode12(ByteDecode12(ek_tidle)
@@ -68,7 +68,7 @@ pub(crate) fn ml_kem_encaps Result {
// These length checks are a bit redundant...but present for completeness and paranoia
- ensure!(ct.len() == 32 * (du as usize * K + dv as usize), "Alg17: ct len not 32 * ...");
+ debug_assert_eq!(ct.len(), 32 * (du as usize * K + dv as usize), "Alg17: ct len not 32 * ...");
// Ciphertext type check
- ensure!(dk.len() == 768 * K + 96, "Alg17: dk len not 768 ..."); // Decapsulation key type check
+ debug_assert_eq!(dk.len(), 768 * K + 96, "Alg17: dk len not 768 ..."); // Decapsulation key type check
// 1019 For some applications, further validation of the decapsulation key dk_tilde may be appropriate. For
// 1020 instance, in cases where dk_tilde was generated by a third party, users may want to ensure that the four
diff --git a/src/ntt.rs b/src/ntt.rs
index 59c783e..c6cbf61 100644
--- a/src/ntt.rs
+++ b/src/ntt.rs
@@ -121,7 +121,6 @@ pub(crate) fn ntt_inv(f_hat: &[Z; 256]) -> [Z; 256] {
/// Input: Two arrays `f_hat` ∈ `Z^{256}_q` and `g_hat` ∈ `Z^{256}_q` ▷ the coefficients of two NTT representations
/// Output: An array `h_hat` ∈ `Z^{256}_q` ▷ the coefficients of the product of the inputs
#[must_use]
-#[allow(clippy::cast_possible_truncation)]
pub(crate) fn multiply_ntts(f_hat: &[Z; 256], g_hat: &[Z; 256]) -> [Z; 256] {
let mut h_hat: [Z; 256] = [Z::default(); 256];
@@ -168,7 +167,7 @@ pub(crate) fn base_case_multiply(a0: Z, a1: Z, b0: Z, b1: Z, gamma: Z) -> (Z, Z)
/// HAC Algorithm 14.76 Right-to-left binary exponentiation mod Q.
#[must_use]
-#[allow(clippy::cast_possible_truncation)] // on result
+#[allow(clippy::cast_possible_truncation)] // on result as u16 (try_from not const)
const fn pow_mod_q(g: u16, e: u8) -> u16 {
let g = g as u64;
let mut result = 1;
@@ -187,12 +186,11 @@ const fn pow_mod_q(g: u16, e: u8) -> u16 {
}
-#[allow(clippy::cast_possible_truncation)] // i as u8
const fn gen_zeta_table() -> [u16; 256] {
let mut result = [0u16; 256];
let mut i = 0;
while i < 256u16 {
- result[i as usize] = pow_mod_q(ZETA, (i as u8).reverse_bits());
+ result[i as usize] = pow_mod_q(ZETA, (i.to_le_bytes()[0]).reverse_bits());
i += 1;
}
result
diff --git a/src/sampling.rs b/src/sampling.rs
index efda021..09c537e 100644
--- a/src/sampling.rs
+++ b/src/sampling.rs
@@ -24,7 +24,6 @@ pub(crate) fn sample_ntt(mut byte_stream_b: impl XofReader) -> Result<[Z; 256],
// The proportion of fails is approx 3.098e-12 or 2**{-38}; re-run with fresh randomness.
// See cdf at https://www.wolframalpha.com/input?i=binomial+distribution+calculator&assumption=%7B%22F%22%2C+%22BinomialProbabilities%22%2C+%22x%22%7D+-%3E%22256%22&assumption=%7B%22F%22%2C+%22BinomialProbabilities%22%2C+%22n%22%7D+-%3E%22384%22&assumption=%7B%22F%22%2C+%22BinomialProbabilities%22%2C+%22p%22%7D+-%3E%223329%2F4095%22
// 3: while j < 256 do --> this is adapted for constant-time operation
- // #[allow(clippy::cast_possible_truncation)] // mask as u16
for _k in 0..192 {
//
// Note: two samples (d1, d2) are drawn per loop iteration
@@ -110,6 +109,7 @@ pub(crate) fn sample_poly_cbd(eta: u32, byte_array_b: &[u8]) -> [Z; 256] {
}
+// the u types below and above could use a bit more thought
// Count u8 ones in constant time (u32 helps perf)
#[allow(clippy::cast_possible_truncation)] // return res as u16
fn count_ones(x: u32) -> u16 {
diff --git a/src/traits.rs b/src/traits.rs
index 68c8246..14160a8 100644
--- a/src/traits.rs
+++ b/src/traits.rs
@@ -1,4 +1,5 @@
use rand_core::CryptoRngCore;
+
#[cfg(feature = "default-rng")]
use rand_core::OsRng;
@@ -16,10 +17,10 @@ pub trait KeyGen {
/// Generates an encapsulation and decapsulation key key pair specific to this security parameter set.
- /// This function utilizes the OS default random number generator, and makes no (constant)
- /// timing assurances.
+ /// This function utilizes the OS default random number generator and is intended to operate in constant
+ /// time. (the function suffix will change to `_ct` in the forthcoming 0.2.0 release)
/// # Errors
- /// Returns an error when the random number generator fails; propagates internal errors.
+ /// Returns an error when the random number generator fails, or when the internal sampling overruns.
/// # Examples
/// ```rust
/// # use std::error::Error;
@@ -51,10 +52,10 @@ pub trait KeyGen {
/// Generates an encapsulation and decapsulation key key pair specific to this security parameter set.
- /// This function utilizes a supplied random number generator, and makes no (constant)
- /// timing assurances.
+ /// This function utilizes a provided random number generator and is intended to operate in constant
+ /// time. (the function suffix will change to `_ct` in the forthcoming 0.2.0 release)
/// # Errors
- /// Returns an error when the random number generator fails; propagates internal errors.
+ /// Returns an error when the random number generator fails, or when the internal sampling overruns.
/// # Examples
/// ```rust
/// # use std::error::Error;
@@ -85,7 +86,8 @@ pub trait KeyGen {
) -> Result<(Self::EncapsKey, Self::DecapsKey), &'static str>;
- /// Performs validation between an encapsulation key and a decapsulation key.
+ /// Performs validation between an encapsulation key and a decapsulation key (both in bytes). This function is
+ /// not intended to operate in constant-time.
/// # Examples
/// ```rust
/// # use std::error::Error;
@@ -114,10 +116,11 @@ pub trait Encaps {
/// Generates a shared secret and ciphertext from an encapsulation key specific to this security parameter set.
- /// This function utilizes the OS default random number generator, and makes no (constant)
- /// timing assurances.
+ /// This function utilizes the OS default random number generator and is intended to operate in constant
+ /// time. (the function suffix will change to `_ct` in the forthcoming 0.2.0 release)
/// # Errors
- /// Returns an error when the random number generator fails; propagates internal errors.
+ /// Returns an error when the random number generator fails, a malformed encaps key is provided, an internal
+ /// sampling overrun occurs, along with any other internal errors.
/// # Examples
/// ```rust
/// # use std::error::Error;
@@ -150,10 +153,11 @@ pub trait Encaps {
/// Generates a shared secret and ciphertext from an encapsulation key specific to this security parameter set.
- /// This function utilizes a supplied random number generator, and makes no (constant)
- /// timing assurances.
+ /// This function utilizes a provided random number generator and is intended to operate in constant
+ /// time. (the function suffix will change to `_ct` in the forthcoming 0.2.0 release)
/// # Errors
- /// Returns an error when the random number generator fails; propagates internal errors.
+ /// Returns an error when the random number generator fails, a malformed encaps key is provided, an internal
+ /// sampling overrun occurs, along with any other internal errors.
/// # Examples
/// ```rust
/// # use std::error::Error;
@@ -194,7 +198,8 @@ pub trait Decaps {
/// Generates a shared secret from a decapsulation key and ciphertext specific to this security parameter set.
- /// This function makes no (constant) timing assurances.
+ /// This function is intended to operate in constant-time. (the function suffix will change to `_ct` in the
+ /// forthcoming 0.2.0 release)
/// # Errors
/// Returns an error when the random number generator fails; propagates internal errors.
/// # Examples
diff --git a/src/types.rs b/src/types.rs
index 3320869..835f60c 100644
--- a/src/types.rs
+++ b/src/types.rs
@@ -1,7 +1,7 @@
+use crate::Q;
use subtle::ConditionallySelectable;
use zeroize::{Zeroize, ZeroizeOnDrop};
-use crate::Q;
/// Correctly sized encapsulation key specific to the target security parameter set.
#[derive(Clone, Zeroize, ZeroizeOnDrop)]
@@ -22,7 +22,7 @@ pub struct CipherText(pub(crate) [u8; CT_LEN]);
// While Z is nice, simple and correct, the performance is suboptimal.
-// This will be addressed (particularly in matrix operations etc) over,
+// This will be addressed (particularly in matrix operations etc) over
// the medium-term - potentially as a 256-entry row.
/// Stored as u16, but arithmetic as u32 (so we can multiply/reduce/etc)
@@ -34,8 +34,6 @@ pub(crate) struct Z(u16);
impl Z {
const M: u64 = 2u64.pow(32) / (Q as u64);
- #[allow(clippy::cast_possible_truncation)]
-
pub(crate) fn get_u16(self) -> u16 { self.0 }
pub(crate) fn get_u32(self) -> u32 { u32::from(self.0) }
@@ -47,17 +45,19 @@ impl Z {
let sum = self.0.wrapping_add(other.0);
let (trial, borrow) = sum.overflowing_sub(Q);
let result = u16::conditional_select(&trial, &sum, u8::from(borrow).into());
+ debug_assert!(result < Q);
Self(result)
}
#[inline(always)]
- #[allow(clippy::cast_possible_truncation)]
+ #[allow(clippy::cast_possible_truncation)] // for perf
pub(crate) fn or(self, other: u32) -> Self { Self(self.0 | other as u16) }
#[inline(always)]
pub(crate) fn sub(self, other: Self) -> Self {
let (diff, borrow) = self.0.overflowing_sub(other.0);
let result = u16::conditional_select(&diff, &diff.wrapping_add(Q), u8::from(borrow).into());
+ debug_assert!(result < Q);
Self(result)
}
@@ -66,10 +66,11 @@ impl Z {
pub(crate) fn mul(self, other: Self) -> Self {
let prod = u64::from(self.0) * u64::from(other.0);
let quot = prod * Self::M;
- let quot = quot >> (32);
+ let quot = quot >> 32;
let rem = prod - quot * u64::from(Q);
let (diff, borrow) = (rem as u16).overflowing_sub(Q);
let result = u16::conditional_select(&diff, &diff.wrapping_add(Q), u8::from(borrow).into());
+ debug_assert!(result < Q);
Self(result)
}
}
diff --git a/tests/fails.rs b/tests/fails.rs
index 2a034e5..1442fad 100644
--- a/tests/fails.rs
+++ b/tests/fails.rs
@@ -1,9 +1,8 @@
+use fips203::ml_kem_512;
+use fips203::traits::{KeyGen, SerDes};
use rand_chacha::rand_core::SeedableRng;
use rand_core::RngCore;
-use fips203::ml_kem_512;
-use fips203::traits::{Decaps, KeyGen, SerDes};
-
// Highlights potential validation opportunities
#[test]
fn fails_512() {
@@ -16,7 +15,7 @@ fn fails_512() {
let mut bad_ct_bytes = [0u8; ml_kem_512::CT_LEN];
rng.fill_bytes(&mut bad_ct_bytes);
- let bad_ct = ml_kem_512::CipherText::try_from_bytes(bad_ct_bytes);
+ let _bad_ct = ml_kem_512::CipherText::try_from_bytes(bad_ct_bytes);
// Note: FIPS 203 validation per page 31 only puts size constraints on the ciphertext.
// A Result is used to allow for future expansion of validation...
// assert!(bad_ct.is_err());
@@ -26,12 +25,12 @@ fn fails_512() {
let bad_dk = ml_kem_512::DecapsKey::try_from_bytes(bad_dk_bytes);
// Note: FIPS 203 validation per page 31 only puts size constraints on the decaps key.
// A Result is used to allow for future expansion of validation...
- // assert!(bad_dk.is_err());
+ assert!(bad_dk.is_err());
// We can validate the non-correspondence of these serialized keypair
assert!(!ml_kem_512::KG::validate_keypair_vt(&bad_ek_bytes, &bad_dk_bytes));
- let bad_ssk_bytes = bad_dk.unwrap().try_decaps_vt(&bad_ct.unwrap());
- assert!(bad_ssk_bytes.is_err());
+ // let bad_ssk_bytes = bad_dk.unwrap().try_decaps_vt(&bad_ct.unwrap());
+ // assert!(bad_ssk_bytes.is_err());
}
}
diff --git a/tests/native.rs b/tests/native.rs
index b71398b..12e02b9 100644
--- a/tests/native.rs
+++ b/tests/native.rs
@@ -1,8 +1,8 @@
-use rand_core::SeedableRng;
-
use fips203::ml_kem_512;
use fips203::traits::{Decaps, Encaps, KeyGen, SerDes};
use hex_literal::hex;
+use rand_core::SeedableRng;
+
#[test]
fn wasm_match() {