diff --git a/willow/proto/willow/BUILD b/willow/proto/willow/BUILD index 2389ae3..cc70ef2 100644 --- a/willow/proto/willow/BUILD +++ b/willow/proto/willow/BUILD @@ -1,4 +1,4 @@ -# Copyright 2021 Google LLC +# Copyright 2025 Google LLC # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -16,8 +16,7 @@ load("@protobuf//bazel:cc_proto_library.bzl", "cc_proto_library") load("@protobuf//bazel:proto_library.bzl", "proto_library") package( - default_applicable_licenses = [ - ], + default_applicable_licenses = ["//third_party/secure_aggregation:license"], default_visibility = ["//visibility:public"], ) @@ -30,3 +29,14 @@ cc_proto_library( name = "decryptor_cc_proto", deps = [":decryptor_proto"], ) + +proto_library( + name = "input_spec_proto", + srcs = ["input_spec.proto"], + compatible_with = ["//buildenv/target:non_prod"], +) + +cc_proto_library( + name = "input_spec_cc_proto", + deps = [":input_spec_proto"], +) diff --git a/willow/proto/willow/input_spec.proto b/willow/proto/willow/input_spec.proto new file mode 100644 index 0000000..7a99da6 --- /dev/null +++ b/willow/proto/willow/input_spec.proto @@ -0,0 +1,83 @@ +// Copyright 2025 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +edition = "2023"; + +package third_party_secure_aggregation_willow_proto_willow; + +option java_multiple_files = true; + +// This message describes the result of a group-by query. +// It contains a list of output vectors, each with a name, data type, and role. +// The role is either GROUP_BY or METRIC. The latter represents a metric to be +// aggregated (added). Group-by columns are expected to be of type STRING, while +// metrics are expected to be of type INT64. +message InputSpec { + // Supported data types for output vectors in `ExampleQuerySpec` in + // plan.proto. + enum DataType { + DATA_TYPE_UNSPECIFIED = 0; + INT32 = 1; + INT64 = 2; + BOOL = 3; + FLOAT = 4; + DOUBLE = 5; + BYTES = 6; + STRING = 7; + } + // Defines a domain as an interval. + message Interval { + // The lower bound of the interval. The interval is inclusive. + double min = 1; + // The upper bound of the interval. The interval is inclusive. + double max = 2; + } + + message StringValues { + repeated string values = 1; + } + + // A new message type to represent the domain specification. + message DomainSpec { + oneof domain_type { + // Defines a domain as an ordered list of string values. + StringValues string_values = 1; + + // Defines a domain as an interval of values. + Interval interval = 2; + } + } + + enum RoleType { + ROLE_TYPE_UNSPECIFIED = 0; + GROUP_BY = 1; + METRIC = 2; + } + + message InputVectorSpec { + // The output vector name. + string vector_name = 1; + + // The data type for each entry in the vector. + DataType data_type = 2; + + // The role of the vector in the result, e.g. group-by column. + RoleType role = 3; + + // An optional field to define the domain of the output vector. + // This could be used for validation or other logic. + DomainSpec domain_spec = 4; + } + repeated InputVectorSpec input_vector_specs = 1; +} diff --git a/willow/src/input_encoding/BUILD b/willow/src/input_encoding/BUILD new file mode 100644 index 0000000..7de8fb2 --- /dev/null +++ b/willow/src/input_encoding/BUILD @@ -0,0 +1,48 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +load("@rules_cc//cc:cc_library.bzl", "cc_library") +load("@rules_cc//cc:cc_test.bzl", "cc_test") + +package( + default_applicable_licenses = [ + ], +) + +cc_library( + name = "willow_explicit_encoder", + srcs = ["willow_explicit_encoder.cc"], + hdrs = [ + "willow_encoder_factory.h", + "willow_explicit_encoder.h", + ], + deps = [ + "@abseil-cpp//absl/memory", + "@abseil-cpp//absl/status", + "@abseil-cpp//absl/status:statusor", + "@abseil-cpp//absl/strings", + "//willow/proto/willow:input_spec_cc_proto", + ], +) + +cc_test( + name = "willow_explicit_encoder_test", + srcs = ["willow_explicit_encoder_test.cc"], + deps = [ + ":willow_explicit_encoder", + "@googletest//:gtest_main", + "@abseil-cpp//absl/status", + "//willow/proto/willow:input_spec_cc_proto", + ], +) diff --git a/willow/src/input_encoding/willow_encoder_factory.h b/willow/src/input_encoding/willow_encoder_factory.h new file mode 100644 index 0000000..032b0ba --- /dev/null +++ b/willow/src/input_encoding/willow_encoder_factory.h @@ -0,0 +1,160 @@ +// Copyright 2025 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef SECURE_AGGREGATION_WILLOW_SRC_INPUT_ENCODING_WILLOW_ENCODER_FACTORY_H_ +#define SECURE_AGGREGATION_WILLOW_SRC_INPUT_ENCODING_WILLOW_ENCODER_FACTORY_H_ +#include +#include +#include +#include +#include +#include + +#include "absl/memory/memory.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/str_cat.h" +#include "willow/proto/willow/input_spec.proto.h" +#include "willow/src/input_encoding/willow_explicit_encoder.h" + +namespace secure_aggregation { +namespace willow { + +using ::third_party_secure_aggregation_willow_proto_willow::InputSpec; +using InputVectorSpec = ::third_party_secure_aggregation_willow_proto_willow:: + InputSpec_InputVectorSpec; + +// The maximum size of the Cartesian product of domains for string features. +constexpr int64_t kMaxGlobalOutputDomainSize = 1000000; + +// Factory class that constructs non-copyable instances of children classes of +// WillowInputEncoder. +class WillowInputEncoderFactory { + public: + static absl::Status ValidateInputAndSpec( + const std::unordered_map>& input_data, + const std::unordered_map>& + group_by_data, + const InputSpec input_spec) { + // Check that input_data is not empty + if (input_data.empty()) { + return absl::InvalidArgumentError("input_data must not be empty."); + } + // Check that all provided vectors in input_data, group_by_data, and in + // input_spec have the same length + int l = input_data.begin()->second.size(); + for (const auto& [name, data] : input_data) { + if (data.size() != l) { + return absl::InvalidArgumentError( + "All input and group_by vectors must have the same length."); + } + } + for (const auto& [name, data] : group_by_data) { + if (data.size() != l) { + return absl::InvalidArgumentError( + "All input and group_b vectors must have the same length."); + } + } + + // Check that input_data and group_by_data together have the same keys as + // input_spec, their data types match, and the type is either int or string. + if (input_data.size() + group_by_data.size() != + input_spec.input_vector_specs_size()) { + return absl::InvalidArgumentError( + "input_spec must have the same number of entries as the sum of " + "entries in input_data and group_by_data."); + } + + std::unordered_map spec_map; + for (const auto& spec : input_spec.input_vector_specs()) { + spec_map[spec.vector_name()] = &spec; + } + + for (const auto& [name, data] : input_data) { + auto it = spec_map.find(name); + if (it == spec_map.end()) { + return absl::InvalidArgumentError(absl::StrCat( + "Key ", name, " found in input_data but not in input_spec.")); + } + const auto& spec = it->second; + if (spec->data_type() != InputSpec::INT64) { + return absl::InvalidArgumentError( + absl::StrCat("Type mismatch for key ", name, + ": input_data type is int64_t but input_spec type " + "is not INT64.")); + } + } + + for (const auto& [name, data] : group_by_data) { + auto it = spec_map.find(name); + if (it == spec_map.end()) { + return absl::InvalidArgumentError(absl::StrCat( + "Key ", name, " found in group_by_data but not in input_spec.")); + } + const auto& spec = it->second; + if (spec->data_type() != InputSpec::STRING) { + return absl::InvalidArgumentError(absl::StrCat( + "Type mismatch for key ", name, + ": group_by_data type is string but input_spec type is " + "not STRING.")); + } + for (const auto& d : data) { + const auto& domain_values = + spec->domain_spec().string_values().values(); + if (std::find(domain_values.begin(), domain_values.end(), d) == + domain_values.end()) { + return absl::InvalidArgumentError(absl::StrCat( + "Domain mismatch for key ", name, ": group_by_data value ", d, + " not found in domain.")); + } + } + } + + // Check that the combined size of the string domains is less than the + // maximum allowed size. + int64_t encoded_domain_size = 1; + for (const auto& [name, _] : group_by_data) { + encoded_domain_size *= + spec_map.at(name)->domain_spec().string_values().values_size(); + if (kMaxGlobalOutputDomainSize < encoded_domain_size) { + return absl::InvalidArgumentError( + "Global output domain size exceeds maximum threshold."); + } + } + return absl::OkStatus(); + } + + // Creates an instance of ExplicitWillowInputEncoder. + static absl::StatusOr> + CreateExplicitWillowInputEncoder( + const std::unordered_map>& input_data, + const std::unordered_map>& + group_by_data, + const InputSpec& input_spec) { + // Check that input_data and input_spec have the same keys, their data + // types match, and the type is either int or string. + absl::Status status = + ValidateInputAndSpec(input_data, group_by_data, input_spec); + if (!status.ok()) { + return status; + } + return absl::WrapUnique( + new WillowInputExplicitEncoder(input_data, group_by_data, input_spec)); + } +}; + +} // namespace willow +} // namespace secure_aggregation + +#endif // SECURE_AGGREGATION_WILLOW_SRC_INPUT_ENCODING_WILLOW_ENCODER_FACTORY_H_ diff --git a/willow/src/input_encoding/willow_explicit_encoder.cc b/willow/src/input_encoding/willow_explicit_encoder.cc new file mode 100644 index 0000000..be51f60 --- /dev/null +++ b/willow/src/input_encoding/willow_explicit_encoder.cc @@ -0,0 +1,225 @@ +// Copyright 2025 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "willow/src/input_encoding/willow_explicit_encoder.h" + +#include +#include +#include +#include +#include +#include +#include +#include + +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/str_cat.h" +#include "willow/proto/willow/input_spec.proto.h" + +namespace secure_aggregation { +namespace willow { + +using InputVectorSpec = ::third_party_secure_aggregation_willow_proto_willow:: + InputSpec_InputVectorSpec; + +struct VectorHasher { + std::size_t operator()(const std::vector& v) const { + std::string digest = ""; + for (int i : v) { + digest += std::to_string(i) + ","; + } + return std::hash()(digest); + } +}; + +// Recursive helper function to generate combinations +void findCombinations( + const std::vector& sizes, std::vector>& result, + std::vector& currentCombination, + std::unordered_map, int, VectorHasher>& inverse_map, + int depth) { + // Base case + if (depth == sizes.size()) { + inverse_map[currentCombination] = result.size(); + result.push_back(currentCombination); + return; + } + + // Recursive step + for (int i = 0; i < sizes[depth]; ++i) { + // Add the current index 'i' to our combination. + currentCombination.push_back(i); + // Recurse to handle the next size in the list. + findCombinations(sizes, result, currentCombination, inverse_map, depth + 1); + // Backtrack: Remove the element we just added. + currentCombination.pop_back(); + } +} + +std::unordered_map, int, VectorHasher> getCombinations( + const std::vector& sizes) { + std::vector> result; + std::vector currentCombination; + std::unordered_map, int, VectorHasher> inverse_map; + findCombinations(sizes, result, currentCombination, inverse_map, 0); + return inverse_map; +} + +absl::StatusOr>> +WillowInputExplicitEncoder::Encode() const { + std::unordered_map spec_map; + for (const auto& spec : input_spec_.input_vector_specs()) { + spec_map[spec.vector_name()] = &spec; + } + + // Define an ordering of the group-by keys. + std::vector sorted_group_by_keys; + sorted_group_by_keys.reserve(group_by_data_.size()); + for (const auto& [key, _] : group_by_data_) { + sorted_group_by_keys.push_back(key); + } + std::sort(sorted_group_by_keys.begin(), sorted_group_by_keys.end()); + + // Collect the sizes of the string domains for each group-by key. + std::vector sizes; + sizes.reserve(sorted_group_by_keys.size()); + for (const auto& key : sorted_group_by_keys) { + sizes.push_back( + spec_map.at(key)->domain_spec().string_values().values_size()); + } + + // Generate all combinations of group-by keys. The value of the + // combination_2_index map is the index of the combination corresponding to + // the vector of indices. + std::unordered_map, int, VectorHasher> combination_2_index = + getCombinations(sizes); + + // Compute the total number of elements in the cartesian product of the + // string domains, which corresponds to the length of the domain once + // encoded as a vector. + int64_t encoded_domain_size = 1; + for (const auto& key : sorted_group_by_keys) { + encoded_domain_size *= + spec_map.at(key)->domain_spec().string_values().values_size(); + } + std::unordered_map> result; + // iterate over input_data + for (const auto& [name, data] : input_data_) { + // initialize the vector for each input_data key + result[name] = std::vector(encoded_domain_size, 0); + std::vector indices; + // iterate over data and copy the corresponding entries to the result + // vector to their location in the encoded domain + for (int i = 0; i < data.size(); ++i) { + indices.clear(); + // iterate over group keys to determine the combination index that + // correspondss to data[i] + for (const auto& g_name : sorted_group_by_keys) { + auto key = group_by_data_.at(g_name)[i]; + // find the index of the key in the string domain. Note that the + // validation ensures that it is present. + int index = -1; + for (int j = 0; + j < + spec_map.at(g_name)->domain_spec().string_values().values_size(); + ++j) { + if (spec_map.at(g_name)->domain_spec().string_values().values(j) == + key) { + index = j; + break; + } + } + indices.push_back(index); + } + result[name][combination_2_index[indices]] = data[i]; + } + } + return result; +} + +absl::StatusOr< + std::pair>, + std::unordered_map>>> +WillowInputExplicitEncoder::Decode( + const std::unordered_map>& encoded_data) + const { + std::unordered_map spec_map; + for (const auto& spec : input_spec_.input_vector_specs()) { + spec_map[spec.vector_name()] = &spec; + } + + std::vector sorted_group_by_keys; + sorted_group_by_keys.reserve(group_by_data_.size()); + for (const auto& [key, _] : group_by_data_) { + sorted_group_by_keys.push_back(key); + } + std::sort(sorted_group_by_keys.begin(), sorted_group_by_keys.end()); + + std::vector sizes; + sizes.reserve(sorted_group_by_keys.size()); + for (const auto& key : sorted_group_by_keys) { + sizes.push_back( + spec_map.at(key)->domain_spec().string_values().values_size()); + } + + std::vector> index_to_combination; + std::vector currentCombination; + std::unordered_map, int, VectorHasher> combination_to_index; + findCombinations(sizes, index_to_combination, currentCombination, + combination_to_index, 0); + + int64_t encoded_domain_size = 1; + for (const auto& key : sorted_group_by_keys) { + encoded_domain_size *= + spec_map.at(key)->domain_spec().string_values().values_size(); + } + + std::unordered_map> decoded_metrics; + std::unordered_map> decoded_groups; + + for (int i = 0; i < encoded_domain_size; ++i) { + bool has_nonzero_metric = false; + for (const auto& [metric_name, data] : encoded_data) { + if (i >= data.size()) { + return absl::InvalidArgumentError( + absl::StrCat("Encoded data for metric ", metric_name, + " has wrong size: expected ", encoded_domain_size, + ", got ", data.size())); + } + if (data.at(i) != 0) { + has_nonzero_metric = true; + break; + } + } + + if (has_nonzero_metric) { + const auto& combination = index_to_combination[i]; + for (int j = 0; j < sorted_group_by_keys.size(); ++j) { + const auto& key_name = sorted_group_by_keys[j]; + int val_idx = combination[j]; + decoded_groups[key_name].push_back( + spec_map.at(key_name)->domain_spec().string_values().values( + val_idx)); + } + for (const auto& [metric_name, data] : encoded_data) { + decoded_metrics[metric_name].push_back(data.at(i)); + } + } + } + return std::make_pair(decoded_metrics, decoded_groups); +} + +} // namespace willow +} // namespace secure_aggregation diff --git a/willow/src/input_encoding/willow_explicit_encoder.h b/willow/src/input_encoding/willow_explicit_encoder.h new file mode 100644 index 0000000..802e97a --- /dev/null +++ b/willow/src/input_encoding/willow_explicit_encoder.h @@ -0,0 +1,85 @@ +// Copyright 2025 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef SECURE_AGGREGATION_WILLOW_SRC_INPUT_ENCODING_WILLOW_EXPLICIT_ENCODER_H_ +#define SECURE_AGGREGATION_WILLOW_SRC_INPUT_ENCODING_WILLOW_EXPLICIT_ENCODER_H_ + +#include +#include +#include +#include +#include + +#include "absl/status/statusor.h" +#include "willow/proto/willow/input_spec.proto.h" + +namespace secure_aggregation { +namespace willow { + +using InputSpec = + ::third_party_secure_aggregation_willow_proto_willow::InputSpec; + +class WillowInputEncoder { + public: + virtual ~WillowInputEncoder() = default; + + virtual absl::StatusOr>> + Encode() const = 0; + + virtual absl::StatusOr< + std::pair>, + std::unordered_map>>> + Decode(const std::unordered_map>& + encoded_data) const = 0; +}; + +// WillowInputExplicitEncoder must be instantiated through the factory class +// WillowInputEncoderFactory. +class WillowInputExplicitEncoder : public WillowInputEncoder { + public: + WillowInputExplicitEncoder(const WillowInputExplicitEncoder&) = delete; + WillowInputExplicitEncoder& operator=(const WillowInputExplicitEncoder&) = + delete; + ~WillowInputExplicitEncoder() override = default; + + absl::StatusOr>> Encode() + const override; + + absl::StatusOr< + std::pair>, + std::unordered_map>>> + Decode(const std::unordered_map>& + encoded_data) const override; + + private: + const std::unordered_map> input_data_; + const std::unordered_map> + group_by_data_; + const InputSpec input_spec_; + + WillowInputExplicitEncoder( + const std::unordered_map>& input_data, + const std::unordered_map>& + group_by_data, + const InputSpec& input_spec) + : input_data_(input_data), + group_by_data_(group_by_data), + input_spec_(input_spec) {} + friend class WillowInputEncoderFactory; +}; + +} // namespace willow +} // namespace secure_aggregation + +#endif // SECURE_AGGREGATION_WILLOW_SRC_INPUT_ENCODING_WILLOW_EXPLICIT_ENCODER_H_ diff --git a/willow/src/input_encoding/willow_explicit_encoder_test.cc b/willow/src/input_encoding/willow_explicit_encoder_test.cc new file mode 100644 index 0000000..936cd31 --- /dev/null +++ b/willow/src/input_encoding/willow_explicit_encoder_test.cc @@ -0,0 +1,404 @@ +// Copyright 2025 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "willow/src/input_encoding/willow_explicit_encoder.h" + +#include +#include +#include +#include +#include + +#include "absl/status/status.h" +#include "gmock/gmock.h" +#include "gtest/gtest.h" +#include "willow/proto/willow/input_spec.proto.h" +#include "willow/src/input_encoding/willow_encoder_factory.h" + +namespace secure_aggregation { +namespace willow { +namespace { + +using ::testing::ElementsAre; +using ::testing::HasSubstr; +using ::testing::Pair; +using ::testing::UnorderedElementsAre; +using ::testing::status::IsOkAndHolds; +using ::testing::status::StatusIs; + +TEST(WillowInputEncoderFactoryTest, ValidateInputAndSpecLengthMismatch) { + std::unordered_map> input_data; + input_data["metric1"] = {1, 2, 3}; + std::unordered_map> group_by_data; + group_by_data["feature1"] = {"a", "b", "a"}; + InputSpec input_spec; + InputSpec::InputVectorSpec* spec1 = input_spec.add_input_vector_specs(); + spec1->set_vector_name("metric1"); + spec1->set_data_type(InputSpec::INT64); + // Missing spec for "feature1" + + EXPECT_THAT( + WillowInputEncoderFactory::ValidateInputAndSpec(input_data, group_by_data, + input_spec), + StatusIs( + absl::StatusCode::kInvalidArgument, + HasSubstr( + "input_spec must have the same number of entries as the sum of " + "entries in input_data and group_by_data."))); +} + +TEST(WillowInputEncoderFactoryTest, ValidateInputAndSpecTypeMismatch) { + std::unordered_map> input_data; + input_data["metric1"] = {1, 2, 3}; + std::unordered_map> group_by_data; + InputSpec input_spec; + InputSpec::InputVectorSpec* spec1 = input_spec.add_input_vector_specs(); + spec1->set_vector_name("metric1"); + spec1->set_data_type(InputSpec::STRING); + + EXPECT_THAT(WillowInputEncoderFactory::ValidateInputAndSpec( + input_data, group_by_data, input_spec), + StatusIs(absl::StatusCode::kInvalidArgument, + HasSubstr("Type mismatch for key metric1"))); +} + +TEST(WillowInputEncoderFactoryTest, ValidateInputAndSpecEmptyInputData) { + std::unordered_map> input_data; + std::unordered_map> group_by_data; + group_by_data["feature1"] = {"a", "b", "a"}; + InputSpec input_spec; + InputSpec::InputVectorSpec* spec1 = input_spec.add_input_vector_specs(); + spec1->set_vector_name("feature1"); + spec1->set_data_type(InputSpec::STRING); + + EXPECT_THAT(WillowInputEncoderFactory::ValidateInputAndSpec( + input_data, group_by_data, input_spec), + StatusIs(absl::StatusCode::kInvalidArgument, + HasSubstr("input_data must not be empty"))); +} + +TEST(WillowInputEncoderFactoryTest, ValidateInputAndSpecDomainValueNotFound) { + std::unordered_map> input_data; + input_data["metric1"] = {1}; + std::unordered_map> group_by_data; + group_by_data["feature1"] = {"c"}; + InputSpec input_spec; + InputSpec::InputVectorSpec* spec1 = input_spec.add_input_vector_specs(); + spec1->set_vector_name("metric1"); + spec1->set_data_type(InputSpec::INT64); + InputSpec::InputVectorSpec* spec2 = input_spec.add_input_vector_specs(); + spec2->set_vector_name("feature1"); + spec2->set_data_type(InputSpec::STRING); + spec2->mutable_domain_spec()->mutable_string_values()->add_values("a"); + spec2->mutable_domain_spec()->mutable_string_values()->add_values("b"); + + EXPECT_THAT(WillowInputEncoderFactory::ValidateInputAndSpec( + input_data, group_by_data, input_spec), + StatusIs(absl::StatusCode::kInvalidArgument, + HasSubstr("Domain mismatch for key feature1"))); +} + +TEST(WillowInputEncoderFactoryTest, + ValidateInputAndSpecInputDataVectorLengthMismatch) { + std::unordered_map> input_data; + input_data["metric1"] = {1, 2, 3}; + input_data["metric2"] = {1, 2}; + std::unordered_map> group_by_data; + InputSpec input_spec; + InputSpec::InputVectorSpec* spec1 = input_spec.add_input_vector_specs(); + spec1->set_vector_name("metric1"); + spec1->set_data_type(InputSpec::INT64); + InputSpec::InputVectorSpec* spec2 = input_spec.add_input_vector_specs(); + spec2->set_vector_name("metric2"); + spec2->set_data_type(InputSpec::INT64); + + EXPECT_THAT(WillowInputEncoderFactory::ValidateInputAndSpec( + input_data, group_by_data, input_spec), + StatusIs(absl::StatusCode::kInvalidArgument, + HasSubstr("must have the same length"))); +} + +TEST(WillowInputEncoderFactoryTest, + ValidateInputAndSpecGroupByDataVectorLengthMismatch) { + std::unordered_map> input_data; + input_data["metric1"] = {1, 2, 3}; + std::unordered_map> group_by_data; + group_by_data["feature1"] = {"a", "b"}; + InputSpec input_spec; + InputSpec::InputVectorSpec* spec1 = input_spec.add_input_vector_specs(); + spec1->set_vector_name("metric1"); + spec1->set_data_type(InputSpec::INT64); + InputSpec::InputVectorSpec* spec2 = input_spec.add_input_vector_specs(); + spec2->set_vector_name("feature1"); + spec2->set_data_type(InputSpec::STRING); + spec2->mutable_domain_spec()->mutable_string_values()->add_values("a"); + spec2->mutable_domain_spec()->mutable_string_values()->add_values("b"); + + EXPECT_THAT(WillowInputEncoderFactory::ValidateInputAndSpec( + input_data, group_by_data, input_spec), + StatusIs(absl::StatusCode::kInvalidArgument, + HasSubstr("must have the same length"))); +} + +TEST(WillowInputEncoderFactoryTest, + ValidateInputAndSpecDomainSizeVectorLengthMismatch) { + std::unordered_map> input_data; + input_data["metric1"] = {1, 2, 3}; + std::unordered_map> group_by_data; + group_by_data["feature1"] = {"a", "b", "c"}; + InputSpec input_spec; + InputSpec::InputVectorSpec* spec1 = input_spec.add_input_vector_specs(); + spec1->set_vector_name("metric1"); + spec1->set_data_type(InputSpec::INT64); + spec1->mutable_domain_spec()->mutable_string_values()->add_values("x"); + InputSpec::InputVectorSpec* spec2 = input_spec.add_input_vector_specs(); + spec2->set_vector_name("feature1"); + spec2->set_data_type(InputSpec::STRING); + spec2->mutable_domain_spec()->mutable_string_values()->add_values("a"); + spec2->mutable_domain_spec()->mutable_string_values()->add_values("b"); + + EXPECT_THAT(WillowInputEncoderFactory::ValidateInputAndSpec( + input_data, group_by_data, input_spec), + StatusIs(absl::StatusCode::kInvalidArgument, + HasSubstr("Domain mismatch for key feature1: " + "group_by_data value c not found in domain"))); +} + +TEST(WillowInputEncoderFactoryTest, ValidateInputAndSpecInputKeyNotInSpec) { + std::unordered_map> input_data; + input_data["metric1"] = {1}; + input_data["metric2"] = {2}; + std::unordered_map> group_by_data; + InputSpec input_spec; + InputSpec::InputVectorSpec* spec1 = input_spec.add_input_vector_specs(); + spec1->set_vector_name("metric1"); + spec1->set_data_type(InputSpec::INT64); + spec1->mutable_domain_spec()->mutable_string_values()->add_values("x"); + + EXPECT_THAT( + WillowInputEncoderFactory::ValidateInputAndSpec(input_data, group_by_data, + input_spec), + StatusIs( + absl::StatusCode::kInvalidArgument, + HasSubstr( + "input_spec must have the same number of entries as the sum of " + "entries in input_data and group_by_data."))); +} + +TEST(WillowInputEncoderFactoryTest, ValidateInputAndSpecGroupByKeyNotInSpec) { + std::unordered_map> input_data; + input_data["metric1"] = {1}; + std::unordered_map> group_by_data; + group_by_data["feature1"] = {"a"}; + InputSpec input_spec; + InputSpec::InputVectorSpec* spec1 = input_spec.add_input_vector_specs(); + spec1->set_vector_name("metric1"); + spec1->set_data_type(InputSpec::INT64); + spec1->mutable_domain_spec()->mutable_string_values()->add_values("x"); + + EXPECT_THAT( + WillowInputEncoderFactory::ValidateInputAndSpec(input_data, group_by_data, + input_spec), + StatusIs( + absl::StatusCode::kInvalidArgument, + HasSubstr( + "input_spec must have the same number of entries as the sum of " + "entries in input_data and group_by_data."))); +} + +TEST(WillowInputEncoderFactoryTest, ValidateInputAndSpecGroupByTypeMismatch) { + std::unordered_map> input_data; + input_data["metric1"] = {1}; + std::unordered_map> group_by_data; + group_by_data["feature1"] = {"a"}; + InputSpec input_spec; + InputSpec::InputVectorSpec* spec1 = input_spec.add_input_vector_specs(); + spec1->set_vector_name("metric1"); + spec1->set_data_type(InputSpec::INT64); + spec1->mutable_domain_spec()->mutable_string_values()->add_values("x"); + InputSpec::InputVectorSpec* spec2 = input_spec.add_input_vector_specs(); + spec2->set_vector_name("feature1"); + spec2->set_data_type(InputSpec::INT64); + spec2->mutable_domain_spec()->mutable_string_values()->add_values("y"); + + EXPECT_THAT(WillowInputEncoderFactory::ValidateInputAndSpec( + input_data, group_by_data, input_spec), + StatusIs(absl::StatusCode::kInvalidArgument, + HasSubstr("Type mismatch for key feature1"))); +} + +TEST(WillowInputEncoderFactoryTest, + ValidateInputAndSpecGlobalDomainSizeExceeded) { + std::unordered_map> input_data; + input_data["metric1"] = {1}; + std::unordered_map> group_by_data; + group_by_data["feature1"] = {"a"}; + InputSpec input_spec; + InputSpec::InputVectorSpec* spec1 = input_spec.add_input_vector_specs(); + spec1->set_vector_name("metric1"); + spec1->set_data_type(InputSpec::INT64); + spec1->mutable_domain_spec()->mutable_string_values()->add_values("1, 2, 3"); + InputSpec::InputVectorSpec* spec2 = input_spec.add_input_vector_specs(); + spec2->set_vector_name("feature1"); + spec2->set_data_type(InputSpec::STRING); + spec2->mutable_domain_spec()->mutable_string_values()->add_values("a"); + for (int i = 0; i < 1000000; ++i) { + spec2->mutable_domain_spec()->mutable_string_values()->add_values( + std::to_string(i)); + } + + EXPECT_THAT(WillowInputEncoderFactory::ValidateInputAndSpec( + input_data, group_by_data, input_spec), + StatusIs(absl::StatusCode::kInvalidArgument, + HasSubstr("Global output domain size exceeds"))); +} + +InputSpec::InputVectorSpec CreateStringSpec( + const std::string& name, const std::vector& domain) { + InputSpec::InputVectorSpec spec; + spec.set_vector_name(name); + spec.set_data_type(InputSpec::STRING); + for (const auto& val : domain) { + spec.mutable_domain_spec()->mutable_string_values()->add_values(val); + } + return spec; +} + +InputSpec::InputVectorSpec CreateIntSpec(const std::string& name) { + InputSpec::InputVectorSpec spec; + spec.set_vector_name(name); + spec.set_data_type(InputSpec::INT64); + return spec; +} + +TEST(WillowInputEncoderFactoryTest, EncodeSimpleGroupBy) { + std::unordered_map> input_data; + input_data["metric1"] = {10, 20, 5}; + std::unordered_map> group_by_data; + group_by_data["country"] = {"US", "CA", "US"}; + group_by_data["lang"] = {"en", "es", "es"}; + InputSpec input_spec; + *input_spec.add_input_vector_specs() = CreateIntSpec("metric1"); + *input_spec.add_input_vector_specs() = CreateStringSpec( + "country", {"CA", "GB", "MX", "US"}); // CA=0, GB=1, MX=2, US=3 + *input_spec.add_input_vector_specs() = + CreateStringSpec("lang", {"en", "es"}); // en=0, es=1 + + // group_by keys are sorted: "country", "lang" + // value_to_index_maps["country"]: {"CA":0, "GB":1, "MX":2, "US":3} + // value_to_index_maps["lang"]: {"en":0, "es":1} + + // Row 0: country=US(3), lang=en(0). metric1=10. + // combo_index = 3*2 + 0 = 6 + // Row 1: country=CA(0), lang=es(1). metric1=20. + // combo_index = 0*2 + 1 = 1 + // Row 2: country=US(3), lang=es(1). metric1=5. + // combo_index = 3*2 + 1 = 7 + + // Expected histogram for metric1: + // Index 0 (CA, en): 0 + // Index 1 (CA, es): 20 + // Index 2 (GB, en): 0 + // Index 3 (GB, es): 0 + // Index 4 (MX, en): 0 + // Index 5 (MX, es): 0 + // Index 6 (US, en): 10 + // Index 7 (US, es): 5 + // Result: [0, 20, 0, 0, 0, 0, 10, 5] + + ASSERT_OK_AND_ASSIGN( + std::unique_ptr encoder, + WillowInputEncoderFactory::CreateExplicitWillowInputEncoder( + input_data, group_by_data, input_spec)); + + EXPECT_THAT(encoder->Encode(), + IsOkAndHolds(UnorderedElementsAre( + Pair("metric1", ElementsAre(0, 20, 0, 0, 0, 0, 10, 5))))); +} + +TEST(WillowInputEncoderFactoryTest, EncodeTwoMetricsOneGroupBy) { + std::unordered_map> input_data; + input_data["metric1"] = {10, 20}; + input_data["metric2"] = {100, 200}; + std::unordered_map> group_by_data; + group_by_data["country"] = {"US", "CA"}; + InputSpec input_spec; + *input_spec.add_input_vector_specs() = CreateIntSpec("metric1"); + *input_spec.add_input_vector_specs() = CreateIntSpec("metric2"); + *input_spec.add_input_vector_specs() = + CreateStringSpec("country", {"CA", "US"}); // CA=0, US=1 + + // group_by keys are sorted: "country" + // value_to_index_maps["country"]: {"CA":0, "US":1} + // combinations: {0}->0, {1}->1 + + // Row 0: country=US(1), metric1=10, metric2=100. + // combo_index for {1} is 1. + // result["metric1"][1]=10, result["metric2"][1]=100 + // Row 1: country=CA(0), metric1=20, metric2=200. + // combo_index for {0} is 0. + // result["metric1"][0]=20, result["metric2"][0]=200 + + // Expected: + // metric1: [20, 10] + // metric2: [200, 100] + + ASSERT_OK_AND_ASSIGN( + std::unique_ptr encoder, + WillowInputEncoderFactory::CreateExplicitWillowInputEncoder( + input_data, group_by_data, input_spec)); + + EXPECT_THAT(encoder->Encode(), IsOkAndHolds(UnorderedElementsAre( + Pair("metric1", ElementsAre(20, 10)), + Pair("metric2", ElementsAre(200, 100))))); +} + +TEST(WillowInputEncoderFactoryTest, EncodeThenDecode) { + std::unordered_map> input_data; + input_data["metric1"] = {10, 20, 5}; + std::unordered_map> group_by_data; + group_by_data["country"] = {"US", "CA", "US"}; + group_by_data["lang"] = {"en", "es", "es"}; + InputSpec input_spec; + *input_spec.add_input_vector_specs() = CreateIntSpec("metric1"); + *input_spec.add_input_vector_specs() = + CreateStringSpec("country", {"CA", "GB", "MX", "US"}); + *input_spec.add_input_vector_specs() = CreateStringSpec("lang", {"en", "es"}); + + ASSERT_OK_AND_ASSIGN( + std::unique_ptr encoder, + WillowInputEncoderFactory::CreateExplicitWillowInputEncoder( + input_data, group_by_data, input_spec)); + + ASSERT_OK_AND_ASSIGN(auto encoded_data, encoder->Encode()); + ASSERT_OK_AND_ASSIGN(auto decoded_pair, encoder->Decode(encoded_data)); + + const auto& decoded_metrics = decoded_pair.first; + const auto& decoded_groups = decoded_pair.second; + + // The decoded output is sparse and only contains rows with non-zero metrics. + // The order depends on iteration over dense vector. + // metric1 values for combo indices 1,6,7 are 20,10,5. + // The decoded result should contain 3 rows in order of combination index. + // combo 1: CA, es, metric1=20 + // combo 6: US, en, metric1=10 + // combo 7: US, es, metric1=5 + EXPECT_THAT(decoded_metrics.at("metric1"), ElementsAre(20, 10, 5)); + EXPECT_THAT(decoded_groups.at("country"), ElementsAre("CA", "US", "US")); + EXPECT_THAT(decoded_groups.at("lang"), ElementsAre("es", "en", "es")); +} + +} // namespace +} // namespace willow +} // namespace secure_aggregation