Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
164 changes: 100 additions & 64 deletions shell_wrapper/kahe.cc
Original file line number Diff line number Diff line change
Expand Up @@ -157,6 +157,7 @@ absl::StatusOr<RnsPolynomial> GenerateSecretKey(
}

namespace internal {

absl::StatusOr<RnsPolynomial> EncryptPolynomial(
const RnsPolynomial& plaintext, const RnsPolynomial& secret_key,
RnsInt plaintext_modulus_rns, int log_n, const RnsPolynomial& a,
Expand Down Expand Up @@ -196,51 +197,31 @@ absl::StatusOr<RnsPolynomial> DecryptPolynomial(
return p;
}

std::vector<std::vector<BigInteger>> PackMessagesRaw(const Integer* messages,
int num_messages,
Integer packing_base,
int num_packing,
int num_coeffs) {
// NOTE: temporary implementation that copies the input. We can avoid copying
// by modifying the original PackMessages function to work with pointers
// directly.
std::vector<Integer> messages_copy;
messages_copy.reserve(num_messages);
for (int i = 0; i < num_messages; ++i) {
messages_copy.push_back(messages[i]);
}
return rlwe::PackMessages<Integer, BigInteger>(messages_copy, packing_base,
num_packing, num_coeffs);
}

int UnpackMessagesRaw(
const std::vector<std::vector<BigInteger>>& packed_messages,
uint64_t packing_base, int num_packing, int output_values_length,
uint64_t* output_values) {
std::vector<uint64_t> unpacked_messages =
rlwe::UnpackMessages(packed_messages, packing_base, num_packing);

auto count = std::min(static_cast<size_t>(output_values_length),
unpacked_messages.size());
std::copy_n(unpacked_messages.begin(), count, output_values);
return count;
}

} // namespace internal

absl::StatusOr<std::vector<RnsPolynomial>> EncodeAndEncryptVector(
std::vector<std::vector<BigInteger>>& packed_messages,
const std::vector<BigInteger>& packed_values,
const RnsPolynomial& secret_key, const KahePublicParameters& params,
Prng* prng) {
std::vector<RnsPolynomial> ciphertexts;

if (packed_messages.size() > params.public_polynomials.size()) {
return absl::InvalidArgumentError("Input too long");
std::vector<std::vector<BigInteger>> plaintexts;
plaintexts.reserve(params.public_polynomials.size());
int num_coeffs = 1 << params.context->LogN();

for (size_t i = 0; i < packed_values.size(); i += num_coeffs) {
size_t chunk_end = std::min<size_t>(packed_values.size(), i + num_coeffs);
plaintexts.emplace_back(packed_values.begin() + i,
packed_values.begin() + chunk_end);
}
if (plaintexts.size() > params.public_polynomials.size()) {
return absl::InvalidArgumentError("input too long.");
}

for (int i = 0; i < packed_messages.size(); ++i) {
const auto& packed_message = packed_messages[i];
std::vector<RnsPolynomial> ciphertexts;
for (int i = 0; i < plaintexts.size(); ++i) {
const auto& packed_message = plaintexts[i];
const RnsPolynomial& a = params.public_polynomials[i];
// EncodeBgv will pad `packed_message` with zeros to the length of a
// polynomial coefficient vector.
SECAGG_ASSIGN_OR_RETURN(
RnsPolynomial plaintext,
params.encoder.EncodeBgv<BigInteger>(
Expand All @@ -257,10 +238,15 @@ absl::StatusOr<std::vector<RnsPolynomial>> EncodeAndEncryptVector(
return ciphertexts;
}

absl::StatusOr<std::vector<std::vector<BigInteger>>> DecodeAndDecryptVector(
absl::StatusOr<std::vector<BigInteger>> DecodeAndDecryptVector(
absl::Span<const RnsPolynomial> ciphertexts,
const RnsPolynomial& secret_key, const KahePublicParameters& params) {
std::vector<std::vector<BigInteger>> all_packed_messages;
if (ciphertexts.size() > params.public_polynomials.size()) {
return absl::InvalidArgumentError(
"The size of `ciphertexts` cannot be larger than the size of public "
"polynomials.");
}
std::vector<BigInteger> all_packed_messages;
for (int i = 0; i < ciphertexts.size(); ++i) {
const auto& ciphertext = ciphertexts[i];
const RnsPolynomial& a = params.public_polynomials[i];
Expand All @@ -272,7 +258,8 @@ absl::StatusOr<std::vector<std::vector<BigInteger>>> DecodeAndDecryptVector(
params.encoder.DecodeBgv<BigInteger>(
std::move(plaintext), params.plaintext_modulus, params.moduli,
params.modulus_hats, params.modulus_hats_invs));
all_packed_messages.push_back(std::move(packed_messages));
all_packed_messages.insert(all_packed_messages.end(),
packed_messages.begin(), packed_messages.end());
}
return all_packed_messages;
}
Expand Down Expand Up @@ -326,27 +313,79 @@ FfiStatus GenerateSecretKeyWrapper(const KahePublicParametersWrapper& params,
return MakeFfiStatus();
}

FfiStatus Encrypt(rust::Slice<const uint64_t> input_values,
uint64_t packing_base, uint64_t num_packing,
FfiStatus PackMessagesRaw(rust::Slice<const uint64_t> messages,
uint64_t packing_base, uint64_t packing_dimension,
uint64_t num_packed_values,
BigIntVectorWrapper* packed_values) {
// Validate the wrappers.
if (packed_values == nullptr) {
return MakeFfiStatus(absl::InvalidArgumentError(
secure_aggregation::kNullPointerErrorMessage));
}

// Allocate the vector for output packed values if needed.
if (packed_values->ptr == nullptr) {
packed_values->ptr =
std::make_unique<std::vector<secure_aggregation::BigInteger>>();
}
auto curr_packed_values =
rlwe::PackMessagesFlat<secure_aggregation::Integer,
secure_aggregation::BigInteger>(
absl::MakeSpan(messages.data(), messages.size()), packing_base,
packing_dimension);
if (curr_packed_values.size() > num_packed_values) {
return MakeFfiStatus(absl::InvalidArgumentError(
"The number of packed values exceeds `num_packed_values`."));
}
// Pad with zeros if needed.
curr_packed_values.resize(num_packed_values, 0);
// Append the packed values to the end of the output vector.
packed_values->ptr->insert(packed_values->ptr->end(),
curr_packed_values.begin(),
curr_packed_values.end());
return MakeFfiStatus();
}

FfiStatus UnpackMessagesRaw(uint64_t packing_base, uint64_t packing_dimension,
uint64_t num_packed_values,
BigIntVectorWrapper& packed_values,
rust::Vec<uint64_t>& out) {
// Validate the wrappers.
if (packed_values.ptr == nullptr) {
return MakeFfiStatus(absl::InvalidArgumentError(
secure_aggregation::kNullPointerErrorMessage));
}
if (packed_values.ptr->size() < num_packed_values) {
return MakeFfiStatus(
absl::InvalidArgumentError("insufficient number of packed values."));
}
std::vector<uint64_t> unpacked_messages =
rlwe::UnpackMessagesFlat<secure_aggregation::Integer,
secure_aggregation::BigInteger>(
absl::MakeSpan(*packed_values.ptr).subspan(0, num_packed_values),
packing_base, packing_dimension);
packed_values.ptr->erase(packed_values.ptr->begin(),
packed_values.ptr->begin() + num_packed_values);
for (auto& val : unpacked_messages) {
out.push_back(val);
}
return MakeFfiStatus();
}

FfiStatus Encrypt(const BigIntVectorWrapper& packed_values,
const RnsPolynomialWrapper& secret_key,
const KahePublicParametersWrapper& params,
SingleThreadHkdfWrapper* prng, RnsPolynomialVecWrapper* out) {
// Validate the wrappers.
if (secret_key.ptr == nullptr || params.ptr == nullptr || prng == nullptr ||
prng->ptr == nullptr || out == nullptr) {
if (packed_values.ptr == nullptr || secret_key.ptr == nullptr ||
params.ptr == nullptr || prng == nullptr || prng->ptr == nullptr ||
out == nullptr) {
return MakeFfiStatus(absl::InvalidArgumentError(
secure_aggregation::kNullPointerErrorMessage));
}

// Packing parameters must be valid, e.g. checked on the Rust side.
int num_coeffs = 1 << params.ptr->context->LogN();
std::vector<std::vector<secure_aggregation::BigInteger>> packed_messages =
secure_aggregation::internal::PackMessagesRaw(
input_values.data(), input_values.size(), packing_base, num_packing,
num_coeffs);

auto ciphertext_vec = secure_aggregation::EncodeAndEncryptVector(
packed_messages, *secret_key.ptr, *params.ptr, prng->ptr.get());
*packed_values.ptr, *secret_key.ptr, *params.ptr, prng->ptr.get());

if (!ciphertext_vec.ok()) {
return MakeFfiStatus(ciphertext_vec.status());
Expand All @@ -357,14 +396,13 @@ FfiStatus Encrypt(rust::Slice<const uint64_t> input_values,
return MakeFfiStatus();
}

FfiStatus Decrypt(uint64_t packing_base, uint64_t num_packing,
const RnsPolynomialVecWrapper& ciphertexts,
FfiStatus Decrypt(const RnsPolynomialVecWrapper& ciphertexts,
const RnsPolynomialWrapper& secret_key,
const KahePublicParametersWrapper& params,
rust::Slice<uint64_t> output_values, uint64_t* n_written) {
BigIntVectorWrapper* output_values) {
// Validate the wrappers.
if (secret_key.ptr == nullptr || params.ptr == nullptr ||
ciphertexts.ptr == nullptr || n_written == nullptr) {
ciphertexts.ptr == nullptr || output_values == nullptr) {
return MakeFfiStatus(absl::InvalidArgumentError(
secure_aggregation::kNullPointerErrorMessage));
}
Expand All @@ -378,15 +416,13 @@ FfiStatus Decrypt(uint64_t packing_base, uint64_t num_packing,
}
}

auto messages = secure_aggregation::DecodeAndDecryptVector(
auto decrypted_values = secure_aggregation::DecodeAndDecryptVector(
*ciphertexts.ptr, *secret_key.ptr, *params.ptr);
if (!messages.ok()) {
return MakeFfiStatus(messages.status());
if (!decrypted_values.ok()) {
return MakeFfiStatus(decrypted_values.status());
}

*n_written = secure_aggregation::internal::UnpackMessagesRaw(
messages.value(), packing_base, num_packing, output_values.size(),
output_values.data());

output_values->ptr =
std::make_unique<std::vector<secure_aggregation::BigInteger>>(
std::move(decrypted_values.value()));
return MakeFfiStatus();
}
86 changes: 42 additions & 44 deletions shell_wrapper/kahe.h
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@
namespace secure_aggregation {
// Forward-declare types for use by the cxx-generated `kahe.rs.h`.
struct KahePublicParameters;
struct BigIntVector;
} // namespace secure_aggregation

#include "shell_wrapper/kahe.rs.h"
Expand Down Expand Up @@ -93,36 +94,18 @@ absl::StatusOr<RnsPolynomial> DecryptPolynomial(
const RnsPolynomial& a,
absl::Span<const rlwe::PrimeModulus<ModularInt>* const> moduli);

// Packs messages taken from a raw pointer.
// Expects packing_base > 1, num_packing > 0, num_coeffs > 0,
// packing_base^num_packing < std::numeric_limits<BigInteger>::max().
std::vector<std::vector<BigInteger>> PackMessagesRaw(const uint64_t* messages,
int num_messages,
uint64_t packing_base,
int num_packing,
int num_coeffs);

// Unpacks messages into a buffer `output_values` of length
// at least `output_values_length`. Returns the elements written to the buffer
// (0 if it didn't write anything).
// Expects packing_base > 1, num_packing > 0, num_coeffs > 0,
// packing_base^num_packing < std::numeric_limits<BigInteger>::max().
int UnpackMessagesRaw(
const std::vector<std::vector<BigInteger>>& packed_messages,
Integer packing_base, int num_packing, int output_values_length,
Integer* output_values);

} // namespace internal

// Encrypts a vector of packed messages, where each coordinate is a vector of
// integers that will be encoded into a single polynomial.
// Encrypts a vector of packed messages, where the packed messages are first
// encoded into plaintext polynomials and then encrypted.
absl::StatusOr<std::vector<RnsPolynomial>> EncodeAndEncryptVector(
std::vector<std::vector<BigInteger>>& packed_messages,
const std::vector<BigInteger>& packed_values,
const RnsPolynomial& secret_key, const KahePublicParameters& params,
Prng* prng);

// Decrypts a vector of ciphertexts.
absl::StatusOr<std::vector<std::vector<BigInteger>>> DecodeAndDecryptVector(
// Decrypts a vector of ciphertexts, and returns the concatenated vector of
// decrypted messages.
absl::StatusOr<std::vector<BigInteger>> DecodeAndDecryptVector(
absl::Span<const RnsPolynomial> ciphertexts,
const RnsPolynomial& secret_key, const KahePublicParameters& params);

Expand Down Expand Up @@ -158,32 +141,47 @@ FfiStatus GenerateSecretKeyWrapper(const KahePublicParametersWrapper& params,
SingleThreadHkdfWrapper* prng,
RnsPolynomialWrapper* out);

// Packs, encodes and encrypts the messages contained in the `input_values`
// buffer. `packing_base` and `num_packing` are the parameters for packing: the
// encoder takes large modular integers obtained by combining `num_packing`
// smaller uint64_t values, each of which is less than `packing_base`. If
// successful, returns OK and sets *out to a vector of ciphertexts, each of
// which is a polynomial.
FfiStatus Encrypt(rust::Slice<const uint64_t> input_values,
uint64_t packing_base, uint64_t num_packing,
// Packs `messages` into a vector of BigIntegers using base `packing_base`
// encoding, where the packed values are appended to `packed_values`.
// Expects `packed_values` to be a valid pointer but the underlying vector
// may be unallocated, and expects packing_base > 1, packing_dimension > 0,
// num_coeffs > 0, packing_base^packing_dimension <
// std::numeric_limits<BigInteger>::max().
// Note that `messages` is effectively padded with zeros to the nearest multiple
// of `packing_dimension` before packing.
FfiStatus PackMessagesRaw(rust::Slice<const uint64_t> messages,
uint64_t packing_base, uint64_t packing_dimension,
uint64_t num_packed_values,
BigIntVectorWrapper* packed_values);

// Unpacks messages stored at `packed_values[0..num_packed_values]` and appends
// them to `out`, and removes these packed values from `packed_values`.
// Expects `packed_values.ptr` to be a valid pointer to the vector of packed
// values, and expects packing_base > 1, packing_dimension > 0,
// num_packed_values > 0, packing_base^packing_dimension <
// std::numeric_limits<BigInteger>::max().
FfiStatus UnpackMessagesRaw(uint64_t packing_base, uint64_t packing_dimension,
uint64_t num_packed_values,
BigIntVectorWrapper& packed_values,
rust::Vec<uint64_t>& out);

// Encrypts the messages contained in `packed_values`. If successful, returns OK
// and sets *out to a vector of ciphertext polynomials.
// Expects `out` to be a valid pointer but the underlying vector may be
// unallocated.
FfiStatus Encrypt(const BigIntVectorWrapper& packed_values,
const RnsPolynomialWrapper& secret_key,
const KahePublicParametersWrapper& params,
SingleThreadHkdfWrapper* prng, RnsPolynomialVecWrapper* out);

// Decrypts, decodes and unpacks `ciphertexts` into the `output_values` buffer.
// Returns a status, and writes the number of outputs written to `n_written`.
// Each decoded message is a large modular integer, which is then split into
// `num_packing` smaller uint64_t values by decomposing into base
// `packing_base`. Note: Internally, decryption yields num_coeffs * num_packing
// * ciphertexts.size() Integers exactly (decryption does padding), with
// num_coeffs = 1 << params->ptr->context->LogN(). But this function can
// take a smaller `output_values_length` to fill a smaller buffer, e.g. to
// remove padding if the plaintext vector length is known.
FfiStatus Decrypt(uint64_t packing_base, uint64_t num_packing,
const RnsPolynomialVecWrapper& ciphertexts,
// Decrypts `ciphertexts` into a vector written to `output_values` buffer, and
// returns a status.
// Expects `output_values` to be a valid pointer but the underlying vector may
// be unallocated.
FfiStatus Decrypt(const RnsPolynomialVecWrapper& ciphertexts,
const RnsPolynomialWrapper& secret_key,
const KahePublicParametersWrapper& params,
rust::Slice<uint64_t> output_values, uint64_t* n_written);
BigIntVectorWrapper* output_values);

} // extern "C"

Expand Down
Loading