@@ -157,6 +157,7 @@ absl::StatusOr<RnsPolynomial> GenerateSecretKey(
157157}
158158
159159namespace internal {
160+
160161absl::StatusOr<RnsPolynomial> EncryptPolynomial (
161162 const RnsPolynomial& plaintext, const RnsPolynomial& secret_key,
162163 RnsInt plaintext_modulus_rns, int log_n, const RnsPolynomial& a,
@@ -196,51 +197,31 @@ absl::StatusOr<RnsPolynomial> DecryptPolynomial(
196197 return p;
197198}
198199
199- std::vector<std::vector<BigInteger>> PackMessagesRaw (const Integer* messages,
200- int num_messages,
201- Integer packing_base,
202- int num_packing,
203- int num_coeffs) {
204- // NOTE: temporary implementation that copies the input. We can avoid copying
205- // by modifying the original PackMessages function to work with pointers
206- // directly.
207- std::vector<Integer> messages_copy;
208- messages_copy.reserve (num_messages);
209- for (int i = 0 ; i < num_messages; ++i) {
210- messages_copy.push_back (messages[i]);
211- }
212- return rlwe::PackMessages<Integer, BigInteger>(messages_copy, packing_base,
213- num_packing, num_coeffs);
214- }
215-
216- int UnpackMessagesRaw (
217- const std::vector<std::vector<BigInteger>>& packed_messages,
218- uint64_t packing_base, int num_packing, int output_values_length,
219- uint64_t * output_values) {
220- std::vector<uint64_t > unpacked_messages =
221- rlwe::UnpackMessages (packed_messages, packing_base, num_packing);
222-
223- auto count = std::min (static_cast <size_t >(output_values_length),
224- unpacked_messages.size ());
225- std::copy_n (unpacked_messages.begin (), count, output_values);
226- return count;
227- }
228-
229200} // namespace internal
230201
231202absl::StatusOr<std::vector<RnsPolynomial>> EncodeAndEncryptVector (
232- std::vector<std::vector< BigInteger>>& packed_messages ,
203+ const std::vector<BigInteger>& packed_values ,
233204 const RnsPolynomial& secret_key, const KahePublicParameters& params,
234205 Prng* prng) {
235- std::vector<RnsPolynomial> ciphertexts;
236-
237- if (packed_messages.size () > params.public_polynomials .size ()) {
238- return absl::InvalidArgumentError (" Input too long" );
206+ std::vector<std::vector<BigInteger>> plaintexts;
207+ plaintexts.reserve (params.public_polynomials .size ());
208+ int num_coeffs = 1 << params.context ->LogN ();
209+
210+ for (size_t i = 0 ; i < packed_values.size (); i += num_coeffs) {
211+ size_t chunk_end = std::min<size_t >(packed_values.size (), i + num_coeffs);
212+ plaintexts.emplace_back (packed_values.begin () + i,
213+ packed_values.begin () + chunk_end);
214+ }
215+ if (plaintexts.size () > params.public_polynomials .size ()) {
216+ return absl::InvalidArgumentError (" input too long." );
239217 }
240218
241- for (int i = 0 ; i < packed_messages.size (); ++i) {
242- const auto & packed_message = packed_messages[i];
219+ std::vector<RnsPolynomial> ciphertexts;
220+ for (int i = 0 ; i < plaintexts.size (); ++i) {
221+ const auto & packed_message = plaintexts[i];
243222 const RnsPolynomial& a = params.public_polynomials [i];
223+ // EncodeBgv will pad `packed_message` with zeros to the length of a
224+ // polynomial coefficient vector.
244225 SECAGG_ASSIGN_OR_RETURN (
245226 RnsPolynomial plaintext,
246227 params.encoder .EncodeBgv <BigInteger>(
@@ -257,10 +238,15 @@ absl::StatusOr<std::vector<RnsPolynomial>> EncodeAndEncryptVector(
257238 return ciphertexts;
258239}
259240
260- absl::StatusOr<std::vector<std::vector< BigInteger> >> DecodeAndDecryptVector (
241+ absl::StatusOr<std::vector<BigInteger>> DecodeAndDecryptVector (
261242 absl::Span<const RnsPolynomial> ciphertexts,
262243 const RnsPolynomial& secret_key, const KahePublicParameters& params) {
263- std::vector<std::vector<BigInteger>> all_packed_messages;
244+ if (ciphertexts.size () > params.public_polynomials .size ()) {
245+ return absl::InvalidArgumentError (
246+ " The size of `ciphertexts` cannot be larger than the size of public "
247+ " polynomials." );
248+ }
249+ std::vector<BigInteger> all_packed_messages;
264250 for (int i = 0 ; i < ciphertexts.size (); ++i) {
265251 const auto & ciphertext = ciphertexts[i];
266252 const RnsPolynomial& a = params.public_polynomials [i];
@@ -272,7 +258,8 @@ absl::StatusOr<std::vector<std::vector<BigInteger>>> DecodeAndDecryptVector(
272258 params.encoder .DecodeBgv <BigInteger>(
273259 std::move (plaintext), params.plaintext_modulus , params.moduli ,
274260 params.modulus_hats , params.modulus_hats_invs ));
275- all_packed_messages.push_back (std::move (packed_messages));
261+ all_packed_messages.insert (all_packed_messages.end (),
262+ packed_messages.begin (), packed_messages.end ());
276263 }
277264 return all_packed_messages;
278265}
@@ -326,27 +313,79 @@ FfiStatus GenerateSecretKeyWrapper(const KahePublicParametersWrapper& params,
326313 return MakeFfiStatus ();
327314}
328315
329- FfiStatus Encrypt (rust::Slice<const uint64_t > input_values,
330- uint64_t packing_base, uint64_t num_packing,
316+ FfiStatus PackMessagesRaw (rust::Slice<const uint64_t > messages,
317+ uint64_t packing_base, uint64_t packing_dimension,
318+ uint64_t num_packed_values,
319+ BigIntVectorWrapper* packed_values) {
320+ // Validate the wrappers.
321+ if (packed_values == nullptr ) {
322+ return MakeFfiStatus (absl::InvalidArgumentError (
323+ secure_aggregation::kNullPointerErrorMessage ));
324+ }
325+
326+ // Allocate the vector for output packed values if needed.
327+ if (packed_values->ptr == nullptr ) {
328+ packed_values->ptr =
329+ std::make_unique<std::vector<secure_aggregation::BigInteger>>();
330+ }
331+ auto curr_packed_values =
332+ rlwe::PackMessagesFlat<secure_aggregation::Integer,
333+ secure_aggregation::BigInteger>(
334+ absl::MakeSpan (messages.data (), messages.size ()), packing_base,
335+ packing_dimension);
336+ if (curr_packed_values.size () > num_packed_values) {
337+ return MakeFfiStatus (absl::InvalidArgumentError (
338+ " The number of packed values exceeds `num_packed_values`." ));
339+ }
340+ // Pad with zeros if needed.
341+ curr_packed_values.resize (num_packed_values, 0 );
342+ // Append the packed values to the end of the output vector.
343+ packed_values->ptr ->insert (packed_values->ptr ->end (),
344+ curr_packed_values.begin (),
345+ curr_packed_values.end ());
346+ return MakeFfiStatus ();
347+ }
348+
349+ FfiStatus UnpackMessagesRaw (uint64_t packing_base, uint64_t packing_dimension,
350+ uint64_t num_packed_values,
351+ BigIntVectorWrapper& packed_values,
352+ rust::Vec<uint64_t >& out) {
353+ // Validate the wrappers.
354+ if (packed_values.ptr == nullptr ) {
355+ return MakeFfiStatus (absl::InvalidArgumentError (
356+ secure_aggregation::kNullPointerErrorMessage ));
357+ }
358+ if (packed_values.ptr ->size () < num_packed_values) {
359+ return MakeFfiStatus (
360+ absl::InvalidArgumentError (" insufficient number of packed values." ));
361+ }
362+ std::vector<uint64_t > unpacked_messages =
363+ rlwe::UnpackMessagesFlat<secure_aggregation::Integer,
364+ secure_aggregation::BigInteger>(
365+ absl::MakeSpan (*packed_values.ptr ).subspan (0 , num_packed_values),
366+ packing_base, packing_dimension);
367+ packed_values.ptr ->erase (packed_values.ptr ->begin (),
368+ packed_values.ptr ->begin () + num_packed_values);
369+ for (auto & val : unpacked_messages) {
370+ out.push_back (val);
371+ }
372+ return MakeFfiStatus ();
373+ }
374+
375+ FfiStatus Encrypt (const BigIntVectorWrapper& packed_values,
331376 const RnsPolynomialWrapper& secret_key,
332377 const KahePublicParametersWrapper& params,
333378 SingleThreadHkdfWrapper* prng, RnsPolynomialVecWrapper* out) {
334379 // Validate the wrappers.
335- if (secret_key.ptr == nullptr || params.ptr == nullptr || prng == nullptr ||
336- prng->ptr == nullptr || out == nullptr ) {
380+ if (packed_values.ptr == nullptr || secret_key.ptr == nullptr ||
381+ params.ptr == nullptr || prng == nullptr || prng->ptr == nullptr ||
382+ out == nullptr ) {
337383 return MakeFfiStatus (absl::InvalidArgumentError (
338384 secure_aggregation::kNullPointerErrorMessage ));
339385 }
340386
341- // Packing parameters must be valid, e.g. checked on the Rust side.
342- int num_coeffs = 1 << params.ptr ->context ->LogN ();
343- std::vector<std::vector<secure_aggregation::BigInteger>> packed_messages =
344- secure_aggregation::internal::PackMessagesRaw (
345- input_values.data (), input_values.size (), packing_base, num_packing,
346- num_coeffs);
347-
348387 auto ciphertext_vec = secure_aggregation::EncodeAndEncryptVector (
349- packed_messages , *secret_key.ptr , *params.ptr , prng->ptr .get ());
388+ *packed_values. ptr , *secret_key.ptr , *params.ptr , prng->ptr .get ());
350389
351390 if (!ciphertext_vec.ok ()) {
352391 return MakeFfiStatus (ciphertext_vec.status ());
@@ -357,14 +396,13 @@ FfiStatus Encrypt(rust::Slice<const uint64_t> input_values,
357396 return MakeFfiStatus ();
358397}
359398
360- FfiStatus Decrypt (uint64_t packing_base, uint64_t num_packing,
361- const RnsPolynomialVecWrapper& ciphertexts,
399+ FfiStatus Decrypt (const RnsPolynomialVecWrapper& ciphertexts,
362400 const RnsPolynomialWrapper& secret_key,
363401 const KahePublicParametersWrapper& params,
364- rust::Slice< uint64_t > output_values, uint64_t * n_written ) {
402+ BigIntVectorWrapper* output_values ) {
365403 // Validate the wrappers.
366404 if (secret_key.ptr == nullptr || params.ptr == nullptr ||
367- ciphertexts.ptr == nullptr || n_written == nullptr ) {
405+ ciphertexts.ptr == nullptr || output_values == nullptr ) {
368406 return MakeFfiStatus (absl::InvalidArgumentError (
369407 secure_aggregation::kNullPointerErrorMessage ));
370408 }
@@ -378,15 +416,13 @@ FfiStatus Decrypt(uint64_t packing_base, uint64_t num_packing,
378416 }
379417 }
380418
381- auto messages = secure_aggregation::DecodeAndDecryptVector (
419+ auto decrypted_values = secure_aggregation::DecodeAndDecryptVector (
382420 *ciphertexts.ptr , *secret_key.ptr , *params.ptr );
383- if (!messages .ok ()) {
384- return MakeFfiStatus (messages .status ());
421+ if (!decrypted_values .ok ()) {
422+ return MakeFfiStatus (decrypted_values .status ());
385423 }
386-
387- *n_written = secure_aggregation::internal::UnpackMessagesRaw (
388- messages.value (), packing_base, num_packing, output_values.size (),
389- output_values.data ());
390-
424+ output_values->ptr =
425+ std::make_unique<std::vector<secure_aggregation::BigInteger>>(
426+ std::move (decrypted_values.value ()));
391427 return MakeFfiStatus ();
392428}
0 commit comments