Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 13 additions & 3 deletions willow/proto/willow/BUILD
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright 2021 Google LLC
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand All @@ -16,8 +16,7 @@ load("@protobuf//bazel:cc_proto_library.bzl", "cc_proto_library")
load("@protobuf//bazel:proto_library.bzl", "proto_library")

package(
default_applicable_licenses = [
],
default_applicable_licenses = ["//third_party/secure_aggregation:license"],
default_visibility = ["//visibility:public"],
)

Expand All @@ -30,3 +29,14 @@ cc_proto_library(
name = "decryptor_cc_proto",
deps = [":decryptor_proto"],
)

proto_library(
name = "input_spec_proto",
srcs = ["input_spec.proto"],
compatible_with = ["//buildenv/target:non_prod"],
)

cc_proto_library(
name = "input_spec_cc_proto",
deps = [":input_spec_proto"],
)
83 changes: 83 additions & 0 deletions willow/proto/willow/input_spec.proto
Original file line number Diff line number Diff line change
@@ -0,0 +1,83 @@
// Copyright 2025 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

edition = "2023";

package third_party_secure_aggregation_willow_proto_willow;

option java_multiple_files = true;

// This message describes the result of a group-by query.
// It contains a list of output vectors, each with a name, data type, and role.
// The role is either GROUP_BY or METRIC. The latter represents a metric to be
// aggregated (added). Group-by columns are expected to be of type STRING, while
// metrics are expected to be of type INT64.
message InputSpec {
// Supported data types for output vectors in `ExampleQuerySpec` in
// plan.proto.
enum DataType {
DATA_TYPE_UNSPECIFIED = 0;
INT32 = 1;
INT64 = 2;
BOOL = 3;
FLOAT = 4;
DOUBLE = 5;
BYTES = 6;
STRING = 7;
}
// Defines a domain as an interval.
message Interval {
// The lower bound of the interval. The interval is inclusive.
double min = 1;
// The upper bound of the interval. The interval is inclusive.
double max = 2;
}

message StringValues {
repeated string values = 1;
}

// A new message type to represent the domain specification.
message DomainSpec {
oneof domain_type {
// Defines a domain as an ordered list of string values.
StringValues string_values = 1;

// Defines a domain as an interval of values.
Interval interval = 2;
}
}

enum RoleType {
ROLE_TYPE_UNSPECIFIED = 0;
GROUP_BY = 1;
METRIC = 2;
}

message InputVectorSpec {
// The output vector name.
string vector_name = 1;

// The data type for each entry in the vector.
DataType data_type = 2;

// The role of the vector in the result, e.g. group-by column.
RoleType role = 3;

// An optional field to define the domain of the output vector.
// This could be used for validation or other logic.
DomainSpec domain_spec = 4;
}
repeated InputVectorSpec input_vector_specs = 1;
}
48 changes: 48 additions & 0 deletions willow/src/input_encoding/BUILD
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

load("@rules_cc//cc:cc_library.bzl", "cc_library")
load("@rules_cc//cc:cc_test.bzl", "cc_test")

package(
default_applicable_licenses = [
],
)

cc_library(
name = "willow_explicit_encoder",
srcs = ["willow_explicit_encoder.cc"],
hdrs = [
"willow_encoder_factory.h",
"willow_explicit_encoder.h",
],
deps = [
"@abseil-cpp//absl/memory",
"@abseil-cpp//absl/status",
"@abseil-cpp//absl/status:statusor",
"@abseil-cpp//absl/strings",
"//willow/proto/willow:input_spec_cc_proto",
],
)

cc_test(
name = "willow_explicit_encoder_test",
srcs = ["willow_explicit_encoder_test.cc"],
deps = [
":willow_explicit_encoder",
"@googletest//:gtest_main",
"@abseil-cpp//absl/status",
"//willow/proto/willow:input_spec_cc_proto",
],
)
160 changes: 160 additions & 0 deletions willow/src/input_encoding/willow_encoder_factory.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,160 @@
// Copyright 2025 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#ifndef SECURE_AGGREGATION_WILLOW_SRC_INPUT_ENCODING_WILLOW_ENCODER_FACTORY_H_
#define SECURE_AGGREGATION_WILLOW_SRC_INPUT_ENCODING_WILLOW_ENCODER_FACTORY_H_
#include <algorithm>
#include <cstdint>
#include <memory>
#include <string>
#include <unordered_map>
#include <vector>

#include "absl/memory/memory.h"
#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "absl/strings/str_cat.h"
#include "willow/proto/willow/input_spec.proto.h"
#include "willow/src/input_encoding/willow_explicit_encoder.h"

