Skip to content

Commit

Permalink
code opts 2
Browse files Browse the repository at this point in the history
  • Loading branch information
eschorn1 committed Apr 1, 2024
1 parent 814cf41 commit 3834004
Show file tree
Hide file tree
Showing 12 changed files with 101 additions and 76 deletions.
2 changes: 1 addition & 1 deletion fuzz/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ $ cargo fuzz run fuzz_all -j 4
Coverage status of ml_kem_512 is robust, see:

~~~
#7851: cov: 6312 ft: 3969 corp: 26 exec/s 4 oom/timeout/crash: 0/0/0 time: 843s job: 55 dft_time: 0
#3543: cov: 6156 ft: 4187 corp: 31 exec/s 5 oom/timeout/crash: 0/0/0 time: 170s job: 33 dft_time: 0
# Warning: the following tools are tricky to install/configure
$ cargo install cargo-cov
Expand Down
18 changes: 10 additions & 8 deletions src/byte_fns.rs
Original file line number Diff line number Diff line change
Expand Up @@ -25,10 +25,13 @@ use crate::Q;
/// Input: integer array `F ∈ Z^{256}_m`, where `m = 2^d if d < 12` and `m = q if d = 12` <br>
/// Output: byte array B ∈ B^{32·d}
pub(crate) fn byte_encode(d: u32, integers_f: &[Z; 256], bytes_b: &mut [u8]) {
debug_assert_eq!(bytes_b.len(), 32 * d as usize, "Alg 4: bytes len is not 32 * d");
debug_assert!(integers_f
.iter()
.all(|f| f.get_u16() <= if d < 12 { 1 << d } else { Q }));
debug_assert_eq!(bytes_b.len(), 32 * d as usize, "Alg 4: bytes_b len is not 32 * d");
debug_assert!(
integers_f
.iter()
.all(|f| f.get_u16() <= if d < 12 { 1 << d } else { Q }),
"Alg 4: integers_f out of range"
);
//
// Our "working" register, from which to drop bytes out of
let mut temp = 0u32;
Expand All @@ -47,11 +50,10 @@ pub(crate) fn byte_encode(d: u32, integers_f: &[Z; 256], bytes_b: &mut [u8]) {
bit_index += d as usize;

// While we have enough bits to drop a byte, do so
#[allow(clippy::cast_possible_truncation)] // Intentional truncation
while bit_index > 7 {
//
// Drop the byte
bytes_b[byte_index] = temp as u8;
bytes_b[byte_index] = temp.to_le_bytes()[0];

// Update the indices
temp >>= 8;
Expand Down Expand Up @@ -87,7 +89,7 @@ pub(crate) fn byte_decode(
bit_index += 8;

// If we have enough bits to drop an int, do so
#[allow(clippy::cast_possible_truncation)] // Intentional truncation
#[allow(clippy::cast_possible_truncation)] // Intentional truncation, temp as u16
while bit_index >= d {
//
// Mask off the upper portion and drop it in
Expand All @@ -103,7 +105,7 @@ pub(crate) fn byte_decode(
}

let m = if d < 12 { 1 << d } else { Q };
ensure!(integers_f.iter().all(|e| e.get_u16() < m), "Alg5: integers out of range");
ensure!(integers_f.iter().all(|e| e.get_u16() < m), "Alg 5: integers out of range");
Ok(())
}

Expand Down
10 changes: 5 additions & 5 deletions src/helpers.rs
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ pub(crate) fn mul_mat_vec<const K: usize>(
) -> [[Z; 256]; K] {
let mut w_hat = [[Z::default(); 256]; K];
for i in 0..K {
#[allow(clippy::needless_range_loop)]
#[allow(clippy::needless_range_loop)] // alternative is harder to understand
for j in 0..K {
let tmp = multiply_ntts(&a_hat[i][j], &u_hat[j]);
for k in 0..256 {
Expand All @@ -57,9 +57,9 @@ pub(crate) fn mul_mat_t_vec<const K: usize>(
a_hat: &[[[Z; 256]; K]; K], u_hat: &[[Z; 256]; K],
) -> [[Z; 256]; K] {
let mut y_hat = [[Z::default(); 256]; K];
#[allow(clippy::needless_range_loop)]
#[allow(clippy::needless_range_loop)] // alternative is harder to understand
for i in 0..K {
#[allow(clippy::needless_range_loop)]
#[allow(clippy::needless_range_loop)] // alternative is harder to understand
for j in 0..K {
let tmp = multiply_ntts(&a_hat[j][i], &u_hat[j]);
for k in 0..256 {
Expand Down Expand Up @@ -152,7 +152,7 @@ pub(crate) fn compress(d: u32, inout: &mut [Z]) {
// Barrett constants should be resolved at compile time
let q64 = u64::from(Q);
let k = 32;
let m = 2u64.pow(k) / q64;
let m = (1 << k) / q64;
// Barrett division, quotient could be too small by one
let top = u64::from(x_ref.get_u32()) << d;
let quot = (top * m) >> k;
Expand All @@ -171,7 +171,7 @@ pub(crate) fn compress(d: u32, inout: &mut [Z]) {
#[allow(clippy::cast_possible_truncation)] // last line
pub(crate) fn decompress(d: u32, inout: &mut [Z]) {
for y_ref in &mut *inout {
let qy = u32::from(Q) * y_ref.get_u32() + 2u32.pow(d) - 1;
let qy = u32::from(Q) * y_ref.get_u32() + (1 << d) - 1;
y_ref.set_u16((qy >> d) as u16);
}
}
24 changes: 11 additions & 13 deletions src/k_pke.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ use crate::types::Z;
///
/// Output: encryption key `ekPKE ∈ B^{384·k+32}` <br>
/// Output: decryption key `dkPKE ∈ B^{384·k}`
#[allow(clippy::similar_names, clippy::module_name_repetitions)]
#[allow(clippy::similar_names)]
pub(crate) fn k_pke_key_gen<const K: usize, const ETA1_64: usize>(
rng: &mut impl CryptoRngCore, eta1: u32, ek_pke: &mut [u8], dk_pke: &mut [u8],
) -> Result<(), &'static str> {
Expand All @@ -36,12 +36,11 @@ pub(crate) fn k_pke_key_gen<const K: usize, const ETA1_64: usize>(
for (i, row) in a_hat.iter_mut().enumerate().take(K) {
//
// 5: for (j ← 0; j < k; j++)
#[allow(clippy::cast_possible_truncation)] // i and j as u8
for (j, entry) in row.iter_mut().enumerate().take(K) {
//
// 6: A_hat[i, j] ← SampleNTT(XOF(ρ, i, j)) ▷ each entry of  uniform in NTT domain
// See page 21 regarding transpose of i, j -? j, i in XOF() https://csrc.nist.gov/files/pubs/fips/203/ipd/docs/fips-203-initial-public-comments-2023.pdf
*entry = sample_ntt(xof(&rho, j as u8, i as u8))?;
*entry = sample_ntt(xof(&rho, j.to_le_bytes()[0], i.to_le_bytes()[0]))?;

// 7: end for
}
Expand Down Expand Up @@ -118,10 +117,10 @@ pub(crate) fn k_pke_encrypt<const K: usize, const ETA1_64: usize, const ETA2_64:
du: u32, dv: u32, eta1: u32, eta2: u32, ek: &[u8], m: &[u8], randomness: &[u8; 32],
ct: &mut [u8],
) -> Result<(), &'static str> {
debug_assert_eq!(ek.len(), 384 * K + 32, "Alg13: ek len not 384 * K + 32");
debug_assert_eq!(m.len(), 32, "Alg13: m len not 32");
debug_assert_eq!(eta1 as usize * 64, ETA1_64, "Alg13: eta1 size mismatch");
debug_assert_eq!(eta2 as usize * 64, ETA2_64, "Alg13: eta2 size mismatch");
debug_assert_eq!(ek.len(), 384 * K + 32, "Alg 13: ek len not 384 * K + 32");
debug_assert_eq!(m.len(), 32, "Alg 13: m len not 32");
debug_assert_eq!(eta1 as usize * 64, ETA1_64, "Alg 13: eta1 size mismatch");
debug_assert_eq!(eta2 as usize * 64, ETA2_64, "Alg 13: eta2 size mismatch");

// 1: N ← 0
let mut n = 0;
Expand All @@ -132,7 +131,7 @@ pub(crate) fn k_pke_encrypt<const K: usize, const ETA1_64: usize, const ETA2_64:
byte_decode(12, &ek[384 * i..384 * (i + 1)], &mut t_hat[i])?;
}

// 3: 3: ρ ← ekPKE [384k : 384k + 32] ▷ extract 32-byte seed from ekPKE
// 3: ρ ← ekPKE [384k : 384k + 32] ▷ extract 32-byte seed from ekPKE
let mut rho = [0u8; 32];
rho.copy_from_slice(&ek[384 * K..(384 * K + 32)]);

Expand All @@ -141,11 +140,10 @@ pub(crate) fn k_pke_encrypt<const K: usize, const ETA1_64: usize, const ETA2_64:
for (i, row) in a_hat.iter_mut().enumerate().take(K) {
//
// 5: for (j ← 0; j < k; j++)
#[allow(clippy::cast_possible_truncation)] // i and j as u8
for (j, entry) in row.iter_mut().enumerate().take(K) {
//
// 6: Â[i, j] ← SampleNTT(XOF(ρ, i, j))
*entry = sample_ntt(xof(&rho, j as u8, i as u8))?;
*entry = sample_ntt(xof(&rho, j.to_le_bytes()[0], i.to_le_bytes()[0]))?;

// 7: end for
}
Expand Down Expand Up @@ -203,7 +201,7 @@ pub(crate) fn k_pke_encrypt<const K: usize, const ETA1_64: usize, const ETA2_64:
let mut v = ntt_inv(&dot_t_prod(&t_hat, &r_hat));
v = add_vecs(&add_vecs(&[v], &[e2]), &[mu])[0];

// 22: c1 ← ByteEncode_{du}(Compress_{du}(u)) ▷ ByteEncodedu is run k times
// 22: c1 ← ByteEncode_{du}(Compress_{du}(u)) ▷ ByteEncode_{du} is run k times
let step = 32 * du as usize;
for i in 0..K {
compress(du, &mut u[i]);
Expand All @@ -228,11 +226,11 @@ pub(crate) fn k_pke_encrypt<const K: usize, const ETA1_64: usize, const ETA2_64:
pub(crate) fn k_pke_decrypt<const K: usize>(
du: u32, dv: u32, dk: &[u8], ct: &[u8],
) -> Result<[u8; 32], &'static str> {
debug_assert_eq!(dk.len(), 384 * K, "Alg14: dk len not 384 * K");
debug_assert_eq!(dk.len(), 384 * K, "Alg 14: dk len not 384 * K");
debug_assert_eq!(
ct.len(),
32 * (du as usize * K + dv as usize),
"Alg14: ct len not 32 * (DU * K + DV)"
"Alg 14: ct len not 32 * (DU * K + DV)"
);

// 1: c1 ← c[0 : 32du k]
Expand Down
34 changes: 28 additions & 6 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,12 @@
// Implements FIPS 203 draft Module-Lattice-based Key-Encapsulation Mechanism Standard.
// See <https://nvlpubs.nist.gov/nistpubs/FIPS/NIST.FIPS.203.ipd.pdf>

// TODO Roadmap
// 0. Stay current with FIPS 203 updates
// 1. Perf: optimize/minimize modular reductions, minimize u16 arith, consider avx2/aarch64
// (currently, code is 'optimized' for safety and change-support, with reasonable perf)
// 2. Slightly more intelligent fuzzing (as dk contains h(ek))

// Functionality map per FIPS 203 draft
//
// Algorithm 2 BitsToBytes(b) on page 17 --> optimized out (byte_fns.rs)
Expand Down Expand Up @@ -66,6 +72,7 @@ pub mod traits;
const Q: u16 = 3329;
const ZETA: u16 = 17;


/// Shared Secret Key length for all ML-KEM variants (in bytes)
pub const SSK_LEN: usize = 32;

Expand Down Expand Up @@ -103,7 +110,7 @@ impl PartialEq for SharedSecretKey {
macro_rules! functionality {
() => {
use crate::byte_fns::byte_decode;
use crate::helpers::h;
use crate::helpers::{ensure, h};
use crate::ml_kem::{ml_kem_decaps, ml_kem_encaps, ml_kem_key_gen};
use crate::traits::{Decaps, Encaps, KeyGen, SerDes};
use crate::types::Z;
Expand Down Expand Up @@ -139,9 +146,12 @@ macro_rules! functionality {
}

fn validate_keypair_vt(ek: &Self::EncapsByteArray, dk: &Self::DecapsByteArray) -> bool {
// Note that size is checked by only accepting ref to correctly sized byte array
let len_ek_pke = 384 * K + 32;
let len_dk_pke = 384 * K;
// dk should contain ek
let same_ek = (*ek == dk[len_dk_pke..(len_dk_pke + len_ek_pke)]);
// dk should contain hash of ek
let same_h =
(h(ek) == dk[(len_dk_pke + len_ek_pke)..(len_dk_pke + len_ek_pke + 32)]);
same_ek & same_h
Expand Down Expand Up @@ -208,8 +218,17 @@ macro_rules! functionality {

fn try_from_bytes(dk: Self::ByteArray) -> Result<Self, &'static str> {
// Validation per pg 31. Note that the two checks specify fixed sizes, and these
// functions take only byte arrays of correct size. Nonetheless, we use a Result
// here in case future opportunities for further validation arise.
// functions take only byte arrays of correct size. Nonetheless, we take the
// opportunity to validate the ek and h(ek).
let len_ek_pke = 384 * K + 32;
let len_dk_pke = 384 * K;
let ek = &dk[len_dk_pke..len_dk_pke + EK_LEN];
let _res =
EncapsKey::try_from_bytes(ek.try_into().map_err(|_| "Malformed encaps key")?)?;
ensure!(
h(ek) == dk[(len_dk_pke + len_ek_pke)..(len_dk_pke + len_ek_pke + 32)],
"Encaps hash wrong"
);
Ok(DecapsKey { 0: dk })
}
}
Expand All @@ -232,17 +251,20 @@ macro_rules! functionality {
#[cfg(test)]
mod tests {
use super::*;
use crate::types::EncapsKey;
use rand_chacha::rand_core::SeedableRng;

#[test]
fn smoke_test() {
let mut rng = rand_chacha::ChaCha8Rng::seed_from_u64(123);
for _i in 0..10 {
for _i in 0..100 {
let (ek, dk) = KG::try_keygen_with_rng_vt(&mut rng).unwrap();
let (ssk1, ct) = ek.try_encaps_with_rng_vt(&mut rng).unwrap();
let ssk2 = dk.try_decaps_vt(&ct).unwrap();
assert!(KG::validate_keypair_vt(&ek.into_bytes(), &dk.into_bytes()));
assert_eq!(ssk1, ssk2, "Shared secrets differ");
assert!(KG::validate_keypair_vt(&ek.clone().into_bytes(), &dk.into_bytes()));
assert_eq!(ssk1, ssk2);
assert_eq!(ek.clone().0, EncapsKey::try_from_bytes(ek.into_bytes()).unwrap().0);
// the other SerDes routines don't really have logic...
}
}
}
Expand Down
18 changes: 9 additions & 9 deletions src/ml_kem.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
use crate::byte_fns::{byte_decode, byte_encode};
use crate::helpers::{ensure, g, h, j};
use crate::helpers::{g, h, j};
use crate::k_pke::{k_pke_decrypt, k_pke_encrypt, k_pke_key_gen};
use crate::types::Z;
use crate::SharedSecretKey;
Expand All @@ -13,13 +13,13 @@ use rand_core::CryptoRngCore;
pub(crate) fn ml_kem_key_gen<const K: usize, const ETA1_64: usize>(
rng: &mut impl CryptoRngCore, eta1: u32, ek: &mut [u8], dk: &mut [u8],
) -> Result<(), &'static str> {
debug_assert_eq!(ek.len(), 384 * K + 32, "Alg15: ek len not 384 * K + 32");
debug_assert_eq!(dk.len(), 768 * K + 96, "Alg15: dk len not 768 * K + 96");
debug_assert_eq!(ek.len(), 384 * K + 32, "Alg 15: ek len not 384 * K + 32");
debug_assert_eq!(dk.len(), 768 * K + 96, "Alg 15: dk len not 768 * K + 96");

// 1: z ←− B32 ▷ z is 32 random bytes (see Section 3.3)
let mut z = [0u8; 32];
rng.try_fill_bytes(&mut z)
.map_err(|_| "Alg15: Random number generator failed")?;
.map_err(|_| "Alg 15: Random number generator failed")?;

// 2: (ek_{PKE}, dk_{PKE}) ← K-PKE.KeyGen() ▷ run key generation for K-PKE
let p1 = 384 * K;
Expand Down Expand Up @@ -47,11 +47,11 @@ pub(crate) fn ml_kem_key_gen<const K: usize, const ETA1_64: usize>(
pub(crate) fn ml_kem_encaps<const K: usize, const ETA1_64: usize, const ETA2_64: usize>(
rng: &mut impl CryptoRngCore, du: u32, dv: u32, eta1: u32, eta2: u32, ek: &[u8], ct: &mut [u8],
) -> Result<SharedSecretKey, &'static str> {
debug_assert_eq!(ek.len(), 384 * K + 32, "Alg16: ek len not 384 * K + 32"); // also: size check at top level
debug_assert_eq!(ek.len(), 384 * K + 32, "Alg 16: ek len not 384 * K + 32"); // also: size check at top level
debug_assert_eq!(
ct.len(),
32 * (du as usize * K + dv as usize),
"Alg16: ct len not 32*(DU*K+DV)"
"Alg 16: ct len not 32*(DU*K+DV)"
); // also: size check at top level

// modulus check: perform the computation ek ← ByteEncode12(ByteDecode12(ek_tidle)
Expand All @@ -68,7 +68,7 @@ pub(crate) fn ml_kem_encaps<const K: usize, const ETA1_64: usize, const ETA2_64:
}
pass
},
"Alg16: ek fails modulus check"
"Alg 16: ek fails modulus check"
);

// 1: m ←− B32 ▷ m is 32 random bytes (see Section 3.3)
Expand Down Expand Up @@ -104,9 +104,9 @@ pub(crate) fn ml_kem_decaps<
du: u32, dv: u32, eta1: u32, eta2: u32, dk: &[u8], ct: &[u8],
) -> Result<SharedSecretKey, &'static str> {
// These length checks are a bit redundant...but present for completeness and paranoia
ensure!(ct.len() == 32 * (du as usize * K + dv as usize), "Alg17: ct len not 32 * ...");
debug_assert_eq!(ct.len(), 32 * (du as usize * K + dv as usize), "Alg17: ct len not 32 * ...");
// Ciphertext type check
ensure!(dk.len() == 768 * K + 96, "Alg17: dk len not 768 ..."); // Decapsulation key type check
debug_assert_eq!(dk.len(), 768 * K + 96, "Alg17: dk len not 768 ..."); // Decapsulation key type check

// 1019 For some applications, further validation of the decapsulation key dk_tilde may be appropriate. For
// 1020 instance, in cases where dk_tilde was generated by a third party, users may want to ensure that the four
Expand Down
6 changes: 2 additions & 4 deletions src/ntt.rs
Original file line number Diff line number Diff line change
Expand Up @@ -121,7 +121,6 @@ pub(crate) fn ntt_inv(f_hat: &[Z; 256]) -> [Z; 256] {
/// Input: Two arrays `f_hat` ∈ `Z^{256}_q` and `g_hat` ∈ `Z^{256}_q` ▷ the coefficients of two NTT representations <br>
/// Output: An array `h_hat` ∈ `Z^{256}_q` ▷ the coefficients of the product of the inputs
#[must_use]
#[allow(clippy::cast_possible_truncation)]
pub(crate) fn multiply_ntts(f_hat: &[Z; 256], g_hat: &[Z; 256]) -> [Z; 256] {
let mut h_hat: [Z; 256] = [Z::default(); 256];

Expand Down Expand Up @@ -168,7 +167,7 @@ pub(crate) fn base_case_multiply(a0: Z, a1: Z, b0: Z, b1: Z, gamma: Z) -> (Z, Z)

/// HAC Algorithm 14.76 Right-to-left binary exponentiation mod Q.
#[must_use]
#[allow(clippy::cast_possible_truncation)] // on result
#[allow(clippy::cast_possible_truncation)] // on result as u16 (try_from not const)
const fn pow_mod_q(g: u16, e: u8) -> u16 {
let g = g as u64;
let mut result = 1;
Expand All @@ -187,12 +186,11 @@ const fn pow_mod_q(g: u16, e: u8) -> u16 {
}


#[allow(clippy::cast_possible_truncation)] // i as u8
const fn gen_zeta_table() -> [u16; 256] {
let mut result = [0u16; 256];
let mut i = 0;
while i < 256u16 {
result[i as usize] = pow_mod_q(ZETA, (i as u8).reverse_bits());
result[i as usize] = pow_mod_q(ZETA, (i.to_le_bytes()[0]).reverse_bits());
i += 1;
}
result
Expand Down
2 changes: 1 addition & 1 deletion src/sampling.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,6 @@ pub(crate) fn sample_ntt(mut byte_stream_b: impl XofReader) -> Result<[Z; 256],
// The proportion of fails is approx 3.098e-12 or 2**{-38}; re-run with fresh randomness.
// See cdf at https://www.wolframalpha.com/input?i=binomial+distribution+calculator&assumption=%7B%22F%22%2C+%22BinomialProbabilities%22%2C+%22x%22%7D+-%3E%22256%22&assumption=%7B%22F%22%2C+%22BinomialProbabilities%22%2C+%22n%22%7D+-%3E%22384%22&assumption=%7B%22F%22%2C+%22BinomialProbabilities%22%2C+%22p%22%7D+-%3E%223329%2F4095%22
// 3: while j < 256 do --> this is adapted for constant-time operation
// #[allow(clippy::cast_possible_truncation)] // mask as u16
for _k in 0..192 {
//
// Note: two samples (d1, d2) are drawn per loop iteration
Expand Down Expand Up @@ -110,6 +109,7 @@ pub(crate) fn sample_poly_cbd(eta: u32, byte_array_b: &[u8]) -> [Z; 256] {
}


// the u types below and above could use a bit more thought
// Count u8 ones in constant time (u32 helps perf)
#[allow(clippy::cast_possible_truncation)] // return res as u16
fn count_ones(x: u32) -> u16 {
Expand Down

0 comments on commit 3834004

Please sign in to comment.