From 232524900ff2ecdefb3bde71f4c6fd10f463046a Mon Sep 17 00:00:00 2001 From: Stanislav Chiknavaryan Date: Mon, 27 Oct 2025 17:39:36 -0700 Subject: [PATCH] No public description PiperOrigin-RevId: 824752645 --- shell_wrapper/kahe.cc | 164 +++-- shell_wrapper/kahe.h | 86 ++- shell_wrapper/kahe.rs | 150 +++-- shell_wrapper/kahe_test.cc | 585 +++++++++++------- shell_wrapper/kahe_test.rs | 239 ++++--- shell_wrapper/shell_types.cc | 4 +- shell_wrapper/shell_types.h | 3 +- shell_wrapper/shell_types_test.cc | 8 +- shell_wrapper/single_thread_hkdf.h | 6 +- willow/benches/BUILD | 2 + willow/benches/shell_benchmarks.rs | 53 +- willow/proto/willow/BUILD | 10 + willow/proto/willow/decryptor.proto | 2 +- willow/proto/willow/key.proto | 31 + willow/src/shell/BUILD | 34 +- willow/src/shell/kahe.rs | 398 ++++++------ willow/src/shell/parameters.proto | 20 +- willow/src/shell/parameters.rs | 142 +++-- willow/src/shell/parameters_generation.rs | 26 +- willow/src/shell/parameters_utils.rs | 125 ++++ willow/src/testing_utils/BUILD | 4 + .../testing_utils/shell_testing_parameters.rs | 75 ++- willow/src/testing_utils/testing_utils.rs | 4 +- willow/src/willow_v1/BUILD | 1 + willow/src/willow_v1/client.rs | 63 +- willow/src/zk/rlwe_relation.rs | 78 +-- willow/tests/BUILD | 1 + willow/tests/willow_v1_shell.rs | 177 ++++-- 28 files changed, 1567 insertions(+), 924 deletions(-) create mode 100644 willow/proto/willow/key.proto create mode 100644 willow/src/shell/parameters_utils.rs diff --git a/shell_wrapper/kahe.cc b/shell_wrapper/kahe.cc index 7039090..da0c7d8 100644 --- a/shell_wrapper/kahe.cc +++ b/shell_wrapper/kahe.cc @@ -157,6 +157,7 @@ absl::StatusOr GenerateSecretKey( } namespace internal { + absl::StatusOr EncryptPolynomial( const RnsPolynomial& plaintext, const RnsPolynomial& secret_key, RnsInt plaintext_modulus_rns, int log_n, const RnsPolynomial& a, @@ -196,51 +197,31 @@ absl::StatusOr DecryptPolynomial( return p; } -std::vector> PackMessagesRaw(const Integer* messages, - int num_messages, - Integer packing_base, - int num_packing, - int num_coeffs) { - // NOTE: temporary implementation that copies the input. We can avoid copying - // by modifying the original PackMessages function to work with pointers - // directly. - std::vector messages_copy; - messages_copy.reserve(num_messages); - for (int i = 0; i < num_messages; ++i) { - messages_copy.push_back(messages[i]); - } - return rlwe::PackMessages(messages_copy, packing_base, - num_packing, num_coeffs); -} - -int UnpackMessagesRaw( - const std::vector>& packed_messages, - uint64_t packing_base, int num_packing, int output_values_length, - uint64_t* output_values) { - std::vector unpacked_messages = - rlwe::UnpackMessages(packed_messages, packing_base, num_packing); - - auto count = std::min(static_cast(output_values_length), - unpacked_messages.size()); - std::copy_n(unpacked_messages.begin(), count, output_values); - return count; -} - } // namespace internal absl::StatusOr> EncodeAndEncryptVector( - std::vector>& packed_messages, + const std::vector& packed_values, const RnsPolynomial& secret_key, const KahePublicParameters& params, Prng* prng) { - std::vector ciphertexts; - - if (packed_messages.size() > params.public_polynomials.size()) { - return absl::InvalidArgumentError("Input too long"); + std::vector> plaintexts; + plaintexts.reserve(params.public_polynomials.size()); + int num_coeffs = 1 << params.context->LogN(); + + for (size_t i = 0; i < packed_values.size(); i += num_coeffs) { + size_t chunk_end = std::min(packed_values.size(), i + num_coeffs); + plaintexts.emplace_back(packed_values.begin() + i, + packed_values.begin() + chunk_end); + } + if (plaintexts.size() > params.public_polynomials.size()) { + return absl::InvalidArgumentError("input too long."); } - for (int i = 0; i < packed_messages.size(); ++i) { - const auto& packed_message = packed_messages[i]; + std::vector ciphertexts; + for (int i = 0; i < plaintexts.size(); ++i) { + const auto& packed_message = plaintexts[i]; const RnsPolynomial& a = params.public_polynomials[i]; + // EncodeBgv will pad `packed_message` with zeros to the length of a + // polynomial coefficient vector. SECAGG_ASSIGN_OR_RETURN( RnsPolynomial plaintext, params.encoder.EncodeBgv( @@ -257,10 +238,15 @@ absl::StatusOr> EncodeAndEncryptVector( return ciphertexts; } -absl::StatusOr>> DecodeAndDecryptVector( +absl::StatusOr> DecodeAndDecryptVector( absl::Span ciphertexts, const RnsPolynomial& secret_key, const KahePublicParameters& params) { - std::vector> all_packed_messages; + if (ciphertexts.size() > params.public_polynomials.size()) { + return absl::InvalidArgumentError( + "The size of `ciphertexts` cannot be larger than the size of public " + "polynomials."); + } + std::vector all_packed_messages; for (int i = 0; i < ciphertexts.size(); ++i) { const auto& ciphertext = ciphertexts[i]; const RnsPolynomial& a = params.public_polynomials[i]; @@ -272,7 +258,8 @@ absl::StatusOr>> DecodeAndDecryptVector( params.encoder.DecodeBgv( std::move(plaintext), params.plaintext_modulus, params.moduli, params.modulus_hats, params.modulus_hats_invs)); - all_packed_messages.push_back(std::move(packed_messages)); + all_packed_messages.insert(all_packed_messages.end(), + packed_messages.begin(), packed_messages.end()); } return all_packed_messages; } @@ -326,27 +313,79 @@ FfiStatus GenerateSecretKeyWrapper(const KahePublicParametersWrapper& params, return MakeFfiStatus(); } -FfiStatus Encrypt(rust::Slice input_values, - uint64_t packing_base, uint64_t num_packing, +FfiStatus PackMessagesRaw(rust::Slice messages, + uint64_t packing_base, uint64_t packing_dimension, + uint64_t num_packed_values, + BigIntVectorWrapper* packed_values) { + // Validate the wrappers. + if (packed_values == nullptr) { + return MakeFfiStatus(absl::InvalidArgumentError( + secure_aggregation::kNullPointerErrorMessage)); + } + + // Allocate the vector for output packed values if needed. + if (packed_values->ptr == nullptr) { + packed_values->ptr = + std::make_unique>(); + } + auto curr_packed_values = + rlwe::PackMessagesFlat( + absl::MakeSpan(messages.data(), messages.size()), packing_base, + packing_dimension); + if (curr_packed_values.size() > num_packed_values) { + return MakeFfiStatus(absl::InvalidArgumentError( + "The number of packed values exceeds `num_packed_values`.")); + } + // Pad with zeros if needed. + curr_packed_values.resize(num_packed_values, 0); + // Append the packed values to the end of the output vector. + packed_values->ptr->insert(packed_values->ptr->end(), + curr_packed_values.begin(), + curr_packed_values.end()); + return MakeFfiStatus(); +} + +FfiStatus UnpackMessagesRaw(uint64_t packing_base, uint64_t packing_dimension, + uint64_t num_packed_values, + BigIntVectorWrapper& packed_values, + rust::Vec& out) { + // Validate the wrappers. + if (packed_values.ptr == nullptr) { + return MakeFfiStatus(absl::InvalidArgumentError( + secure_aggregation::kNullPointerErrorMessage)); + } + if (packed_values.ptr->size() < num_packed_values) { + return MakeFfiStatus( + absl::InvalidArgumentError("insufficient number of packed values.")); + } + std::vector unpacked_messages = + rlwe::UnpackMessagesFlat( + absl::MakeSpan(*packed_values.ptr).subspan(0, num_packed_values), + packing_base, packing_dimension); + packed_values.ptr->erase(packed_values.ptr->begin(), + packed_values.ptr->begin() + num_packed_values); + for (auto& val : unpacked_messages) { + out.push_back(val); + } + return MakeFfiStatus(); +} + +FfiStatus Encrypt(const BigIntVectorWrapper& packed_values, const RnsPolynomialWrapper& secret_key, const KahePublicParametersWrapper& params, SingleThreadHkdfWrapper* prng, RnsPolynomialVecWrapper* out) { // Validate the wrappers. - if (secret_key.ptr == nullptr || params.ptr == nullptr || prng == nullptr || - prng->ptr == nullptr || out == nullptr) { + if (packed_values.ptr == nullptr || secret_key.ptr == nullptr || + params.ptr == nullptr || prng == nullptr || prng->ptr == nullptr || + out == nullptr) { return MakeFfiStatus(absl::InvalidArgumentError( secure_aggregation::kNullPointerErrorMessage)); } - // Packing parameters must be valid, e.g. checked on the Rust side. - int num_coeffs = 1 << params.ptr->context->LogN(); - std::vector> packed_messages = - secure_aggregation::internal::PackMessagesRaw( - input_values.data(), input_values.size(), packing_base, num_packing, - num_coeffs); - auto ciphertext_vec = secure_aggregation::EncodeAndEncryptVector( - packed_messages, *secret_key.ptr, *params.ptr, prng->ptr.get()); + *packed_values.ptr, *secret_key.ptr, *params.ptr, prng->ptr.get()); if (!ciphertext_vec.ok()) { return MakeFfiStatus(ciphertext_vec.status()); @@ -357,14 +396,13 @@ FfiStatus Encrypt(rust::Slice input_values, return MakeFfiStatus(); } -FfiStatus Decrypt(uint64_t packing_base, uint64_t num_packing, - const RnsPolynomialVecWrapper& ciphertexts, +FfiStatus Decrypt(const RnsPolynomialVecWrapper& ciphertexts, const RnsPolynomialWrapper& secret_key, const KahePublicParametersWrapper& params, - rust::Slice output_values, uint64_t* n_written) { + BigIntVectorWrapper* output_values) { // Validate the wrappers. if (secret_key.ptr == nullptr || params.ptr == nullptr || - ciphertexts.ptr == nullptr || n_written == nullptr) { + ciphertexts.ptr == nullptr || output_values == nullptr) { return MakeFfiStatus(absl::InvalidArgumentError( secure_aggregation::kNullPointerErrorMessage)); } @@ -378,15 +416,13 @@ FfiStatus Decrypt(uint64_t packing_base, uint64_t num_packing, } } - auto messages = secure_aggregation::DecodeAndDecryptVector( + auto decrypted_values = secure_aggregation::DecodeAndDecryptVector( *ciphertexts.ptr, *secret_key.ptr, *params.ptr); - if (!messages.ok()) { - return MakeFfiStatus(messages.status()); + if (!decrypted_values.ok()) { + return MakeFfiStatus(decrypted_values.status()); } - - *n_written = secure_aggregation::internal::UnpackMessagesRaw( - messages.value(), packing_base, num_packing, output_values.size(), - output_values.data()); - + output_values->ptr = + std::make_unique>( + std::move(decrypted_values.value())); return MakeFfiStatus(); } diff --git a/shell_wrapper/kahe.h b/shell_wrapper/kahe.h index 926c9d1..dbfac25 100644 --- a/shell_wrapper/kahe.h +++ b/shell_wrapper/kahe.h @@ -35,6 +35,7 @@ namespace secure_aggregation { // Forward-declare types for use by the cxx-generated `kahe.rs.h`. struct KahePublicParameters; +struct BigIntVector; } // namespace secure_aggregation #include "shell_wrapper/kahe.rs.h" @@ -93,36 +94,18 @@ absl::StatusOr DecryptPolynomial( const RnsPolynomial& a, absl::Span* const> moduli); -// Packs messages taken from a raw pointer. -// Expects packing_base > 1, num_packing > 0, num_coeffs > 0, -// packing_base^num_packing < std::numeric_limits::max(). -std::vector> PackMessagesRaw(const uint64_t* messages, - int num_messages, - uint64_t packing_base, - int num_packing, - int num_coeffs); - -// Unpacks messages into a buffer `output_values` of length -// at least `output_values_length`. Returns the elements written to the buffer -// (0 if it didn't write anything). -// Expects packing_base > 1, num_packing > 0, num_coeffs > 0, -// packing_base^num_packing < std::numeric_limits::max(). -int UnpackMessagesRaw( - const std::vector>& packed_messages, - Integer packing_base, int num_packing, int output_values_length, - Integer* output_values); - } // namespace internal -// Encrypts a vector of packed messages, where each coordinate is a vector of -// integers that will be encoded into a single polynomial. +// Encrypts a vector of packed messages, where the packed messages are first +// encoded into plaintext polynomials and then encrypted. absl::StatusOr> EncodeAndEncryptVector( - std::vector>& packed_messages, + const std::vector& packed_values, const RnsPolynomial& secret_key, const KahePublicParameters& params, Prng* prng); -// Decrypts a vector of ciphertexts. -absl::StatusOr>> DecodeAndDecryptVector( +// Decrypts a vector of ciphertexts, and returns the concatenated vector of +// decrypted messages. +absl::StatusOr> DecodeAndDecryptVector( absl::Span ciphertexts, const RnsPolynomial& secret_key, const KahePublicParameters& params); @@ -158,32 +141,47 @@ FfiStatus GenerateSecretKeyWrapper(const KahePublicParametersWrapper& params, SingleThreadHkdfWrapper* prng, RnsPolynomialWrapper* out); -// Packs, encodes and encrypts the messages contained in the `input_values` -// buffer. `packing_base` and `num_packing` are the parameters for packing: the -// encoder takes large modular integers obtained by combining `num_packing` -// smaller uint64_t values, each of which is less than `packing_base`. If -// successful, returns OK and sets *out to a vector of ciphertexts, each of -// which is a polynomial. -FfiStatus Encrypt(rust::Slice input_values, - uint64_t packing_base, uint64_t num_packing, +// Packs `messages` into a vector of BigIntegers using base `packing_base` +// encoding, where the packed values are appended to `packed_values`. +// Expects `packed_values` to be a valid pointer but the underlying vector +// may be unallocated, and expects packing_base > 1, packing_dimension > 0, +// num_coeffs > 0, packing_base^packing_dimension < +// std::numeric_limits::max(). +// Note that `messages` is effectively padded with zeros to the nearest multiple +// of `packing_dimension` before packing. +FfiStatus PackMessagesRaw(rust::Slice messages, + uint64_t packing_base, uint64_t packing_dimension, + uint64_t num_packed_values, + BigIntVectorWrapper* packed_values); + +// Unpacks messages stored at `packed_values[0..num_packed_values]` and appends +// them to `out`, and removes these packed values from `packed_values`. +// Expects `packed_values.ptr` to be a valid pointer to the vector of packed +// values, and expects packing_base > 1, packing_dimension > 0, +// num_packed_values > 0, packing_base^packing_dimension < +// std::numeric_limits::max(). +FfiStatus UnpackMessagesRaw(uint64_t packing_base, uint64_t packing_dimension, + uint64_t num_packed_values, + BigIntVectorWrapper& packed_values, + rust::Vec& out); + +// Encrypts the messages contained in `packed_values`. If successful, returns OK +// and sets *out to a vector of ciphertext polynomials. +// Expects `out` to be a valid pointer but the underlying vector may be +// unallocated. +FfiStatus Encrypt(const BigIntVectorWrapper& packed_values, const RnsPolynomialWrapper& secret_key, const KahePublicParametersWrapper& params, SingleThreadHkdfWrapper* prng, RnsPolynomialVecWrapper* out); -// Decrypts, decodes and unpacks `ciphertexts` into the `output_values` buffer. -// Returns a status, and writes the number of outputs written to `n_written`. -// Each decoded message is a large modular integer, which is then split into -// `num_packing` smaller uint64_t values by decomposing into base -// `packing_base`. Note: Internally, decryption yields num_coeffs * num_packing -// * ciphertexts.size() Integers exactly (decryption does padding), with -// num_coeffs = 1 << params->ptr->context->LogN(). But this function can -// take a smaller `output_values_length` to fill a smaller buffer, e.g. to -// remove padding if the plaintext vector length is known. -FfiStatus Decrypt(uint64_t packing_base, uint64_t num_packing, - const RnsPolynomialVecWrapper& ciphertexts, +// Decrypts `ciphertexts` into a vector written to `output_values` buffer, and +// returns a status. +// Expects `output_values` to be a valid pointer but the underlying vector may +// be unallocated. +FfiStatus Decrypt(const RnsPolynomialVecWrapper& ciphertexts, const RnsPolynomialWrapper& secret_key, const KahePublicParametersWrapper& params, - rust::Slice output_values, uint64_t* n_written); + BigIntVectorWrapper* output_values); } // extern "C" diff --git a/shell_wrapper/kahe.rs b/shell_wrapper/kahe.rs index 6c41076..491b50e 100644 --- a/shell_wrapper/kahe.rs +++ b/shell_wrapper/kahe.rs @@ -18,9 +18,17 @@ use shell_types::{Moduli, RnsContextRef, RnsPolynomial, RnsPolynomialVec}; use single_thread_hkdf::{SeedWrapper, SingleThreadHkdfWrapper}; use status::rust_status_from_cpp; +use std::collections::HashMap; use std::marker::PhantomData; use std::mem::MaybeUninit; +#[derive(Debug, PartialEq, Clone)] +pub struct PackedVectorConfig { + pub base: u64, + pub dimension: u64, + pub num_packed_coeffs: u64, +} + #[cxx::bridge] mod ffi { /// Owned KahePublicParameters behind a unique_ptr. @@ -28,6 +36,10 @@ mod ffi { pub ptr: UniquePtr, } + pub struct BigIntVectorWrapper { + pub ptr: UniquePtr>, + } + unsafe extern "C++" { include!("shell_wrapper/kahe.h"); include!("shell_wrapper/shell_types.h"); @@ -35,6 +47,9 @@ mod ffi { #[namespace = "secure_aggregation"] type KahePublicParameters; + #[namespace = "secure_aggregation"] + type BigInteger; + type FfiStatus = shell_types::ffi::FfiStatus; type ModuliWrapper = shell_types::ffi::ModuliWrapper; #[namespace = "secure_aggregation"] @@ -66,10 +81,24 @@ mod ffi { out: *mut RnsPolynomialWrapper, ) -> FfiStatus; - pub unsafe fn Encrypt( - input_values: &[u64], + pub unsafe fn PackMessagesRaw( + messages: &[u64], + packing_base: u64, + packing_dimension: u64, + num_packed_values: u64, + packed_values: *mut BigIntVectorWrapper, + ) -> FfiStatus; + + pub unsafe fn UnpackMessagesRaw( packing_base: u64, - num_packing: u64, + packing_dimension: u64, + num_packed_values: u64, + packed_values: &mut BigIntVectorWrapper, + out: &mut Vec, + ) -> FfiStatus; + + pub unsafe fn Encrypt( + packed_values: &BigIntVectorWrapper, secret_key: &RnsPolynomialWrapper, params: &KahePublicParametersWrapper, prng: *mut SingleThreadHkdfWrapper, @@ -77,13 +106,10 @@ mod ffi { ) -> FfiStatus; pub unsafe fn Decrypt( - packing_base: u64, - num_packing: u64, ciphertexts: &RnsPolynomialVecWrapper, secret_key: &RnsPolynomialWrapper, params: &KahePublicParametersWrapper, - output_values: &mut [u64], - n_written: *mut u64, + output_values: *mut BigIntVectorWrapper, ) -> FfiStatus; } } @@ -152,68 +178,86 @@ pub fn generate_secret_key( Ok(unsafe { out.assume_init() }) } -/// Encrypts a vector of values. -/// -/// The values are encoded as a polynomial, then encrypted with the secret key -/// and the public polynomial at `public_polynomial_index` in `params`. +pub use ffi::BigIntVectorWrapper; + +/// Encrypts the vectors stored in `input_vectors` using `secret_key` and the public polynomials +/// stored in `params`. The input vectors are packed according to the given `packed_vector_configs`. +/// Returns the resulting ciphertexts. pub fn encrypt( - input_values: &[u64], + input_vectors: &HashMap>, + packed_vector_configs: &HashMap, secret_key: &RnsPolynomial, params: &KahePublicParametersWrapper, - packing_base: u64, - num_packing: usize, prng: &mut SingleThreadHkdfWrapper, ) -> Result { + let mut packed_values = MaybeUninit::::zeroed(); + // SAFETY: No lifetime constraints (`PackMessagesRaw` may create a new vector of BigIntegers + // wrapped by `packed_values` which does not keep any reference to the inputs). + // `PackMessagesRaw` only appends to the C++ vector wrapped by `packed_values`, + // allocating it in case it is NULL (in the first iteration). + for (id, packed_vector_config) in packed_vector_configs.iter() { + if !input_vectors.contains_key(id) { + return Err(status::invalid_argument(format!("Input vector with id {} not found", id))); + } + rust_status_from_cpp(unsafe { + ffi::PackMessagesRaw( + &input_vectors[id], + packed_vector_config.base, + packed_vector_config.dimension, + packed_vector_config.num_packed_coeffs, + packed_values.as_mut_ptr(), + ) + })?; + } + let mut out = MaybeUninit::::zeroed(); - // SAFETY: No lifetime constraints (`Encrypt` creates a new polynomial which - // does not keep any reference to the inputs). `Encrypt` reads the - // `input_values` buffer within a valid range. + // SAFETY: No lifetime constraints (`Encrypt` creates a new vector of polynomials wrapped by + // `out` which does not keep any reference to the inputs). `Encrypt` reads the C++ vector + // wrapped by `packed_values`, updates the states wrapped by `prng`, and writes into the C++ + // vector wrapped by `out`. rust_status_from_cpp(unsafe { - ffi::Encrypt( - input_values, - packing_base, - num_packing as u64, - secret_key, - params, - prng, - out.as_mut_ptr(), - ) + ffi::Encrypt(&packed_values.assume_init(), secret_key, params, prng, out.as_mut_ptr()) })?; // SAFETY: `out` is safely initialized if we get to this point. Ok(unsafe { out.assume_init() }) } -/// Decrypts a ciphertext that was encrypted with `secret_key` and the public -/// polynomial a stored at `public_polynomial_index` in `params`. Writes the -/// decrypted values into `output_values`. Returns the number of values written. -/// -/// This low-level API works with slices. The caller can allocate vectors if -/// they want. Using an uninitialized Vec::with_capacity works too, but then we -/// need to manually update the length with `unsafe { -/// output_values.set_len(n_messages_written) }` because Rust doesn't know that -/// C has written into the vector. +/// Decrypts ciphertexts that were encrypted with `secret_key` and the public polynomials stored +/// in `params`. Returns the unpacked decrypted values. +/// The decrypted values are unpacked according to the given `packed_vector_configs`. pub fn decrypt( ciphertext: &RnsPolynomialVec, secret_key: &RnsPolynomial, params: &KahePublicParametersWrapper, - packing_base: u64, - num_packing: usize, - output_values: &mut [u64], -) -> Result { - // SAFETY: No lifetime constraints (`DecryptionResult` just holds two ints and - // does not keep any reference to the inputs). `Decrypt` only modifies the - // `output_values` buffer within a valid range. - let mut n_written = 0u64; + packed_vector_configs: &HashMap, +) -> Result>, status::StatusError> { + let mut packed_values = MaybeUninit::::zeroed(); + // SAFETY: No lifetime constraints (`packed_values` does not keep any reference to the inputs). + // `Decrypt` creates a new C++ vector wrapped by `output_values` and only modifies this buffer. rust_status_from_cpp(unsafe { - ffi::Decrypt( - packing_base, - num_packing as u64, - ciphertext, - secret_key, - params, - output_values, - &mut n_written, - ) + ffi::Decrypt(ciphertext, secret_key, params, packed_values.as_mut_ptr()) })?; - Ok(n_written as usize) + + let mut output_vectors = HashMap::>::new(); + // Assume the packed values are stored in the same order as the configs. + for (id, packed_vector_config) in packed_vector_configs.iter() { + let unpacked_size = + (packed_vector_config.num_packed_coeffs * packed_vector_config.dimension) as usize; + let mut unpacked_values = Vec::with_capacity(unpacked_size); + /// SAFETY: No lifetime constraints (output values of `UnpackMessagesRaw` do not keep any + /// reference to its inputs). `UnpackMessagesRaw` reads and removes a prefix of the C++ + /// vector wrapped by `packed_values`, and writes into the buffer `out`. + rust_status_from_cpp(unsafe { + ffi::UnpackMessagesRaw( + packed_vector_config.base, + packed_vector_config.dimension, + packed_vector_config.num_packed_coeffs, + packed_values.assume_init_mut(), + &mut unpacked_values, + ) + })?; + output_vectors.insert(id.clone(), unpacked_values); + } + + Ok(output_vectors) } diff --git a/shell_wrapper/kahe_test.cc b/shell_wrapper/kahe_test.cc index 2ff7c21..bcfae7b 100644 --- a/shell_wrapper/kahe_test.cc +++ b/shell_wrapper/kahe_test.cc @@ -17,12 +17,12 @@ #include #include -#include #include #include #include #include +#include "absl/status/status.h" #include "absl/types/span.h" #include "gmock/gmock.h" #include "gtest/gtest.h" @@ -46,12 +46,11 @@ namespace secure_aggregation { namespace { using secure_aggregation::secagg_internal::StatusIs; -using ::testing::IsNull; constexpr int kLogN = 12; constexpr int kNumCoeffs = 1 << kLogN; -const std::vector kQs = {1125899906826241ULL, - 1125899906629633ULL}; // q ~ 2^100 +const std::vector kQs = {1125899906826241ULL, + 1125899906629633ULL}; // q ~ 2^100 // We need t * e in [-q/2, q/2). // We take kLogT < 100 - 1 - log2(kTailBoundMultiplier) - log2(kPrgErrorS) @@ -64,8 +63,8 @@ const RnsContextConfig kRnsContextConfig = { .t = 2, // Dummy RNS plaintext modulus here }; -rust::Slice ToRustSlice(absl::Span s) { - return rust::Slice(s.data(), s.size()); +rust::Slice ToRustSlice(absl::Span s) { + return rust::Slice(s.data(), s.size()); } using ::ToRustSlice; // Import into namespace for correct resolution. @@ -231,12 +230,12 @@ TEST(KaheTest, VectorEncryptDecrypt) { // Encrypt random input vector that uses all the polynomial coefficients. constexpr int num_polynomials = 10; - std::vector> all_packed_messages; - std::vector packed_messages; - packed_messages.reserve(kNumCoeffs); + std::vector all_packed_messages; + all_packed_messages.reserve(kNumCoeffs * num_polynomials); for (int i = 0; i < num_polynomials; ++i) { auto packed_messages = testing::SampleUint256Messages(kNumCoeffs, kT); - all_packed_messages.push_back(std::move(packed_messages)); + all_packed_messages.insert(all_packed_messages.end(), + packed_messages.begin(), packed_messages.end()); } SECAGG_ASSERT_OK_AND_ASSIGN( @@ -249,11 +248,112 @@ TEST(KaheTest, VectorEncryptDecrypt) { EXPECT_EQ(all_packed_messages, decrypted); } +TEST(KaheTest, PackMessagesRawAllocatesOutputVectorIfNull) { + constexpr Integer packing_base = 10; + constexpr int packing_dimension = 1; + constexpr int num_packed_values = 10; + std::vector messages = + rlwe::testing::SampleMessages(num_packed_values, packing_base); + // Create a wrapper with an unallocated vector. + BigIntVectorWrapper packed_values{.ptr = nullptr}; + SECAGG_EXPECT_OK(UnwrapFfiStatus( + PackMessagesRaw(ToRustSlice(messages), packing_base, packing_dimension, + num_packed_values, &packed_values))); + ASSERT_NE(packed_values.ptr, nullptr); + EXPECT_EQ(packed_values.ptr->size(), num_packed_values); + EXPECT_EQ(*packed_values.ptr, + (rlwe::PackMessagesFlat(messages, packing_base, + packing_dimension))); +} + +TEST(KaheTest, PackMessagesRawPadsWithZeros) { + constexpr Integer packing_base = 10; + constexpr int packing_dimension = 3; + constexpr int num_messages = 5; + constexpr int num_packed_values = 10; + + std::vector messages = + rlwe::testing::SampleMessages(num_messages, packing_base); + BigIntVectorWrapper packed_values{.ptr = nullptr}; + SECAGG_EXPECT_OK(UnwrapFfiStatus( + PackMessagesRaw(ToRustSlice(messages), packing_base, packing_dimension, + num_packed_values, &packed_values))); + EXPECT_EQ(packed_values.ptr->size(), num_packed_values); + + // Check that the prefix of the packed values match the expected packed + // values. + std::vector expected_packed_values = + rlwe::PackMessagesFlat(messages, packing_base, + packing_dimension); + ASSERT_LT(expected_packed_values.size(), packed_values.ptr->size()); + EXPECT_EQ( + absl::MakeSpan(*packed_values.ptr).first(expected_packed_values.size()), + expected_packed_values); + + // The suffix should be padded with zeros. + EXPECT_THAT( + absl::MakeSpan(*packed_values.ptr).subspan(expected_packed_values.size()), + ::testing::Each(::testing::Eq(0))); +} + +TEST(KaheTest, PackMessagesRawAppendsPackedValues) { + constexpr Integer packing_base = 10; + constexpr int packing_dimension = 1; + constexpr int num_packed_values = 10; + constexpr BigInteger kT = 65537; + + // Create a wrapper with a vector of already packed values. + std::vector already_packed_values = + testing::SampleUint256Messages(num_packed_values, kT); + BigIntVectorWrapper packed_values{ + .ptr = std::make_unique>(already_packed_values)}; + + // Pack more values and check that they are appended to the existing vector. + std::vector messages = + rlwe::testing::SampleMessages(num_packed_values, packing_base); + SECAGG_EXPECT_OK(UnwrapFfiStatus( + PackMessagesRaw(ToRustSlice(messages), packing_base, packing_dimension, + num_packed_values, &packed_values))); + EXPECT_EQ(packed_values.ptr->size(), num_packed_values * 2); + EXPECT_EQ(absl::MakeSpan(*packed_values.ptr).first(num_packed_values), + already_packed_values); + EXPECT_EQ(absl::MakeSpan(*packed_values.ptr).last(num_packed_values), + (rlwe::PackMessagesFlat(messages, packing_base, + packing_dimension))); +} + +TEST(KaheTest, UnpackMessagesRawRemovesConsumedPackedValues) { + constexpr Integer packing_base = 10; + constexpr int packing_dimension = 1; + constexpr int num_packed_values = 10; + // Since packing_dimension == 1, `packed` is the same as unpacked messages. + std::vector packed = + testing::SampleUint256Messages(num_packed_values * 2, packing_base); + BigIntVectorWrapper packed_values{ + .ptr = std::make_unique>(packed)}; + + // Unpack `num_packed_values` messages, which should remove the first + // `num_packed_values` elements from the vector in `packed_values`. + rust::Vec unpacked_messages; + SECAGG_EXPECT_OK(UnwrapFfiStatus( + UnpackMessagesRaw(packing_base, packing_dimension, num_packed_values, + packed_values, unpacked_messages))); + EXPECT_EQ(packed_values.ptr->size(), num_packed_values); + EXPECT_EQ(unpacked_messages.size(), num_packed_values); + // Unpacked values should match the first half of the original packed values. + for (int i = 0; i < num_packed_values; ++i) { + EXPECT_EQ(unpacked_messages[i], static_cast(packed[i])); + } + // Check that the remaining packed values are unchanged. + EXPECT_EQ(absl::MakeSpan(*packed_values.ptr).first(num_packed_values), + absl::MakeSpan(packed).subspan(num_packed_values)); +} + TEST(KaheTest, PackAndEncrypt) { constexpr int num_packing = 8; constexpr int num_public_polynomials = 2; constexpr int num_messages = 30; - constexpr uint64_t packing_base = 2; + constexpr Integer packing_base = 2; SECAGG_ASSERT_OK_AND_ASSIGN(std::string public_seed, Prng::GenerateSeed()); SECAGG_ASSERT_OK_AND_ASSIGN( @@ -264,264 +364,236 @@ TEST(KaheTest, PackAndEncrypt) { SECAGG_ASSERT_OK_AND_ASSIGN(auto prng, Prng::Create(seed)); SECAGG_ASSERT_OK_AND_ASSIGN(auto key, GenerateSecretKey(params, prng.get())); - std::vector input_vec = + std::vector input_messages = rlwe::testing::SampleMessages(num_messages, packing_base); + std::vector packed_messages = + rlwe::PackMessagesFlat(input_messages, packing_base, + num_packing); + // packed_messages length should be ceil(num_messages / num_packing). + int num_packed_messages = (num_messages + num_packing - 1) / num_packing; + EXPECT_EQ(packed_messages.size(), num_packed_messages); + + // Check that PackMessagesRaw works as expected. + BigIntVectorWrapper raw_packed_messages_wrapper{ + .ptr = std::make_unique>()}; + SECAGG_ASSERT_OK(UnwrapFfiStatus( + PackMessagesRaw(ToRustSlice(input_messages), packing_base, num_packing, + num_packed_messages, &raw_packed_messages_wrapper))); + EXPECT_EQ(*raw_packed_messages_wrapper.ptr, packed_messages); - std::vector> packed_messages = - rlwe::PackMessages(input_vec, packing_base, - num_packing, kNumCoeffs); - EXPECT_EQ(packed_messages.size(), - 1); // Only one polynomial needed. - - // Check that RawPack works as expected. - int num_coeffs = 1 << params.context->LogN(); - EXPECT_EQ(num_coeffs, kNumCoeffs); - - std::vector> raw_packed_messages = - secure_aggregation::internal::PackMessagesRaw( - input_vec.data(), input_vec.size(), packing_base, num_packing, - num_coeffs); - EXPECT_EQ(raw_packed_messages, packed_messages); - + // Encrypt the packed messages. SECAGG_ASSERT_OK_AND_ASSIGN(auto ciphertexts, secure_aggregation::EncodeAndEncryptVector( packed_messages, key, params, prng.get())); - + EXPECT_EQ(ciphertexts.size(), 1); // Only one ciphertext polynomial needed. SECAGG_ASSERT_OK_AND_ASSIGN( auto decrypted, secure_aggregation::DecodeAndDecryptVector(ciphertexts, key, params)); - EXPECT_EQ(decrypted.size(), 1); // Only one ciphertext polynomial needed. - - std::vector unpacked_messages = - rlwe::UnpackMessages(decrypted, packing_base, num_packing); - - // Check that UnpackRaw works as expected. - uint64_t decrypted_buffer[2 * num_messages]; - auto n_messages_written = secure_aggregation::internal::UnpackMessagesRaw( - decrypted, packing_base, num_packing, 2 * num_messages, decrypted_buffer); - EXPECT_EQ(n_messages_written, 2 * num_messages); - EXPECT_EQ(absl::MakeSpan(decrypted_buffer, num_messages), - absl::MakeSpan(unpacked_messages).subspan(0, num_messages)); - - // Check decrypted messages and padding. - EXPECT_EQ(absl::MakeSpan(unpacked_messages).subspan(0, num_messages), - absl::MakeSpan(input_vec).subspan(0, num_messages)); + EXPECT_EQ(decrypted.size(), kNumCoeffs); + + // Check that UnpackMessagesRaw works as expected. + std::vector expected_unpacked_messages = + rlwe::UnpackMessagesFlat(decrypted, packing_base, + num_packing); + BigIntVectorWrapper decrypted_wrapper{ + .ptr = std::make_unique>(std::move(decrypted))}; + rust::Vec unpacked_messages; + SECAGG_ASSERT_OK(UnwrapFfiStatus( + UnpackMessagesRaw(packing_base, num_packing, packed_messages.size(), + decrypted_wrapper, unpacked_messages))); + EXPECT_EQ(absl::MakeSpan(unpacked_messages.data(), num_messages), + absl::MakeSpan(expected_unpacked_messages.data(), num_messages)); + // Check against the original input messages. + EXPECT_EQ(absl::MakeSpan(unpacked_messages.data(), num_messages), + absl::MakeSpan(input_messages).subspan(0, num_messages)); + // Check unpacked messages are padded with zeros. + ASSERT_GE(expected_unpacked_messages.size(), num_messages); EXPECT_THAT( - absl::MakeSpan(unpacked_messages) + absl::MakeSpan(unpacked_messages.data(), unpacked_messages.size()) .subspan(num_messages, unpacked_messages.size() - num_messages), ::testing::Each(::testing::Eq(0))); } TEST(KaheTest, RawVectorEncryptOnePolynomial) { constexpr int num_packing = 2; - constexpr uint64_t num_public_polynomials = 2; - FfiStatus status; + constexpr int num_public_polynomials = 2; + constexpr int num_messages = 10; + constexpr Integer packing_base = 10; + std::unique_ptr public_seed; - status = GenerateSingleThreadHkdfSeed(public_seed); - SECAGG_ASSERT_OK(UnwrapFfiStatus(status)); + SECAGG_ASSERT_OK(UnwrapFfiStatus(GenerateSingleThreadHkdfSeed(public_seed))); KahePublicParametersWrapper params; - status = CreateKahePublicParametersWrapper( + SECAGG_ASSERT_OK(UnwrapFfiStatus(CreateKahePublicParametersWrapper( kLogN, kLogT, ToRustSlice(kQs), num_public_polynomials, - ToRustSlice(*public_seed), ¶ms); - SECAGG_ASSERT_OK(UnwrapFfiStatus(status)); + ToRustSlice(*public_seed), ¶ms))); std::unique_ptr private_seed; - status = GenerateSingleThreadHkdfSeed(private_seed); - SECAGG_ASSERT_OK(UnwrapFfiStatus(status)); + SECAGG_ASSERT_OK(UnwrapFfiStatus(GenerateSingleThreadHkdfSeed(private_seed))); SingleThreadHkdfWrapper prng; - status = CreateSingleThreadHkdf(ToRustSlice(*private_seed), prng); - SECAGG_ASSERT_OK(UnwrapFfiStatus(status)); + SECAGG_ASSERT_OK(UnwrapFfiStatus( + CreateSingleThreadHkdf(ToRustSlice(*private_seed), prng))); RnsPolynomialWrapper key; SECAGG_ASSERT_OK( UnwrapFfiStatus(GenerateSecretKeyWrapper(params, &prng, &key))); // Generate random messages that fit on one polynomial. - constexpr int num_messages = 10; - constexpr uint64_t packing_base = 10; - std::vector input_values = + std::vector input_messages = rlwe::testing::SampleMessages(num_messages, packing_base); + BigIntVectorWrapper packed_messages_wrapper{ + .ptr = std::make_unique>()}; + int num_packed_messages = (num_messages + num_packing - 1) / num_packing; + SECAGG_ASSERT_OK(UnwrapFfiStatus( + PackMessagesRaw(ToRustSlice(input_messages), packing_base, num_packing, + num_packed_messages, &packed_messages_wrapper))); + // packed_messages length should be ceil(num_messages / num_packing). + EXPECT_EQ(packed_messages_wrapper.ptr->size(), num_packed_messages); + RnsPolynomialVecWrapper ciphertexts; - SECAGG_ASSERT_OK( - UnwrapFfiStatus(Encrypt(ToRustSlice(input_values), packing_base, - num_packing, key, params, &prng, &ciphertexts))); + SECAGG_ASSERT_OK(UnwrapFfiStatus( + Encrypt(packed_messages_wrapper, key, params, &prng, &ciphertexts))); // Check that decryption works when we decrypt only what we need. - uint64_t decrypted[num_messages]; - uint64_t n_written; - SECAGG_ASSERT_OK(UnwrapFfiStatus( - Decrypt(packing_base, num_packing, ciphertexts, key, params, - rust::Slice(decrypted, num_messages), &n_written))); + BigIntVectorWrapper decrypted_wrapper{ + .ptr = std::make_unique>()}; + SECAGG_ASSERT_OK( + UnwrapFfiStatus(Decrypt(ciphertexts, key, params, &decrypted_wrapper))); + + rust::Vec unpacked_decrypted_messages; + SECAGG_ASSERT_OK(UnwrapFfiStatus(UnpackMessagesRaw( + packing_base, num_packing, decrypted_wrapper.ptr->size(), + decrypted_wrapper, unpacked_decrypted_messages))); // Filled the whole buffer with right messages. - EXPECT_EQ(num_messages, n_written); - EXPECT_EQ(absl::MakeSpan(decrypted), absl::MakeSpan(input_values)); + EXPECT_EQ(absl::MakeSpan(unpacked_decrypted_messages.data(), num_messages), + absl::MakeSpan(input_messages)); // Check that decryption still work when we receive some padding. - constexpr uint64_t buffer_length = + constexpr int buffer_length = 2 * kNumCoeffs * num_packing; // Room for 2 plaintext polynomials. - constexpr uint64_t padded_length = + constexpr int padded_length = kNumCoeffs * num_packing; // What the padded input really needs. - uint64_t decrypted_long[buffer_length] = {}; - decrypted_long[padded_length - 1] = 42; // Check that we overwrite this. - decrypted_long[padded_length] = 42; // Check that we don't overwrite this. - n_written = 0; + BigIntVectorWrapper decrypted_long_messages_wrapper{ + .ptr = std::make_unique>()}; SECAGG_ASSERT_OK(UnwrapFfiStatus( - Decrypt(packing_base, num_packing, ciphertexts, key, params, - rust::Slice(decrypted_long, buffer_length), &n_written))); + Decrypt(ciphertexts, key, params, &decrypted_long_messages_wrapper))); - // Decrypt doesn't fill the whole buffer. - EXPECT_EQ(n_written, padded_length); + rust::Vec unpacked_decrypted_long_messages; + unpacked_decrypted_long_messages.reserve(buffer_length); + SECAGG_ASSERT_OK(UnwrapFfiStatus(UnpackMessagesRaw( + packing_base, num_packing, decrypted_long_messages_wrapper.ptr->size(), + decrypted_long_messages_wrapper, unpacked_decrypted_long_messages))); // The non-zero messages are identical. - EXPECT_EQ(absl::MakeSpan(decrypted_long).subspan(0, num_messages), - absl::MakeSpan(input_values)); + EXPECT_EQ( + absl::MakeSpan(unpacked_decrypted_long_messages.data(), num_messages), + absl::MakeSpan(input_messages)); // Decrypted messages are padded to zero up to the end of the polynomial. - EXPECT_THAT(absl::MakeSpan(decrypted_long) + EXPECT_THAT(absl::MakeSpan(unpacked_decrypted_long_messages.data(), + unpacked_decrypted_long_messages.size()) .subspan(num_messages, padded_length - num_messages), ::testing::Each(::testing::Eq(0))); - - // The canary is unchanged. - EXPECT_EQ(decrypted_long[padded_length], 42); } TEST(KaheTest, RawVectorEncryptTwoPolynomials) { constexpr int num_packing = 8; - constexpr uint64_t num_public_polynomials = 2; - FfiStatus status; + constexpr int num_public_polynomials = 2; + std::unique_ptr public_seed; - status = GenerateSingleThreadHkdfSeed(public_seed); - SECAGG_ASSERT_OK(UnwrapFfiStatus(status)); + SECAGG_ASSERT_OK(UnwrapFfiStatus(GenerateSingleThreadHkdfSeed(public_seed))); KahePublicParametersWrapper params; - status = CreateKahePublicParametersWrapper( + SECAGG_ASSERT_OK(UnwrapFfiStatus(CreateKahePublicParametersWrapper( kLogN, kLogT, ToRustSlice(kQs), num_public_polynomials, - ToRustSlice(*public_seed), ¶ms); - SECAGG_ASSERT_OK(UnwrapFfiStatus(status)); + ToRustSlice(*public_seed), ¶ms))); std::unique_ptr private_seed; - status = GenerateSingleThreadHkdfSeed(private_seed); - SECAGG_ASSERT_OK(UnwrapFfiStatus(status)); + SECAGG_ASSERT_OK(UnwrapFfiStatus(GenerateSingleThreadHkdfSeed(private_seed))); SingleThreadHkdfWrapper prng; - status = CreateSingleThreadHkdf(ToRustSlice(*private_seed), prng); - SECAGG_ASSERT_OK(UnwrapFfiStatus(status)); + SECAGG_ASSERT_OK(UnwrapFfiStatus( + CreateSingleThreadHkdf(ToRustSlice(*private_seed), prng))); RnsPolynomialWrapper key; SECAGG_ASSERT_OK( UnwrapFfiStatus(GenerateSecretKeyWrapper(params, &prng, &key))); // Generate random messages that need two polynomials. constexpr int num_messages = kNumCoeffs * num_packing + 10; - constexpr uint64_t packing_base = 2; - std::vector input_vec = + constexpr int num_packed_messages = + (num_messages + num_packing - 1) / num_packing; + constexpr Integer packing_base = 2; + std::vector input_messages = rlwe::testing::SampleMessages(num_messages, packing_base); - uint64_t input_values[num_messages]; - for (int i = 0; i < num_messages; ++i) { - input_values[i] = input_vec[i]; - } - RnsPolynomialVecWrapper ciphertexts; - SECAGG_ASSERT_OK( - UnwrapFfiStatus(Encrypt(ToRustSlice(input_values), packing_base, - num_packing, key, params, &prng, &ciphertexts))); - - // Check that decryption works when we decrypt only what we need. - uint64_t decrypted[num_messages]; - uint64_t n_written; + BigIntVectorWrapper packed_messages_wrapper{ + .ptr = std::make_unique>()}; SECAGG_ASSERT_OK(UnwrapFfiStatus( - Decrypt(packing_base, num_packing, ciphertexts, key, params, - rust::Slice(decrypted, num_messages), &n_written))); - - EXPECT_EQ(n_written, num_messages); - - for (int i = 0; i < num_messages; ++i) { - EXPECT_EQ(input_values[i], decrypted[i]); - } - - // Check that decryption is padded properly. - constexpr int num_ciphertext_polynomials = - 2; // Input fits on two polynomials. - constexpr int padded_length = - kNumCoeffs * num_packing * num_ciphertext_polynomials; - constexpr int buffer_length = padded_length * 2; - uint64_t decrypted_padded[buffer_length]; + PackMessagesRaw(ToRustSlice(input_messages), packing_base, num_packing, + num_packed_messages, &packed_messages_wrapper))); + RnsPolynomialVecWrapper ciphertexts; SECAGG_ASSERT_OK(UnwrapFfiStatus( - Decrypt(packing_base, num_packing, ciphertexts, key, params, - rust::Slice(decrypted_padded, buffer_length), &n_written))); - - EXPECT_EQ(n_written, padded_length); - for (int i = 0; i < num_messages; ++i) { - EXPECT_EQ(input_values[i], decrypted[i]); - } - for (int i = num_messages; i < padded_length; ++i) { - EXPECT_EQ(decrypted_padded[i], 0); - } + Encrypt(packed_messages_wrapper, key, params, &prng, &ciphertexts))); - // Check that the padding is not too long. - constexpr int wrong_num_ciphertext_polynomials = 3; - constexpr int wrong_padded_length = - kNumCoeffs * num_packing * wrong_num_ciphertext_polynomials; - constexpr int wrong_buffer_length = wrong_padded_length * 2; - uint64_t wrong_decrypted_padded[wrong_buffer_length]; - SECAGG_ASSERT_OK(UnwrapFfiStatus(Decrypt( - packing_base, num_packing, ciphertexts, key, params, - rust::Slice(wrong_decrypted_padded, wrong_buffer_length), &n_written))); - - EXPECT_NE(n_written, wrong_padded_length); + // Check that decryption works when we decrypt only what we need. + BigIntVectorWrapper decrypted_wrapper{ + .ptr = std::make_unique>()}; + SECAGG_ASSERT_OK( + UnwrapFfiStatus(Decrypt(ciphertexts, key, params, &decrypted_wrapper))); + rust::Vec unpacked_decrypted_messages; + SECAGG_ASSERT_OK(UnwrapFfiStatus(UnpackMessagesRaw( + packing_base, num_packing, decrypted_wrapper.ptr->size(), + decrypted_wrapper, unpacked_decrypted_messages))); + + EXPECT_GE(unpacked_decrypted_messages.size(), num_messages); + EXPECT_EQ(absl::MakeSpan(input_messages), + absl::MakeSpan(unpacked_decrypted_messages.data(), num_messages)); } TEST(KaheTest, Failures) { constexpr int num_packing = 8; - constexpr uint64_t num_public_polynomials = 2; - FfiStatus status; + constexpr int num_public_polynomials = 2; + std::unique_ptr public_seed; - status = GenerateSingleThreadHkdfSeed(public_seed); - SECAGG_ASSERT_OK(UnwrapFfiStatus(status)); + SECAGG_ASSERT_OK(UnwrapFfiStatus(GenerateSingleThreadHkdfSeed(public_seed))); KahePublicParametersWrapper params; - status = CreateKahePublicParametersWrapper( + SECAGG_ASSERT_OK(UnwrapFfiStatus(CreateKahePublicParametersWrapper( kLogN, kLogT, ToRustSlice(kQs), num_public_polynomials, - ToRustSlice(*public_seed), ¶ms); - SECAGG_ASSERT_OK(UnwrapFfiStatus(status)); + ToRustSlice(*public_seed), ¶ms))); std::unique_ptr private_seed; - status = GenerateSingleThreadHkdfSeed(private_seed); - SECAGG_ASSERT_OK(UnwrapFfiStatus(status)); + SECAGG_ASSERT_OK(UnwrapFfiStatus(GenerateSingleThreadHkdfSeed(private_seed))); SingleThreadHkdfWrapper prng; - status = CreateSingleThreadHkdf(ToRustSlice(*private_seed), prng); - SECAGG_ASSERT_OK(UnwrapFfiStatus(status)); + SECAGG_ASSERT_OK(UnwrapFfiStatus( + CreateSingleThreadHkdf(ToRustSlice(*private_seed), prng))); RnsPolynomialWrapper key; SECAGG_ASSERT_OK( UnwrapFfiStatus(GenerateSecretKeyWrapper(params, &prng, &key))); // Generate random messages that need 3 polynomials. constexpr int num_messages = kNumCoeffs * num_packing * 3; - constexpr uint64_t packing_base = 2; - std::vector input_vec = + constexpr int num_packed_messages = + (num_messages + num_packing - 1) / num_packing; + constexpr Integer packing_base = 2; + std::vector input_messages = rlwe::testing::SampleMessages(num_messages, packing_base); - uint64_t input_values[num_messages]; - for (int i = 0; i < num_messages; ++i) { - input_values[i] = input_vec[i]; - } // Check that encryption fails if we don't have enough public polynomials. - RnsPolynomialVecWrapper vec_out; - EXPECT_THAT( - UnwrapFfiStatus(Encrypt(ToRustSlice(input_values), packing_base, - num_packing, key, params, &prng, &vec_out)), - StatusIs(absl::StatusCode::kInvalidArgument)); + BigIntVectorWrapper packed_messages_wrapper{ + .ptr = std::make_unique>()}; + SECAGG_ASSERT_OK(UnwrapFfiStatus( + PackMessagesRaw(ToRustSlice(input_messages), packing_base, num_packing, + num_packed_messages, &packed_messages_wrapper))); + RnsPolynomialVecWrapper ciphertexts; + EXPECT_THAT(UnwrapFfiStatus(Encrypt(packed_messages_wrapper, key, params, + &prng, &ciphertexts)), + StatusIs(absl::StatusCode::kInvalidArgument)); // Check failures on invalid pointers or wrappers KahePublicParametersWrapper bad_params = {.ptr = nullptr}; - EXPECT_THAT(UnwrapFfiStatus(Encrypt( - rust::Slice(input_values, - 2 * kNumCoeffs * num_packing), - packing_base, num_packing, key, bad_params, &prng, &vec_out)), + EXPECT_THAT(UnwrapFfiStatus(Encrypt(packed_messages_wrapper, key, bad_params, + &prng, &ciphertexts)), StatusIs(absl::StatusCode::kInvalidArgument)); - constexpr uint64_t buffer_length = - 2 * kNumCoeffs * num_packing; // Room for 2 plaintext polynomials. - uint64_t decrypted_long[buffer_length] = {}; - uint64_t n_written; - vec_out.ptr = nullptr; - EXPECT_THAT(UnwrapFfiStatus(Decrypt( - packing_base, num_packing, vec_out, key, params, - rust::Slice(decrypted_long, buffer_length), &n_written)), + RnsPolynomialVecWrapper bad_ciphertexts{.len = 1, .ptr = nullptr}; + BigIntVectorWrapper decrypted_wrapper; + EXPECT_THAT(UnwrapFfiStatus( + Decrypt(bad_ciphertexts, key, params, &decrypted_wrapper)), StatusIs(absl::StatusCode::kInvalidArgument)); // Also check keygen and parameters. @@ -531,79 +603,146 @@ TEST(KaheTest, Failures) { StatusIs(absl::StatusCode::kInvalidArgument)); EXPECT_THAT(UnwrapFfiStatus(CreateKahePublicParametersWrapper( kLogN, kLogT, ToRustSlice(kQs), num_public_polynomials, - ToRustSlice(*public_seed), nullptr)), + ToRustSlice(*public_seed), /*out=*/nullptr)), + StatusIs(absl::StatusCode::kInvalidArgument)); +} + +TEST(KaheTest, PackMessagesRawFailsIfNullOutputWrapper) { + constexpr Integer packing_base = 10; + constexpr int packing_dimension = 1; + constexpr int num_messages = 10; + std::vector input_messages = + rlwe::testing::SampleMessages(num_messages, packing_base); + EXPECT_THAT(UnwrapFfiStatus(PackMessagesRaw( + ToRustSlice(input_messages), packing_base, packing_dimension, + num_messages, /*packed_values=*/nullptr)), + StatusIs(absl::StatusCode::kInvalidArgument)); +} + +TEST(KaheTest, PackMessagesRawFailsIfInputTooLong) { + constexpr Integer packing_base = 10; + constexpr int packing_dimension = 1; + constexpr int num_packed_messages = 10; + constexpr int bad_num_messages = num_packed_messages * packing_dimension + 1; + std::vector bad_input_messages = + rlwe::testing::SampleMessages(bad_num_messages, packing_base); + BigIntVectorWrapper packed_messages_wrapper{ + .ptr = std::make_unique>()}; + EXPECT_THAT( + UnwrapFfiStatus(PackMessagesRaw( + ToRustSlice(bad_input_messages), packing_base, packing_dimension, + num_packed_messages, &packed_messages_wrapper)), + StatusIs(absl::StatusCode::kInvalidArgument)); +} + +TEST(KaheTest, UnpackMessagesRawFailsIfUnallocatedPackedValues) { + constexpr Integer packing_base = 10; + constexpr int packing_dimension = 1; + constexpr int num_packed_messages = 10; + BigIntVectorWrapper bad_packed_values{.ptr = nullptr}; + rust::Vec unpacked_messages; + EXPECT_THAT(UnwrapFfiStatus(UnpackMessagesRaw( + packing_base, packing_dimension, num_packed_messages, + bad_packed_values, unpacked_messages)), + StatusIs(absl::StatusCode::kInvalidArgument)); +} + +TEST(KaheTest, UnpackMessagesRawFailsIfPackedValuesTooShort) { + constexpr Integer packing_base = 10; + constexpr int packing_dimension = 1; + constexpr int num_packed_messages = 10; + // A wrapper with a packed message vector that is shorter than expected. + BigIntVectorWrapper bad_packed_values{ + .ptr = std::make_unique>(num_packed_messages - 1, + 0)}; + rust::Vec unpacked_messages; + EXPECT_THAT(UnwrapFfiStatus(UnpackMessagesRaw( + packing_base, packing_dimension, num_packed_messages, + bad_packed_values, unpacked_messages)), StatusIs(absl::StatusCode::kInvalidArgument)); } TEST(KaheTest, AddInPlacePolynomial) { - constexpr uint64_t num_public_polynomials = 1; - FfiStatus status; + constexpr int num_public_polynomials = 1; + std::unique_ptr public_seed; - status = GenerateSingleThreadHkdfSeed(public_seed); - SECAGG_ASSERT_OK(UnwrapFfiStatus(status)); + SECAGG_ASSERT_OK(UnwrapFfiStatus(GenerateSingleThreadHkdfSeed(public_seed))); KahePublicParametersWrapper params; - status = CreateKahePublicParametersWrapper( + SECAGG_ASSERT_OK(UnwrapFfiStatus(CreateKahePublicParametersWrapper( kLogN, kLogT, ToRustSlice(kQs), num_public_polynomials, - ToRustSlice(*public_seed), ¶ms); - SECAGG_ASSERT_OK(UnwrapFfiStatus(status)); + ToRustSlice(*public_seed), ¶ms))); auto moduli = CreateModuliWrapperFromKaheParams(params); std::unique_ptr private_seed; - status = GenerateSingleThreadHkdfSeed(private_seed); - SECAGG_ASSERT_OK(UnwrapFfiStatus(status)); + SECAGG_ASSERT_OK(UnwrapFfiStatus(GenerateSingleThreadHkdfSeed(private_seed))); SingleThreadHkdfWrapper prng; - status = CreateSingleThreadHkdf(ToRustSlice(*private_seed), prng); - SECAGG_ASSERT_OK(UnwrapFfiStatus(status)); + SECAGG_ASSERT_OK(UnwrapFfiStatus( + CreateSingleThreadHkdf(ToRustSlice(*private_seed), prng))); // Generate two keys. - RnsPolynomialWrapper key_1; + RnsPolynomialWrapper key1; SECAGG_ASSERT_OK( - UnwrapFfiStatus(GenerateSecretKeyWrapper(params, &prng, &key_1))); - RnsPolynomialWrapper key_2; + UnwrapFfiStatus(GenerateSecretKeyWrapper(params, &prng, &key1))); + RnsPolynomialWrapper key2; SECAGG_ASSERT_OK( - UnwrapFfiStatus(GenerateSecretKeyWrapper(params, &prng, &key_2))); + UnwrapFfiStatus(GenerateSecretKeyWrapper(params, &prng, &key2))); // Sample two messages and encrypt them. constexpr int num_messages = 10; - constexpr uint64_t packing_base = 10; - constexpr uint64_t input_domain = + constexpr Integer packing_base = 10; + constexpr Integer input_domain = packing_base / 2; // 2 inputs should fit in the base. constexpr int num_packing = 3; - std::vector input_values_1 = + constexpr int num_packed_messages = + (num_messages + num_packing - 1) / num_packing; + std::vector input_values1 = rlwe::testing::SampleMessages(num_messages, input_domain); - RnsPolynomialVecWrapper ciphertexts_1; - SECAGG_ASSERT_OK(UnwrapFfiStatus(Encrypt(ToRustSlice(input_values_1), - packing_base, num_packing, key_1, - params, &prng, &ciphertexts_1))); - std::vector input_values_2 = + BigIntVectorWrapper packed_messages_wrapper1{ + .ptr = std::make_unique>()}; + SECAGG_ASSERT_OK(UnwrapFfiStatus( + PackMessagesRaw(ToRustSlice(input_values1), packing_base, num_packing, + num_packed_messages, &packed_messages_wrapper1))); + RnsPolynomialVecWrapper ciphertexts1; + SECAGG_ASSERT_OK(UnwrapFfiStatus( + Encrypt(packed_messages_wrapper1, key1, params, &prng, &ciphertexts1))); + std::vector input_values2 = rlwe::testing::SampleMessages(num_messages, input_domain); - RnsPolynomialVecWrapper ciphertexts_2; - SECAGG_ASSERT_OK(UnwrapFfiStatus(Encrypt(ToRustSlice(input_values_2), - packing_base, num_packing, key_2, - params, &prng, &ciphertexts_2))); + BigIntVectorWrapper packed_messages_wrapper2{ + .ptr = std::make_unique>()}; + SECAGG_ASSERT_OK(UnwrapFfiStatus( + PackMessagesRaw(ToRustSlice(input_values2), packing_base, num_packing, + num_packed_messages, &packed_messages_wrapper2))); + RnsPolynomialVecWrapper ciphertexts2; + SECAGG_ASSERT_OK(UnwrapFfiStatus( + Encrypt(packed_messages_wrapper2, key2, params, &prng, &ciphertexts2))); // Check that we can add keys (single polynomials) correctly. SECAGG_ASSERT_OK_AND_ASSIGN(auto manual_sum_copy, - key_1.ptr->Add(*key_2.ptr, params.ptr->moduli)); - SECAGG_ASSERT_OK(UnwrapFfiStatus(AddInPlace(moduli, &key_1, &key_2))); - ASSERT_EQ(manual_sum_copy, *key_2.ptr); + key1.ptr->Add(*key2.ptr, params.ptr->moduli)); + SECAGG_ASSERT_OK(UnwrapFfiStatus(AddInPlace(moduli, &key1, &key2))); + ASSERT_EQ(manual_sum_copy, *key2.ptr); // Check that we can add vectors of polynomials. SECAGG_ASSERT_OK_AND_ASSIGN( - manual_sum_copy, ciphertexts_1.ptr->at(0).Add(ciphertexts_2.ptr->at(0), - params.ptr->moduli)); + manual_sum_copy, + ciphertexts1.ptr->at(0).Add(ciphertexts2.ptr->at(0), params.ptr->moduli)); SECAGG_ASSERT_OK( - UnwrapFfiStatus(AddInPlaceVec(moduli, &ciphertexts_1, &ciphertexts_2))); - ASSERT_EQ(manual_sum_copy, ciphertexts_2.ptr->at(0)); + UnwrapFfiStatus(AddInPlaceVec(moduli, &ciphertexts1, &ciphertexts2))); + ASSERT_EQ(manual_sum_copy, ciphertexts2.ptr->at(0)); // Check homomorphism. - uint64_t decrypted[num_messages]; - uint64_t n_written; - SECAGG_ASSERT_OK(UnwrapFfiStatus( - Decrypt(packing_base, num_packing, ciphertexts_2, key_2, params, - rust::Slice(decrypted, num_messages), &n_written))); + BigIntVectorWrapper decrypted_wrapper{ + .ptr = std::make_unique>()}; + SECAGG_ASSERT_OK( + UnwrapFfiStatus(Decrypt(ciphertexts2, key2, params, &decrypted_wrapper))); + rust::Vec unpacked_decrypted_messages; + unpacked_decrypted_messages.reserve(num_messages); + SECAGG_ASSERT_OK(UnwrapFfiStatus(UnpackMessagesRaw( + packing_base, num_packing, decrypted_wrapper.ptr->size(), + decrypted_wrapper, unpacked_decrypted_messages))); for (int i = 0; i < num_messages; ++i) { - EXPECT_EQ(input_values_1[i] + input_values_2[i], decrypted[i]); + EXPECT_EQ(input_values1[i] + input_values2[i], + unpacked_decrypted_messages[i]); } } diff --git a/shell_wrapper/kahe_test.rs b/shell_wrapper/kahe_test.rs index 3e2d52f..d7cca4f 100644 --- a/shell_wrapper/kahe_test.rs +++ b/shell_wrapper/kahe_test.rs @@ -14,13 +14,14 @@ use googletest::{ expect_that, fail, gtest, - matchers::{container_eq, eq}, - Result, + matchers::{container_eq, eq, gt}, + verify_that, Result, }; -use kahe::{create_public_parameters, decrypt, encrypt, generate_secret_key}; +use kahe::{create_public_parameters, decrypt, encrypt, generate_secret_key, PackedVectorConfig}; use rand::Rng; use status::StatusErrorCode; use status_matchers_rs::status_is; +use std::collections::HashMap; // RNS configuration. LOG_T is the bit length of the KAHE plaintext modulus. const LOG_T: u64 = 11; @@ -29,6 +30,8 @@ const QS: [u64; 2] = [1125899906826241, 1125899906629633]; #[gtest] fn encrypt_decrypt() -> Result<()> { + const DEFAULT_ID: &str = "default"; + // Generate public parameters. let public_seed = single_thread_hkdf::generate_seed()?; let num_public_polynomials = 1; @@ -40,26 +43,24 @@ fn encrypt_decrypt() -> Result<()> { let secret_key = generate_secret_key(¶ms, &mut prng)?; // Encrypt small vector. `ciphertext` is a wrapper around a C++ pointer. - let input_domain = 10; - let num_packing = 2; let input_values = vec![1, 2, 3]; - let ciphertext = - encrypt(&input_values, &secret_key, ¶ms, input_domain, num_packing, &mut prng)?; - - // Allocate a small buffer, and decrypt into it. - let output_values_length = 3; - let mut output_values = vec![0; output_values_length]; - let n_written = - decrypt(&ciphertext, &secret_key, ¶ms, input_domain, num_packing, &mut output_values)?; - - expect_that!(n_written, eq(3)); - expect_that!(output_values, container_eq(input_values)); - + let plaintext = HashMap::from([(String::from(DEFAULT_ID), input_values.clone())]); + let packed_vector_configs = HashMap::from([( + String::from(DEFAULT_ID), + PackedVectorConfig { base: 10, dimension: 2, num_packed_coeffs: 2 }, + )]); + let ciphertext = encrypt(&plaintext, &packed_vector_configs, &secret_key, ¶ms, &mut prng)?; + + let output_values = decrypt(&ciphertext, &secret_key, ¶ms, &packed_vector_configs)?; + expect_that!(output_values.contains_key(DEFAULT_ID), eq(true)); + expect_that!(output_values[DEFAULT_ID][..3], container_eq(input_values)); Ok(()) } #[gtest] fn encrypt_decrypt_padding() -> Result<()> { + const DEFAULT_ID: &str = "default"; + // Generate public parameters and secret key. let public_seed = single_thread_hkdf::generate_seed()?; let num_public_polynomials = 1; @@ -68,38 +69,41 @@ fn encrypt_decrypt_padding() -> Result<()> { let mut prng = single_thread_hkdf::create(&seed)?; let secret_key = generate_secret_key(¶ms, &mut prng)?; - // Generate a short random vector, encrypt and decrypt it. - let num_messages = 40; + // Generate a short random vector. + let num_input_values = 40; let input_domain = 10; - let num_packing = 2; - let mut input_values: Vec = Vec::with_capacity(num_messages); - for _ in 0..num_messages { - input_values.push(rand::thread_rng().gen_range(0..input_domain)); - } - let ciphertext = - encrypt(&input_values, &secret_key, ¶ms, input_domain, num_packing, &mut prng)?; - - // Number of values packed into one polynomial. - let padded_length = ((1 << LOG_N) * num_packing) as usize; - - // Allocate more than enough space. - let output_values_length = padded_length * 2; - let mut output_values = vec![42; output_values_length]; - - // Decrypt into the buffer. The rest should be unused. - let n_written = - decrypt(&ciphertext, &secret_key, ¶ms, input_domain, num_packing, &mut output_values)?; + let packing_dimension = 3; + // Set num_packed_coeffs to be larger than the actual number of packed values. The packing + // function should pad with zeros to fill in the packed vector. + let num_packed_coeffs = + (num_input_values + packing_dimension - 1) / packing_dimension + 1 as usize; + let input_values: Vec = + (0..num_input_values).map(|_| rand::thread_rng().gen_range(0..input_domain)).collect(); + + // Encrypt the vector. + let plaintext = HashMap::from([(String::from(DEFAULT_ID), input_values.clone())]); + let packed_vector_configs = HashMap::from([( + String::from(DEFAULT_ID), + PackedVectorConfig { + base: input_domain as u64, + dimension: packing_dimension as u64, + num_packed_coeffs: num_packed_coeffs as u64, + }, + )]); + let ciphertext = encrypt(&plaintext, &packed_vector_configs, &secret_key, ¶ms, &mut prng)?; + + // Decrypt and unpack the ciphertexts. + let decrypted = decrypt(&ciphertext, &secret_key, ¶ms, &packed_vector_configs)?; + let output_values = &decrypted[DEFAULT_ID]; // Check that message is correctly decrypted with right padding. - expect_that!(n_written, eq(padded_length)); - expect_that!(output_values[..num_messages], container_eq(input_values)); + let padded_length = (num_packed_coeffs * packing_dimension) as usize; + expect_that!(output_values.len(), eq(padded_length)); + expect_that!(output_values.len(), gt(num_input_values)); + expect_that!(output_values[..num_input_values], container_eq(input_values)); expect_that!( - output_values[num_messages..padded_length], - container_eq(vec![0; padded_length - num_messages]) - ); - expect_that!( - output_values[padded_length..output_values_length], - container_eq(vec![42; output_values_length - padded_length]) + output_values[num_input_values..], + container_eq(vec![0; padded_length - num_input_values]) ); Ok(()) @@ -107,6 +111,8 @@ fn encrypt_decrypt_padding() -> Result<()> { #[gtest] fn encrypt_decrypt_long() -> Result<()> { + const DEFAULT_ID: &str = "default"; + // Generate public parameters and secret key. let public_seed = single_thread_hkdf::generate_seed()?; let num_public_polynomials = 10; // Generate enough a's to pass long messages. @@ -114,53 +120,126 @@ fn encrypt_decrypt_long() -> Result<()> { let seed = single_thread_hkdf::generate_seed()?; let mut prng = single_thread_hkdf::create(&seed)?; let secret_key = generate_secret_key(¶ms, &mut prng)?; - let num_packing = 8; - - // Number of values packed into one polynomial. - let poly_capacity = ((1 << LOG_N) * num_packing) as usize; + let packing_dimension = 8 as usize; + let num_coeffs_per_poly = (1 << LOG_N) as usize; + // Number of values can be packed into one polynomial. + let poly_capacity = num_coeffs_per_poly * packing_dimension; // Generate a long random vector, encrypt and decrypt it. let input_domain = 2; - let num_messages = 3 * poly_capacity + 1; - - let mut input_values: Vec = Vec::with_capacity(num_messages); - for _ in 0..num_messages { - input_values.push(rand::thread_rng().gen_range(0..input_domain)); - } - let ciphertext = - encrypt(&input_values, &secret_key, ¶ms, input_domain, num_packing, &mut prng)?; - - // Allocate more than enough space. - let output_values_length = num_messages * 2; - let mut output_values = vec![42; output_values_length]; - - // Decrypt into the buffer. The rest should be unused. - let n_written = - decrypt(&ciphertext, &secret_key, ¶ms, input_domain, num_packing, &mut output_values)?; + let num_input_values = 3 * poly_capacity + 1; + let num_packed_coeffs = (num_input_values + packing_dimension - 1) / packing_dimension as usize; + let input_values: Vec = + (0..num_input_values).map(|_| rand::thread_rng().gen_range(0..input_domain)).collect(); + let plaintext = HashMap::from([(String::from(DEFAULT_ID), input_values.clone())]); + let packed_vector_configs = HashMap::from([( + String::from(DEFAULT_ID), + PackedVectorConfig { + base: input_domain as u64, + dimension: packing_dimension as u64, + num_packed_coeffs: num_packed_coeffs as u64, + }, + )]); + let ciphertext = encrypt(&plaintext, &packed_vector_configs, &secret_key, ¶ms, &mut prng)?; + + let decrypted = decrypt(&ciphertext, &secret_key, ¶ms, &packed_vector_configs)?; + let output_values = &decrypted[DEFAULT_ID]; // Check that message is correctly decrypted with right padding. - let padded_length = 4 * poly_capacity; // Last polynomial is padded. - expect_that!(n_written, eq(padded_length)); - expect_that!(output_values[..num_messages], container_eq(input_values)); - expect_that!( - output_values[num_messages..padded_length], - container_eq(vec![0; padded_length - num_messages]) - ); + let padded_length = num_packed_coeffs * packing_dimension; + expect_that!(output_values.len(), eq(padded_length)); + expect_that!(output_values.len(), gt(num_input_values)); + expect_that!(output_values[..num_input_values], container_eq(input_values)); expect_that!( - output_values[padded_length..output_values_length], - container_eq(vec![42; output_values_length - padded_length]) + output_values[num_input_values..], + container_eq(vec![0; padded_length - num_input_values]) ); // If the input is too long, we should fail. - let num_messages = num_public_polynomials * poly_capacity + 1; - let mut input_values: Vec = Vec::with_capacity(num_messages); - for _ in 0..num_messages { - input_values.push(rand::thread_rng().gen_range(0..input_domain)); - } - match encrypt(&input_values, &secret_key, ¶ms, input_domain, num_packing, &mut prng) { + let num_values_too_long = num_public_polynomials * poly_capacity + 1; + let input_values_too_long: Vec = + (0..num_values_too_long).map(|_| rand::thread_rng().gen_range(0..input_domain)).collect(); + let plaintext_too_long = HashMap::from([(String::from(DEFAULT_ID), input_values_too_long)]); + match encrypt(&plaintext_too_long, &packed_vector_configs, &secret_key, ¶ms, &mut prng) { Err(e) => expect_that!(e, status_is(StatusErrorCode::InvalidArgument)), Ok(_) => fail!("Expected call to fail")?, } Ok(()) } + +#[gtest] +fn encrypt_decrypt_two_vectors() -> Result<()> { + const ID0: &str = "fst"; + const ID1: &str = "snd"; + + // Generate public parameters and secret key. + let public_seed = single_thread_hkdf::generate_seed()?; + let num_public_polynomials = 1; + let params = create_public_parameters(LOG_N, LOG_T, &QS, num_public_polynomials, &public_seed)?; + let seed = single_thread_hkdf::generate_seed()?; + let mut prng = single_thread_hkdf::create(&seed)?; + let secret_key = generate_secret_key(¶ms, &mut prng)?; + + // Specifications for the two input vectors. + let input_domains = [10, 8]; + let packing_dimensions = [2, 3]; + let num_input_values = [9, 13]; + // The number of packed coefficients for both vectors. + let num_packed_coeffs = [5, 5]; + + let packed_vector_configs = HashMap::from([ + ( + String::from(ID0), + PackedVectorConfig { + base: input_domains[0] as u64, + dimension: packing_dimensions[0] as u64, + num_packed_coeffs: num_packed_coeffs[0] as u64, + }, + ), + ( + String::from(ID1), + PackedVectorConfig { + base: input_domains[1] as u64, + dimension: packing_dimensions[1] as u64, + num_packed_coeffs: num_packed_coeffs[1] as u64, + }, + ), + ]); + + // The plaintext contains two vectors. + let input_values0: Vec = (0..num_input_values[0]) + .map(|_| rand::thread_rng().gen_range(0..input_domains[0])) + .collect(); + let input_values1: Vec = (0..num_input_values[1]) + .map(|_| rand::thread_rng().gen_range(0..input_domains[1])) + .collect(); + let plaintext = HashMap::from([ + (String::from(ID0), input_values0.clone()), + (String::from(ID1), input_values1.clone()), + ]); + let ciphertext = encrypt(&plaintext, &packed_vector_configs, &secret_key, ¶ms, &mut prng)?; + + // Decrypt and check the output contains the two vectors that are padded correctly. + let decrypted = decrypt(&ciphertext, &secret_key, ¶ms, &packed_vector_configs)?; + verify_that!(decrypted.contains_key(ID0), eq(true))?; + verify_that!(decrypted.contains_key(ID1), eq(true))?; + + let output_values0 = &decrypted[ID0]; + let output_values1 = &decrypted[ID1]; + expect_that!(output_values0.len(), eq(num_packed_coeffs[0] * packing_dimensions[0])); + expect_that!(output_values0.len(), gt(num_input_values[0])); + expect_that!(output_values0[..num_input_values[0]], container_eq(input_values0)); + expect_that!( + output_values0[num_input_values[0]..], + container_eq(vec![0; num_packed_coeffs[0] * packing_dimensions[0] - num_input_values[0]]) + ); + expect_that!(output_values1.len(), eq(num_packed_coeffs[1] * packing_dimensions[1])); + expect_that!(output_values1.len(), gt(num_input_values[1])); + expect_that!(output_values1[..num_input_values[1]], container_eq(input_values1)); + expect_that!( + output_values1[num_input_values[1]..], + container_eq(vec![0; num_packed_coeffs[1] * packing_dimensions[1] - num_input_values[1]]) + ); + Ok(()) +} diff --git a/shell_wrapper/shell_types.cc b/shell_wrapper/shell_types.cc index a348b30..4ff45d8 100644 --- a/shell_wrapper/shell_types.cc +++ b/shell_wrapper/shell_types.cc @@ -112,7 +112,8 @@ FfiStatus WriteSmallRnsPolynomialToBuffer(const RnsPolynomialWrapper* poly, } FfiStatus ReadSmallRnsPolynomialFromBuffer(const int64_t* buffer, - uint64_t buffer_len, uint64_t log_n, + uint64_t buffer_len, + uint64_t num_coeffs, ModuliWrapper moduli, RnsPolynomialWrapper* out) { if (buffer == nullptr || out == nullptr) { @@ -120,7 +121,6 @@ FfiStatus ReadSmallRnsPolynomialFromBuffer(const int64_t* buffer, secure_aggregation::kNullPointerErrorMessage)); } - int num_coeffs = 1 << log_n; if (buffer_len > num_coeffs) { return MakeFfiStatus( absl::InvalidArgumentError("Buffer has too many coefficients, it does " diff --git a/shell_wrapper/shell_types.h b/shell_wrapper/shell_types.h index 93dfe66..a840b3f 100644 --- a/shell_wrapper/shell_types.h +++ b/shell_wrapper/shell_types.h @@ -90,7 +90,8 @@ FfiStatus WriteSmallRnsPolynomialToBuffer(const RnsPolynomialWrapper* poly, // a RnsPolynomialWrapper containing the polynomial in RNS coefficient form to // `out`. FfiStatus ReadSmallRnsPolynomialFromBuffer(const int64_t* buffer, - uint64_t buffer_len, uint64_t log_n, + uint64_t buffer_len, + uint64_t num_coeffs, ModuliWrapper moduli, RnsPolynomialWrapper* out); diff --git a/shell_wrapper/shell_types_test.cc b/shell_wrapper/shell_types_test.cc index fed0446..0f4939c 100644 --- a/shell_wrapper/shell_types_test.cc +++ b/shell_wrapper/shell_types_test.cc @@ -235,7 +235,7 @@ TEST(ShellTypesTest, ReadWriteSmallRnsPolynomialToBufferKahe) { } RnsPolynomialWrapper poly{nullptr}; - status = ReadSmallRnsPolynomialFromBuffer(buffer, buffer_len, kLogN, + status = ReadSmallRnsPolynomialFromBuffer(buffer, buffer_len, 1 << kLogN, moduli_wrapper, &poly); SECAGG_EXPECT_OK(UnwrapFfiStatus(status)); EXPECT_NE(poly.ptr, nullptr); @@ -321,9 +321,9 @@ TEST(ShellTypesTest, ReadWriteErrors) { } // Try to read from the buffer. - status = - ReadSmallRnsPolynomialFromBuffer(long_input_buffer, long_input_buffer_len, - kLogN, moduli_wrapper, &poly_wrapper); + status = ReadSmallRnsPolynomialFromBuffer(long_input_buffer, + long_input_buffer_len, 1 << kLogN, + moduli_wrapper, &poly_wrapper); // We should get an error. EXPECT_THAT(UnwrapFfiStatus(status), StatusIs(absl::StatusCode::kInvalidArgument, diff --git a/shell_wrapper/single_thread_hkdf.h b/shell_wrapper/single_thread_hkdf.h index 75b3f49..305bdfd 100644 --- a/shell_wrapper/single_thread_hkdf.h +++ b/shell_wrapper/single_thread_hkdf.h @@ -35,8 +35,6 @@ FfiStatus CreateSingleThreadHkdf(rust::Slice seed, SingleThreadHkdfWrapper& out); FfiStatus Rand8(SingleThreadHkdfWrapper& prng, uint8_t& out); -size_t SingleThreadHkdfSeedLength(); - // FFI-friendly wrapper around crypto::tink::subtle::ComputeHkdf, with fixed // hash function SHA256. FfiStatus ComputeHkdfWrapper(rust::Slice input, @@ -44,4 +42,8 @@ FfiStatus ComputeHkdfWrapper(rust::Slice input, rust::Slice info, size_t out_len, std::unique_ptr& out); +extern "C" { +size_t SingleThreadHkdfSeedLength(); +} + #endif // SECURE_AGGREGATION_WILLOW_SRC_PRNG_SINGLE_THREAD_HKDF_WRAPPER_H_ diff --git a/willow/benches/BUILD b/willow/benches/BUILD index 4c0104c..7c18ba0 100644 --- a/willow/benches/BUILD +++ b/willow/benches/BUILD @@ -27,7 +27,9 @@ rust_library( srcs = ["shell_benchmarks.rs"], deps = [ "@crate_index//:clap", + "//willow/src/api:willow_api_common", "//willow/src/shell:kahe_shell", + "//willow/src/shell:shell_parameters_generation", "//willow/src/shell:single_thread_hkdf", "//willow/src/shell:vahe_shell", "//willow/src/testing_utils", diff --git a/willow/benches/shell_benchmarks.rs b/willow/benches/shell_benchmarks.rs index 431db63..12b7c66 100644 --- a/willow/benches/shell_benchmarks.rs +++ b/willow/benches/shell_benchmarks.rs @@ -13,20 +13,23 @@ // limitations under the License. use clap::Parser; +use std::collections::HashMap; use std::hint::black_box; use std::time::Duration; use client_traits::SecureAggregationClient; use decryptor_traits::SecureAggregationDecryptor; -use kahe_shell::{ShellKahe, ShellKaheConfig}; +use kahe_shell::ShellKahe; use kahe_traits::KaheBase; use prng_traits::SecurePrng; use server_traits::SecureAggregationServer; -use shell_testing_parameters::{make_ahe_config, make_kahe_rns_config}; +use shell_parameters_generation::generate_packing_config; +use shell_testing_parameters::{make_ahe_config, make_kahe_config_for}; use single_thread_hkdf::SingleThreadHkdfPrng; use testing_utils::{generate_random_unsigned_vector, ShellClient, ShellClientMessage}; use vahe_shell::ShellVahe; use verifier_traits::SecureAggregationVerifier; +use willow_api_common::AggregationConfig; use willow_v1_client::WillowV1Client; use willow_v1_common::{ CiphertextContribution, DecryptionRequestContribution, DecryptorPublicKey, @@ -36,6 +39,8 @@ use willow_v1_decryptor::{DecryptorState, WillowV1Decryptor}; use willow_v1_server::{ServerState, WillowV1Server}; use willow_v1_verifier::{VerifierState, WillowV1Verifier}; +const DEFAULT_ID: &str = "default"; + #[derive(Parser, Debug)] #[command(version, about, long_about = None)] pub struct Args { @@ -116,17 +121,22 @@ struct BaseInputs { fn setup_base(args: &Args) -> BaseInputs { // Create common configs and seeds. Prepare enough public polynomials to // accomodate the input length. - let kahe_rns_config = make_kahe_rns_config(args.plaintext_modulus_bits).unwrap(); - let num_coeffs = 1 << kahe_rns_config.log_n; - let num_public_polynomials = (args.input_length as f64 / num_coeffs as f64).ceil() as usize; - let kahe_config = ShellKaheConfig::new( - args.input_domain, - args.max_num_clients, - args.num_packing, - num_public_polynomials, - kahe_rns_config.clone(), - ) - .unwrap(); + let default_id = String::from(DEFAULT_ID); + let aggregation_config = AggregationConfig { + vector_lengths_and_bounds: HashMap::from([( + default_id.clone(), + (args.input_length as isize, args.input_domain as i64), + )]), + max_number_of_decryptors: 1, + max_number_of_clients: args.max_num_clients as i64, + max_decryptor_dropouts: 0, + session_id: String::from("benchmark"), + willow_version: (1, 0), + }; + let packed_vector_configs = + generate_packing_config(args.plaintext_modulus_bits, &aggregation_config).unwrap(); + let kahe_config = + make_kahe_config_for(args.plaintext_modulus_bits, packed_vector_configs).unwrap(); let ahe_config = make_ahe_config(); let public_kahe_seed = SingleThreadHkdfPrng::generate_seed().unwrap(); let public_ahe_seed = SingleThreadHkdfPrng::generate_seed().unwrap(); @@ -197,8 +207,9 @@ struct ClientInputs { fn setup_client(args: &Args) -> ClientInputs { let inputs = setup_base(args); - let plaintext = generate_random_unsigned_vector(args.input_length, args.input_domain); - ClientInputs { client: inputs.client, public_key: inputs.public_key, plaintext } + let input_values = generate_random_unsigned_vector(args.input_length, args.input_domain); + let plaintext = HashMap::from([(String::from(DEFAULT_ID), input_values)]); + ClientInputs { client: inputs.client, public_key: inputs.public_key, plaintext: plaintext } } fn run_client(inputs: &mut ClientInputs) { @@ -227,8 +238,9 @@ fn setup_verifier_verify_client_message(args: &Args) -> VerifierInputs { let mut decryption_request_contributions = vec![]; for _ in 0..args.n_iterations { // Generates a plaintext and encrypts. - let client_plaintext = + let client_input_values = generate_random_unsigned_vector(args.input_length, args.input_domain); + let client_plaintext = HashMap::from([(String::from(DEFAULT_ID), client_input_values)]); let client_message = inputs.client.create_client_message(&client_plaintext, &inputs.public_key).unwrap(); let (_, decryption_request_contribution) = @@ -260,8 +272,9 @@ fn setup_server_handle_client_message(args: &Args) -> ServerInputs { let mut ciphertext_contributions = vec![]; for _ in 0..args.n_iterations { // Generates a plaintext and encrypts. - let client_plaintext = + let client_input_values = generate_random_unsigned_vector(args.input_length, args.input_domain); + let client_plaintext = HashMap::from([(String::from(DEFAULT_ID), client_input_values)]); let client_message = inputs.client.create_client_message(&client_plaintext, &inputs.public_key).unwrap(); let (ciphertext_contribution, _) = @@ -297,7 +310,8 @@ fn setup_server_recover_aggregation_result(args: &Args) -> ServerRecoverInputs { let mut inputs = setup_base(args); // Client generates a plaintext and encrypts. - let client_plaintext = generate_random_unsigned_vector(args.input_length, args.input_domain); + let client_input_values = generate_random_unsigned_vector(args.input_length, args.input_domain); + let client_plaintext = HashMap::from([(String::from(DEFAULT_ID), client_input_values)]); let client_message = inputs.client.create_client_message(&client_plaintext, &inputs.public_key).unwrap(); @@ -347,8 +361,9 @@ fn setup_decryptor_partial_decryption(args: &Args) -> DecryptorInputs { let mut inputs = setup_base(args); for _ in 0..args.max_num_clients { // Generates a plaintext and encrypts. - let client_plaintext = + let client_input_values = generate_random_unsigned_vector(args.input_length, args.input_domain); + let client_plaintext = HashMap::from([(String::from(DEFAULT_ID), client_input_values)]); let client_message = inputs.client.create_client_message(&client_plaintext, &inputs.public_key).unwrap(); diff --git a/willow/proto/willow/BUILD b/willow/proto/willow/BUILD index 2389ae3..fbed395 100644 --- a/willow/proto/willow/BUILD +++ b/willow/proto/willow/BUILD @@ -30,3 +30,13 @@ cc_proto_library( name = "decryptor_cc_proto", deps = [":decryptor_proto"], ) + +proto_library( + name = "key_proto", + srcs = ["key.proto"], +) + +cc_proto_library( + name = "key_cc_proto", + deps = [":key_proto"], +) diff --git a/willow/proto/willow/decryptor.proto b/willow/proto/willow/decryptor.proto index e13f243..8766de8 100644 --- a/willow/proto/willow/decryptor.proto +++ b/willow/proto/willow/decryptor.proto @@ -1,4 +1,4 @@ -// Copyright 2021 Google LLC +// Copyright 2025 Google LLC // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. diff --git a/willow/proto/willow/key.proto b/willow/proto/willow/key.proto new file mode 100644 index 0000000..7efab0d --- /dev/null +++ b/willow/proto/willow/key.proto @@ -0,0 +1,31 @@ +// Copyright 2025 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +edition = "2023"; + +package secure_aggregation.willow; + +option java_multiple_files = true; + +// A simple container for a Willow public cryptographic key that embeds the key +// ID and the "raw" key material. +message Key { + // A short identifier for the key to allow decryption or verification to find + // the corresponding key from a list. + bytes key_id = 1; + + // The raw key material. This is the serialized bytes of the key material + // returned by the key generation service. + bytes key_material = 2; +} diff --git a/willow/src/shell/BUILD b/willow/src/shell/BUILD index 7a62036..eb679dd 100644 --- a/willow/src/shell/BUILD +++ b/willow/src/shell/BUILD @@ -77,6 +77,7 @@ rust_test( deps = [ ":kahe_shell", "@crate_index//:googletest", + "//shell_wrapper:kahe", "//shell_wrapper:status", "//willow/src/testing_utils", "//willow/src/testing_utils:shell_testing_parameters", @@ -91,7 +92,11 @@ rust_library( deps = [ ":ahe_shell", ":kahe_shell", + ":shell_parameters_generation", + "@protobuf//rust:protobuf", + "//shell_wrapper:kahe", "//shell_wrapper:status", + "//willow/src/api:willow_api_common", ], ) @@ -111,14 +116,39 @@ rust_proto_library( deps = [":shell_parameters_proto"], ) +# Parameters utilities (e.g. conversion between rust structs and protos) +rust_library( + name = "shell_parameters_utils", + srcs = ["parameters_utils.rs"], + crate_root = "parameters_utils.rs", + deps = [ + ":kahe_shell", + ":shell_parameters_rust_proto", + "@protobuf//rust:protobuf", + "//shell_wrapper:kahe", + "//shell_wrapper:status", + ], +) + +rust_test( + name = "shell_parameters_utils_test", + crate = ":shell_parameters_utils", + deps = [ + ":shell_parameters_utils", + "@crate_index//:googletest", + "@protobuf//rust:protobuf_gtest_matchers", + "//shell_wrapper:status", + ], +) + # Parameter generation rust_library( name = "shell_parameters_generation", srcs = ["parameters_generation.rs"], crate_root = "parameters_generation.rs", deps = [ - ":shell_parameters_rust_proto", - "@protobuf//rust:protobuf", + ":kahe_shell", + "//shell_wrapper:kahe", "//shell_wrapper:status", "//willow/src/api:willow_api_common", ], diff --git a/willow/src/shell/kahe.rs b/willow/src/shell/kahe.rs index 202042b..cacbb92 100644 --- a/willow/src/shell/kahe.rs +++ b/willow/src/shell/kahe.rs @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -use kahe::KahePublicParametersWrapper; +use kahe::{KahePublicParametersWrapper, PackedVectorConfig}; use kahe_traits::{ KaheBase, KaheDecrypt, KaheEncrypt, KaheKeygen, TrySecretKeyFrom, TrySecretKeyInto, }; @@ -21,106 +21,93 @@ use shell_types::{ write_small_rns_polynomial_to_buffer, RnsPolynomial, RnsPolynomialVec, }; use single_thread_hkdf::{Seed, SingleThreadHkdfPrng}; +use std::collections::HashMap; /// Number of bits supported by the C++ big integer type used for KAHE /// plaintext. -const BIG_INT_BITS: u64 = 256; - -/// Stores parameters to create a new RNS context for KAHE. -#[derive(Debug, Clone)] -pub struct KaheRnsConfig { - pub log_n: u64, - pub log_t: u64, - pub qs: Vec, -} +const BIG_INT_BITS: usize = 256; -/// ShellKahe configuration. For a fixed RNS context, we can have multiple -/// values for the other parameters (e.g. short or long inputs, or different -/// combinations of packing_base/num_packing that fit within the same plaintext -/// modulus). Can only be created with valid parameters from outside this crate. -#[derive(Debug, Clone)] +#[derive(Debug, PartialEq, Clone)] pub struct ShellKaheConfig { - pub input_domain: u64, - pub max_num_clients: usize, + pub log_n: usize, + pub moduli: Vec, + pub log_t: usize, pub num_public_polynomials: usize, - pub rns_config: KaheRnsConfig, - pub(crate) packing_base: u64, - pub(crate) num_packing: usize, -} - -impl ShellKaheConfig { - /// Validates parameters and creates a new ShellKaheConfig instance. - pub fn new( - input_domain: u64, - max_num_clients: usize, - num_packing: usize, - num_public_polynomials: usize, - rns_config: KaheRnsConfig, - ) -> Result { - if num_packing == 0 { - return Err(status::invalid_argument("num_packing must be > 0")); - } - // B = n * t - let packing_base = input_domain * max_num_clients as u64; - if packing_base <= 1 { - return Err(status::invalid_argument("packing_base must be > 1")); - } - if rns_config.log_t > BIG_INT_BITS { - return Err(status::invalid_argument(format!( - "log_t must be <= {} for plaintexts to fit in the C++ big integer type, got {}", - BIG_INT_BITS, rns_config.log_t - ))); - } - let log_packing_base = (packing_base as f64).log2().ceil() as u64; - if (num_packing as u64) * log_packing_base > rns_config.log_t { - return Err(status::invalid_argument(format!( - "packing_base^num_packing must not be larger than the KAHE plaintext modulus 2^log_t: packing_base = {}, num_packing = {}, log_t = {}", packing_base, num_packing, rns_config.log_t - ))); - } - Ok(Self { - input_domain, - max_num_clients, - num_public_polynomials, - rns_config, - packing_base, - num_packing, - }) - } + pub packed_vector_configs: HashMap, } /// Base type holding public KAHE configuration and C++ parameters. pub struct ShellKahe { - input_domain: u64, - packing_base: u64, - num_packing: usize, - rns_config: KaheRnsConfig, + /// Parameters used to initialize ShellKahe. + config: ShellKaheConfig, + + /// Number of coefficients in a KAHE polynomial. + num_coeffs: usize, + + /// The KAHE public parameters implemented in C++, including the public polynomials and + /// the parameters to instantiate the KAHE scheme. public_kahe_parameters: KahePublicParametersWrapper, } impl ShellKahe { - /// Creates a new ShellKahe instance. - pub fn new(config: ShellKaheConfig, public_seed: &Seed) -> Result { + pub fn new( + shell_kahe_config: ShellKaheConfig, + public_seed: &Seed, + ) -> Result { + Self::validate_kahe_config(&shell_kahe_config)?; + let num_coeffs = 1 << shell_kahe_config.log_n; let public_kahe_parameters = kahe::create_public_parameters( - config.rns_config.log_n, - config.rns_config.log_t, - &config.rns_config.qs, - config.num_public_polynomials, + shell_kahe_config.log_n as u64, + shell_kahe_config.log_t as u64, + &shell_kahe_config.moduli, + shell_kahe_config.num_public_polynomials, &public_seed, )?; - Ok(Self { - input_domain: config.input_domain, - packing_base: config.packing_base, - num_packing: config.num_packing, - rns_config: config.rns_config, - public_kahe_parameters, - }) + Ok(Self { config: shell_kahe_config, num_coeffs, public_kahe_parameters }) + } + + /// Validates KAHE parameters in ShellKaheConfig. + fn validate_kahe_config(config: &ShellKaheConfig) -> Result<(), status::StatusError> { + if config.log_t > BIG_INT_BITS { + return Err(status::invalid_argument(format!( + "log_t must be <= {} for plaintexts to fit in the C++ big integer type, got {}", + BIG_INT_BITS, config.log_t + ))); + } + for (id, packed_vector_config) in config.packed_vector_configs.iter() { + let base = packed_vector_config.base; + let dimension = packed_vector_config.dimension; + let num_packed_coeffs = packed_vector_config.num_packed_coeffs; + if base <= 1 { + return Err(status::invalid_argument(format!("base must be > 1, got {}", base))); + } + if dimension <= 0 { + return Err(status::invalid_argument(format!( + "For packing id {}, dimension must be > 0, got {}", + id, dimension + ))); + } + if num_packed_coeffs <= 0 { + return Err(status::invalid_argument(format!( + "For packing id {}, num_packed_coeffs must be > 0, got {}", + id, num_packed_coeffs + ))); + } + let log_base = (base as f64).log2().ceil() as u64; + if log_base * dimension > config.log_t as u64 { + return Err(status::invalid_argument(format!( + "For packing id {}, base^dimension must not be larger than the KAHE plaintext modulus 2^log_t+1: base = {}, dimension = {}, log_t = {}", id, base, dimension, config.log_t + ))); + } + } + Ok(()) } } impl KaheBase for ShellKahe { type SecretKey = RnsPolynomial; - type Plaintext = Vec; + type Plaintext = HashMap>; type Ciphertext = RnsPolynomialVec; @@ -150,8 +137,22 @@ impl KaheBase for ShellKahe { right.len() ))); } - for (i, v) in left.iter().enumerate() { - right[i] += v; + for (id, values) in left.iter() { + if let Some(right_values) = right.get_mut(id) { + if right_values.len() != values.len() { + return Err(status::invalid_argument(format!( + "right values for key {} must have the same length as left, got {} and {}", + id, + right_values.len(), + values.len() + ))); + } + for (i, v) in values.iter().enumerate() { + right_values[i] += v; + } + } else { + return Err(status::invalid_argument(format!("right must contain key {}", id))); + } } Ok(()) } @@ -181,21 +182,38 @@ impl KaheEncrypt for ShellKahe { r: &mut Self::Rng, ) -> Result { // Check that inputs are valid to avoid packing and plaintext overflow errors. - for v in pt.iter() { - if *v >= self.input_domain { - return Err(status::invalid_argument(format!( - "plaintext value {} is larger than the input domain {}", - *v, self.input_domain - ))); + for (id, values) in pt.iter() { + if let Some(packed_vector_config) = self.config.packed_vector_configs.get(id) { + let max_length = + packed_vector_config.dimension * packed_vector_config.num_packed_coeffs; + if values.len() > max_length as usize { + return Err(status::invalid_argument(format!( + "plaintext for id {} can have at most {} elements, got {}", + id, + max_length, + values.len() + ))); + } + for v in values.iter() { + if *v >= packed_vector_config.base { + return Err(status::invalid_argument(format!( + "plaintext for id {} cannot contain values larger than the input bound {}, got {}", + id, + packed_vector_config.base, + *v, + ))); + } + } + } else { + return Err(status::invalid_argument(format!("unknown plaintext id {}", id))); } } kahe::encrypt( - &pt[..], + &pt, + &self.config.packed_vector_configs, &sk, &self.public_kahe_parameters, - self.packing_base, - self.num_packing, &mut r.0, ) } @@ -207,42 +225,19 @@ impl KaheDecrypt for ShellKahe { ct: &Self::Ciphertext, sk: &Self::SecretKey, ) -> Result { - // Allocate the right number of values to hold an unpacked and padded output. - let num_coeffs = 1 << self.rns_config.log_n; - let num_values = num_coeffs * self.num_packing * (ct.len as usize); - let mut output_values = vec![0; num_values]; - - // Decrypt into the buffer. - let n_written = kahe::decrypt( - &ct, - &sk, - &self.public_kahe_parameters, - self.packing_base, - self.num_packing, - &mut output_values[..], - )?; - - if n_written != num_values { - return Err(status::internal(format!( - "Expected {} decrypted values, but got {}.", - num_values, n_written - ))); - } - Ok(output_values) + kahe::decrypt(&ct, &sk, &self.public_kahe_parameters, &self.config.packed_vector_configs) } } impl TrySecretKeyInto> for ShellKahe { fn try_secret_key_into(&self, sk: Self::SecretKey) -> Result, status::StatusError> { - let num_coeffs = 1 << self.rns_config.log_n; - let mut signed_values: Vec = vec![0; num_coeffs]; + let mut signed_values: Vec = vec![0; self.num_coeffs]; let moduli = kahe::get_moduli(&self.public_kahe_parameters); - let n_written = write_small_rns_polynomial_to_buffer(&sk, &moduli, &mut signed_values[..])?; - if n_written != num_coeffs { + if n_written != self.num_coeffs { return Err(status::internal(format!( "Expected {} coefficients, but got {}.", - num_coeffs, n_written + self.num_coeffs, n_written ))); } @@ -255,20 +250,18 @@ impl TrySecretKeyFrom> for ShellKahe { &self, sk_buffer: Vec, ) -> Result { - let log_n = self.rns_config.log_n as usize; - if sk_buffer.len() < log_n { + if sk_buffer.len() < self.num_coeffs { return Err(status::invalid_argument(format!( "secret key buffer is too short: {} < {}", sk_buffer.len(), - self.rns_config.log_n + self.num_coeffs ))); } let moduli = kahe::get_moduli(&self.public_kahe_parameters); - let num_coeffs = 1 << log_n; let poly = read_small_rns_polynomial_from_buffer( - &sk_buffer[..num_coeffs], // Remove potential padding from AHE decryption. - self.rns_config.log_n, + &sk_buffer[..self.num_coeffs], // Remove potential padding from AHE decryption. + self.num_coeffs as u64, &moduli, )?; Ok(poly) @@ -278,7 +271,17 @@ impl TrySecretKeyFrom> for ShellKahe { #[cfg(test)] mod test { // Instead of `super::*` because we consume types from other testing crates. + use googletest::{gtest, verify_eq, verify_le}; + use kahe::PackedVectorConfig; use kahe_shell::*; + use kahe_traits::{ + KaheBase, KaheDecrypt, KaheEncrypt, KaheKeygen, TrySecretKeyFrom, TrySecretKeyInto, + }; + use prng_traits::SecurePrng; + use shell_testing_parameters::{make_kahe_config_for, set_kahe_num_public_polynomials}; + use single_thread_hkdf::SingleThreadHkdfPrng; + use std::collections::HashMap; + use testing_utils::generate_random_unsigned_vector; /// Standard deviation of the discrete Gaussian distribution used for /// secret key generation. Hardcoded in shell_wrapper/kahe.h for now (if we ever @@ -293,61 +296,41 @@ mod test { /// Tail bound for the case of a single secret key. const TAIL_BOUND: i64 = (TAIL_BOUND_MULTIPLIER * SECRET_KEY_STD + 1.0) as i64; - use googletest::{gtest, verify_eq, verify_le, verify_that}; - use kahe_traits::{ - KaheBase, KaheDecrypt, KaheEncrypt, KaheKeygen, TrySecretKeyFrom, TrySecretKeyInto, - }; - use prng_traits::SecurePrng; - use shell_testing_parameters::make_kahe_rns_config; - use single_thread_hkdf::SingleThreadHkdfPrng; - use testing_utils::generate_random_unsigned_vector; + /// Default ID used in tests. + const DEFAULT_ID: &str = "default"; #[gtest] fn test_encrypt_decrypt_short() -> googletest::Result<()> { let plaintext_modulus_bits = 39; - let input_domain = 10; - let max_num_clients = 100; - let num_packing = 2; - let num_public_polynomials = 1; - let rns_config = make_kahe_rns_config(plaintext_modulus_bits)?; - let kahe_config = ShellKaheConfig::new( - input_domain, - max_num_clients, - num_packing, - num_public_polynomials, - rns_config.clone(), - )?; + let packed_vector_configs = HashMap::from([( + String::from(DEFAULT_ID), + PackedVectorConfig { base: 10, dimension: 2, num_packed_coeffs: 5 }, + )]); + let kahe_config = make_kahe_config_for(plaintext_modulus_bits, packed_vector_configs)?; let public_seed = SingleThreadHkdfPrng::generate_seed()?; let kahe = ShellKahe::new(kahe_config, &public_seed)?; - let pt = vec![1, 2, 3, 4, 5, 6, 7, 8]; + let pt = HashMap::from([(String::from(DEFAULT_ID), vec![0, 1, 2, 3, 4, 5, 6, 7, 8, 9])]); let seed = SingleThreadHkdfPrng::generate_seed()?; let mut prng = SingleThreadHkdfPrng::create(&seed)?; let sk = kahe.key_gen(&mut prng)?; let ct = kahe.encrypt(&pt, &sk, &mut prng)?; let decrypted = kahe.decrypt(&ct, &sk)?; - verify_eq!(&pt, &decrypted[..pt.len()]) + verify_eq!(&pt, &decrypted) } #[gtest] fn test_encrypt_decrypt_with_serialized_key() -> googletest::Result<()> { let plaintext_modulus_bits = 39; - let input_domain = 10; - let max_num_clients = 100; - let num_packing = 2; - let num_public_polynomials = 1; - let rns_config = make_kahe_rns_config(plaintext_modulus_bits)?; - let kahe_config = ShellKaheConfig::new( - input_domain, - max_num_clients, - num_packing, - num_public_polynomials, - rns_config.clone(), - )?; + let packed_vector_configs = HashMap::from([( + String::from(DEFAULT_ID), + PackedVectorConfig { base: 10, dimension: 2, num_packed_coeffs: 5 }, + )]); + let kahe_config = make_kahe_config_for(plaintext_modulus_bits, packed_vector_configs)?; let public_seed = SingleThreadHkdfPrng::generate_seed()?; let kahe = ShellKahe::new(kahe_config, &public_seed)?; - let pt = vec![1, 2, 3, 4, 5, 6, 7, 8]; + let pt = HashMap::from([(String::from(DEFAULT_ID), vec![0, 1, 2, 3, 4, 5, 6, 7, 8, 9])]); let seed = SingleThreadHkdfPrng::generate_seed()?; let mut prng = SingleThreadHkdfPrng::create(&seed)?; let sk = kahe.key_gen(&mut prng)?; @@ -359,24 +342,28 @@ mod test { // Check that the decrypted value is the same as the original plaintext. let decrypted = kahe.decrypt(&ct, &sk_recovered)?; - verify_eq!(&pt, &decrypted[..pt.len()]) + verify_eq!(&pt, &decrypted) } #[gtest] fn test_encrypt_decrypt_long() -> googletest::Result<()> { let plaintext_modulus_bits = 17; let input_domain = 5; - let max_num_clients = 1000; - let num_packing = 1; - let num_public_polynomials = 2; - let rns_config = make_kahe_rns_config(plaintext_modulus_bits)?; - let kahe_config = ShellKaheConfig::new( - input_domain, - max_num_clients, - num_packing, - num_public_polynomials, - rns_config.clone(), - )?; + let packed_vector_configs = HashMap::from([( + String::from(DEFAULT_ID), + PackedVectorConfig { + base: input_domain, + dimension: 1, + num_packed_coeffs: 0, // Dummy value until we compute it from kahe_config. + }, + )]); + let mut kahe_config = make_kahe_config_for(plaintext_modulus_bits, packed_vector_configs)?; + // Set the number of packed coefficients to 2x the KAHE ring degree. + let num_messages = (1 << kahe_config.log_n) * 2; // Needs two polynomials. + let packed_vector_config = kahe_config.packed_vector_configs.get_mut(DEFAULT_ID).unwrap(); + packed_vector_config.num_packed_coeffs = num_messages; + set_kahe_num_public_polynomials(&mut kahe_config); + let public_seed = SingleThreadHkdfPrng::generate_seed()?; let kahe = ShellKahe::new(kahe_config, &public_seed)?; @@ -385,8 +372,10 @@ mod test { let sk = kahe.key_gen(&mut prng)?; // Generate a random vector, encrypt and decrypt it. - let num_messages = (1 << rns_config.log_n) * 2; // Needs two polynomials. - let pt = generate_random_unsigned_vector(num_messages, input_domain); + let pt = HashMap::from([( + String::from(DEFAULT_ID), + generate_random_unsigned_vector(num_messages as usize, input_domain as u64), + )]); let ct = kahe.encrypt(&pt, &sk, &mut prng)?; let decrypted = kahe.decrypt(&ct, &sk)?; verify_eq!(pt, decrypted) // Both vectors are padded to the same length. @@ -397,32 +386,36 @@ mod test { fn add_two_inputs() -> googletest::Result<()> { let plaintext_modulus_bits = 93; let input_domain = 10; - let max_num_clients = 2; - let num_packing = 1; - let num_public_polynomials = 2; - let rns_config = make_kahe_rns_config(plaintext_modulus_bits)?; - let kahe_config = ShellKaheConfig::new( - input_domain, - max_num_clients, - num_packing, - num_public_polynomials, - rns_config.clone(), - )?; + let num_messages = 50; + let packed_vector_configs = HashMap::from([( + String::from(DEFAULT_ID), + PackedVectorConfig { + base: input_domain * 2, + dimension: 1, + num_packed_coeffs: num_messages, + }, + )]); + let kahe_config = make_kahe_config_for(plaintext_modulus_bits, packed_vector_configs)?; + let public_seed = SingleThreadHkdfPrng::generate_seed()?; let kahe = ShellKahe::new(kahe_config, &public_seed)?; let seed = SingleThreadHkdfPrng::generate_seed()?; let mut prng = SingleThreadHkdfPrng::create(&seed)?; - let num_messages = 50; - // Client 1 let sk1 = kahe.key_gen(&mut prng)?; - let pt1 = generate_random_unsigned_vector(num_messages, input_domain); + let pt1 = HashMap::from([( + String::from(DEFAULT_ID), + generate_random_unsigned_vector(num_messages as usize, input_domain as u64), + )]); let ct1 = kahe.encrypt(&pt1, &sk1, &mut prng)?; // Client 2 let mut sk2 = kahe.key_gen(&mut prng)?; - let mut pt2 = generate_random_unsigned_vector(num_messages, input_domain); + let mut pt2 = HashMap::from([( + String::from(DEFAULT_ID), + generate_random_unsigned_vector(num_messages as usize, input_domain as u64), + )]); let mut ct2 = kahe.encrypt(&pt2, &sk2, &mut prng)?; // Decryptor adds up keys @@ -432,24 +425,15 @@ mod test { kahe.add_ciphertexts_in_place(&ct1, &mut ct2)?; let pt_sum = kahe.decrypt(&ct2, &sk2)?; kahe.add_plaintexts_in_place(&pt1, &mut pt2)?; - verify_eq!(&pt2, &pt_sum[..num_messages]) + verify_eq!(&pt2, &pt_sum) } #[gtest] fn read_write_secret_key() -> googletest::Result<()> { let plaintext_modulus_bits = 17; - let input_domain = 2; - let max_num_clients = 100; - let num_packing = 2; - let num_public_polynomials = 1; - let rns_config = make_kahe_rns_config(plaintext_modulus_bits)?; - let kahe_config = ShellKaheConfig::new( - input_domain, - max_num_clients, - num_packing, - num_public_polynomials, - rns_config.clone(), - )?; + let packed_vector_configs = HashMap::from([]); + let kahe_config = make_kahe_config_for(plaintext_modulus_bits, packed_vector_configs)?; + let public_seed = SingleThreadHkdfPrng::generate_seed()?; let kahe = ShellKahe::new(kahe_config, &public_seed)?; let seed = SingleThreadHkdfPrng::generate_seed()?; @@ -491,22 +475,12 @@ mod test { fn test_key_serialization_is_homomorphic() -> googletest::Result<()> { // Set up a ShellKahe instance. let plaintext_modulus_bits = 39; - let input_domain = 10; - let max_num_clients = 100; - let num_packing = 2; - let num_public_polynomials = 1; - let rns_config = make_kahe_rns_config(plaintext_modulus_bits)?; - let kahe_config = ShellKaheConfig::new( - input_domain, - max_num_clients, - num_packing, - num_public_polynomials, - rns_config.clone(), - )?; + let packed_vector_configs = HashMap::from([]); + let kahe_config = make_kahe_config_for(plaintext_modulus_bits, packed_vector_configs)?; let public_seed = SingleThreadHkdfPrng::generate_seed()?; let kahe = ShellKahe::new(kahe_config, &public_seed)?; - let pt = vec![1, 2, 3, 4, 5, 6, 7, 8]; + // The seed used to sample the secret keys. let seed = SingleThreadHkdfPrng::generate_seed()?; // Generate two keys, write them to buffers then add the buffers together. diff --git a/willow/src/shell/parameters.proto b/willow/src/shell/parameters.proto index 1e2b6b8..69ab514 100644 --- a/willow/src/shell/parameters.proto +++ b/willow/src/shell/parameters.proto @@ -14,25 +14,21 @@ edition = "2023"; -package secure_aggregation_willow; +package secure_aggregation.willow; // This proto defines how to pack an input vector into a KAHE plaintext. -// An input vector is split into `num_packed_coeffs` many sub-vectors of length -// `dimension` each. Each sub-vector is then packed into a single plaintext -// coefficient using base `base` encoding to allow summation over all clients' -// contributions. -message PackedVectorConfig { +// An input vector is split into `num_packed_coeffs` many sub-vectors of +// length `dimension` each. Each sub-vector is then packed into a single +// plaintext coefficient using base `base` encoding to allow summation over +// all clients' contributions. +message PackedVectorConfigProto { int64 base = 1; int64 dimension = 2; int64 num_packed_coeffs = 3; } -message ShellKahePackingConfig { - map packed_vectors = 1; -} - // This proto defines the parameters for instantiating the KAHE scheme -message ShellKaheConfig { +message ShellKaheConfigProto { // The first two fields define the KAHE ciphertext ring Z[X]/(q, X^N+1), // where `log_n` is log2(N) and `moduli` is the list of prime factors of q. int64 log_n = 1; @@ -46,5 +42,5 @@ message ShellKaheConfig { int64 num_public_polynomials = 4; // Configures how input vectors are packed into the KAHE plaintext. - ShellKahePackingConfig packing_config = 5; + map packed_vectors = 5; } diff --git a/willow/src/shell/parameters.rs b/willow/src/shell/parameters.rs index 5f42924..500563f 100644 --- a/willow/src/shell/parameters.rs +++ b/willow/src/shell/parameters.rs @@ -13,7 +13,9 @@ // limitations under the License. use ahe_shell::ShellAheConfig; -use kahe_shell::{KaheRnsConfig, ShellKaheConfig}; +use kahe_shell::ShellKaheConfig; +use shell_parameters_generation::{divide_and_roundup, generate_packing_config}; +use willow_api_common::AggregationConfig; /// This file contains parameters for the KAHE and AHE schemes in Willow, which /// are selected to have at least 128 bits of computational security and 40 bits @@ -35,14 +37,12 @@ use kahe_shell::{KaheRnsConfig, ShellKaheConfig}; /// - input of length 1K with 32-bit domain /// - max number of clients 10M /// - max number of decryptors 100 -const KAHE_LOG_N_1K_10M: u64 = 12; -const KAHE_LOG_T_1K_10M: u64 = 56; +const KAHE_LOG_N_1K_10M: usize = 12; +const KAHE_LOG_T_1K_10M: usize = 56; const KAHE_QS_1K_10M: [u64; 2] = [ 274877816833, // 38 bits 274877718529, // 38 bits ]; -const KAHE_NUM_PACKING_1K_10M: usize = 1; -const KAHE_NUM_PUBLIC_POLY_1K_10M: usize = 1; const AHE_LOG_N_1K_10M: u64 = 12; const AHE_T_1K_10M: u64 = 109965; const AHE_QS_1K_10M: [u64; 2] = [1099510824961, 1099508760577]; // 80 bits total @@ -52,16 +52,14 @@ const AHE_S_FLOOD_1K_10M: f64 = 3.0834e+16; /// - input of length 100K with 32-bit domain /// - max number of clients 10M /// - max number of decryptors 100 -const KAHE_LOG_N_100K_10M: u64 = 13; -const KAHE_LOG_T_100K_10M: u64 = 168; +const KAHE_LOG_N_100K_10M: usize = 13; +const KAHE_LOG_T_100K_10M: usize = 168; const KAHE_QS_100K_10M: [u64; 4] = [ 1125899906629633, // 50 bits 1125899905744897, // 50 bits 1125899905351681, // 50 bits 1125899903827969, // 50 bits ]; -const KAHE_NUM_PACKING_100K_10M: usize = 3; -const KAHE_NUM_PUBLIC_POLY_100K_10M: usize = 5; const AHE_LOG_N_100K_10M: u64 = 12; const AHE_T_100K_10M: u64 = 6582404323; const AHE_QS_100K_10M: [u64; 2] = [281474976546817, 281474975662081]; // 96 bits total @@ -71,16 +69,14 @@ const AHE_S_FLOOD_100K_10M: f64 = 3.0834e+16; /// - input of length 10M with 32-bit domain /// - max number of clients 10M /// - max number of decryptors 100 -const KAHE_LOG_N_10M_10M: u64 = 14; -const KAHE_LOG_T_10M_10M: u64 = 224; +const KAHE_LOG_N_10M_10M: usize = 14; +const KAHE_LOG_T_10M_10M: usize = 224; const KAHE_QS_10M_10M: [u64; 4] = [ 2305843009211596801, // 61 bits 2305843009211400193, // 61 bits 2305843009210515457, // 61 bits 2305843009210023937, // 61 bits ]; -const KAHE_NUM_PACKING_10M_10M: usize = 4; -const KAHE_NUM_PUBLIC_POLY_10M_10M: usize = 153; const AHE_LOG_N_10M_10M: u64 = 12; const AHE_T_10M_10M: u64 = 7121256483; const AHE_QS_10M_10M: [u64; 2] = [281474976546817, 281474975662081]; // 96 bits total @@ -89,28 +85,40 @@ const AHE_S_FLOOD_10M_10M: f64 = 3.0834e+16; /// Creates a pair (ShellKaheConfig, ShellAheConfig) to be used to instantiate /// KAHE and AHE schemes for the given protocol setting. pub fn create_shell_configs( - input_length: u64, - input_domain: u64, - max_num_clients: usize, - max_num_decryptors: usize, + aggregation_config: &AggregationConfig, ) -> Result<(ShellKaheConfig, ShellAheConfig), status::StatusError> { - if input_length <= 1000 - && input_domain <= (1u64 << 32) - && max_num_clients <= 10_000_000 - && max_num_decryptors <= 100 + // Use heuristics to select parameters. + let total_input_length: i64 = aggregation_config + .vector_lengths_and_bounds + .values() + .map(|(length, _)| *length as i64) + .sum(); + let max_input_bound = aggregation_config + .vector_lengths_and_bounds + .values() + .map(|(_, bound)| bound) + .max() + .unwrap(); + + if total_input_length <= 1000 + && *max_input_bound <= (1i64 << 32) + && aggregation_config.max_number_of_clients <= 10_000_000 + && aggregation_config.max_number_of_decryptors <= 100 { + let packed_vector_configs = generate_packing_config(KAHE_LOG_T_1K_10M, aggregation_config)?; + let kahe_total_num_coeffs: usize = packed_vector_configs + .values() + .map(|packed_vector_cfg| packed_vector_cfg.num_packed_coeffs as usize) + .sum(); + let kahe_num_coeffs = 1 << KAHE_LOG_N_1K_10M; return Ok(( - ShellKaheConfig::new( - input_domain, - max_num_clients, - KAHE_NUM_PACKING_1K_10M, - KAHE_NUM_PUBLIC_POLY_1K_10M, - KaheRnsConfig { - log_n: KAHE_LOG_N_1K_10M, - log_t: KAHE_LOG_T_1K_10M, - qs: KAHE_QS_1K_10M.to_vec(), - }, - )?, + ShellKaheConfig { + log_n: KAHE_LOG_N_1K_10M, + moduli: KAHE_QS_1K_10M.to_vec(), + log_t: KAHE_LOG_T_1K_10M, + num_public_polynomials: divide_and_roundup(kahe_total_num_coeffs, kahe_num_coeffs), + packed_vector_configs, + }, ShellAheConfig { log_n: AHE_LOG_N_1K_10M, t: AHE_T_1K_10M, @@ -120,23 +128,26 @@ pub fn create_shell_configs( )); } - if input_length <= 100_000 - && input_domain <= (1u64 << 32) - && max_num_clients <= 10_000_000 - && max_num_decryptors <= 100 + if total_input_length <= 100_000 + && *max_input_bound <= (1i64 << 32) + && aggregation_config.max_number_of_clients <= 10_000_000 + && aggregation_config.max_number_of_decryptors <= 100 { + let packed_vector_configs = + generate_packing_config(KAHE_LOG_T_100K_10M, aggregation_config)?; + let kahe_total_num_coeffs: usize = packed_vector_configs + .values() + .map(|packed_vector_cfg| packed_vector_cfg.num_packed_coeffs as usize) + .sum(); + let kahe_num_coeffs = 1 << KAHE_LOG_N_100K_10M; return Ok(( - ShellKaheConfig::new( - input_domain, - max_num_clients, - KAHE_NUM_PACKING_100K_10M, - KAHE_NUM_PUBLIC_POLY_100K_10M, - KaheRnsConfig { - log_n: KAHE_LOG_N_100K_10M, - log_t: KAHE_LOG_T_100K_10M, - qs: KAHE_QS_100K_10M.to_vec(), - }, - )?, + ShellKaheConfig { + log_n: KAHE_LOG_N_100K_10M, + moduli: KAHE_QS_100K_10M.to_vec(), + log_t: KAHE_LOG_T_100K_10M, + num_public_polynomials: divide_and_roundup(kahe_total_num_coeffs, kahe_num_coeffs), + packed_vector_configs, + }, ShellAheConfig { log_n: AHE_LOG_N_100K_10M, t: AHE_T_100K_10M, @@ -146,23 +157,26 @@ pub fn create_shell_configs( )); } - if input_length <= 10_000_000 - && input_domain <= (1u64 << 32) - && max_num_clients <= 10000000 - && max_num_decryptors <= 100 + if total_input_length <= 10_000_000 + && *max_input_bound <= (1i64 << 32) + && aggregation_config.max_number_of_clients <= 10000000 + && aggregation_config.max_number_of_decryptors <= 100 { + let packed_vector_configs = + generate_packing_config(KAHE_LOG_T_10M_10M, aggregation_config)?; + let kahe_total_num_coeffs: usize = packed_vector_configs + .values() + .map(|packed_vector_cfg| packed_vector_cfg.num_packed_coeffs as usize) + .sum(); + let kahe_num_coeffs = 1 << KAHE_LOG_N_10M_10M; return Ok(( - ShellKaheConfig::new( - input_domain, - max_num_clients, - KAHE_NUM_PACKING_10M_10M, - KAHE_NUM_PUBLIC_POLY_10M_10M, - KaheRnsConfig { - log_n: KAHE_LOG_N_10M_10M, - log_t: KAHE_LOG_T_10M_10M, - qs: KAHE_QS_10M_10M.to_vec(), - }, - )?, + ShellKaheConfig { + log_n: KAHE_LOG_N_10M_10M, + moduli: KAHE_QS_10M_10M.to_vec(), + log_t: KAHE_LOG_T_10M_10M, + num_public_polynomials: divide_and_roundup(kahe_total_num_coeffs, kahe_num_coeffs), + packed_vector_configs, + }, ShellAheConfig { log_n: AHE_LOG_N_10M_10M, t: AHE_T_10M_10M, @@ -173,7 +187,7 @@ pub fn create_shell_configs( } Err(status::invalid_argument(format!( - "input setting is not supported: input_length = {}, input_domain = {}, max_num_clients = {}, max_num_decryptors = {}", - input_length, input_domain, max_num_clients, max_num_decryptors + "input setting is not supported: aggregation_config = {:?}", + aggregation_config ))) } diff --git a/willow/src/shell/parameters_generation.rs b/willow/src/shell/parameters_generation.rs index 6cea3c6..96de16c 100644 --- a/willow/src/shell/parameters_generation.rs +++ b/willow/src/shell/parameters_generation.rs @@ -12,8 +12,8 @@ // See the License for the specific language governing permissions and // limitations under the License. -use protobuf::{proto, ProtoStr}; -use shell_parameters_rust_proto::{PackedVectorConfig, ShellKaheConfig, ShellKahePackingConfig}; +use kahe::PackedVectorConfig; +use std::collections::HashMap; use willow_api_common::AggregationConfig; /// Generating KAHE and AHE parameters given the Willow protocol configuration. @@ -24,7 +24,7 @@ const MAX_PACKING_BASE_BITS: usize = 63; const BIG_INT_BITS: usize = 256; // Returns ceil(x / y). -fn divide_and_roundup(x: usize, y: usize) -> usize { +pub fn divide_and_roundup(x: usize, y: usize) -> usize { (x + y - 1) / y } @@ -34,7 +34,7 @@ fn divide_and_roundup(x: usize, y: usize) -> usize { pub fn generate_packing_config( plaintext_bits: usize, agg_config: &AggregationConfig, -) -> Result { +) -> Result, status::StatusError> { if plaintext_bits == 0 { return Err(status::invalid_argument("`plaintext_bits` must be positive.")); } @@ -47,7 +47,7 @@ pub fn generate_packing_config( if agg_config.max_number_of_clients <= 0 { return Err(status::invalid_argument("`max_number_of_clients` must be positive.")); } - let mut packing_config = ShellKahePackingConfig::new(); + let mut packing_configs = HashMap::::new(); for (id, (length, bound)) in agg_config.vector_lengths_and_bounds.iter() { if *length <= 0 { return Err(status::invalid_argument(format!( @@ -83,14 +83,14 @@ pub fn generate_packing_config( ))); } let num_packed_coeffs = divide_and_roundup(*length as usize, dimension); - packing_config.packed_vectors_mut().insert( - ProtoStr::from_str(&id), - proto!(PackedVectorConfig { - base: base as i64, - dimension: dimension as i64, - num_packed_coeffs: num_packed_coeffs as i64, - }), + packing_configs.insert( + id.clone(), + PackedVectorConfig { + base: base as u64, + dimension: dimension as u64, + num_packed_coeffs: num_packed_coeffs as u64, + }, ); } - Ok(packing_config) + Ok(packing_configs) } diff --git a/willow/src/shell/parameters_utils.rs b/willow/src/shell/parameters_utils.rs new file mode 100644 index 0000000..c309d11 --- /dev/null +++ b/willow/src/shell/parameters_utils.rs @@ -0,0 +1,125 @@ +// Copyright 2025 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +use kahe::PackedVectorConfig; +use kahe_shell::ShellKaheConfig; +use protobuf::{proto, ProtoStr}; +use shell_parameters_rust_proto::{ + PackedVectorConfigProto, PackedVectorConfigProtoView, ShellKaheConfigProto, + ShellKaheConfigProtoView, +}; +use std::collections::HashMap; + +/// This file contains some utility functions for working with Willow parameters: +/// - Conversions between Rust structs and their corresponding protos. + +/// Convert a rust struct `PackedVectorConfig` to the corresponding proto. +pub fn packed_vector_config_to_proto(config: &PackedVectorConfig) -> PackedVectorConfigProto { + proto!(PackedVectorConfigProto { + base: config.base as i64, + dimension: config.dimension as i64, + num_packed_coeffs: config.num_packed_coeffs as i64, + }) +} + +/// Convert a `PackedVectorConfigProto` to its corresponding rust struct. +pub fn packed_vector_config_from_proto(proto: PackedVectorConfigProtoView) -> PackedVectorConfig { + PackedVectorConfig { + base: proto.base() as u64, + dimension: proto.dimension() as u64, + num_packed_coeffs: proto.num_packed_coeffs() as u64, + } +} + +/// Convert a rust struct `ShellKaheConfig` to the corresponding proto. +pub fn kahe_config_to_proto(config: &ShellKaheConfig) -> ShellKaheConfigProto { + proto!(ShellKaheConfigProto { + log_n: config.log_n as i64, + moduli: config.moduli.clone().into_iter(), + log_t: config.log_t as i64, + num_public_polynomials: config.num_public_polynomials as i64, + packed_vectors: config + .packed_vector_configs + .iter() + .map(|(id, packed_vector_config)| { + (ProtoStr::from_str(&id), packed_vector_config_to_proto(&packed_vector_config)) + }) + .collect::>() + .into_iter(), + }) +} + +/// Convert a `ShellKaheConfigProto` to the corresponding rust struct. +pub fn kahe_config_from_proto( + proto: ShellKaheConfigProtoView, +) -> Result { + Ok(ShellKaheConfig { + log_n: proto.log_n() as usize, + moduli: proto.moduli().iter().collect(), + log_t: proto.log_t() as usize, + num_public_polynomials: proto.num_public_polynomials() as usize, + packed_vector_configs: proto + .packed_vectors() + .iter() + .map(|(id, packed_vector_config)| { + if let Ok(id_str) = id.to_str() { + Ok((id_str.to_string(), packed_vector_config_from_proto(packed_vector_config))) + } else { + Err(status::invalid_argument("invalid id in `packed_vectors`.")) + } + }) + .collect::, _>>()?, + }) +} + +#[cfg(test)] +mod test { + use super::*; + use googletest::prelude::*; + + #[gtest] + fn test_packed_vector_config_proto_roundtrip() -> googletest::Result<()> { + let config = PackedVectorConfig { base: 8u64, dimension: 2u64, num_packed_coeffs: 1024u64 }; + let proto = packed_vector_config_to_proto(&config); + let config_from_proto = packed_vector_config_from_proto(proto.as_view()); + verify_eq!(config_from_proto, config) + } + + #[gtest] + fn test_kahe_config_proto_roundtrip() -> googletest::Result<()> { + let config = ShellKaheConfig { + log_n: 10usize, + moduli: vec![65537u64, 12289u64], + log_t: 5usize, + num_public_polynomials: 2usize, + packed_vector_configs: HashMap::from([ + ( + String::from("vector0"), + PackedVectorConfig { base: 16u64, dimension: 8u64, num_packed_coeffs: 1024u64 }, + ), + ( + String::from("vector1"), + PackedVectorConfig { + base: 65536u64, + dimension: 1u64, + num_packed_coeffs: 16u64, + }, + ), + ]), + }; + let proto = kahe_config_to_proto(&config); + let config_from_proto = kahe_config_from_proto(proto.as_view())?; + verify_eq!(config_from_proto, config) + } +} diff --git a/willow/src/testing_utils/BUILD b/willow/src/testing_utils/BUILD index ade60d2..466761f 100644 --- a/willow/src/testing_utils/BUILD +++ b/willow/src/testing_utils/BUILD @@ -47,9 +47,12 @@ rust_library( testonly = 1, srcs = ["shell_testing_parameters.rs"], deps = [ + "//shell_wrapper:kahe", "//shell_wrapper:status", + "//willow/src/api:willow_api_common", "//willow/src/shell:ahe_shell", "//willow/src/shell:kahe_shell", + "//willow/src/shell:shell_parameters_generation", "//willow/src/traits:ahe_traits", "//willow/src/traits:kahe_traits", ], @@ -65,6 +68,7 @@ rust_library( deps = [ "@crate_index//:rand", "//shell_wrapper:status", + "//willow/src/api:willow_api_common", "//willow/src/shell:kahe_shell", "//willow/src/shell:single_thread_hkdf", "//willow/src/shell:vahe_shell", diff --git a/willow/src/testing_utils/shell_testing_parameters.rs b/willow/src/testing_utils/shell_testing_parameters.rs index 065c307..8252b1e 100644 --- a/willow/src/testing_utils/shell_testing_parameters.rs +++ b/willow/src/testing_utils/shell_testing_parameters.rs @@ -13,13 +13,18 @@ // limitations under the License. use ahe_shell::ShellAheConfig; -use kahe_shell::{KaheRnsConfig, ShellKaheConfig}; +use kahe::PackedVectorConfig; +use kahe_shell::ShellKaheConfig; +use shell_parameters_generation::{divide_and_roundup, generate_packing_config}; +use std::collections::HashMap; +use willow_api_common::AggregationConfig; -/// Creates an KAHE RNS configuration with the given plaintext modulus bits, by +/// Creates an KAHE configuration with the given plaintext modulus bits, by /// looking up some pre-generated configurations. -pub fn make_kahe_rns_config( +pub fn make_kahe_config_for( plaintext_modulus_bits: usize, -) -> Result { + packed_vector_configs: HashMap, +) -> Result { // Configurations below come from: // google3/experimental/users/baiyuli/async_rlwe_secagg/parameters.cc, // originally generated with: @@ -29,36 +34,60 @@ pub fn make_kahe_rns_config( // log2(kTailBoundMultiplier) - log2(kPrgErrorS) // = composite_modulus_bits - 7 match plaintext_modulus_bits { - 17 => Ok(KaheRnsConfig { log_n: 10, log_t: 17, qs: vec![16760833] }), - 39 => Ok(KaheRnsConfig { log_n: 11, log_t: 39, qs: vec![70368744067073] }), + 17 => { + let total_num_coeffs = + packed_vector_configs.values().map(|cfg| cfg.num_packed_coeffs as usize).sum(); + Ok(ShellKaheConfig { + log_n: 10, + moduli: vec![16760833u64], + log_t: 17, + packed_vector_configs, + num_public_polynomials: divide_and_roundup(total_num_coeffs, 1 << 10), + }) + } + 39 => { + let total_num_coeffs = + packed_vector_configs.values().map(|cfg| cfg.num_packed_coeffs as usize).sum(); + Ok(ShellKaheConfig { + log_n: 11, + moduli: vec![70368744067073u64], + log_t: 39, + packed_vector_configs, + num_public_polynomials: divide_and_roundup(total_num_coeffs, 1 << 11), + }) + } 93 => { - Ok(KaheRnsConfig { log_n: 12, log_t: 93, qs: vec![1125899906826241, 1125899906629633] }) + let total_num_coeffs = + packed_vector_configs.values().map(|cfg| cfg.num_packed_coeffs as usize).sum(); + Ok(ShellKaheConfig { + log_n: 12, + moduli: vec![1125899906826241u64, 1125899906629633u64], + log_t: 93, + packed_vector_configs, + num_public_polynomials: divide_and_roundup(total_num_coeffs, 1 << 12), + }) } _ => Err(status::invalid_argument(format!( - "No RNS configuration for plaintext_modulus_bits = {}", + "No KAHE configuration for plaintext_modulus_bits = {}", plaintext_modulus_bits ))), } } +pub fn set_kahe_num_public_polynomials(kahe_config: &mut ShellKaheConfig) -> () { + let num_coeffs_per_poly = 1 << kahe_config.log_n; + let total_num_coeffs = + kahe_config.packed_vector_configs.values().map(|cfg| cfg.num_packed_coeffs as usize).sum(); + kahe_config.num_public_polynomials = divide_and_roundup(total_num_coeffs, num_coeffs_per_poly); +} + /// Creates a sample KAHE configuration, for quick tests that need just any /// valid configuration. -pub fn make_kahe_config() -> ShellKaheConfig { +pub fn make_kahe_config(aggregation_config: &AggregationConfig) -> ShellKaheConfig { const PLAINTEXT_MODULUS_BITS: usize = 93; - const INPUT_DOMAIN: u64 = 10; - const MAX_NUM_CLIENTS: usize = 100_000; - const NUM_PACKING: usize = 2; - const NUM_PUBLIC_POLYNOMIALS: usize = 1; - - let rns_config = make_kahe_rns_config(PLAINTEXT_MODULUS_BITS).unwrap(); - ShellKaheConfig::new( - INPUT_DOMAIN, - MAX_NUM_CLIENTS, - NUM_PACKING, - NUM_PUBLIC_POLYNOMIALS, - rns_config, - ) - .unwrap() + let packed_vector_configs = + generate_packing_config(PLAINTEXT_MODULUS_BITS, aggregation_config).unwrap(); + make_kahe_config_for(PLAINTEXT_MODULUS_BITS, packed_vector_configs).unwrap() } /// Creates an AHE configuration with 69-bit main modulus and 64-bit RNS moduli. diff --git a/willow/src/testing_utils/testing_utils.rs b/willow/src/testing_utils/testing_utils.rs index fadd121..8c953a9 100644 --- a/willow/src/testing_utils/testing_utils.rs +++ b/willow/src/testing_utils/testing_utils.rs @@ -20,6 +20,7 @@ use shell_testing_parameters::{make_ahe_config, make_kahe_config}; use single_thread_hkdf::Seed; use vahe_shell::ShellVahe; use vahe_traits::{Recover, VaheBase}; +use willow_api_common::AggregationConfig; use willow_v1_client::WillowV1Client; use willow_v1_common::{WillowClientMessage, WillowCommon}; @@ -43,10 +44,11 @@ pub fn generate_random_signed_vector(num_values: usize, max_absolute_value: i64) /// Creates a `WillowCommon` for SHELL with the default AHE/KAHE configurations /// and the given public seeds. pub fn create_willow_common( + aggregation_config: &AggregationConfig, public_kahe_seed: &Seed, public_ahe_seed: &Seed, ) -> WillowCommon { - let kahe = ShellKahe::new(make_kahe_config(), public_kahe_seed).unwrap(); + let kahe = ShellKahe::new(make_kahe_config(aggregation_config), public_kahe_seed).unwrap(); let vahe = ShellVahe::new(make_ahe_config(), public_ahe_seed).unwrap(); WillowCommon { kahe, vahe } } diff --git a/willow/src/willow_v1/BUILD b/willow/src/willow_v1/BUILD index 844c39a..961c08f 100644 --- a/willow/src/willow_v1/BUILD +++ b/willow/src/willow_v1/BUILD @@ -54,6 +54,7 @@ rust_test( crate = ":willow_v1_client", deps = [ "@crate_index//:googletest", + "//willow/src/api:willow_api_common", "//willow/src/shell:single_thread_hkdf", "//willow/src/testing_utils", "//willow/src/traits:prng_traits", diff --git a/willow/src/willow_v1/client.rs b/willow/src/willow_v1/client.rs index 2d2de89..c197eb7 100644 --- a/willow/src/willow_v1/client.rs +++ b/willow/src/willow_v1/client.rs @@ -66,34 +66,49 @@ mod test { use super::*; use ahe_traits::{AheKeygen, PartialDec}; - use googletest::{gtest, verify_eq}; + use googletest::prelude::container_eq; + use googletest::{gtest, verify_eq, verify_that}; use kahe_traits::{KaheDecrypt, TrySecretKeyFrom}; use prng_traits::SecurePrng; use single_thread_hkdf::SingleThreadHkdfPrng; + use std::collections::HashMap; use testing_utils::create_willow_common; use vahe_traits::{Recover, VaheBase}; + use willow_api_common::AggregationConfig; #[gtest] fn test_create_client_message() -> googletest::Result<()> { + let default_id = String::from("default"); + let aggregation_config = AggregationConfig { + vector_lengths_and_bounds: HashMap::from([(default_id.clone(), (16, 10))]), + max_number_of_decryptors: 1, + max_number_of_clients: 1, + max_decryptor_dropouts: 0, + session_id: String::from("test"), + willow_version: (1, 0), + }; // Generate public parameters for KAHE and AHE. let public_kahe_seed = SingleThreadHkdfPrng::generate_seed()?; let public_ahe_seed = SingleThreadHkdfPrng::generate_seed()?; // Create a client. - let common = create_willow_common(&public_kahe_seed, &public_ahe_seed); + let common = create_willow_common(&aggregation_config, &public_kahe_seed, &public_ahe_seed); let client_seed = SingleThreadHkdfPrng::generate_seed()?; let prng = SingleThreadHkdfPrng::create(&client_seed)?; let mut client = WillowV1Client { common: common, prng: prng }; // Generate AHE keys. - let common = create_willow_common(&public_kahe_seed, &public_ahe_seed); + let common = create_willow_common(&aggregation_config, &public_kahe_seed, &public_ahe_seed); let seed = SingleThreadHkdfPrng::generate_seed()?; let mut prng = SingleThreadHkdfPrng::create(&seed)?; let (sk_share, pk_share, _) = common.vahe.key_gen(&mut prng)?; let public_key = common.vahe.aggregate_public_key_shares(&[pk_share])?; // Create client message. - let client_plaintext = vec![1, 2, 3, 4, 5, 6, 7, 8, 7, 6, 5, 4, 3, 2, 1]; + let client_plaintext = HashMap::from([( + default_id.clone(), + vec![1, 2, 3, 4, 5, 6, 7, 8, 7, 6, 5, 4, 3, 2, 1], + )]); let client_message = client.create_client_message(&client_plaintext, &public_key)?; // Decrypt client message. @@ -109,37 +124,58 @@ mod test { let decrypted_plaintext = common.kahe.decrypt(&client_message.kahe_ciphertext, &decrypted_kahe_key)?; - verify_eq!(decrypted_plaintext[..client_plaintext.len()], client_plaintext) + verify_that!(decrypted_plaintext.keys().collect::>(), container_eq([&default_id]))?; + let client_plaintext_length = client_plaintext.get(&default_id).unwrap().len(); + verify_eq!( + decrypted_plaintext.get(&default_id).unwrap()[..client_plaintext_length], + client_plaintext.get(&default_id).unwrap()[..] + ) } #[gtest] fn test_client_messages_are_aggregatable() -> googletest::Result<()> { + let default_id = String::from("default"); + let aggregation_config = AggregationConfig { + vector_lengths_and_bounds: HashMap::from([(default_id.clone(), (16, 10))]), + max_number_of_decryptors: 1, + max_number_of_clients: 2, + max_decryptor_dropouts: 0, + session_id: String::from("test"), + willow_version: (1, 0), + }; + // Generate public parameters for KAHE and AHE. let public_kahe_seed = SingleThreadHkdfPrng::generate_seed()?; let public_ahe_seed = SingleThreadHkdfPrng::generate_seed()?; // Create a client. - let common = create_willow_common(&public_kahe_seed, &public_ahe_seed); + let common = create_willow_common(&aggregation_config, &public_kahe_seed, &public_ahe_seed); let client1_seed = SingleThreadHkdfPrng::generate_seed()?; let prng = SingleThreadHkdfPrng::create(&client1_seed)?; let mut client1 = WillowV1Client { common: common, prng: prng }; // Create a second client. - let common = create_willow_common(&public_kahe_seed, &public_ahe_seed); + let common = create_willow_common(&aggregation_config, &public_kahe_seed, &public_ahe_seed); let client2_seed = SingleThreadHkdfPrng::generate_seed()?; let prng = SingleThreadHkdfPrng::create(&client2_seed)?; let mut client2 = WillowV1Client { common: common, prng: prng }; // Generate AHE keys. - let common = create_willow_common(&public_kahe_seed, &public_ahe_seed); + let common = create_willow_common(&aggregation_config, &public_kahe_seed, &public_ahe_seed); let seed = SingleThreadHkdfPrng::generate_seed()?; let mut prng = SingleThreadHkdfPrng::create(&seed)?; let (sk_share, pk_share, _) = common.vahe.key_gen(&mut prng)?; let public_key = common.vahe.aggregate_public_key_shares(&[pk_share])?; // Create client messages. - let client1_plaintext = vec![1, 2, 3, 4, 5, 6, 7, 8, 7, 6, 5, 4, 3, 2, 1]; - let client2_plaintext = vec![1, 1, 2, 3, 5, 8, 3, 1, 4, 5, 9, 4, 3, 7, 0]; + let client1_plaintext = HashMap::from([( + default_id.clone(), + vec![1, 2, 3, 4, 5, 6, 7, 8, 7, 6, 5, 4, 3, 2, 1], + )]); + let client2_plaintext = HashMap::from([( + default_id.clone(), + vec![1, 1, 2, 3, 5, 8, 3, 1, 4, 5, 9, 4, 3, 7, 0], + )]); let expected_output = vec![2, 3, 5, 7, 10, 14, 10, 9, 11, 11, 14, 8, 6, 9, 1]; let mut client_message = client1.create_client_message(&client1_plaintext, &public_key)?; let extra_message = client2.create_client_message(&client2_plaintext, &public_key)?; @@ -167,6 +203,11 @@ mod test { let decrypted_plaintext = common.kahe.decrypt(&client_message.kahe_ciphertext, &decrypted_kahe_key)?; - verify_eq!(decrypted_plaintext[..expected_output.len()], expected_output) + verify_that!(decrypted_plaintext.keys().collect::>(), container_eq([&default_id]))?; + let client_plaintext_length = client1_plaintext.get(&default_id).unwrap().len(); + verify_eq!( + decrypted_plaintext.get(&default_id).unwrap()[..client_plaintext_length], + expected_output + ) } } diff --git a/willow/src/zk/rlwe_relation.rs b/willow/src/zk/rlwe_relation.rs index fb2339b..efc97b5 100644 --- a/willow/src/zk/rlwe_relation.rs +++ b/willow/src/zk/rlwe_relation.rs @@ -1061,7 +1061,7 @@ mod tests { let a_buffer = [1, 2, 3, 4]; let expected_result = [1 as u128, 2 as u128, 3 as u128, 4 as u128]; let n = 4; - let a = read_small_rns_polynomial_from_buffer(&a_buffer, 2, &moduli)?; + let a = read_small_rns_polynomial_from_buffer(&a_buffer, n as u64, &moduli)?; let a_unpacked = unpack_rns_polynomial(&context, &a, n)?; assert!(a_unpacked.eq(&expected_result)); Ok(()) @@ -1102,8 +1102,8 @@ mod tests { let poly_buffer = [1; 4096]; - let a = read_small_rns_polynomial_from_buffer(&poly_buffer, 12, &moduli)?; - let c = read_small_rns_polynomial_from_buffer(&poly_buffer, 12, &moduli)?; + let a = read_small_rns_polynomial_from_buffer(&poly_buffer, n as u64, &moduli)?; + let c = read_small_rns_polynomial_from_buffer(&poly_buffer, n as u64, &moduli)?; let statement = RlweRelationProofStatement { n: n, @@ -1140,8 +1140,8 @@ mod tests { let poly_buffer = [1; 4096]; - let a = read_small_rns_polynomial_from_buffer(&poly_buffer, 12, &moduli)?; - let c = read_small_rns_polynomial_from_buffer(&poly_buffer, 12, &moduli)?; + let a = read_small_rns_polynomial_from_buffer(&poly_buffer, n as u64, &moduli)?; + let c = read_small_rns_polynomial_from_buffer(&poly_buffer, n as u64, &moduli)?; let statement = RlweRelationProofStatement { n: n, @@ -1177,8 +1177,8 @@ mod tests { let poly_buffer = [1; 4096]; - let a = read_small_rns_polynomial_from_buffer(&poly_buffer, 12, &moduli)?; - let c = read_small_rns_polynomial_from_buffer(&poly_buffer, 12, &moduli)?; + let a = read_small_rns_polynomial_from_buffer(&poly_buffer, n as u64, &moduli)?; + let c = read_small_rns_polynomial_from_buffer(&poly_buffer, n as u64, &moduli)?; let statement = RlweRelationProofStatement { n: n, @@ -1207,8 +1207,8 @@ mod tests { let poly_buffer = [1; 4096]; - let a = read_small_rns_polynomial_from_buffer(&poly_buffer, 12, &moduli)?; - let c = read_small_rns_polynomial_from_buffer(&poly_buffer, 12, &moduli)?; + let a = read_small_rns_polynomial_from_buffer(&poly_buffer, n as u64, &moduli)?; + let c = read_small_rns_polynomial_from_buffer(&poly_buffer, n as u64, &moduli)?; let statement = RlweRelationProofStatement { n: n, @@ -1294,11 +1294,11 @@ mod tests { let c_buffer = [5, -8, 9, 17]; let v_buffer = [-1, -1, 4, 0]; - let a = read_small_rns_polynomial_from_buffer(&a_buffer, 2, &moduli)?; - let c = read_small_rns_polynomial_from_buffer(&c_buffer, 2, &moduli)?; - let r = read_small_rns_polynomial_from_buffer(&r_buffer, 2, &moduli)?; - let e = read_small_rns_polynomial_from_buffer(&e_buffer, 2, &moduli)?; - let v = read_small_rns_polynomial_from_buffer(&v_buffer, 2, &moduli)?; + let a = read_small_rns_polynomial_from_buffer(&a_buffer, n as u64, &moduli)?; + let c = read_small_rns_polynomial_from_buffer(&c_buffer, n as u64, &moduli)?; + let r = read_small_rns_polynomial_from_buffer(&r_buffer, n as u64, &moduli)?; + let e = read_small_rns_polynomial_from_buffer(&e_buffer, n as u64, &moduli)?; + let v = read_small_rns_polynomial_from_buffer(&v_buffer, n as u64, &moduli)?; let statement = RlweRelationProofStatement { n: n, @@ -1340,11 +1340,11 @@ mod tests { let c_buffer = [5, -8, 9, 17]; let v_buffer = [-1, -1, 4, 0]; - let a = read_small_rns_polynomial_from_buffer(&a_buffer, 2, &moduli)?; - let c = read_small_rns_polynomial_from_buffer(&c_buffer, 2, &moduli)?; - let r = read_small_rns_polynomial_from_buffer(&r_buffer, 2, &moduli)?; - let e = read_small_rns_polynomial_from_buffer(&e_buffer, 2, &moduli)?; - let v = read_small_rns_polynomial_from_buffer(&v_buffer, 2, &moduli)?; + let a = read_small_rns_polynomial_from_buffer(&a_buffer, n as u64, &moduli)?; + let c = read_small_rns_polynomial_from_buffer(&c_buffer, n as u64, &moduli)?; + let r = read_small_rns_polynomial_from_buffer(&r_buffer, n as u64, &moduli)?; + let e = read_small_rns_polynomial_from_buffer(&e_buffer, n as u64, &moduli)?; + let v = read_small_rns_polynomial_from_buffer(&v_buffer, n as u64, &moduli)?; let statement = RlweRelationProofStatement { n: n, @@ -1386,11 +1386,11 @@ mod tests { let c_buffer = [5, -8, 9, 17]; let v_buffer = [-1, -1, 4, 0]; - let a = read_small_rns_polynomial_from_buffer(&a_buffer, 2, &moduli)?; - let c = read_small_rns_polynomial_from_buffer(&c_buffer, 2, &moduli)?; - let r = read_small_rns_polynomial_from_buffer(&r_buffer, 2, &moduli)?; - let e = read_small_rns_polynomial_from_buffer(&e_buffer, 2, &moduli)?; - let v = read_small_rns_polynomial_from_buffer(&v_buffer, 2, &moduli)?; + let a = read_small_rns_polynomial_from_buffer(&a_buffer, n as u64, &moduli)?; + let c = read_small_rns_polynomial_from_buffer(&c_buffer, n as u64, &moduli)?; + let r = read_small_rns_polynomial_from_buffer(&r_buffer, n as u64, &moduli)?; + let e = read_small_rns_polynomial_from_buffer(&e_buffer, n as u64, &moduli)?; + let v = read_small_rns_polynomial_from_buffer(&v_buffer, n as u64, &moduli)?; let mut statement = RlweRelationProofStatement { n: n, @@ -1455,11 +1455,11 @@ mod tests { let c_buffer = [5, -8, 9, 17]; let v_buffer = [-1, -1, 4, 0]; - let a = read_small_rns_polynomial_from_buffer(&a_buffer, 2, &moduli)?; - let c = read_small_rns_polynomial_from_buffer(&c_buffer, 2, &moduli)?; - let r = read_small_rns_polynomial_from_buffer(&r_buffer, 2, &moduli)?; - let e = read_small_rns_polynomial_from_buffer(&e_buffer, 2, &moduli)?; - let v = read_small_rns_polynomial_from_buffer(&v_buffer, 2, &moduli)?; + let a = read_small_rns_polynomial_from_buffer(&a_buffer, n as u64, &moduli)?; + let c = read_small_rns_polynomial_from_buffer(&c_buffer, n as u64, &moduli)?; + let r = read_small_rns_polynomial_from_buffer(&r_buffer, n as u64, &moduli)?; + let e = read_small_rns_polynomial_from_buffer(&e_buffer, n as u64, &moduli)?; + let v = read_small_rns_polynomial_from_buffer(&v_buffer, n as u64, &moduli)?; let mut statement = RlweRelationProofStatement { n: n, @@ -1535,11 +1535,11 @@ mod tests { let c_buffer = [5, -8, 9, 17]; let v_buffer = [-1, -1, 4, 0]; - let a = read_small_rns_polynomial_from_buffer(&a_buffer, 2, &moduli)?; - let c = read_small_rns_polynomial_from_buffer(&c_buffer, 2, &moduli)?; - let r = read_small_rns_polynomial_from_buffer(&r_buffer, 2, &moduli)?; - let e = read_small_rns_polynomial_from_buffer(&e_buffer, 2, &moduli)?; - let v = read_small_rns_polynomial_from_buffer(&v_buffer, 2, &moduli)?; + let a = read_small_rns_polynomial_from_buffer(&a_buffer, n as u64, &moduli)?; + let c = read_small_rns_polynomial_from_buffer(&c_buffer, n as u64, &moduli)?; + let r = read_small_rns_polynomial_from_buffer(&r_buffer, n as u64, &moduli)?; + let e = read_small_rns_polynomial_from_buffer(&e_buffer, n as u64, &moduli)?; + let v = read_small_rns_polynomial_from_buffer(&v_buffer, n as u64, &moduli)?; let mut statement = RlweRelationProofStatement { n: n, @@ -1618,11 +1618,11 @@ mod tests { let c_buffer = [5, -8, 9, 17]; let v_buffer = [-1, -1, 4, 0]; - let a = read_small_rns_polynomial_from_buffer(&a_buffer, 2, &moduli)?; - let c = read_small_rns_polynomial_from_buffer(&c_buffer, 2, &moduli)?; - let r = read_small_rns_polynomial_from_buffer(&r_buffer, 2, &moduli)?; - let e = read_small_rns_polynomial_from_buffer(&e_buffer, 2, &moduli)?; - let v = read_small_rns_polynomial_from_buffer(&v_buffer, 2, &moduli)?; + let a = read_small_rns_polynomial_from_buffer(&a_buffer, n as u64, &moduli)?; + let c = read_small_rns_polynomial_from_buffer(&c_buffer, n as u64, &moduli)?; + let r = read_small_rns_polynomial_from_buffer(&r_buffer, n as u64, &moduli)?; + let e = read_small_rns_polynomial_from_buffer(&e_buffer, n as u64, &moduli)?; + let v = read_small_rns_polynomial_from_buffer(&v_buffer, n as u64, &moduli)?; let mut statement = RlweRelationProofStatement { n: n, diff --git a/willow/tests/BUILD b/willow/tests/BUILD index d6c5cc3..78239d9 100644 --- a/willow/tests/BUILD +++ b/willow/tests/BUILD @@ -27,6 +27,7 @@ rust_test( "@crate_index//:googletest", "//shell_wrapper:status", "//shell_wrapper:status_matchers_rs", + "//willow/src/api:willow_api_common", "//willow/src/shell:ahe_shell", "//willow/src/shell:kahe_shell", "//willow/src/shell:parameters_shell", diff --git a/willow/tests/willow_v1_shell.rs b/willow/tests/willow_v1_shell.rs index 693eeca..e740c7c 100644 --- a/willow/tests/willow_v1_shell.rs +++ b/willow/tests/willow_v1_shell.rs @@ -14,6 +14,7 @@ use client_traits::SecureAggregationClient; use decryptor_traits::SecureAggregationDecryptor; +use googletest::prelude::container_eq; use googletest::{gtest, verify_eq, verify_that}; use kahe_shell::ShellKahe; use parameters_shell::create_shell_configs; @@ -22,42 +23,64 @@ use server_traits::SecureAggregationServer; use single_thread_hkdf::SingleThreadHkdfPrng; use status::StatusErrorCode; use status_matchers_rs::status_is; +use std::collections::HashMap; use testing_utils::{create_willow_common, generate_random_unsigned_vector}; use vahe_shell::ShellVahe; use verifier_traits::SecureAggregationVerifier; +use willow_api_common::AggregationConfig; use willow_v1_client::WillowV1Client; use willow_v1_common::WillowCommon; use willow_v1_decryptor::{DecryptorState, WillowV1Decryptor}; use willow_v1_server::{ServerState, WillowV1Server}; use willow_v1_verifier::{VerifierState, WillowV1Verifier}; +/// Generates an AggregationConfig for test cases in this file. +fn generate_aggregation_config( + vector_id: String, + vector_length: isize, + vector_bound: i64, + max_number_of_decryptors: i64, + max_number_of_clients: i64, +) -> AggregationConfig { + AggregationConfig { + vector_lengths_and_bounds: HashMap::from([(vector_id, (vector_length, vector_bound))]), + max_number_of_decryptors, + max_number_of_clients, + max_decryptor_dropouts: 0, + session_id: String::from("test"), + willow_version: (1, 0), + } +} + /// Encrypt and decrypt with a single decryptor and single client. #[gtest] fn encrypt_decrypt_one() -> googletest::Result<()> { + let default_id = String::from("default"); + let aggregation_config = generate_aggregation_config(default_id.clone(), 16, 10, 1, 1); let public_kahe_seed = SingleThreadHkdfPrng::generate_seed().unwrap(); let public_ahe_seed = SingleThreadHkdfPrng::generate_seed().unwrap(); // Create client. - let common = create_willow_common(&public_kahe_seed, &public_ahe_seed); + let common = create_willow_common(&aggregation_config, &public_kahe_seed, &public_ahe_seed); let seed = SingleThreadHkdfPrng::generate_seed().unwrap(); let prng = SingleThreadHkdfPrng::create(&seed).unwrap(); let mut client = WillowV1Client { common, prng }; // Create decryptor, which needs its own `common` (with same public polynomials // generated from the seeds) and `prng`. - let common = create_willow_common(&public_kahe_seed, &public_ahe_seed); + let common = create_willow_common(&aggregation_config, &public_kahe_seed, &public_ahe_seed); let seed = SingleThreadHkdfPrng::generate_seed().unwrap(); let prng = SingleThreadHkdfPrng::create(&seed).unwrap(); let mut decryptor_state = DecryptorState::new(); let mut decryptor = WillowV1Decryptor { common, prng }; // Create server. - let common = create_willow_common(&public_kahe_seed, &public_ahe_seed); + let common = create_willow_common(&aggregation_config, &public_kahe_seed, &public_ahe_seed); let server = WillowV1Server { common }; let mut server_state = ServerState::new(); // Create verifier. - let common = create_willow_common(&public_kahe_seed, &public_ahe_seed); + let common = create_willow_common(&aggregation_config, &public_kahe_seed, &public_ahe_seed); let verifier = WillowV1Verifier { common }; let mut verifier_state = VerifierState::new(); @@ -71,7 +94,8 @@ fn encrypt_decrypt_one() -> googletest::Result<()> { let public_key = server.create_decryptor_public_key(&server_state).unwrap(); // Client encrypts. - let client_plaintext = vec![1, 2, 3, 4, 5, 6, 7, 8, 7, 6, 5, 4, 3, 2, 1]; + let client_plaintext = + HashMap::from([(default_id.clone(), vec![1, 2, 3, 4, 5, 6, 7, 8, 7, 6, 5, 4, 3, 2, 1])]); let client_message = client.create_client_message(&client_plaintext, &public_key).unwrap(); // The client message is split and handled by the server and verifier. @@ -93,19 +117,29 @@ fn encrypt_decrypt_one() -> googletest::Result<()> { let aggregation_result = server.recover_aggregation_result(&server_state).unwrap(); // Check that the (padded) result matches the client plaintext. - verify_eq!(aggregation_result[..client_plaintext.len()], client_plaintext) + verify_that!(aggregation_result.keys().collect::>(), container_eq([&default_id]))?; + let client_plaintext_length = client_plaintext.get(&default_id).unwrap().len(); + verify_eq!( + aggregation_result.get(&default_id).unwrap()[..client_plaintext_length], + client_plaintext.get(&default_id).unwrap()[..] + ) } // Encrypt and decrypt with multiple clients and a single decryptor. #[gtest] fn encrypt_decrypt_multiple_clients() -> googletest::Result<()> { + const NUM_CLIENTS: i64 = 10; + let default_id = String::from("default"); + let aggregation_config = + generate_aggregation_config(default_id.clone(), 16, 10, 1, NUM_CLIENTS); + let public_kahe_seed = SingleThreadHkdfPrng::generate_seed().unwrap(); let public_ahe_seed = SingleThreadHkdfPrng::generate_seed().unwrap(); // Create clients. let mut clients = vec![]; - for _ in 0..10 { - let common = create_willow_common(&public_kahe_seed, &public_ahe_seed); + for _ in 0..NUM_CLIENTS { + let common = create_willow_common(&aggregation_config, &public_kahe_seed, &public_ahe_seed); let seed = SingleThreadHkdfPrng::generate_seed().unwrap(); let prng = SingleThreadHkdfPrng::create(&seed).unwrap(); let client = WillowV1Client { common, prng }; @@ -114,19 +148,19 @@ fn encrypt_decrypt_multiple_clients() -> googletest::Result<()> { // Create decryptor, which needs its own `common` (with same public polynomials // generated from the seeds) and `prng`. - let common = create_willow_common(&public_kahe_seed, &public_ahe_seed); + let common = create_willow_common(&aggregation_config, &public_kahe_seed, &public_ahe_seed); let seed = SingleThreadHkdfPrng::generate_seed().unwrap(); let prng = SingleThreadHkdfPrng::create(&seed).unwrap(); let mut decryptor_state = DecryptorState::new(); let mut decryptor = WillowV1Decryptor { common, prng }; // Create server. - let common = create_willow_common(&public_kahe_seed, &public_ahe_seed); + let common = create_willow_common(&aggregation_config, &public_kahe_seed, &public_ahe_seed); let server = WillowV1Server { common }; let mut server_state = ServerState::new(); // Create verifier. - let common = create_willow_common(&public_kahe_seed, &public_ahe_seed); + let common = create_willow_common(&aggregation_config, &public_kahe_seed, &public_ahe_seed); let verifier = WillowV1Verifier { common }; let mut verifier_state = VerifierState::new(); @@ -142,10 +176,11 @@ fn encrypt_decrypt_multiple_clients() -> googletest::Result<()> { // Clients encrypt. let mut expected_output = vec![0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]; for client in &mut clients { - let client_plaintext = vec![1, 2, 3, 4, 5, 6, 7, 8, 7, 6, 5, 4, 3, 2, 1]; + let client_input_values = vec![1, 2, 3, 4, 5, 6, 7, 8, 7, 6, 5, 4, 3, 2, 1]; for i in 0..expected_output.len() { - expected_output[i] += client_plaintext[i]; + expected_output[i] += client_input_values[i]; } + let client_plaintext = HashMap::from([(default_id.clone(), client_input_values)]); let client_message = client.create_client_message(&client_plaintext, &public_key).unwrap(); // The client message is split and handled by the server and verifier. let (ciphertext_contribution, decryption_request_contribution) = @@ -167,19 +202,29 @@ fn encrypt_decrypt_multiple_clients() -> googletest::Result<()> { let aggregation_result = server.recover_aggregation_result(&server_state).unwrap(); // Check that the (padded) result matches the client plaintext. - verify_eq!(aggregation_result[..expected_output.len()], expected_output) + verify_that!(aggregation_result.keys().collect::>(), container_eq([&default_id]))?; + verify_eq!( + aggregation_result.get(&default_id).unwrap()[..expected_output.len()], + expected_output + ) } // Encrypt and decrypt with multiple clients including invalid client proofs and a single decryptor. #[gtest] fn encrypt_decrypt_multiple_clients_including_invalid_proofs() -> googletest::Result<()> { + const NUM_MAX_CLIENTS: i64 = 10; + const NUM_GOOD_CLIENTS: i64 = 10; + const NUM_BAD_CLIENTS: i64 = 5; + let default_id = String::from("default"); + let aggregation_config = + generate_aggregation_config(default_id.clone(), 16, 10, 1, NUM_MAX_CLIENTS); let public_kahe_seed = SingleThreadHkdfPrng::generate_seed().unwrap(); let public_ahe_seed = SingleThreadHkdfPrng::generate_seed().unwrap(); // Create clients. let mut good_clients = vec![]; - for _ in 0..10 { - let common = create_willow_common(&public_kahe_seed, &public_ahe_seed); + for _ in 0..NUM_GOOD_CLIENTS { + let common = create_willow_common(&aggregation_config, &public_kahe_seed, &public_ahe_seed); let seed = SingleThreadHkdfPrng::generate_seed().unwrap(); let prng = SingleThreadHkdfPrng::create(&seed).unwrap(); let client = WillowV1Client { common, prng }; @@ -188,8 +233,8 @@ fn encrypt_decrypt_multiple_clients_including_invalid_proofs() -> googletest::Re // Create bad clients. let mut bad_clients = vec![]; - for _ in 0..5 { - let common = create_willow_common(&public_kahe_seed, &public_ahe_seed); + for _ in 0..NUM_BAD_CLIENTS { + let common = create_willow_common(&aggregation_config, &public_kahe_seed, &public_ahe_seed); let seed = SingleThreadHkdfPrng::generate_seed().unwrap(); let prng = SingleThreadHkdfPrng::create(&seed).unwrap(); let client = WillowV1Client { common, prng }; @@ -198,19 +243,19 @@ fn encrypt_decrypt_multiple_clients_including_invalid_proofs() -> googletest::Re // Create decryptor, which needs its own `common` (with same public polynomials // generated from the seeds) and `prng`. - let common = create_willow_common(&public_kahe_seed, &public_ahe_seed); + let common = create_willow_common(&aggregation_config, &public_kahe_seed, &public_ahe_seed); let seed = SingleThreadHkdfPrng::generate_seed().unwrap(); let prng = SingleThreadHkdfPrng::create(&seed).unwrap(); let mut decryptor_state = DecryptorState::new(); let mut decryptor = WillowV1Decryptor { common, prng }; // Create server. - let common = create_willow_common(&public_kahe_seed, &public_ahe_seed); + let common = create_willow_common(&aggregation_config, &public_kahe_seed, &public_ahe_seed); let server = WillowV1Server { common }; let mut server_state = ServerState::new(); // Create verifier. - let common = create_willow_common(&public_kahe_seed, &public_ahe_seed); + let common = create_willow_common(&aggregation_config, &public_kahe_seed, &public_ahe_seed); let verifier = WillowV1Verifier { common }; let mut verifier_state = VerifierState::new(); @@ -226,10 +271,11 @@ fn encrypt_decrypt_multiple_clients_including_invalid_proofs() -> googletest::Re // Good Clients encrypt and should be included in the aggregation. let mut expected_output = vec![0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]; for client in &mut good_clients { - let client_plaintext = vec![1, 2, 3, 4, 5, 6, 7, 8, 7, 6, 5, 4, 3, 2, 1]; + let client_input_values = vec![1, 2, 3, 4, 5, 6, 7, 8, 7, 6, 5, 4, 3, 2, 1]; for i in 0..expected_output.len() { - expected_output[i] += client_plaintext[i]; + expected_output[i] += client_input_values[i]; } + let client_plaintext = HashMap::from([(default_id.clone(), client_input_values)]); let client_message = client.create_client_message(&client_plaintext, &public_key).unwrap(); // The client message is split and handled by the server and verifier. let (ciphertext_contribution, decryption_request_contribution) = @@ -242,14 +288,16 @@ fn encrypt_decrypt_multiple_clients_including_invalid_proofs() -> googletest::Re let bad_proof; { let client = &mut bad_clients[0]; - let client_plaintext = vec![1, 2, 3, 4, 5, 6, 7, 8, 7, 6, 5, 4, 3, 2, 1]; + let client_input_values = vec![1, 2, 3, 4, 5, 6, 7, 8, 7, 6, 5, 4, 3, 2, 1]; + let client_plaintext = HashMap::from([(default_id.clone(), client_input_values)]); let client_message = client.create_client_message(&client_plaintext, &public_key).unwrap(); bad_proof = client_message.proof; } // Bad Clients encrypt and should not be included in the aggregation. for i in 1..bad_clients.len() { let client = &mut bad_clients[i]; - let client_plaintext = vec![8, 7, 6, 5, 4, 3, 2, 1, 2, 3, 4, 5, 6, 7, 8]; + let client_input_values = vec![8, 7, 6, 5, 4, 3, 2, 1, 2, 3, 4, 5, 6, 7, 8]; + let client_plaintext = HashMap::from([(default_id.clone(), client_input_values)]); let mut client_message = client.create_client_message(&client_plaintext, &public_key).unwrap(); client_message.proof = bad_proof.clone(); @@ -275,28 +323,39 @@ fn encrypt_decrypt_multiple_clients_including_invalid_proofs() -> googletest::Re let aggregation_result = server.recover_aggregation_result(&server_state).unwrap(); // Check that the (padded) result matches the client plaintext. - verify_eq!(aggregation_result[..expected_output.len()], expected_output) + verify_that!(aggregation_result.keys().collect::>(), container_eq([&default_id]))?; + verify_eq!( + aggregation_result.get(&default_id).unwrap()[..expected_output.len()], + expected_output + ) } /// Encrypt and decrypt with multiple clients and multiple decryptors. /// Note: This test uses RLWE parameters for production use. #[gtest] fn encrypt_decrypt_many_clients_decryptors() -> googletest::Result<()> { - const INPUT_LENGTH: u64 = 100_000; // 100K - const INPUT_DOMAIN: u64 = 1u64 << 32; - const MAX_NUM_CLIENTS: usize = 10_000_000; // used to generate parameters. - const MAX_NUM_DECRYPTORS: usize = 100; // used to generate parameters. + const INPUT_LENGTH: isize = 100_000; // 100K + const INPUT_DOMAIN: i64 = 1i64 << 32; + const MAX_NUM_CLIENTS: i64 = 10_000_000; // used to generate parameters. + const MAX_NUM_DECRYPTORS: i64 = 100; // used to generate parameters. const NUM_CLIENTS: usize = 3; // Actual number of clients to create. const NUM_DECRYPTORS: usize = 3; // Actual number of decryptors to create. + let default_id = String::from("default"); + let aggregation_config = generate_aggregation_config( + default_id.clone(), + INPUT_LENGTH, + INPUT_DOMAIN, + MAX_NUM_DECRYPTORS, + MAX_NUM_CLIENTS, + ); + // Create the public seeds for all clients, decryptors, and server. let public_kahe_seed = SingleThreadHkdfPrng::generate_seed().unwrap(); let public_ahe_seed = SingleThreadHkdfPrng::generate_seed().unwrap(); // Create server. - let (kahe_config, ahe_config) = - create_shell_configs(INPUT_LENGTH, INPUT_DOMAIN, MAX_NUM_CLIENTS, MAX_NUM_DECRYPTORS) - .unwrap(); + let (kahe_config, ahe_config) = create_shell_configs(&aggregation_config).unwrap(); let kahe = ShellKahe::new(kahe_config, &public_kahe_seed).unwrap(); let vahe = ShellVahe::new(ahe_config, &public_ahe_seed).unwrap(); let common = WillowCommon { kahe, vahe }; @@ -304,9 +363,7 @@ fn encrypt_decrypt_many_clients_decryptors() -> googletest::Result<()> { let mut server_state = ServerState::new(); // Create verifier. - let (kahe_config, ahe_config) = - create_shell_configs(INPUT_LENGTH, INPUT_DOMAIN, MAX_NUM_CLIENTS, MAX_NUM_DECRYPTORS) - .unwrap(); + let (kahe_config, ahe_config) = create_shell_configs(&aggregation_config).unwrap(); let kahe = ShellKahe::new(kahe_config, &public_kahe_seed).unwrap(); let vahe = ShellVahe::new(ahe_config, &public_ahe_seed).unwrap(); let common = WillowCommon { kahe, vahe }; @@ -318,9 +375,7 @@ fn encrypt_decrypt_many_clients_decryptors() -> googletest::Result<()> { let mut decryptors = vec![]; let mut decryptor_states = vec![]; for _ in 0..NUM_DECRYPTORS { - let (kahe_config, ahe_config) = - create_shell_configs(INPUT_LENGTH, INPUT_DOMAIN, MAX_NUM_CLIENTS, MAX_NUM_DECRYPTORS) - .unwrap(); + let (kahe_config, ahe_config) = create_shell_configs(&aggregation_config).unwrap(); let kahe = ShellKahe::new(kahe_config, &public_kahe_seed).unwrap(); let vahe = ShellVahe::new(ahe_config, &public_ahe_seed).unwrap(); let common = WillowCommon { kahe, vahe }; @@ -345,9 +400,7 @@ fn encrypt_decrypt_many_clients_decryptors() -> googletest::Result<()> { // Create clients, and each client generates their messages. let mut expected_output = vec![0; INPUT_LENGTH as usize]; for _ in 0..NUM_CLIENTS { - let (kahe_config, ahe_config) = - create_shell_configs(INPUT_LENGTH, INPUT_DOMAIN, MAX_NUM_CLIENTS, MAX_NUM_DECRYPTORS) - .unwrap(); + let (kahe_config, ahe_config) = create_shell_configs(&aggregation_config).unwrap(); let kahe = ShellKahe::new(kahe_config, &public_kahe_seed).unwrap(); let vahe = ShellVahe::new(ahe_config, &public_ahe_seed).unwrap(); let common = WillowCommon { kahe, vahe }; @@ -355,10 +408,12 @@ fn encrypt_decrypt_many_clients_decryptors() -> googletest::Result<()> { let prng = SingleThreadHkdfPrng::create(&seed).unwrap(); let mut client = WillowV1Client { common, prng }; - let client_plaintext = generate_random_unsigned_vector(INPUT_LENGTH as usize, INPUT_DOMAIN); + let client_input_values = + generate_random_unsigned_vector(INPUT_LENGTH as usize, INPUT_DOMAIN as u64); for i in 0..expected_output.len() { - expected_output[i] += client_plaintext[i]; + expected_output[i] += client_input_values[i]; } + let client_plaintext = HashMap::from([(default_id.clone(), client_input_values)]); let client_message = client.create_client_message(&client_plaintext, &public_key).unwrap(); // The client message is split and handled by the server and verifier. let (ciphertext_contribution, decryption_request_contribution) = @@ -385,19 +440,28 @@ fn encrypt_decrypt_many_clients_decryptors() -> googletest::Result<()> { let aggregation_result = server.recover_aggregation_result(&server_state).unwrap(); // Check that the (padded) result matches the client plaintext. - verify_eq!(aggregation_result[..expected_output.len()], expected_output) + verify_that!(aggregation_result.keys().collect::>(), container_eq([&default_id]))?; + verify_eq!( + aggregation_result.get(&default_id).unwrap()[..expected_output.len()], + expected_output + ) } // Encrypt and decrypt with multiple clients and multiple decryptors, but no dropout. #[gtest] fn encrypt_decrypt_no_dropout() -> googletest::Result<()> { + const NUM_CLIENTS: i64 = 10; + const NUM_DECRYPTORS: i64 = 10; + let default_id = String::from("default"); + let aggregation_config = + generate_aggregation_config(default_id.clone(), 16, 10, NUM_DECRYPTORS, NUM_CLIENTS); let public_kahe_seed = SingleThreadHkdfPrng::generate_seed().unwrap(); let public_ahe_seed = SingleThreadHkdfPrng::generate_seed().unwrap(); // Create clients. let mut clients = vec![]; - for _ in 0..10 { - let common = create_willow_common(&public_kahe_seed, &public_ahe_seed); + for _ in 0..NUM_CLIENTS { + let common = create_willow_common(&aggregation_config, &public_kahe_seed, &public_ahe_seed); let seed = SingleThreadHkdfPrng::generate_seed().unwrap(); let prng = SingleThreadHkdfPrng::create(&seed).unwrap(); let client = WillowV1Client { common, prng }; @@ -408,8 +472,8 @@ fn encrypt_decrypt_no_dropout() -> googletest::Result<()> { // generated from the seeds) and `prng`. let mut decryptor_states = vec![]; let mut decryptors = vec![]; - for _ in 0..10 { - let common = create_willow_common(&public_kahe_seed, &public_ahe_seed); + for _ in 0..NUM_DECRYPTORS { + let common = create_willow_common(&aggregation_config, &public_kahe_seed, &public_ahe_seed); let seed = SingleThreadHkdfPrng::generate_seed().unwrap(); let prng = SingleThreadHkdfPrng::create(&seed).unwrap(); let decryptor_state = DecryptorState::new(); @@ -419,12 +483,12 @@ fn encrypt_decrypt_no_dropout() -> googletest::Result<()> { } // Create server. - let common = create_willow_common(&public_kahe_seed, &public_ahe_seed); + let common = create_willow_common(&aggregation_config, &public_kahe_seed, &public_ahe_seed); let server = WillowV1Server { common }; let mut server_state = ServerState::new(); // Create verifier. - let common = create_willow_common(&public_kahe_seed, &public_ahe_seed); + let common = create_willow_common(&aggregation_config, &public_kahe_seed, &public_ahe_seed); let verifier = WillowV1Verifier { common }; let mut verifier_state = VerifierState::new(); @@ -442,10 +506,11 @@ fn encrypt_decrypt_no_dropout() -> googletest::Result<()> { // Clients encrypt. let mut expected_output = vec![0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]; for client in &mut clients { - let client_plaintext = vec![1, 2, 3, 4, 5, 6, 7, 8, 7, 6, 5, 4, 3, 2, 1]; + let client_input_values = vec![1, 2, 3, 4, 5, 6, 7, 8, 7, 6, 5, 4, 3, 2, 1]; for i in 0..expected_output.len() { - expected_output[i] += client_plaintext[i]; + expected_output[i] += client_input_values[i]; } + let client_plaintext = HashMap::from([(default_id.clone(), client_input_values)]); let client_message = client.create_client_message(&client_plaintext, &public_key).unwrap(); // The client message is split and handled by the server and verifier. let (ciphertext_contribution, decryption_request_contribution) = @@ -470,5 +535,9 @@ fn encrypt_decrypt_no_dropout() -> googletest::Result<()> { let aggregation_result = server.recover_aggregation_result(&server_state).unwrap(); // Check that the (padded) result matches the client plaintext. - verify_eq!(aggregation_result[..expected_output.len()], expected_output) + verify_that!(aggregation_result.keys().collect::>(), container_eq([&default_id]))?; + verify_eq!( + aggregation_result.get(&default_id).unwrap()[..expected_output.len()], + expected_output + ) }