namespace secure_aggregation {
namespace willow {

using ::third_party_secure_aggregation_willow_proto_willow::InputSpec;
using InputVectorSpec = ::third_party_secure_aggregation_willow_proto_willow::
InputSpec_InputVectorSpec;

// The maximum size of the Cartesian product of domains for string features.
constexpr int64_t kMaxGlobalOutputDomainSize = 1000000;

// Factory class that constructs non-copyable instances of children classes of
// WillowInputEncoder.
class WillowInputEncoderFactory {
public:
static absl::Status ValidateInputAndSpec(
const std::unordered_map<std::string, std::vector<int64_t>>& input_data,
const std::unordered_map<std::string, std::vector<std::string>>&
group_by_data,
const InputSpec input_spec) {
// Check that input_data is not empty
if (input_data.empty()) {
return absl::InvalidArgumentError("input_data must not be empty.");
}
// Check that all provided vectors in input_data, group_by_data, and in
// input_spec have the same length
int l = input_data.begin()->second.size();
for (const auto& [name, data] : input_data) {
if (data.size() != l) {
return absl::InvalidArgumentError(
"All input and group_by vectors must have the same length.");
}
}
for (const auto& [name, data] : group_by_data) {
if (data.size() != l) {
return absl::InvalidArgumentError(
"All input and group_b vectors must have the same length.");
}
}

// Check that input_data and group_by_data together have the same keys as
// input_spec, their data types match, and the type is either int or string.
if (input_data.size() + group_by_data.size() !=
input_spec.input_vector_specs_size()) {
return absl::InvalidArgumentError(
"input_spec must have the same number of entries as the sum of "
"entries in input_data and group_by_data.");
}

std::unordered_map<std::string, const InputVectorSpec*> spec_map;
for (const auto& spec : input_spec.input_vector_specs()) {
spec_map[spec.vector_name()] = &spec;
}

for (const auto& [name, data] : input_data) {
auto it = spec_map.find(name);
if (it == spec_map.end()) {
return absl::InvalidArgumentError(absl::StrCat(
"Key ", name, " found in input_data but not in input_spec."));
}
const auto& spec = it->second;
if (spec->data_type() != InputSpec::INT64) {
return absl::InvalidArgumentError(
absl::StrCat("Type mismatch for key ", name,
": input_data type is int64_t but input_spec type "
"is not INT64."));
}
}

for (const auto& [name, data] : group_by_data) {
auto it = spec_map.find(name);
if (it == spec_map.end()) {
return absl::InvalidArgumentError(absl::StrCat(
"Key ", name, " found in group_by_data but not in input_spec."));
}
const auto& spec = it->second;
if (spec->data_type() != InputSpec::STRING) {
return absl::InvalidArgumentError(absl::StrCat(
"Type mismatch for key ", name,
": group_by_data type is string but input_spec type is "
"not STRING."));
}
for (const auto& d : data) {
const auto& domain_values =
spec->domain_spec().string_values().values();
if (std::find(domain_values.begin(), domain_values.end(), d) ==
domain_values.end()) {
return absl::InvalidArgumentError(absl::StrCat(
"Domain mismatch for key ", name, ": group_by_data value ", d,
" not found in domain."));
}
}
}

// Check that the combined size of the string domains is less than the
// maximum allowed size.
int64_t encoded_domain_size = 1;
for (const auto& [name, _] : group_by_data) {
encoded_domain_size *=
spec_map.at(name)->domain_spec().string_values().values_size();
if (kMaxGlobalOutputDomainSize < encoded_domain_size) {
return absl::InvalidArgumentError(
"Global output domain size exceeds maximum threshold.");
}
}
return absl::OkStatus();
}

// Creates an instance of ExplicitWillowInputEncoder.
static absl::StatusOr<std::unique_ptr<WillowInputExplicitEncoder>>
CreateExplicitWillowInputEncoder(
const std::unordered_map<std::string, std::vector<int64_t>>& input_data,
const std::unordered_map<std::string, std::vector<std::string>>&
group_by_data,
const InputSpec& input_spec) {
// Check that input_data and input_spec have the same keys, their data
// types match, and the type is either int or string.
absl::Status status =
ValidateInputAndSpec(input_data, group_by_data, input_spec);
if (!status.ok()) {
return status;
}
return absl::WrapUnique(
new WillowInputExplicitEncoder(input_data, group_by_data, input_spec));
}
};

} // namespace willow
} // namespace secure_aggregation

#endif // SECURE_AGGREGATION_WILLOW_SRC_INPUT_ENCODING_WILLOW_ENCODER_FACTORY_H_
Loading
Loading