Skip to content

Commit 2325249

Browse files
stanischikncopybara-github
authored andcommitted
No public description
PiperOrigin-RevId: 824752645
1 parent 3bcd829 commit 2325249

28 files changed

+1567
-924
lines changed

shell_wrapper/kahe.cc

Lines changed: 100 additions & 64 deletions
Original file line numberDiff line numberDiff line change
@@ -157,6 +157,7 @@ absl::StatusOr<RnsPolynomial> GenerateSecretKey(
157157
}
158158

159159
namespace internal {
160+
160161
absl::StatusOr<RnsPolynomial> EncryptPolynomial(
161162
const RnsPolynomial& plaintext, const RnsPolynomial& secret_key,
162163
RnsInt plaintext_modulus_rns, int log_n, const RnsPolynomial& a,
@@ -196,51 +197,31 @@ absl::StatusOr<RnsPolynomial> DecryptPolynomial(
196197
return p;
197198
}
198199

199-
std::vector<std::vector<BigInteger>> PackMessagesRaw(const Integer* messages,
200-
int num_messages,
201-
Integer packing_base,
202-
int num_packing,
203-
int num_coeffs) {
204-
// NOTE: temporary implementation that copies the input. We can avoid copying
205-
// by modifying the original PackMessages function to work with pointers
206-
// directly.
207-
std::vector<Integer> messages_copy;
208-
messages_copy.reserve(num_messages);
209-
for (int i = 0; i < num_messages; ++i) {
210-
messages_copy.push_back(messages[i]);
211-
}
212-
return rlwe::PackMessages<Integer, BigInteger>(messages_copy, packing_base,
213-
num_packing, num_coeffs);
214-
}
215-
216-
int UnpackMessagesRaw(
217-
const std::vector<std::vector<BigInteger>>& packed_messages,
218-
uint64_t packing_base, int num_packing, int output_values_length,
219-
uint64_t* output_values) {
220-
std::vector<uint64_t> unpacked_messages =
221-
rlwe::UnpackMessages(packed_messages, packing_base, num_packing);
222-
223-
auto count = std::min(static_cast<size_t>(output_values_length),
224-
unpacked_messages.size());
225-
std::copy_n(unpacked_messages.begin(), count, output_values);
226-
return count;
227-
}
228-
229200
} // namespace internal
230201

231202
absl::StatusOr<std::vector<RnsPolynomial>> EncodeAndEncryptVector(
232-
std::vector<std::vector<BigInteger>>& packed_messages,
203+
const std::vector<BigInteger>& packed_values,
233204
const RnsPolynomial& secret_key, const KahePublicParameters& params,
234205
Prng* prng) {
235-
std::vector<RnsPolynomial> ciphertexts;
236-
237-
if (packed_messages.size() > params.public_polynomials.size()) {
238-
return absl::InvalidArgumentError("Input too long");
206+
std::vector<std::vector<BigInteger>> plaintexts;
207+
plaintexts.reserve(params.public_polynomials.size());
208+
int num_coeffs = 1 << params.context->LogN();
209+
210+
for (size_t i = 0; i < packed_values.size(); i += num_coeffs) {
211+
size_t chunk_end = std::min<size_t>(packed_values.size(), i + num_coeffs);
212+
plaintexts.emplace_back(packed_values.begin() + i,
213+
packed_values.begin() + chunk_end);
214+
}
215+
if (plaintexts.size() > params.public_polynomials.size()) {
216+
return absl::InvalidArgumentError("input too long.");
239217
}
240218

241-
for (int i = 0; i < packed_messages.size(); ++i) {
242-
const auto& packed_message = packed_messages[i];
219+
std::vector<RnsPolynomial> ciphertexts;
220+
for (int i = 0; i < plaintexts.size(); ++i) {
221+
const auto& packed_message = plaintexts[i];
243222
const RnsPolynomial& a = params.public_polynomials[i];
223+
// EncodeBgv will pad `packed_message` with zeros to the length of a
224+
// polynomial coefficient vector.
244225
SECAGG_ASSIGN_OR_RETURN(
245226
RnsPolynomial plaintext,
246227
params.encoder.EncodeBgv<BigInteger>(
@@ -257,10 +238,15 @@ absl::StatusOr<std::vector<RnsPolynomial>> EncodeAndEncryptVector(
257238
return ciphertexts;
258239
}
259240

260-
absl::StatusOr<std::vector<std::vector<BigInteger>>> DecodeAndDecryptVector(
241+
absl::StatusOr<std::vector<BigInteger>> DecodeAndDecryptVector(
261242
absl::Span<const RnsPolynomial> ciphertexts,
262243
const RnsPolynomial& secret_key, const KahePublicParameters& params) {
263-
std::vector<std::vector<BigInteger>> all_packed_messages;
244+
if (ciphertexts.size() > params.public_polynomials.size()) {
245+
return absl::InvalidArgumentError(
246+
"The size of `ciphertexts` cannot be larger than the size of public "
247+
"polynomials.");
248+
}
249+
std::vector<BigInteger> all_packed_messages;
264250
for (int i = 0; i < ciphertexts.size(); ++i) {
265251
const auto& ciphertext = ciphertexts[i];
266252
const RnsPolynomial& a = params.public_polynomials[i];
@@ -272,7 +258,8 @@ absl::StatusOr<std::vector<std::vector<BigInteger>>> DecodeAndDecryptVector(
272258
params.encoder.DecodeBgv<BigInteger>(
273259
std::move(plaintext), params.plaintext_modulus, params.moduli,
274260
params.modulus_hats, params.modulus_hats_invs));
275-
all_packed_messages.push_back(std::move(packed_messages));
261+
all_packed_messages.insert(all_packed_messages.end(),
262+
packed_messages.begin(), packed_messages.end());
276263
}
277264
return all_packed_messages;
278265
}
@@ -326,27 +313,79 @@ FfiStatus GenerateSecretKeyWrapper(const KahePublicParametersWrapper& params,
326313
return MakeFfiStatus();
327314
}
328315

329-
FfiStatus Encrypt(rust::Slice<const uint64_t> input_values,
330-
uint64_t packing_base, uint64_t num_packing,
316+
FfiStatus PackMessagesRaw(rust::Slice<const uint64_t> messages,
317+
uint64_t packing_base, uint64_t packing_dimension,
318+
uint64_t num_packed_values,
319+
BigIntVectorWrapper* packed_values) {
320+
// Validate the wrappers.
321+
if (packed_values == nullptr) {
322+
return MakeFfiStatus(absl::InvalidArgumentError(
323+
secure_aggregation::kNullPointerErrorMessage));
324+
}
325+
326+
// Allocate the vector for output packed values if needed.
327+
if (packed_values->ptr == nullptr) {
328+
packed_values->ptr =
329+
std::make_unique<std::vector<secure_aggregation::BigInteger>>();
330+
}
331+
auto curr_packed_values =
332+
rlwe::PackMessagesFlat<secure_aggregation::Integer,
333+
secure_aggregation::BigInteger>(
334+
absl::MakeSpan(messages.data(), messages.size()), packing_base,
335+
packing_dimension);
336+
if (curr_packed_values.size() > num_packed_values) {
337+
return MakeFfiStatus(absl::InvalidArgumentError(
338+
"The number of packed values exceeds `num_packed_values`."));
339+
}
340+
// Pad with zeros if needed.
341+
curr_packed_values.resize(num_packed_values, 0);
342+
// Append the packed values to the end of the output vector.
343+
packed_values->ptr->insert(packed_values->ptr->end(),
344+
curr_packed_values.begin(),
345+
curr_packed_values.end());
346+
return MakeFfiStatus();
347+
}
348+
349+
FfiStatus UnpackMessagesRaw(uint64_t packing_base, uint64_t packing_dimension,
350+
uint64_t num_packed_values,
351+
BigIntVectorWrapper& packed_values,
352+
rust::Vec<uint64_t>& out) {
353+
// Validate the wrappers.
354+
if (packed_values.ptr == nullptr) {
355+
return MakeFfiStatus(absl::InvalidArgumentError(
356+
secure_aggregation::kNullPointerErrorMessage));
357+
}
358+
if (packed_values.ptr->size() < num_packed_values) {
359+
return MakeFfiStatus(
360+
absl::InvalidArgumentError("insufficient number of packed values."));
361+
}
362+
std::vector<uint64_t> unpacked_messages =
363+
rlwe::UnpackMessagesFlat<secure_aggregation::Integer,
364+
secure_aggregation::BigInteger>(
365+
absl::MakeSpan(*packed_values.ptr).subspan(0, num_packed_values),
366+
packing_base, packing_dimension);
367+
packed_values.ptr->erase(packed_values.ptr->begin(),
368+
packed_values.ptr->begin() + num_packed_values);
369+
for (auto& val : unpacked_messages) {
370+
out.push_back(val);
371+
}
372+
return MakeFfiStatus();
373+
}
374+
375+
FfiStatus Encrypt(const BigIntVectorWrapper& packed_values,
331376
const RnsPolynomialWrapper& secret_key,
332377
const KahePublicParametersWrapper& params,
333378
SingleThreadHkdfWrapper* prng, RnsPolynomialVecWrapper* out) {
334379
// Validate the wrappers.
335-
if (secret_key.ptr == nullptr || params.ptr == nullptr || prng == nullptr ||
336-
prng->ptr == nullptr || out == nullptr) {
380+
if (packed_values.ptr == nullptr || secret_key.ptr == nullptr ||
381+
params.ptr == nullptr || prng == nullptr || prng->ptr == nullptr ||
382+
out == nullptr) {
337383
return MakeFfiStatus(absl::InvalidArgumentError(
338384
secure_aggregation::kNullPointerErrorMessage));
339385
}
340386

341-
// Packing parameters must be valid, e.g. checked on the Rust side.
342-
int num_coeffs = 1 << params.ptr->context->LogN();
343-
std::vector<std::vector<secure_aggregation::BigInteger>> packed_messages =
344-
secure_aggregation::internal::PackMessagesRaw(
345-
input_values.data(), input_values.size(), packing_base, num_packing,
346-
num_coeffs);
347-
348387
auto ciphertext_vec = secure_aggregation::EncodeAndEncryptVector(
349-
packed_messages, *secret_key.ptr, *params.ptr, prng->ptr.get());
388+
*packed_values.ptr, *secret_key.ptr, *params.ptr, prng->ptr.get());
350389

351390
if (!ciphertext_vec.ok()) {
352391
return MakeFfiStatus(ciphertext_vec.status());
@@ -357,14 +396,13 @@ FfiStatus Encrypt(rust::Slice<const uint64_t> input_values,
357396
return MakeFfiStatus();
358397
}
359398

360-
FfiStatus Decrypt(uint64_t packing_base, uint64_t num_packing,
361-
const RnsPolynomialVecWrapper& ciphertexts,
399+
FfiStatus Decrypt(const RnsPolynomialVecWrapper& ciphertexts,
362400
const RnsPolynomialWrapper& secret_key,
363401
const KahePublicParametersWrapper& params,
364-
rust::Slice<uint64_t> output_values, uint64_t* n_written) {
402+
BigIntVectorWrapper* output_values) {
365403
// Validate the wrappers.
366404
if (secret_key.ptr == nullptr || params.ptr == nullptr ||
367-
ciphertexts.ptr == nullptr || n_written == nullptr) {
405+
ciphertexts.ptr == nullptr || output_values == nullptr) {
368406
return MakeFfiStatus(absl::InvalidArgumentError(
369407
secure_aggregation::kNullPointerErrorMessage));
370408
}
@@ -378,15 +416,13 @@ FfiStatus Decrypt(uint64_t packing_base, uint64_t num_packing,
378416
}
379417
}
380418

381-
auto messages = secure_aggregation::DecodeAndDecryptVector(
419+
auto decrypted_values = secure_aggregation::DecodeAndDecryptVector(
382420
*ciphertexts.ptr, *secret_key.ptr, *params.ptr);
383-
if (!messages.ok()) {
384-
return MakeFfiStatus(messages.status());
421+
if (!decrypted_values.ok()) {
422+
return MakeFfiStatus(decrypted_values.status());
385423
}
386-
387-
*n_written = secure_aggregation::internal::UnpackMessagesRaw(
388-
messages.value(), packing_base, num_packing, output_values.size(),
389-
output_values.data());
390-
424+
output_values->ptr =
425+
std::make_unique<std::vector<secure_aggregation::BigInteger>>(
426+
std::move(decrypted_values.value()));
391427
return MakeFfiStatus();
392428
}

shell_wrapper/kahe.h

Lines changed: 42 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@
3535
namespace secure_aggregation {
3636
// Forward-declare types for use by the cxx-generated `kahe.rs.h`.
3737
struct KahePublicParameters;
38+
struct BigIntVector;
3839
} // namespace secure_aggregation
3940

4041
#include "shell_wrapper/kahe.rs.h"
@@ -93,36 +94,18 @@ absl::StatusOr<RnsPolynomial> DecryptPolynomial(
9394
const RnsPolynomial& a,
9495
absl::Span<const rlwe::PrimeModulus<ModularInt>* const> moduli);
9596

96-
// Packs messages taken from a raw pointer.
97-
// Expects packing_base > 1, num_packing > 0, num_coeffs > 0,
98-
// packing_base^num_packing < std::numeric_limits<BigInteger>::max().
99-
std::vector<std::vector<BigInteger>> PackMessagesRaw(const uint64_t* messages,
100-
int num_messages,
101-
uint64_t packing_base,
102-
int num_packing,
103-
int num_coeffs);
104-
105-
// Unpacks messages into a buffer `output_values` of length
106-
// at least `output_values_length`. Returns the elements written to the buffer
107-
// (0 if it didn't write anything).
108-
// Expects packing_base > 1, num_packing > 0, num_coeffs > 0,
109-
// packing_base^num_packing < std::numeric_limits<BigInteger>::max().
110-
int UnpackMessagesRaw(
111-
const std::vector<std::vector<BigInteger>>& packed_messages,
112-
Integer packing_base, int num_packing, int output_values_length,
113-
Integer* output_values);
114-
11597
} // namespace internal
11698

117-
// Encrypts a vector of packed messages, where each coordinate is a vector of
118-
// integers that will be encoded into a single polynomial.
99+
// Encrypts a vector of packed messages, where the packed messages are first
100+
// encoded into plaintext polynomials and then encrypted.
119101
absl::StatusOr<std::vector<RnsPolynomial>> EncodeAndEncryptVector(
120-
std::vector<std::vector<BigInteger>>& packed_messages,
102+
const std::vector<BigInteger>& packed_values,
121103
const RnsPolynomial& secret_key, const KahePublicParameters& params,
122104
Prng* prng);
123105

124-
// Decrypts a vector of ciphertexts.
125-
absl::StatusOr<std::vector<std::vector<BigInteger>>> DecodeAndDecryptVector(
106+
// Decrypts a vector of ciphertexts, and returns the concatenated vector of
107+
// decrypted messages.
108+
absl::StatusOr<std::vector<BigInteger>> DecodeAndDecryptVector(
126109
absl::Span<const RnsPolynomial> ciphertexts,
127110
const RnsPolynomial& secret_key, const KahePublicParameters& params);
128111

@@ -158,32 +141,47 @@ FfiStatus GenerateSecretKeyWrapper(const KahePublicParametersWrapper& params,
158141
SingleThreadHkdfWrapper* prng,
159142
RnsPolynomialWrapper* out);
160143

161-
// Packs, encodes and encrypts the messages contained in the `input_values`
162-
// buffer. `packing_base` and `num_packing` are the parameters for packing: the
163-
// encoder takes large modular integers obtained by combining `num_packing`
164-
// smaller uint64_t values, each of which is less than `packing_base`. If
165-
// successful, returns OK and sets *out to a vector of ciphertexts, each of
166-
// which is a polynomial.
167-
FfiStatus Encrypt(rust::Slice<const uint64_t> input_values,
168-
uint64_t packing_base, uint64_t num_packing,
144+
// Packs `messages` into a vector of BigIntegers using base `packing_base`
145+
// encoding, where the packed values are appended to `packed_values`.
146+
// Expects `packed_values` to be a valid pointer but the underlying vector
147+
// may be unallocated, and expects packing_base > 1, packing_dimension > 0,
148+
// num_coeffs > 0, packing_base^packing_dimension <
149+
// std::numeric_limits<BigInteger>::max().
150+
// Note that `messages` is effectively padded with zeros to the nearest multiple
151+
// of `packing_dimension` before packing.
152+
FfiStatus PackMessagesRaw(rust::Slice<const uint64_t> messages,
153+
uint64_t packing_base, uint64_t packing_dimension,
154+
uint64_t num_packed_values,
155+
BigIntVectorWrapper* packed_values);
156+
157+
// Unpacks messages stored at `packed_values[0..num_packed_values]` and appends
158+
// them to `out`, and removes these packed values from `packed_values`.
159+
// Expects `packed_values.ptr` to be a valid pointer to the vector of packed
160+
// values, and expects packing_base > 1, packing_dimension > 0,
161+
// num_packed_values > 0, packing_base^packing_dimension <
162+
// std::numeric_limits<BigInteger>::max().
163+
FfiStatus UnpackMessagesRaw(uint64_t packing_base, uint64_t packing_dimension,
164+
uint64_t num_packed_values,
165+
BigIntVectorWrapper& packed_values,
166+
rust::Vec<uint64_t>& out);
167+
168+
// Encrypts the messages contained in `packed_values`. If successful, returns OK
169+
// and sets *out to a vector of ciphertext polynomials.
170+
// Expects `out` to be a valid pointer but the underlying vector may be
171+
// unallocated.
172+
FfiStatus Encrypt(const BigIntVectorWrapper& packed_values,
169173
const RnsPolynomialWrapper& secret_key,
170174
const KahePublicParametersWrapper& params,
171175
SingleThreadHkdfWrapper* prng, RnsPolynomialVecWrapper* out);
172176

173-
// Decrypts, decodes and unpacks `ciphertexts` into the `output_values` buffer.
174-
// Returns a status, and writes the number of outputs written to `n_written`.
175-
// Each decoded message is a large modular integer, which is then split into
176-
// `num_packing` smaller uint64_t values by decomposing into base
177-
// `packing_base`. Note: Internally, decryption yields num_coeffs * num_packing
178-
// * ciphertexts.size() Integers exactly (decryption does padding), with
179-
// num_coeffs = 1 << params->ptr->context->LogN(). But this function can
180-
// take a smaller `output_values_length` to fill a smaller buffer, e.g. to
181-
// remove padding if the plaintext vector length is known.
182-
FfiStatus Decrypt(uint64_t packing_base, uint64_t num_packing,
183-
const RnsPolynomialVecWrapper& ciphertexts,
177+
// Decrypts `ciphertexts` into a vector written to `output_values` buffer, and
178+
// returns a status.
179+
// Expects `output_values` to be a valid pointer but the underlying vector may
180+
// be unallocated.
181+
FfiStatus Decrypt(const RnsPolynomialVecWrapper& ciphertexts,
184182
const RnsPolynomialWrapper& secret_key,
185183
const KahePublicParametersWrapper& params,
186-
rust::Slice<uint64_t> output_values, uint64_t* n_written);
184+
BigIntVectorWrapper* output_values);
187185

188186
} // extern "C"
189187

0 commit comments

Comments
 (0)