Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 15 additions & 2 deletions WORKSPACE.bazel
Original file line number Diff line number Diff line change
@@ -1,5 +1,20 @@
load("@bazel_tools//tools/build_defs/repo:http.bzl", "http_archive")


# gRPC
# must be included separately, since we need to load transitive deps of grpc.
http_archive(
name = "com_github_grpc_grpc",
strip_prefix = "grpc-1.55.0",
urls = [
"https://github.com/grpc/grpc/archive/refs/tags/v1.55.0.zip",
],
)
load("@com_github_grpc_grpc//bazel:grpc_deps.bzl", "grpc_deps")
grpc_deps()
load("@com_github_grpc_grpc//bazel:grpc_extra_deps.bzl", "grpc_extra_deps")
grpc_extra_deps()

# rules_proto defines abstract rules for building Protocol Buffers.
# https://github.com/bazelbuild/rules_proto
http_archive(
Expand Down Expand Up @@ -47,8 +62,6 @@ load("@io_bazel_rules_go//go:deps.bzl", "go_register_toolchains", "go_rules_depe

go_rules_dependencies()

go_register_toolchains(version = "1.19.3")

# Install gtest.
# https://github.com/google/googletest
http_archive(
Expand Down
240 changes: 31 additions & 209 deletions dpf/distributed_point_function.h
Original file line number Diff line number Diff line change
Expand Up @@ -40,8 +40,6 @@
#include "absl/types/span.h"
#include "dpf/aes_128_fixed_key_hash.h"
#include "dpf/distributed_point_function.pb.h"
#include "dpf/internal/evaluate_prg_hwy.h"
#include "dpf/internal/maybe_deref_span.h"
#include "dpf/internal/proto_validator.h"
#include "dpf/internal/value_type_helpers.h"
#include "hwy/aligned_allocator.h"
Expand Down Expand Up @@ -232,17 +230,9 @@ class DistributedPointFunction {
// Returns FAILED_PRECONDITION if `RegisterValueType<T>` has not been called
// for all types in the `DpfParameters` passed at construction.

// Legacy interface for absl::uint128, which doesn't require explicitly
// converting to absl::Span<const absl::uint128>.
absl::StatusOr<std::pair<DpfKey, DpfKey>> GenerateKeysIncremental(
absl::uint128 alpha, const std::vector<absl::uint128>& beta) {
return GenerateKeysIncremental(alpha, absl::MakeConstSpan(beta));
}

// Templated version when all value types are equal.
template <typename T>
// Overload for simple integers.
absl::StatusOr<std::pair<DpfKey, DpfKey>> GenerateKeysIncremental(
absl::uint128 alpha, absl::Span<const T> beta) {
absl::uint128 alpha, absl::Span<const absl::uint128> beta) {
std::vector<Value> values(beta.size());
for (int i = 0; i < static_cast<int>(beta.size()); ++i) {
absl::StatusOr<Value> value = ToValue(beta[i]);
Expand Down Expand Up @@ -377,35 +367,6 @@ class DistributedPointFunction {
&ctx);
}

// Evaluates a span of DPF keys. The i-th key is evaluated at
// evaluation_points[i]. After each hierarchy level, calls `op` on the output
// at that hierarchy level. `op` must be callable with the following
// signature:
//
// op(int hierarchy_level, absl::Span<T> values)
//
// It should return a value that is implicitly convertible to `bool`.
//
// This method is intended for use cases similar to
//
// absl::StatusOr<std::vector<T>> EvaluateAt(
// int hierarchy_level, absl::Span<const absl::uint128> evaluation_points,
// EvaluationContext& ctx)
//
// but without the overhead of EvaluationContext. Instead, all operations on
// intermediate values, and obtaining the final result, should be done via
// `op`.
//
// Return absl::OkStatus() after successfully evaluating `op` on the last
// hierarchy level, or as soon as `op` returns `false`. Returns
// INVALID_ARGUMENT in case any `key` is malformed, or if any of the
// `evaluation_points` are out of range.
template <typename T, typename Fn>
absl::Status EvaluateAndApply(
dpf_internal::MaybeDerefSpan<const DpfKey>,
absl::Span<const absl::uint128> evaluation_points, Fn op,
int evaluation_points_rightshift = 0) const;

// Returns the DpfParameters of this DPF.
inline absl::Span<const DpfParameters> parameters() const {
return parameters_;
Expand Down Expand Up @@ -573,13 +534,6 @@ class DistributedPointFunction {
absl::flat_hash_map<std::string, ValueCorrectionFunction>&
value_correction_functions);

// For the given `key` and `hierarchy_level`, returns the value correction
// words as an array of integers, where the size of the array matches the
// number of batched elements per block.
template <typename T>
absl::StatusOr<std::array<T, dpf_internal::ElementsPerBlock<T>()>>
GetValueCorrectionAsArray(const DpfKey& key, int hierarchy_level) const;

// Joint implementation of the two variants of `EvaluateAt<T>`. If `ctx !=
// NULL`, `key` must point to `ctx->key()`, and `*ctx` will be updated with
// the partial evaluations at this `hierarchy_level`.
Expand Down Expand Up @@ -636,6 +590,8 @@ class DistributedPointFunction {
// correct values for it anyway.
absl::flat_hash_map<std::string, ValueCorrectionFunction>
value_correction_functions_;

friend class KeyGenerationProtocol;
};

//========================//
Expand Down Expand Up @@ -728,7 +684,7 @@ absl::StatusOr<std::vector<T>> DistributedPointFunction::EvaluateUntil(
int previous_log_domain_size = 0;
int previous_hierarchy_level = ctx.previous_hierarchy_level();
if (!prefixes.empty()) {
DCHECK_GE(ctx.previous_hierarchy_level(), 0);
DCHECK(ctx.previous_hierarchy_level() >= 0);
previous_log_domain_size =
parameters_[previous_hierarchy_level].log_domain_size();
for (absl::uint128 prefix : prefixes) {
Expand Down Expand Up @@ -864,7 +820,7 @@ absl::StatusOr<std::vector<T>> DistributedPointFunction::EvaluateUntil(
// Compute the number of outputs we will have. For each prefix, we will have a
// full expansion from the previous heirarchy level to the current heirarchy
// level.
DCHECK_LT(log_domain_size - previous_log_domain_size, 63);
DCHECK(log_domain_size - previous_log_domain_size < 63);
int64_t outputs_per_prefix = int64_t{1}
<< (log_domain_size - previous_log_domain_size);

Expand All @@ -890,26 +846,6 @@ absl::StatusOr<std::vector<T>> DistributedPointFunction::EvaluateUntil(
}
}

template <typename T>
absl::StatusOr<std::array<T, dpf_internal::ElementsPerBlock<T>()>>
DistributedPointFunction::GetValueCorrectionAsArray(const DpfKey& key,
int hierarchy_level) const {
// Get output correction word from `key`.
const ::google::protobuf::RepeatedPtrField<Value>* value_correction = nullptr;
if (hierarchy_level < static_cast<int>(parameters_.size()) - 1) {
value_correction =
&(key.correction_words(hierarchy_to_tree_[hierarchy_level])
.value_correction());
} else {
// Last level value correction is stored in an extra proto field, since we
// have one less correction word than tree levels.
value_correction = &(key.last_level_value_correction());
}

// Split output correction into elements of type T, and return it.
return dpf_internal::ValuesToArray<T>(*value_correction);
}

template <typename T>
absl::StatusOr<std::vector<T>> DistributedPointFunction::EvaluateAtImpl(
const DpfKey& key, int hierarchy_level,
Expand Down Expand Up @@ -954,11 +890,31 @@ absl::StatusOr<std::vector<T>> DistributedPointFunction::EvaluateAtImpl(
return std::vector<T>{}; // Nothing to do.
}

// Get output correction word from `key`.
constexpr int elements_per_block = dpf_internal::ElementsPerBlock<T>();
const ::google::protobuf::RepeatedPtrField<Value>* value_correction = nullptr;
if (hierarchy_level < static_cast<int>(parameters_.size()) - 1) {
value_correction =
&(key.correction_words(hierarchy_to_tree_[hierarchy_level])
.value_correction());
} else {
// Last level value correction is stored in an extra proto field, since we
// have one less correction word than tree levels.
value_correction = &(key.last_level_value_correction());
}

// Split output correction into elements of type T, and save it in
// correction_ints.
absl::StatusOr<std::array<T, elements_per_block>> correction_ints =
dpf_internal::ValuesToArray<T>(*value_correction);
if (!correction_ints.ok()) {
return correction_ints.status();
}

// Split up evaluation_points into tree indices and block indices, if we're
// operating on a packed type. Otherwise set `tree_indices` to
// `evaluation_points`.
hwy::AlignedFreeUniquePtr<absl::uint128[]> maybe_recomputed_tree_indices;
constexpr int elements_per_block = dpf_internal::ElementsPerBlock<T>();
absl::Span<const absl::uint128> tree_indices;
if (elements_per_block > 1) {
maybe_recomputed_tree_indices =
Expand Down Expand Up @@ -1026,22 +982,16 @@ absl::StatusOr<std::vector<T>> DistributedPointFunction::EvaluateAtImpl(
}
DCHECK(static_cast<int64_t>(seeds.size()) == num_evaluation_points);

// Hash `seeds`.
// Hash DPF evaluations.
absl::StatusOr<hwy::AlignedFreeUniquePtr<absl::uint128[]>> hashed_expansion =
HashExpandedSeeds(hierarchy_level, seeds);
if (!hashed_expansion.ok()) {
return hashed_expansion.status();
}

// Get value correction words.
absl::StatusOr<std::array<T, elements_per_block>> correction_ints =
GetValueCorrectionAsArray<T>(key, hierarchy_level);
if (!correction_ints.ok()) {
return correction_ints.status();
}

// Perform value correction.
std::vector<T> result(num_evaluation_points);
std::vector<T> result;
result.reserve(num_evaluation_points);
const int blocks_needed = blocks_needed_[hierarchy_level];
for (int64_t i = 0; i < num_evaluation_points; ++i) {
std::array<T, elements_per_block> current_elements =
Expand All @@ -1053,7 +1003,7 @@ absl::StatusOr<std::vector<T>> DistributedPointFunction::EvaluateAtImpl(
if (elements_per_block > 1) {
block_index = DomainToBlockIndex(evaluation_points[i], hierarchy_level);
}
result[i] = current_elements[block_index];
result.push_back(current_elements[block_index]);
if (selected_partial_evaluations->control_bits[i]) {
result[i] += (*correction_ints)[block_index];
}
Expand All @@ -1069,134 +1019,6 @@ absl::StatusOr<std::vector<T>> DistributedPointFunction::EvaluateAtImpl(
return result;
}

template <typename T, typename Fn>
absl::Status DistributedPointFunction::EvaluateAndApply(
dpf_internal::MaybeDerefSpan<const DpfKey> keys,
absl::Span<const absl::uint128> evaluation_points, Fn op,
int evaluation_points_rightshift) const {
if (evaluation_points.size() != keys.size()) {
return absl::InvalidArgumentError(
"`keys.size()` != `evaluation_points.size()`");
}
for (int i = 0; i < keys.size(); ++i) {
absl::Status status = proto_validator_->ValidateDpfKey(keys[i]);
if (!status.ok()) return status;
}

const int64_t num_keys = keys.size();
const int num_hierarchy_levels = parameters_.size();
DpfExpansion eval;
eval.control_bits.resize(num_keys);
eval.seeds = hwy::AllocateAligned<absl::uint128>(num_keys);
if (eval.seeds == nullptr) {
return absl::ResourceExhaustedError("Memory allocation error");
}
absl::Span<absl::uint128> seeds(eval.seeds.get(), num_keys);
absl::Span<bool> control_bits(eval.control_bits);
hwy::AlignedFreeUniquePtr<absl::uint128[]> correction_seeds;
BitVector correction_control_bits_left, correction_control_bits_right;
std::vector<T> values(num_keys);

// Initialize seeds and control bits.
for (int64_t i = 0; i < num_keys; ++i) {
seeds[i] = absl::MakeUint128(keys[i].seed().high(), keys[i].seed().low());
control_bits[i] = keys[i].party();
}

int start_level = 0;
int stop_level = hierarchy_to_tree_[0];
for (int hierarchy_level = 0; hierarchy_level < num_hierarchy_levels;
++hierarchy_level) {
if (hierarchy_level > 0) {
start_level = stop_level;
stop_level = hierarchy_to_tree_[hierarchy_level];
}

// Compute index shifts for the current level.
const int domain_index_rightshift =
evaluation_points_rightshift + parameters_.back().log_domain_size() -
parameters_[hierarchy_level].log_domain_size();
const int tree_index_rightshift = evaluation_points_rightshift +
parameters_.back().log_domain_size() -
hierarchy_to_tree_[hierarchy_level];

int num_tree_levels = stop_level - start_level;
if (num_tree_levels > 0) {
correction_seeds =
hwy::AllocateAligned<absl::uint128>(num_tree_levels * num_keys);
if (correction_seeds == nullptr) {
return absl::ResourceExhaustedError("Memory allocation error");
}
correction_control_bits_left.resize(num_tree_levels * num_keys);
correction_control_bits_right.resize(num_tree_levels * num_keys);
for (int i = 0; i < num_tree_levels; ++i) {
for (int64_t j = 0; j < num_keys; ++j) {
const int64_t index = i * num_keys + j;
const CorrectionWord& cw = keys[j].correction_words(start_level + i);
correction_seeds[index] =
absl::MakeUint128(cw.seed().high(), cw.seed().low());
correction_control_bits_left[index] = cw.control_left();
correction_control_bits_right[index] = cw.control_right();
}
}

// Evaluate the current hierarchy level for all keys.
absl::Status status = dpf_internal::EvaluateSeeds(
seeds.size(), num_tree_levels, num_tree_levels * num_keys,
seeds.data(), control_bits.data(), evaluation_points.data(),
tree_index_rightshift, correction_seeds.get(),
correction_control_bits_left.data(),
correction_control_bits_right.data(), prg_left_, prg_right_,
seeds.data(), control_bits.data());
if (!status.ok()) {
return status;
}
}

// Hash `seeds`.
absl::StatusOr<hwy::AlignedFreeUniquePtr<absl::uint128[]>>
hashed_expansion = HashExpandedSeeds(hierarchy_level, seeds);
if (!hashed_expansion.ok()) {
return hashed_expansion.status();
}

// Compute value correction for the current level.
constexpr int elements_per_block = dpf_internal::ElementsPerBlock<T>();
const int blocks_needed = blocks_needed_[hierarchy_level];
for (int64_t i = 0; i < num_keys; ++i) {
std::array<T, elements_per_block> current_elements =
dpf_internal::ConvertBytesToArrayOf<T>(absl::string_view(
reinterpret_cast<const char*>(hashed_expansion->get() +
i * blocks_needed),
blocks_needed * sizeof(absl::uint128)));
absl::StatusOr<std::array<T, elements_per_block>> correction_ints =
GetValueCorrectionAsArray<T>(keys[i], hierarchy_level);
if (!correction_ints.ok()) {
return correction_ints.status();
}
int block_index = 0;
if (elements_per_block > 1 && domain_index_rightshift < 128) {
block_index = DomainToBlockIndex(
evaluation_points[i] >> domain_index_rightshift, hierarchy_level);
}
values[i] = current_elements[block_index];
if (control_bits[i]) {
values[i] += (*correction_ints)[block_index];
}
if (keys[i].party() == 1) {
values[i] = -values[i];
}
}

// Call the callback with the values at the current level, and return if the
// result is `false`.
if (!op(values)) {
break;
}
}
return absl::OkStatus();
}

} // namespace distributed_point_functions

#endif // DISTRIBUTED_POINT_FUNCTIONS_DPF_DISTRIBUTED_POINT_FUNCTION_H_
Loading