Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion checker/internal/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -139,9 +139,9 @@ cc_library(
"//common/ast:expr",
"//internal:status_macros",
"//parser:macro",
"@com_google_absl//absl/algorithm:container",
"@com_google_absl//absl/base:no_destructor",
"@com_google_absl//absl/base:nullability",
"@com_google_absl//absl/cleanup",
"@com_google_absl//absl/container:flat_hash_map",
"@com_google_absl//absl/container:flat_hash_set",
"@com_google_absl//absl/status",
Expand Down
33 changes: 18 additions & 15 deletions checker/internal/type_check_env.h
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,15 @@ class TypeCheckEnv {
descriptor_pool)
: descriptor_pool_(std::move(descriptor_pool)),
container_(""),
parent_(nullptr) {};
parent_(nullptr) {}

TypeCheckEnv(absl::Nonnull<std::shared_ptr<const google::protobuf::DescriptorPool>>
descriptor_pool,
std::shared_ptr<google::protobuf::Arena> arena)
: descriptor_pool_(std::move(descriptor_pool)),
arena_(std::move(arena)),
container_(""),
parent_(nullptr) {}

// Move-only.
TypeCheckEnv(TypeCheckEnv&&) = default;
Expand All @@ -110,14 +118,19 @@ class TypeCheckEnv {

const absl::optional<Type>& expected_type() const { return expected_type_; }

absl::Span<const std::unique_ptr<TypeIntrospector>> type_providers() const {
absl::Span<const std::shared_ptr<const TypeIntrospector>> type_providers()
const {
return type_providers_;
}

void AddTypeProvider(std::unique_ptr<TypeIntrospector> provider) {
type_providers_.push_back(std::move(provider));
}

void AddTypeProvider(std::shared_ptr<const TypeIntrospector> provider) {
type_providers_.push_back(std::move(provider));
}

const absl::flat_hash_map<std::string, VariableDecl>& variables() const {
return variables_;
}
Expand Down Expand Up @@ -179,17 +192,6 @@ class TypeCheckEnv {
return descriptor_pool_.get();
}

// Return an arena that can be used to allocate memory for types that will be
// used by the TypeChecker being built.
//
// This is only intended to be used for configuration.
google::protobuf::Arena* ABSL_NONNULL arena() {
if (arena_ == nullptr) {
arena_ = std::make_unique<google::protobuf::Arena>();
}
return arena_.get();
}

private:
explicit TypeCheckEnv(const TypeCheckEnv* ABSL_NONNULL parent)
: descriptor_pool_(parent->descriptor_pool_),
Expand All @@ -200,7 +202,8 @@ class TypeCheckEnv {
absl::string_view type, absl::string_view value) const;

ABSL_NONNULL std::shared_ptr<const google::protobuf::DescriptorPool> descriptor_pool_;
ABSL_NULLABLE std::unique_ptr<google::protobuf::Arena> arena_;
// If set, an arena was needed to allocate types in the environment.
ABSL_NULLABLE std::shared_ptr<const google::protobuf::Arena> arena_;
std::string container_;
const TypeCheckEnv* ABSL_NULLABLE parent_;

Expand All @@ -209,7 +212,7 @@ class TypeCheckEnv {
absl::flat_hash_map<std::string, FunctionDecl> functions_;

// Type providers for custom types.
std::vector<std::unique_ptr<TypeIntrospector>> type_providers_;
std::vector<std::shared_ptr<const TypeIntrospector>> type_providers_;

absl::optional<Type> expected_type_;
};
Expand Down
187 changes: 121 additions & 66 deletions checker/internal/type_checker_builder_impl.cc
Original file line number Diff line number Diff line change
Expand Up @@ -19,9 +19,9 @@
#include <utility>
#include <vector>

#include "absl/algorithm/container.h"
#include "absl/base/no_destructor.h"
#include "absl/base/nullability.h"
#include "absl/cleanup/cleanup.h"
#include "absl/container/flat_hash_map.h"
#include "absl/status/status.h"
#include "absl/status/statusor.h"
Expand Down Expand Up @@ -80,38 +80,126 @@ absl::Status CheckStdMacroOverlap(const FunctionDecl& decl) {
return absl::OkStatus();
}

} // namespace

absl::Status TypeCheckerBuilderImpl::AddContextDeclarationVariables(
const google::protobuf::Descriptor* ABSL_NONNULL descriptor) {
absl::Status AddContextDeclarationVariables(
const google::protobuf::Descriptor* ABSL_NONNULL descriptor, TypeCheckEnv& env) {
for (int i = 0; i < descriptor->field_count(); i++) {
const google::protobuf::FieldDescriptor* proto_field = descriptor->field(i);
MessageTypeField cel_field(proto_field);
cel_field.name();
Type field_type = cel_field.GetType();
if (field_type.IsEnum()) {
field_type = IntType();
}
if (!env_.InsertVariableIfAbsent(
MakeVariableDecl(std::string(cel_field.name()), field_type))) {
if (!env.InsertVariableIfAbsent(
MakeVariableDecl(cel_field.name(), field_type))) {
return absl::AlreadyExistsError(
absl::StrCat("variable '", cel_field.name(),
"' already exists (from context declaration: '",
"' declared multiple times (from context declaration: '",
descriptor->full_name(), "')"));
}
}

return absl::OkStatus();
}

absl::StatusOr<std::unique_ptr<TypeChecker>>
TypeCheckerBuilderImpl::Build() && {
for (const auto* type : context_types_) {
CEL_RETURN_IF_ERROR(AddContextDeclarationVariables(type));
absl::StatusOr<FunctionDecl> MergeFunctionDecls(
const FunctionDecl& existing_decl, const FunctionDecl& new_decl) {
if (existing_decl.name() != new_decl.name()) {
return absl::InternalError(
"Attempted to merge function decls with different names");
}

FunctionDecl merged_decl = existing_decl;
for (const auto& ovl : new_decl.overloads()) {
// We do not tolerate signature collisions, even if they are exact matches.
CEL_RETURN_IF_ERROR(merged_decl.AddOverload(ovl));
}

return merged_decl;
}

} // namespace

absl::Status TypeCheckerBuilderImpl::BuildLibraryConfig(
const CheckerLibrary& library,
TypeCheckerBuilderImpl::ConfigRecord* config) {
target_config_ = config;
absl::Cleanup reset([this] { target_config_ = &default_config_; });

return library.configure(*this);
}

absl::Status TypeCheckerBuilderImpl::ApplyConfig(
TypeCheckerBuilderImpl::ConfigRecord config, TypeCheckEnv& env) {
using FunctionDeclRecord = TypeCheckerBuilderImpl::FunctionDeclRecord;

for (auto& type_provider : config.type_providers) {
env.AddTypeProvider(std::move(type_provider));
}

// TODO: check for subsetter
for (FunctionDeclRecord& fn : config.functions) {
switch (fn.add_semantic) {
case AddSemantic::kInsertIfAbsent: {
std::string name = fn.decl.name();
if (!env.InsertFunctionIfAbsent(std::move(fn.decl))) {
return absl::AlreadyExistsError(
absl::StrCat("function '", name, "' declared multiple times"));
}
break;
}
case AddSemantic::kTryMerge:
const FunctionDecl* existing_decl = env.LookupFunction(fn.decl.name());
FunctionDecl to_add = std::move(fn.decl);
if (existing_decl != nullptr) {
CEL_ASSIGN_OR_RETURN(to_add,
MergeFunctionDecls(*existing_decl, to_add));
}
env.InsertOrReplaceFunction(std::move(to_add));
break;
}
}

for (const google::protobuf::Descriptor* context_type : config.context_types) {
CEL_RETURN_IF_ERROR(AddContextDeclarationVariables(context_type, env));
}

for (VariableDecl& var : config.variables) {
if (!env.InsertVariableIfAbsent(var)) {
return absl::AlreadyExistsError(
absl::StrCat("variable '", var.name(), "' declared multiple times"));
}
}

return absl::OkStatus();
}

absl::StatusOr<std::unique_ptr<TypeChecker>> TypeCheckerBuilderImpl::Build() {
TypeCheckEnv env(descriptor_pool_, arena_);
env.set_container(container_);
if (expected_type_.has_value()) {
env.set_expected_type(*expected_type_);
}

ConfigRecord anonymous_config;
std::vector<ConfigRecord> configs;
for (const auto& library : libraries_) {
ConfigRecord* config = &anonymous_config;
if (!library.id.empty()) {
configs.emplace_back();
config = &configs.back();
}
CEL_RETURN_IF_ERROR(BuildLibraryConfig(library, config));
}

for (const ConfigRecord& config : configs) {
CEL_RETURN_IF_ERROR(ApplyConfig(std::move(config), env));
}
CEL_RETURN_IF_ERROR(ApplyConfig(std::move(anonymous_config), env));

CEL_RETURN_IF_ERROR(ApplyConfig(default_config_, env));

auto checker = std::make_unique<checker_internal::TypeCheckerImpl>(
std::move(env_), options_);
std::move(env), options_);
return checker;
}

Expand All @@ -123,99 +211,66 @@ absl::Status TypeCheckerBuilderImpl::AddLibrary(CheckerLibrary library) {
if (!library.configure) {
return absl::OkStatus();
}
absl::Status status = library.configure(*this);

libraries_.push_back(std::move(library));
return status;
return absl::OkStatus();
}

absl::Status TypeCheckerBuilderImpl::AddVariable(const VariableDecl& decl) {
bool inserted = env_.InsertVariableIfAbsent(decl);
if (!inserted) {
return absl::AlreadyExistsError(
absl::StrCat("variable '", decl.name(), "' already exists"));
}
target_config_->variables.push_back(std::move(decl));
return absl::OkStatus();
}

absl::Status TypeCheckerBuilderImpl::AddContextDeclaration(
absl::string_view type) {
CEL_ASSIGN_OR_RETURN(absl::optional<Type> resolved_type,
env_.LookupTypeName(type));

if (!resolved_type.has_value()) {
const google::protobuf::Descriptor* desc =
descriptor_pool_->FindMessageTypeByName(type);
if (desc == nullptr) {
return absl::NotFoundError(
absl::StrCat("context declaration '", type, "' not found"));
}

if (!resolved_type->IsStruct()) {
if (IsWellKnownMessageType(desc)) {
return absl::InvalidArgumentError(
absl::StrCat("context declaration '", type, "' is not a struct"));
}

if (!resolved_type->AsStruct()->IsMessage()) {
return absl::InvalidArgumentError(
absl::StrCat("context declaration '", type,
"' is not protobuf message backed struct"));
}

const google::protobuf::Descriptor* descriptor =
&(**(resolved_type->AsStruct()->AsMessage()));

if (absl::c_linear_search(context_types_, descriptor)) {
return absl::AlreadyExistsError(
absl::StrCat("context declaration '", type, "' already exists"));
for (const auto* context_type : target_config_->context_types) {
if (context_type->full_name() == desc->full_name()) {
return absl::AlreadyExistsError(
absl::StrCat("context declaration '", type, "' already exists"));
}
}

context_types_.push_back(descriptor);
target_config_->context_types.push_back(desc);
return absl::OkStatus();
}

absl::Status TypeCheckerBuilderImpl::AddFunction(const FunctionDecl& decl) {
CEL_RETURN_IF_ERROR(CheckStdMacroOverlap(decl));
bool inserted = env_.InsertFunctionIfAbsent(decl);
if (!inserted) {
return absl::AlreadyExistsError(
absl::StrCat("function '", decl.name(), "' already exists"));
}
target_config_->functions.push_back(
{std::move(decl), AddSemantic::kInsertIfAbsent});
return absl::OkStatus();
}

absl::Status TypeCheckerBuilderImpl::MergeFunction(const FunctionDecl& decl) {
const FunctionDecl* existing = env_.LookupFunction(decl.name());
if (existing == nullptr) {
return AddFunction(decl);
}

CEL_RETURN_IF_ERROR(CheckStdMacroOverlap(decl));

FunctionDecl merged = *existing;

for (const auto& overload : decl.overloads()) {
if (!merged.AddOverload(overload).ok()) {
return absl::AlreadyExistsError(
absl::StrCat("function '", decl.name(),
"' already has overload that conflicts with overload ''",
overload.id(), "'"));
}
}

env_.InsertOrReplaceFunction(std::move(merged));

target_config_->functions.push_back(
{std::move(decl), AddSemantic::kTryMerge});
return absl::OkStatus();
}

void TypeCheckerBuilderImpl::AddTypeProvider(
std::unique_ptr<TypeIntrospector> provider) {
env_.AddTypeProvider(std::move(provider));
target_config_->type_providers.push_back(std::move(provider));
}

void TypeCheckerBuilderImpl::set_container(absl::string_view container) {
env_.set_container(std::string(container));
container_ = container;
}

void TypeCheckerBuilderImpl::SetExpectedType(const Type& type) {
env_.set_expected_type(type);
expected_type_ = type;
}

} // namespace cel::checker_internal
Loading