diff --git a/base/BUILD b/base/BUILD index 61633b38f..a22e7f742 100644 --- a/base/BUILD +++ b/base/BUILD @@ -10,3 +10,17 @@ cc_library( "status_macros.h", ], ) + +cc_library( + name = "unilib", + srcs = [ + "unilib.cc", + ], + hdrs = [ + "unilib.h", + ], + deps = [ + "@com_github_google_flatbuffers//:flatbuffers", + "@com_google_absl//absl/strings", + ], +) diff --git a/base/unilib.cc b/base/unilib.cc new file mode 100644 index 000000000..6355af23e --- /dev/null +++ b/base/unilib.cc @@ -0,0 +1,19 @@ +#include "base/unilib.h" + +#include "flatbuffers/util.h" + +namespace UniLib { + +// Detects whether a string is valid UTF-8. +bool IsStructurallyValid(absl::string_view str) { + const char *s = &str[0]; + const char *const sEnd = s + str.length(); + while (s < sEnd) { + if (flatbuffers::FromUTF8(&s) < 0) { + return false; + } + } + return true; +} + +} // namespace UniLib diff --git a/base/unilib.h b/base/unilib.h new file mode 100644 index 000000000..3eb1e4958 --- /dev/null +++ b/base/unilib.h @@ -0,0 +1,29 @@ +/* + * Copyright 2021 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef THIRD_PARTY_CEL_CPP_BASE_UNILIB_H_ +#define THIRD_PARTY_CEL_CPP_BASE_UNILIB_H_ + +#include "absl/strings/string_view.h" + +namespace UniLib { + +// Detects whether a string is valid UTF-8. +bool IsStructurallyValid(absl::string_view str); + +} // namespace UniLib + +#endif // THIRD_PARTY_CEL_CPP_BASE_UNILIB_H_ diff --git a/common/type.cc b/common/type.cc index ad9fa0ddc..81ff9b1e6 100644 --- a/common/type.cc +++ b/common/type.cc @@ -72,7 +72,7 @@ UnrecognizedType::UnrecognizedType(absl::string_view full_name) : string_rep_(absl::StrCat("type(\"", full_name, "\")")), hash_code_(internal::Hash(full_name)) { assert(google::protobuf::DescriptorPool::generated_pool()->FindMessageTypeByName( - std::string(full_name)) == nullptr); + full_name.data()) == nullptr); } absl::string_view UnrecognizedType::full_name() const { @@ -89,7 +89,7 @@ Type::Type(const std::string& full_name) auto obj_desc = google::protobuf::DescriptorPool::generated_pool()->FindMessageTypeByName( - full_name); + full_name.data()); if (obj_desc != nullptr) { data_ = ObjectType(obj_desc); return; diff --git a/common/value.h b/common/value.h index 777021a4b..e0b97ba37 100644 --- a/common/value.h +++ b/common/value.h @@ -330,7 +330,7 @@ class Container : public SharedValue { } template static Value GetValue(V&& value) { - return Value::From(std::move(value)); + return Value::From(std::forward(value)); } private: diff --git a/conformance/BUILD b/conformance/BUILD index 6d50aaed4..b357cef48 100644 --- a/conformance/BUILD +++ b/conformance/BUILD @@ -19,9 +19,8 @@ ALL_TESTS = [ "@com_google_cel_spec//tests/simple:testdata/macros.textproto", "@com_google_cel_spec//tests/simple:testdata/namespace.textproto", "@com_google_cel_spec//tests/simple:testdata/plumbing.textproto", - # TODO(issues/92): Support for parse-only proto message creation within a container. - # "@com_google_cel_spec//tests/simple:testdata/proto2.textproto", - # "@com_google_cel_spec//tests/simple:testdata/proto3.textproto", + "@com_google_cel_spec//tests/simple:testdata/proto2.textproto", + "@com_google_cel_spec//tests/simple:testdata/proto3.textproto", "@com_google_cel_spec//tests/simple:testdata/string.textproto", "@com_google_cel_spec//tests/simple:testdata/timestamps.textproto", "@com_google_cel_spec//tests/simple:testdata/unknowns.textproto", @@ -91,18 +90,11 @@ cc_binary( # uncommented when the spec changes to truncation rather than rounding. "--skip_test=conversions/int/double_nearest,double_nearest_neg,double_half_away_neg,double_half_away_pos", "--skip_test=conversions/uint/double_nearest,double_nearest_int,double_half_away", - # TODO(issues/82): Unexpected behavior when converting invalid bytes to string. - "--skip_test=conversions/string/bytes_invalid", - # TODO(issues/83): Missing type() conversion functions - "--skip_test=conversions/type", # TODO(issues/96): Well-known type conversion support. "--skip_test=proto2/literal_wellknown", "--skip_test=proto3/literal_wellknown", "--skip_test=proto2/empty_field/wkt", "--skip_test=proto3/empty_field/wkt", - # TODO(issues/92): Support for parse-only proto message creation within a container. - "--skip_test=proto2/has/undefined", - "--skip_test=proto3/has/undefined", # Requires container support "--skip_test=namespace/namespace/self_eval_container_lookup_unchecked", "--skip_test=basic/self_eval_nonzeroish/self_eval_bytes_invalid_utf8", diff --git a/conformance/server.cc b/conformance/server.cc index 5022c668a..8b9ddac35 100644 --- a/conformance/server.cc +++ b/conformance/server.cc @@ -25,8 +25,6 @@ #include "proto/test/v1/proto3/test_all_types.pb.h" -using absl::Status; -using absl::StatusCode; using ::google::protobuf::Arena; using ::google::protobuf::util::JsonStringToMessage; using ::google::protobuf::util::MessageToJsonString; @@ -42,10 +40,10 @@ class ConformanceServiceImpl { public: explicit ConformanceServiceImpl(std::unique_ptr builder) : builder_(std::move(builder)), - proto2Tests_(&google::api::expr::test::v1::proto2::TestAllTypes:: - default_instance()), - proto3Tests_(&google::api::expr::test::v1::proto3::TestAllTypes:: - default_instance()) {} + proto2_tests_(&google::api::expr::test::v1::proto2::TestAllTypes:: + default_instance()), + proto3_tests_(&google::api::expr::test::v1::proto3::TestAllTypes:: + default_instance()) {} void Parse(const v1alpha1::ParseRequest* request, v1alpha1::ParseResponse* response) { @@ -63,7 +61,7 @@ class ConformanceServiceImpl { } else { google::api::expr::v1alpha1::ParsedExpr out; (out).MergeFrom(parse_status.value()); - response->mutable_parsed_expr()->CopyFrom(out); + *response->mutable_parsed_expr() = out; } } @@ -87,6 +85,7 @@ class ConformanceServiceImpl { google::api::expr::v1alpha1::SourceInfo source_info; google::api::expr::v1alpha1::Expr out; (out).MergeFrom(*expr); + builder_->set_container(request->container()); auto cel_expression_status = builder_->CreateExpression(&out, &source_info); if (!cel_expression_status.ok()) { @@ -144,13 +143,14 @@ class ConformanceServiceImpl { private: std::unique_ptr builder_; - const google::api::expr::test::v1::proto2::TestAllTypes* proto2Tests_; - const google::api::expr::test::v1::proto3::TestAllTypes* proto3Tests_; + const google::api::expr::test::v1::proto2::TestAllTypes* proto2_tests_; + const google::api::expr::test::v1::proto3::TestAllTypes* proto3_tests_; }; int RunServer(bool optimize) { google::protobuf::Arena arena; InterpreterOptions options; + options.enable_qualified_type_identifiers = true; if (optimize) { std::cerr << "Enabling optimizations" << std::endl; @@ -160,14 +160,15 @@ int RunServer(bool optimize) { std::unique_ptr builder = CreateCelExpressionBuilder(options); - builder->AddResolvableEnum( + auto type_registry = builder->GetTypeRegistry(); + type_registry->Register( google::api::expr::test::v1::proto2::GlobalEnum_descriptor()); - builder->AddResolvableEnum( + type_registry->Register( google::api::expr::test::v1::proto3::GlobalEnum_descriptor()); - builder->AddResolvableEnum(google::api::expr::test::v1::proto2::TestAllTypes:: - NestedEnum_descriptor()); - builder->AddResolvableEnum(google::api::expr::test::v1::proto3::TestAllTypes:: - NestedEnum_descriptor()); + type_registry->Register(google::api::expr::test::v1::proto2::TestAllTypes:: + NestedEnum_descriptor()); + type_registry->Register(google::api::expr::test::v1::proto3::TestAllTypes:: + NestedEnum_descriptor()); auto register_status = RegisterBuiltinFunctions(builder->GetRegistry()); if (!register_status.ok()) { std::cerr << "Failed to initialize: " << register_status.ToString() diff --git a/eval/compiler/BUILD b/eval/compiler/BUILD index 8e654c3bf..8de2fd32b 100644 --- a/eval/compiler/BUILD +++ b/eval/compiler/BUILD @@ -17,6 +17,7 @@ cc_library( deps = [ ":constant_folding", ":qualified_reference_resolver", + ":resolver", "//base:status_macros", "//eval/eval:comprehension_step", "//eval/eval:const_value_step", @@ -30,6 +31,7 @@ cc_library( "//eval/eval:jump_step", "//eval/eval:logic_step", "//eval/eval:select_step", + "//eval/eval:shadowable_value_step", "//eval/eval:ternary_step", "//eval/public:ast_traverse", "//eval/public:ast_visitor", @@ -147,6 +149,7 @@ cc_library( "qualified_reference_resolver.h", ], deps = [ + ":resolver", "//base:status_macros", "//eval/eval:const_value_step", "//eval/eval:expression_build_warning", @@ -162,6 +165,21 @@ cc_library( ], ) +cc_library( + name = "resolver", + srcs = ["resolver.cc"], + hdrs = ["resolver.h"], + deps = [ + "//eval/public:cel_builtins", + "//eval/public:cel_function_registry", + "//eval/public:cel_type_registry", + "//eval/public:cel_value", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/types:optional", + "@com_google_protobuf//:protobuf", + ], +) + cc_test( name = "qualified_reference_resolver_test", srcs = [ @@ -174,6 +192,7 @@ cc_test( "//eval/public:cel_builtins", "//eval/public:cel_function", "//eval/public:cel_function_registry", + "//eval/public:cel_type_registry", "//testutil:util", "@com_google_absl//absl/status", "@com_google_absl//absl/types:optional", @@ -203,3 +222,19 @@ cc_test( "@com_google_protobuf//:protobuf", ], ) + +cc_test( + name = "resolver_test", + size = "small", + srcs = ["resolver_test.cc"], + deps = [ + ":resolver", + "//base:status_macros", + "//eval/public:cel_function", + "//eval/public:cel_function_registry", + "//eval/public:cel_type_registry", + "//eval/testutil:test_message_cc_proto", + "@com_google_absl//absl/status", + "@com_google_googletest//:gtest_main", + ], +) diff --git a/eval/compiler/flat_expr_builder.cc b/eval/compiler/flat_expr_builder.cc index 78928580d..4e52d76a6 100644 --- a/eval/compiler/flat_expr_builder.cc +++ b/eval/compiler/flat_expr_builder.cc @@ -11,6 +11,7 @@ #include "absl/strings/string_view.h" #include "eval/compiler/constant_folding.h" #include "eval/compiler/qualified_reference_resolver.h" +#include "eval/compiler/resolver.h" #include "eval/eval/comprehension_step.h" #include "eval/eval/const_value_step.h" #include "eval/eval/container_access_step.h" @@ -23,6 +24,7 @@ #include "eval/eval/jump_step.h" #include "eval/eval/logic_step.h" #include "eval/eval/select_step.h" +#include "eval/eval/shadowable_value_step.h" #include "eval/eval/ternary_step.h" #include "eval/public/ast_traverse.h" #include "eval/public/ast_visitor.h" @@ -149,59 +151,20 @@ class ComprehensionVisitor : public CondVisitor { class FlatExprVisitor : public AstVisitor { public: FlatExprVisitor( - const CelFunctionRegistry* function_registry, ExecutionPath* path, - bool short_circuiting, - const std::set& enums, - absl::string_view container, + const Resolver& resolver, ExecutionPath* path, bool short_circuiting, const absl::flat_hash_map& constant_idents, bool enable_comprehension, BuilderWarnings* warnings, std::set* iter_variable_names) - : flattened_path_(path), + : resolver_(resolver), + flattened_path_(path), progress_status_(absl::OkStatus()), resolved_select_expr_(nullptr), - function_registry_(function_registry), short_circuiting_(short_circuiting), constant_idents_(constant_idents), enable_comprehension_(enable_comprehension), builder_warnings_(warnings), iter_variable_names_(iter_variable_names) { GOOGLE_CHECK(iter_variable_names_); - - auto container_elements = absl::StrSplit(container, '.'); - - // Build list of prefixes from container. Non-empty prefixes must end with - // ".", otherwise prefix "abc.xy" will match "abc.xyz.EnumName". - std::string prefix = ""; - std::vector prefixes; - prefixes.push_back(prefix); - for (const auto& elem : container_elements) { - absl::StrAppend(&prefix, elem, "."); - prefixes.push_back(prefix); - } - - for (const auto& prefix : prefixes) { - for (auto enum_desc : enums) { - absl::string_view enum_name = enum_desc->full_name(); - if (!absl::StartsWith(enum_name, prefix)) { - continue; - } - - auto remainder = absl::StripPrefix(enum_name, prefix); - for (int i = 0; i < enum_desc->value_count(); i++) { - auto value_desc = enum_desc->value(i); - if (value_desc) { - // "prefixes" container is ascending-ordered. As such, we will be - // assigning enum reference to the deepest available. - // E.g. if both a.b.c.Name and a.b.Name are available, and - // we try to reference "Name" with the scope of "a.b.c", - // it will be resolved to "a.b.c.Name". - auto key = absl::StrCat(remainder, !remainder.empty() ? "." : "", - value_desc->name()); - enum_map_[key] = value_desc; - } - } - } - } } void PostVisitConst(const Constant* const_expr, const Expr* expr, @@ -227,7 +190,7 @@ class FlatExprVisitor : public AstVisitor { return; } - std::string path(ident_expr->name()); + const std::string& path = ident_expr->name(); // Automatically replace constant idents with the backing CEL values. auto constant = constant_idents_.find(path); @@ -236,41 +199,38 @@ class FlatExprVisitor : public AstVisitor { return; } - // Attempt to resolve the identifier as an enum. - const google::protobuf::EnumValueDescriptor* value_desc = nullptr; - auto it = enum_map_.find(path); - if (it != enum_map_.end()) { - value_desc = it->second; - } - - // Attempt to resolve a parsed-only expression as a namespaced identifier. + // Attempt to resolve a select expression as a namespaced identifier for an + // enum or type constant value. + absl::optional const_value = absl::nullopt; while (!namespace_stack_.empty()) { - const auto& select_node = namespace_stack_.back(); + const auto& select_node = namespace_stack_.front(); // Generate path in format ".....". - absl::StrAppend(&path, ".", select_node.second); - namespace_map_[select_node.first] = path; - - // Attempt to match namespace - auto it = enum_map_.find(path); - if (it != enum_map_.end()) { - resolved_select_expr_ = select_node.first; - value_desc = it->second; - } - namespace_stack_.pop_back(); - } - - if (resolved_select_expr_) { - if (!resolved_select_expr_->has_select_expr()) { - progress_status_ = absl::InternalError("Unexpected Expr type"); + auto select_expr = select_node.first; + auto qualified_path = absl::StrCat(path, ".", select_node.second); + namespace_map_[select_expr] = qualified_path; + + // Attempt to find a constant enum or type value which matches the + // qualified path present in the expression. Whether the identifier + // can be resolved to a type instance depends on whether the option to + // 'enable_qualified_type_identifiers' is set to true. + const_value = resolver_.FindConstant(qualified_path, select_expr->id()); + if (const_value.has_value()) { + AddStep(CreateShadowableValueStep(qualified_path, const_value.value(), + select_expr->id())); + resolved_select_expr_ = select_expr; + namespace_stack_.clear(); return; } - AddStep(CreateConstValueStep(value_desc, resolved_select_expr_->id())); - return; + namespace_stack_.pop_front(); } - if (value_desc) { - AddStep(CreateConstValueStep(value_desc, expr->id())); + + // Attempt to resolve a simple identifier as an enum or type constant value. + const_value = resolver_.FindConstant(path, expr->id()); + if (const_value.has_value()) { + AddStep(CreateShadowableValueStep(path, const_value.value(), expr->id())); return; } + AddStep(CreateIdentStep(ident_expr, expr->id())); } @@ -283,8 +243,24 @@ class FlatExprVisitor : public AstVisitor { // select_expr. // Chain of multiple SELECT ending with IDENT can represent namespaced // entity. - if (select_expr->operand().has_ident_expr() || - select_expr->operand().has_select_expr()) { + if (!select_expr->test_only() && + (select_expr->operand().has_ident_expr() || + select_expr->operand().has_select_expr())) { + // select expressions are pushed in reverse order: + // google.type.Expr is pushed as: + // - field: 'Expr' + // - field: 'type' + // - id: 'google' + // + // The search order though is as follows: + // - id: 'google.type.Expr' + // - id: 'google.type', field: 'Expr' + // - id: 'google', field: 'type', field: 'Expr' + for (size_t i = 0; i < namespace_stack_.size(); i++) { + auto ns = namespace_stack_[i]; + namespace_stack_[i] = { + ns.first, absl::StrCat(select_expr->field(), ".", ns.second)}; + } namespace_stack_.push_back({expr, select_expr->field()}); } else { namespace_stack_.clear(); @@ -311,7 +287,6 @@ class FlatExprVisitor : public AstVisitor { } std::string select_path = ""; - auto it = namespace_map_.find(expr); if (it != namespace_map_.end()) { select_path = it->second; @@ -364,16 +339,47 @@ class FlatExprVisitor : public AstVisitor { if (cond_visitor) { cond_visitor->PostVisit(expr); cond_visitor_stack_.pop(); - } else { - // Special case for "_[_]". - if (call_expr->function() == builtin::kIndex) { - AddStep(CreateContainerAccessStep(call_expr, expr->id())); + return; + } + + // Special case for "_[_]". + if (call_expr->function() == builtin::kIndex) { + AddStep(CreateContainerAccessStep(call_expr, expr->id())); + return; + } + + // Establish the search criteria for a given function. + absl::string_view function = call_expr->function(); + bool receiver_style = call_expr->has_target(); + size_t num_args = call_expr->args_size() + (receiver_style ? 1 : 0); + auto arguments_matcher = ArgumentsMatcher(num_args); + + // First, search for lazily defined function overloads. + // Lazy functions shadow eager functions with the same signature. + auto lazy_overloads = resolver_.FindLazyOverloads( + function, receiver_style, arguments_matcher, expr->id()); + if (!lazy_overloads.empty()) { + AddStep(CreateFunctionStep(call_expr, expr->id(), lazy_overloads)); + return; + } + + // Second, search for eagerly defined function overloads. + auto overloads = resolver_.FindOverloads(function, receiver_style, + arguments_matcher, expr->id()); + if (overloads.empty()) { + // Create a warning that the overload could not be found. Depending on the + // builder_warnings configuration, this could result in termination of the + // CelExpression creation or an inspectable warning for use within runtime + // logging. + auto status = builder_warnings_->AddWarning( + absl::Status(absl::StatusCode::kInvalidArgument, + "No overloads provided for FunctionStep creation")); + if (!status.ok()) { + SetProgressStatusError(status); return; } - // For regular functions, just create one based on registry. - AddStep(CreateFunctionStep(call_expr, expr->id(), *function_registry_, - builder_warnings_)); } + AddStep(CreateFunctionStep(call_expr, expr->id(), overloads)); } void PreVisitComprehension(const Comprehension*, const Expr* expr, @@ -445,7 +451,27 @@ class FlatExprVisitor : public AstVisitor { if (!progress_status_.ok()) { return; } - AddStep(CreateCreateStructStep(struct_expr, expr->id())); + + // If the message name is empty, this signals that a map should be created. + auto message_name = struct_expr->message_name(); + if (message_name.empty()) { + AddStep(CreateCreateStructStep(struct_expr, expr->id())); + return; + } + + // If the message name is not empty, then the message name must be resolved + // within the container, and if a descriptor is found, then a proto message + // creation step will be created. + auto message_desc = resolver_.FindDescriptor(message_name, expr->id()); + if (message_desc != nullptr) { + AddStep(CreateCreateStructStep(struct_expr, message_desc, expr->id())); + return; + } + + // Otherwise, the message descriptor was not linked into the binary. + SetProgressStatusError(absl::InvalidArgumentError( + "Error configuring message creation: message descriptor not found: " + + message_name)); } absl::Status progress_status() const { return progress_status_; } @@ -484,6 +510,7 @@ class FlatExprVisitor : public AstVisitor { } private: + const Resolver& resolver_; ExecutionPath* flattened_path_; absl::Status progress_status_; @@ -500,12 +527,6 @@ class FlatExprVisitor : public AstVisitor { // field is used as marker suppressing CelExpression creation for SELECTs. const Expr* resolved_select_expr_; - // Fully resolved enum value names. - absl::node_hash_map - enum_map_; - - const CelFunctionRegistry* function_registry_; - bool short_circuiting_; const absl::flat_hash_map& constant_idents_; @@ -724,6 +745,8 @@ FlatExprBuilder::CreateExpressionImpl( std::vector* warnings) const { ExecutionPath execution_path; BuilderWarnings warnings_builder(fail_on_warnings_); + Resolver resolver(container(), GetRegistry(), GetTypeRegistry(), + enable_qualified_type_identifiers_); if (absl::StartsWith(container(), ".") || absl::EndsWith(container(), ".")) { return absl::InvalidArgumentError( @@ -740,9 +763,8 @@ FlatExprBuilder::CreateExpressionImpl( // available, we can skip the reference resolve step here if it's already // done. if (reference_map != nullptr && !reference_map->empty()) { - absl::StatusOr> rewritten = - ResolveReferences(*effective_expr, *reference_map, *GetRegistry(), - container(), &warnings_builder); + absl::StatusOr> rewritten = ResolveReferences( + *effective_expr, *reference_map, resolver, &warnings_builder); if (!rewritten.ok()) { return rewritten.status(); } @@ -763,9 +785,8 @@ FlatExprBuilder::CreateExpressionImpl( } std::set iter_variable_names; - FlatExprVisitor visitor(this->GetRegistry(), &execution_path, - shortcircuiting_, resolvable_enums(), container(), - idents, enable_comprehension_, &warnings_builder, + FlatExprVisitor visitor(resolver, &execution_path, shortcircuiting_, idents, + enable_comprehension_, &warnings_builder, &iter_variable_names); AstTraverse(effective_expr, source_info, &visitor); diff --git a/eval/compiler/flat_expr_builder.h b/eval/compiler/flat_expr_builder.h index 3b18ba9da..970c8a7ff 100644 --- a/eval/compiler/flat_expr_builder.h +++ b/eval/compiler/flat_expr_builder.h @@ -24,7 +24,8 @@ class FlatExprBuilder : public CelExpressionBuilder { constant_arena_(nullptr), enable_comprehension_(true), comprehension_max_iterations_(0), - fail_on_warnings_(true) {} + fail_on_warnings_(true), + enable_qualified_type_identifiers_(false) {} // set_enable_unknowns controls support for unknowns in expressions created. void set_enable_unknowns(bool enabled) { enable_unknowns_ = enabled; } @@ -65,6 +66,12 @@ class FlatExprBuilder : public CelExpressionBuilder { fail_on_warnings_ = should_fail; } + // set_enable_qualified_type_identifiers controls whether select expressions + // may be treated as constant type identifiers during CelExpression creation. + void set_enable_qualified_type_identifiers(bool enabled) { + enable_qualified_type_identifiers_ = enabled; + } + absl::StatusOr> CreateExpression( const google::api::expr::v1alpha1::Expr* expr, const google::api::expr::v1alpha1::SourceInfo* source_info) const override; @@ -98,6 +105,7 @@ class FlatExprBuilder : public CelExpressionBuilder { bool enable_comprehension_; int comprehension_max_iterations_; bool fail_on_warnings_; + bool enable_qualified_type_identifiers_; }; } // namespace runtime diff --git a/eval/compiler/flat_expr_builder_test.cc b/eval/compiler/flat_expr_builder_test.cc index 2a65d2bc3..0a1fcf258 100644 --- a/eval/compiler/flat_expr_builder_test.cc +++ b/eval/compiler/flat_expr_builder_test.cc @@ -146,7 +146,6 @@ TEST(FlatExprBuilderTest, DelayedFunctionResolutionErrors) { CelValue result = eval_status.value(); ASSERT_TRUE(result.IsError()); - EXPECT_THAT(result.ErrorOrDie()->message(), Eq("No matching overloads found")); @@ -966,7 +965,7 @@ TEST(FlatExprBuilderTest, SimpleEnumTest) { cur_expr->mutable_ident_expr()->set_name(enum_name_parts[0]); FlatExprBuilder builder; - builder.AddResolvableEnum(TestMessage::TestEnum_descriptor()); + builder.GetTypeRegistry()->Register(TestMessage::TestEnum_descriptor()); auto build_status = builder.CreateExpression(&expr, &source_info); ASSERT_OK(build_status); @@ -997,7 +996,7 @@ TEST(FlatExprBuilderTest, SimpleEnumIdentTest) { cur_expr->mutable_ident_expr()->set_name(enum_name); FlatExprBuilder builder; - builder.AddResolvableEnum(TestMessage::TestEnum_descriptor()); + builder.GetTypeRegistry()->Register(TestMessage::TestEnum_descriptor()); auto build_status = builder.CreateExpression(&expr, &source_info); ASSERT_OK(build_status); @@ -1068,8 +1067,8 @@ void EvalExpressionWithEnum(absl::string_view enum_name, cur_expr->mutable_ident_expr()->set_name(enum_name_parts[0]); FlatExprBuilder builder; - builder.AddResolvableEnum(TestMessage::TestEnum_descriptor()); - builder.AddResolvableEnum(TestEnum_descriptor()); + builder.GetTypeRegistry()->Register(TestMessage::TestEnum_descriptor()); + builder.GetTypeRegistry()->Register(TestEnum_descriptor()); builder.set_container(std::string(container)); auto build_status = builder.CreateExpression(&expr, &source_info); diff --git a/eval/compiler/qualified_reference_resolver.cc b/eval/compiler/qualified_reference_resolver.cc index 819238226..21f39c5af 100644 --- a/eval/compiler/qualified_reference_resolver.cc +++ b/eval/compiler/qualified_reference_resolver.cc @@ -55,81 +55,40 @@ absl::optional ToNamespace(const Expr& expr) { } } -// Shape matcher for CelFunctions. -// TODO(issues/91): this is the same behavior as parsed exprs in the CPP -// evaluator (just check the right call style and number of arguments), but we -// should have enough type information in a checked expr to find a more -// specific candidate list. -std::vector ArgumentMatcher(int argument_count) { - std::vector argument_matcher(argument_count); - for (int i = 0; i < argument_count; i++) { - argument_matcher[i] = CelValue::Type::kAny; - } - return argument_matcher; -} - -bool OverloadExists(const CelFunctionRegistry& registry, absl::string_view name, - const std::vector& argument_matcher, +bool OverloadExists(const Resolver& resolver, absl::string_view name, + const std::vector& arguments_matcher, bool receiver_style = false) { - return !registry.FindOverloads(name, receiver_style, argument_matcher) + return !resolver.FindOverloads(name, receiver_style, arguments_matcher) .empty() || - !registry.FindLazyOverloads(name, receiver_style, argument_matcher) + !resolver.FindLazyOverloads(name, receiver_style, arguments_matcher) .empty(); } -// Handles checking for the most specific (most qualified) function that matches -// the call shape. -class ContainerLookupHelper { - public: - ContainerLookupHelper(absl::string_view container, - const CelFunctionRegistry& registry) - : registry_(registry) { - auto container_elements = absl::StrSplit(container, '.'); - std::string prefix = ""; - namespace_prefixes_.push_back(prefix); - for (const auto& elem : container_elements) { - // Tolerate trailing / leading '.'. - if (elem.empty()) { - continue; - } - absl::StrAppend(&prefix, elem, "."); - namespace_prefixes_.insert(namespace_prefixes_.begin(), prefix); - } +// Return the qualified name of the most qualified matching overload, or +// nullopt if no matches are found. +absl::optional BestOverloadMatch(const Resolver& resolver, + absl::string_view base_name, + int argument_count) { + if (IsSpecialFunction(base_name)) { + return std::string(base_name); } - - // Return the qualified name of the most qualified matching overload, or - // nullopt if no matches are found. - absl::optional BestOverloadMatch(absl::string_view base_name, - int argument_count) { - if (IsSpecialFunction(base_name)) { - return std::string(base_name); + auto arguments_matcher = ArgumentsMatcher(argument_count); + // Check from most qualified to least qualified for a matching overload. + auto names = resolver.FullyQualifiedNames(base_name); + for (auto name = names.begin(); name != names.end(); ++name) { + if (OverloadExists(resolver, *name, arguments_matcher)) { + return *name; } - auto argument_matcher = ArgumentMatcher(argument_count); - // Check from most qualified to least qualified for a matching overload. - for (auto iter = namespace_prefixes_.begin(); - iter != namespace_prefixes_.end(); ++iter) { - std::string resolved_name = absl::StrCat(*iter, base_name); - if (OverloadExists(registry_, resolved_name, argument_matcher)) { - return resolved_name; - } - } - return absl::nullopt; } - - private: - // Namespace prefixes to check in most to least specific order. - std::vector namespace_prefixes_; - const CelFunctionRegistry& registry_; -}; + return absl::nullopt; +} class ReferenceResolver { public: ReferenceResolver(const google::protobuf::Map& reference_map, - const CelFunctionRegistry& registry, - BuilderWarnings* warnings, absl::string_view container) + const Resolver& resolver, BuilderWarnings* warnings) : reference_map_(reference_map), - container_lookup_helper_(container, registry), - registry_(registry), + resolver_(resolver), warnings_(warnings) {} // Attempt to resolve references in expr. Return true if part of the @@ -247,8 +206,7 @@ class ReferenceResolver { std::string resolved_name = absl::StrCat(maybe_namespace.value(), ".", call_expr->function()); auto maybe_resolved_function = - container_lookup_helper_.BestOverloadMatch(resolved_name, - arg_num); + BestOverloadMatch(resolver_, resolved_name, arg_num); if (maybe_resolved_function.has_value()) { call_expr->set_function(maybe_resolved_function.value()); call_expr->clear_target(); @@ -259,8 +217,8 @@ class ReferenceResolver { } else { // Not a receiver style function call. Check to see if it is a namespaced // function using a shorthand inside the expression container. - auto maybe_resolved_function = container_lookup_helper_.BestOverloadMatch( - call_expr->function(), arg_num); + auto maybe_resolved_function = + BestOverloadMatch(resolver_, call_expr->function(), arg_num); if (!maybe_resolved_function.has_value()) { RETURN_IF_ERROR(warnings_->AddWarning(absl::InvalidArgumentError( absl::StrCat("No overload found in reference resolve step for ", @@ -273,8 +231,8 @@ class ReferenceResolver { // For parity, if we didn't rewrite the receiver call style function, // check that an overload is provided in the builder. if (call_expr->has_target() && - !OverloadExists(registry_, call_expr->function(), - ArgumentMatcher(arg_num + 1), + !OverloadExists(resolver_, call_expr->function(), + ArgumentsMatcher(arg_num + 1), /* receiver_style= */ true)) { RETURN_IF_ERROR(warnings_->AddWarning(absl::InvalidArgumentError( absl::StrCat("No overload found in reference resolve step for ", @@ -353,8 +311,7 @@ class ReferenceResolver { } const google::protobuf::Map& reference_map_; - ContainerLookupHelper container_lookup_helper_; - const CelFunctionRegistry& registry_; + const Resolver& resolver_; BuilderWarnings* warnings_; }; @@ -362,11 +319,10 @@ class ReferenceResolver { absl::StatusOr> ResolveReferences( const Expr& expr, const google::protobuf::Map& reference_map, - const CelFunctionRegistry& registry, absl::string_view container, - BuilderWarnings* warnings) { + const Resolver& resolver, BuilderWarnings* warnings) { Expr out(expr); - ReferenceResolver resolver(reference_map, registry, warnings, container); - absl::StatusOr rewrite_result = resolver.Rewrite(&out); + ReferenceResolver ref_resolver(reference_map, resolver, warnings); + absl::StatusOr rewrite_result = ref_resolver.Rewrite(&out); if (!rewrite_result.ok()) { return rewrite_result.status(); } else if (rewrite_result.value()) { diff --git a/eval/compiler/qualified_reference_resolver.h b/eval/compiler/qualified_reference_resolver.h index fe8bd8b54..069653412 100644 --- a/eval/compiler/qualified_reference_resolver.h +++ b/eval/compiler/qualified_reference_resolver.h @@ -7,8 +7,8 @@ #include "absl/status/status.h" #include "absl/status/statusor.h" #include "absl/strings/string_view.h" +#include "eval/compiler/resolver.h" #include "eval/eval/expression_build_warning.h" -#include "eval/public/cel_function_registry.h" namespace google { namespace api { @@ -23,8 +23,7 @@ namespace runtime { absl::StatusOr> ResolveReferences( const google::api::expr::v1alpha1::Expr& expr, const google::protobuf::Map& reference_map, - const CelFunctionRegistry& registry, absl::string_view container, - BuilderWarnings* warnings); + const Resolver& resolver, BuilderWarnings* warnings); } // namespace runtime } // namespace expr diff --git a/eval/compiler/qualified_reference_resolver_test.cc b/eval/compiler/qualified_reference_resolver_test.cc index 695aef4e3..e36630b2d 100644 --- a/eval/compiler/qualified_reference_resolver_test.cc +++ b/eval/compiler/qualified_reference_resolver_test.cc @@ -9,6 +9,7 @@ #include "eval/public/cel_builtins.h" #include "eval/public/cel_function.h" #include "eval/public/cel_function_registry.h" +#include "eval/public/cel_type_registry.h" #include "testutil/util.h" #include "base/status_macros.h" @@ -84,8 +85,11 @@ TEST(ResolveReferences, Basic) { reference_map[2].set_name("foo.bar.var1"); reference_map[5].set_name("bar.foo.var2"); BuilderWarnings warnings; - CelFunctionRegistry registry; - auto result = ResolveReferences(expr, reference_map, registry, "", &warnings); + CelFunctionRegistry func_registry; + CelTypeRegistry type_registry; + Resolver registry("", &func_registry, &type_registry); + + auto result = ResolveReferences(expr, reference_map, registry, &warnings); ASSERT_OK(result); EXPECT_THAT(result.value(), Optional(EqualsProto(R"( id: 1 @@ -106,8 +110,11 @@ TEST(ResolveReferences, ReturnsNulloptIfNoChanges) { Expr expr = ParseTestProto(kExpr); google::protobuf::Map reference_map; BuilderWarnings warnings; - CelFunctionRegistry registry; - auto result = ResolveReferences(expr, reference_map, registry, "", &warnings); + CelFunctionRegistry func_registry; + CelTypeRegistry type_registry; + Resolver registry("", &func_registry, &type_registry); + + auto result = ResolveReferences(expr, reference_map, registry, &warnings); ASSERT_OK(result.status()); EXPECT_THAT(result.value(), Eq(absl::nullopt)); } @@ -116,11 +123,13 @@ TEST(ResolveReferences, NamespacedIdent) { Expr expr = ParseTestProto(kExpr); google::protobuf::Map reference_map; BuilderWarnings warnings; - CelFunctionRegistry registry; - + CelFunctionRegistry func_registry; + CelTypeRegistry type_registry; + Resolver registry("", &func_registry, &type_registry); reference_map[2].set_name("foo.bar.var1"); reference_map[7].set_name("namespace_x.bar"); - auto result = ResolveReferences(expr, reference_map, registry, "", &warnings); + + auto result = ResolveReferences(expr, reference_map, registry, &warnings); ASSERT_OK(result.status()); EXPECT_THAT(result.value(), Optional(EqualsProto(R"( id: 1 @@ -168,9 +177,12 @@ TEST(ResolveReferences, WarningOnPresenceTest) { })"); google::protobuf::Map reference_map; BuilderWarnings warnings; - CelFunctionRegistry registry; + CelFunctionRegistry func_registry; + CelTypeRegistry type_registry; + Resolver registry("", &func_registry, &type_registry); reference_map[1].set_name("foo.bar.var1"); - auto result = ResolveReferences(expr, reference_map, registry, "", &warnings); + + auto result = ResolveReferences(expr, reference_map, registry, &warnings); ASSERT_OK(result.status()); EXPECT_THAT(result.value(), Eq(absl::nullopt)); EXPECT_THAT( @@ -210,13 +222,16 @@ constexpr char kEnumExpr[] = R"( TEST(ResolveReferences, EnumConstReferenceUsed) { Expr expr = ParseTestProto(kEnumExpr); google::protobuf::Map reference_map; - CelFunctionRegistry registry; - ASSERT_OK(RegisterBuiltinFunctions(®istry)); + CelFunctionRegistry func_registry; + ASSERT_OK(RegisterBuiltinFunctions(&func_registry)); + CelTypeRegistry type_registry; + Resolver registry("", &func_registry, &type_registry); reference_map[2].set_name("foo.bar.var1"); reference_map[5].set_name("bar.foo.Enum.ENUM_VAL1"); reference_map[5].mutable_value()->set_int64_value(9); BuilderWarnings warnings; - auto result = ResolveReferences(expr, reference_map, registry, "", &warnings); + + auto result = ResolveReferences(expr, reference_map, registry, &warnings); ASSERT_OK(result); EXPECT_THAT(result.value(), Optional(EqualsProto(R"( id: 1 @@ -236,13 +251,16 @@ TEST(ResolveReferences, EnumConstReferenceUsed) { TEST(ResolveReferences, ConstReferenceSkipped) { Expr expr = ParseTestProto(kExpr); google::protobuf::Map reference_map; - CelFunctionRegistry registry; - ASSERT_OK(RegisterBuiltinFunctions(®istry)); + CelFunctionRegistry func_registry; + ASSERT_OK(RegisterBuiltinFunctions(&func_registry)); + CelTypeRegistry type_registry; + Resolver registry("", &func_registry, &type_registry); reference_map[2].set_name("foo.bar.var1"); reference_map[2].mutable_value()->set_bool_value(true); reference_map[5].set_name("bar.foo.var2"); BuilderWarnings warnings; - auto result = ResolveReferences(expr, reference_map, registry, "", &warnings); + + auto result = ResolveReferences(expr, reference_map, registry, &warnings); ASSERT_OK(result); EXPECT_THAT(result.value(), Optional(EqualsProto(R"( id: 1 @@ -292,16 +310,19 @@ call_expr { TEST(ResolveReferences, FunctionReferenceBasic) { Expr expr = ParseTestProto(kExtensionAndExpr); google::protobuf::Map reference_map; - CelFunctionRegistry registry; - ASSERT_OK(registry.RegisterLazyFunction( + CelFunctionRegistry func_registry; + ASSERT_OK(func_registry.RegisterLazyFunction( CelFunctionDescriptor("boolean_and", false, { CelValue::Type::kBool, CelValue::Type::kBool, }))); + CelTypeRegistry type_registry; + Resolver registry("", &func_registry, &type_registry); BuilderWarnings warnings; reference_map[1].add_overload_id("udf_boolean_and"); - auto result = ResolveReferences(expr, reference_map, registry, "", &warnings); + + auto result = ResolveReferences(expr, reference_map, registry, &warnings); ASSERT_OK(result.status()); EXPECT_THAT(result.value(), Eq(absl::nullopt)); } @@ -309,10 +330,13 @@ TEST(ResolveReferences, FunctionReferenceBasic) { TEST(ResolveReferences, FunctionReferenceMissingOverloadDetected) { Expr expr = ParseTestProto(kExtensionAndExpr); google::protobuf::Map reference_map; - CelFunctionRegistry registry; + CelFunctionRegistry func_registry; + CelTypeRegistry type_registry; + Resolver registry("", &func_registry, &type_registry); BuilderWarnings warnings; reference_map[1].add_overload_id("udf_boolean_and"); - auto result = ResolveReferences(expr, reference_map, registry, "", &warnings); + + auto result = ResolveReferences(expr, reference_map, registry, &warnings); ASSERT_OK(result.status()); EXPECT_THAT(result.value(), Eq(absl::nullopt)); EXPECT_THAT(warnings.warnings(), @@ -339,12 +363,14 @@ TEST(ResolveReferences, SpecialBuiltinsNotWarned) { for (const char* builtin_fn : special_builtins) { google::protobuf::Map reference_map; // Builtins aren't in the function registry. - CelFunctionRegistry registry; + CelFunctionRegistry func_registry; + CelTypeRegistry type_registry; + Resolver registry("", &func_registry, &type_registry); BuilderWarnings warnings; reference_map[1].add_overload_id(absl::StrCat("builtin.", builtin_fn)); expr.mutable_call_expr()->set_function(builtin_fn); - auto result = - ResolveReferences(expr, reference_map, registry, "", &warnings); + + auto result = ResolveReferences(expr, reference_map, registry, &warnings); ASSERT_OK(result.status()); EXPECT_THAT(result.value(), Eq(absl::nullopt)); EXPECT_THAT(warnings.warnings(), IsEmpty()); @@ -355,10 +381,13 @@ TEST(ResolveReferences, FunctionReferenceMissingOverloadDetectedAndMissingReference) { Expr expr = ParseTestProto(kExtensionAndExpr); google::protobuf::Map reference_map; - CelFunctionRegistry registry; + CelFunctionRegistry func_registry; + CelTypeRegistry type_registry; + Resolver registry("", &func_registry, &type_registry); BuilderWarnings warnings; reference_map[1].set_name("udf_boolean_and"); - auto result = ResolveReferences(expr, reference_map, registry, "", &warnings); + + auto result = ResolveReferences(expr, reference_map, registry, &warnings); ASSERT_OK(result.status()); EXPECT_THAT(result.value(), Eq(absl::nullopt)); EXPECT_THAT( @@ -374,9 +403,12 @@ TEST(ResolveReferences, FunctionReferenceToWrongExprKind) { Expr expr = ParseTestProto(kExtensionAndExpr); google::protobuf::Map reference_map; BuilderWarnings warnings; - CelFunctionRegistry registry; + CelFunctionRegistry func_registry; + CelTypeRegistry type_registry; + Resolver registry("", &func_registry, &type_registry); reference_map[2].add_overload_id("udf_boolean_and"); - auto result = ResolveReferences(expr, reference_map, registry, "", &warnings); + + auto result = ResolveReferences(expr, reference_map, registry, &warnings); ASSERT_OK(result.status()); EXPECT_THAT(result.value(), Eq(absl::nullopt)); EXPECT_THAT(warnings.warnings(), @@ -405,11 +437,14 @@ TEST(ResolveReferences, FunctionReferenceWithTargetNoChange) { Expr expr = ParseTestProto(kReceiverCallExtensionAndExpr); google::protobuf::Map reference_map; BuilderWarnings warnings; - CelFunctionRegistry registry; - ASSERT_OK(registry.RegisterLazyFunction(CelFunctionDescriptor( + CelFunctionRegistry func_registry; + ASSERT_OK(func_registry.RegisterLazyFunction(CelFunctionDescriptor( "boolean_and", true, {CelValue::Type::kBool, CelValue::Type::kBool}))); + CelTypeRegistry type_registry; + Resolver registry("", &func_registry, &type_registry); reference_map[1].add_overload_id("udf_boolean_and"); - auto result = ResolveReferences(expr, reference_map, registry, "", &warnings); + + auto result = ResolveReferences(expr, reference_map, registry, &warnings); ASSERT_OK(result.status()); EXPECT_THAT(result.value(), Eq(absl::nullopt)); EXPECT_THAT(warnings.warnings(), IsEmpty()); @@ -420,9 +455,12 @@ TEST(ResolveReferences, Expr expr = ParseTestProto(kReceiverCallExtensionAndExpr); google::protobuf::Map reference_map; BuilderWarnings warnings; - CelFunctionRegistry registry; + CelFunctionRegistry func_registry; + CelTypeRegistry type_registry; + Resolver registry("", &func_registry, &type_registry); reference_map[1].add_overload_id("udf_boolean_and"); - auto result = ResolveReferences(expr, reference_map, registry, "", &warnings); + + auto result = ResolveReferences(expr, reference_map, registry, &warnings); ASSERT_OK(result.status()); EXPECT_THAT(result.value(), Eq(absl::nullopt)); EXPECT_THAT(warnings.warnings(), @@ -433,11 +471,14 @@ TEST(ResolveReferences, FunctionReferenceWithTargetToNamespacedFunction) { Expr expr = ParseTestProto(kReceiverCallExtensionAndExpr); google::protobuf::Map reference_map; BuilderWarnings warnings; - CelFunctionRegistry registry; - ASSERT_OK(registry.RegisterLazyFunction(CelFunctionDescriptor( + CelFunctionRegistry func_registry; + ASSERT_OK(func_registry.RegisterLazyFunction(CelFunctionDescriptor( "ext.boolean_and", false, {CelValue::Type::kBool}))); + CelTypeRegistry type_registry; + Resolver registry("", &func_registry, &type_registry); reference_map[1].add_overload_id("udf_boolean_and"); - auto result = ResolveReferences(expr, reference_map, registry, "", &warnings); + + auto result = ResolveReferences(expr, reference_map, registry, &warnings); ASSERT_OK(result.status()); EXPECT_THAT(result.value(), Optional(EqualsProto(R"( id: 1 @@ -456,13 +497,15 @@ TEST(ResolveReferences, FunctionReferenceWithTargetToNamespacedFunctionInContainer) { Expr expr = ParseTestProto(kReceiverCallExtensionAndExpr); google::protobuf::Map reference_map; - BuilderWarnings warnings; - CelFunctionRegistry registry; reference_map[1].add_overload_id("udf_boolean_and"); - ASSERT_OK(registry.RegisterLazyFunction(CelFunctionDescriptor( + BuilderWarnings warnings; + CelFunctionRegistry func_registry; + ASSERT_OK(func_registry.RegisterLazyFunction(CelFunctionDescriptor( "com.google.ext.boolean_and", false, {CelValue::Type::kBool}))); - auto result = - ResolveReferences(expr, reference_map, registry, "com.google", &warnings); + CelTypeRegistry type_registry; + Resolver registry("com.google", &func_registry, &type_registry); + + auto result = ResolveReferences(expr, reference_map, registry, &warnings); ASSERT_OK(result.status()); EXPECT_THAT(result.value(), Optional(EqualsProto(R"( id: 1 @@ -507,19 +550,177 @@ TEST(ResolveReferences, FunctionReferenceWithHasTargetNoChange) { Expr expr = ParseTestProto(kReceiverCallHasExtensionAndExpr); google::protobuf::Map reference_map; BuilderWarnings warnings; - CelFunctionRegistry registry; - ASSERT_OK(registry.RegisterLazyFunction(CelFunctionDescriptor( + CelFunctionRegistry func_registry; + ASSERT_OK(func_registry.RegisterLazyFunction(CelFunctionDescriptor( "boolean_and", true, {CelValue::Type::kBool, CelValue::Type::kBool}))); - ASSERT_OK(registry.RegisterLazyFunction(CelFunctionDescriptor( + ASSERT_OK(func_registry.RegisterLazyFunction(CelFunctionDescriptor( "ext.option.boolean_and", true, {CelValue::Type::kBool}))); + CelTypeRegistry type_registry; + Resolver registry("", &func_registry, &type_registry); reference_map[1].add_overload_id("udf_boolean_and"); - auto result = ResolveReferences(expr, reference_map, registry, "", &warnings); + + auto result = ResolveReferences(expr, reference_map, registry, &warnings); ASSERT_OK(result.status()); // The target is unchanged because it is a test_only select. EXPECT_THAT(result.value(), Eq(absl::nullopt)); EXPECT_THAT(warnings.warnings(), IsEmpty()); } +constexpr char kComprehensionExpr[] = R"( +id:17 +comprehension_expr: { + iter_var:"i" + iter_range:{ + id:1 + list_expr:{ + elements:{ + id:2 + const_expr:{int64_value:1} + } + elements:{ + id:3 + ident_expr:{name:"ENUM"} + } + elements:{ + id:4 + const_expr:{int64_value:3} + } + } + } + accu_var:"__result__" + accu_init: { + id:10 + const_expr:{bool_value:false} + } + loop_condition:{ + id:13 + call_expr:{ + function:"@not_strictly_false" + args:{ + id:12 + call_expr:{ + function:"!_" + args:{ + id:11 + ident_expr:{name:"__result__"} + } + } + } + } + } + loop_step:{ + id:15 + call_expr: { + function:"_||_" + args:{ + id:14 + ident_expr: {name:"__result__"} + } + args:{ + id:8 + call_expr:{ + function:"_==_" + args:{ + id:7 ident_expr:{name:"ENUM"} + } + args:{ + id:9 ident_expr:{name:"i"} + } + } + } + } + } + result:{id:16 ident_expr:{name:"__result__"}} +} +)"; +TEST(ResolveReferences, EnumConstReferenceUsedInComprehension) { + Expr expr = ParseTestProto(kComprehensionExpr); + google::protobuf::Map reference_map; + CelFunctionRegistry func_registry; + ASSERT_OK(RegisterBuiltinFunctions(&func_registry)); + CelTypeRegistry type_registry; + Resolver registry("", &func_registry, &type_registry); + reference_map[3].set_name("ENUM"); + reference_map[3].mutable_value()->set_int64_value(2); + reference_map[7].set_name("ENUM"); + reference_map[7].mutable_value()->set_int64_value(2); + BuilderWarnings warnings; + + auto result = ResolveReferences(expr, reference_map, registry, &warnings); + ASSERT_OK(result); + EXPECT_THAT(result.value(), Optional(EqualsProto(R"( + id: 17 + comprehension_expr { + iter_var: "i" + iter_range { + id: 1 + list_expr { + elements { + id: 2 + const_expr { int64_value: 1 } + } + elements { + id: 3 + const_expr { int64_value: 2 } + } + elements { + id: 4 + const_expr { int64_value: 3 } + } + } + } + accu_var: "__result__" + accu_init { + id: 10 + const_expr { bool_value: false } + } + loop_condition { + id: 13 + call_expr { + function: "@not_strictly_false" + args { + id: 12 + call_expr { + function: "!_" + args { + id: 11 + ident_expr { name: "__result__" } + } + } + } + } + } + loop_step { + id: 15 + call_expr { + function: "_||_" + args { + id: 14 + ident_expr { name: "__result__" } + } + args { + id: 8 + call_expr { + function: "_==_" + args { + id: 7 + const_expr { int64_value: 2 } + } + args { + id: 9 + ident_expr { name: "i" } + } + } + } + } + } + result { + id: 16 + ident_expr { name: "__result__" } + } + })"))); +} + } // namespace } // namespace runtime diff --git a/eval/compiler/resolver.cc b/eval/compiler/resolver.cc new file mode 100644 index 000000000..0a6856e7f --- /dev/null +++ b/eval/compiler/resolver.cc @@ -0,0 +1,164 @@ +#include "eval/compiler/resolver.h" + +#include "google/protobuf/descriptor.h" +#include "absl/strings/match.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/str_split.h" +#include "absl/types/optional.h" +#include "eval/public/cel_builtins.h" + +namespace google { +namespace api { +namespace expr { +namespace runtime { + +Resolver::Resolver(absl::string_view container, + const CelFunctionRegistry* function_registry, + const CelTypeRegistry* type_registry, + bool resolve_qualified_type_identifiers) + : namespace_prefixes_(), + enum_value_map_(), + function_registry_(function_registry), + type_registry_(type_registry), + resolve_qualified_type_identifiers_(resolve_qualified_type_identifiers) { + // The constructor for the registry determines the set of possible namespace + // prefixes which may appear within the given expression container, and also + // eagerly maps possible enum names to enum values. + + auto container_elements = absl::StrSplit(container, '.'); + std::string prefix = ""; + namespace_prefixes_.push_back(prefix); + for (const auto& elem : container_elements) { + // Tolerate trailing / leading '.'. + if (elem.empty()) { + continue; + } + absl::StrAppend(&prefix, elem, "."); + namespace_prefixes_.insert(namespace_prefixes_.begin(), prefix); + } + + for (const auto& prefix : namespace_prefixes_) { + for (auto enum_desc : type_registry->Enums()) { + absl::string_view enum_name = enum_desc->full_name(); + if (!absl::StartsWith(enum_name, prefix)) { + continue; + } + + auto remainder = absl::StripPrefix(enum_name, prefix); + for (int i = 0; i < enum_desc->value_count(); i++) { + auto value_desc = enum_desc->value(i); + if (value_desc) { + // "prefixes" container is ascending-ordered. As such, we will be + // assigning enum reference to the deepest available. + // E.g. if both a.b.c.Name and a.b.Name are available, and + // we try to reference "Name" with the scope of "a.b.c", + // it will be resolved to "a.b.c.Name". + auto key = absl::StrCat(remainder, !remainder.empty() ? "." : "", + value_desc->name()); + enum_value_map_[key] = CelValue::CreateInt64(value_desc->number()); + } + } + } + } +} + +std::vector Resolver::FullyQualifiedNames(absl::string_view name, + int64_t expr_id) const { + // TODO(issues/105): refactor the reference resolution into this method. + // and handle the case where this id is in the reference map as either a + // function name or identifier name. + std::vector names; + // Handle the case where the name contains a leading '.' indicating it is + // already fully-qualified. + if (absl::StartsWith(name, ".")) { + std::string fully_qualified_name = std::string(name.substr(1)); + names.push_back(fully_qualified_name); + return names; + } + + // namespace prefixes is guaranteed to contain at least empty string, so this + // function will always produce at least one result. + for (const auto& prefix : namespace_prefixes_) { + std::string fully_qualified_name = absl::StrCat(prefix, name); + names.push_back(fully_qualified_name); + } + return names; +} + +absl::optional Resolver::FindConstant(absl::string_view name, + int64_t expr_id) const { + auto names = FullyQualifiedNames(name, expr_id); + for (const auto& name : names) { + // Attempt to resolve the fully qualified name to a known enum. + auto enum_entry = enum_value_map_.find(name); + if (enum_entry != enum_value_map_.end()) { + return enum_entry->second; + } + // Conditionally resolve fully qualified names as type values if the option + // to do so is configured in the expression builder. If the type name is + // not qualified, then it too may be returned as a constant value. + if (resolve_qualified_type_identifiers_ || !absl::StrContains(name, ".")) { + auto type_value = type_registry_->FindType(name); + if (type_value.has_value()) { + return type_value.value(); + } + } + } + return absl::nullopt; +} + +std::vector Resolver::FindOverloads( + absl::string_view name, bool receiver_style, + const std::vector& types, int64_t expr_id) const { + // Resolve the fully qualified names and then search the function registry + // for possible matches. + std::vector funcs; + auto names = FullyQualifiedNames(name, expr_id); + for (auto it = names.begin(); it != names.end(); it++) { + // Only one set of overloads is returned along the namespace hierarchy as + // the function name resolution follows the same behavior as variable name + // resolution, meaning the most specific definition wins. This is different + // from how C++ namespaces work, as they will accumulate the overload set + // over the namespace hierarchy. + funcs = function_registry_->FindOverloads(*it, receiver_style, types); + if (!funcs.empty()) { + return funcs; + } + } + return funcs; +} + +std::vector Resolver::FindLazyOverloads( + absl::string_view name, bool receiver_style, + const std::vector& types, int64_t expr_id) const { + // Resolve the fully qualified names and then search the function registry + // for possible matches. + std::vector funcs; + auto names = FullyQualifiedNames(name, expr_id); + for (const auto& name : names) { + funcs = function_registry_->FindLazyOverloads(name, receiver_style, types); + if (!funcs.empty()) { + return funcs; + } + } + return funcs; +} + +const google::protobuf::Descriptor* Resolver::FindDescriptor(absl::string_view name, + int64_t expr_id) const { + // Resolve the fully qualified names and then defer to the type registry + // for possible matches. + auto names = FullyQualifiedNames(name, expr_id); + for (const auto& name : names) { + auto desc = type_registry_->FindDescriptor(name); + if (desc != nullptr) { + return desc; + } + } + return nullptr; +} + +} // namespace runtime +} // namespace expr +} // namespace api +} // namespace google diff --git a/eval/compiler/resolver.h b/eval/compiler/resolver.h new file mode 100644 index 000000000..1f79867f1 --- /dev/null +++ b/eval/compiler/resolver.h @@ -0,0 +1,92 @@ +#ifndef THIRD_PARTY_CEL_CPP_EVAL_COMPILER_RESOLVER_H_ +#define THIRD_PARTY_CEL_CPP_EVAL_COMPILER_RESOLVER_H_ + +#include "absl/strings/string_view.h" +#include "absl/types/optional.h" +#include "eval/public/cel_function_registry.h" +#include "eval/public/cel_type_registry.h" +#include "eval/public/cel_value.h" + +namespace google { +namespace api { +namespace expr { +namespace runtime { + +// Resolver assists with finding functions and types within a container. +// +// This class builds on top of the CelFunctionRegistry and CelTypeRegistry by +// layering on the namespace resolution rules of CEL onto the calls provided +// by each of these libraries. +// +// TODO(issues/105): refactor the Resolver to consider CheckedExpr metadata +// for reference resolution. +class Resolver { + public: + Resolver(absl::string_view container, + const CelFunctionRegistry* function_registry, + const CelTypeRegistry* type_registry, + bool resolve_qualified_type_identifiers = true); + + ~Resolver() {} + + // FindConstant will return an enum constant value or a type value if one + // exists for the given name. + // + // Since enums and type identifiers are specified as (potentially) qualified + // names within an expression, there is the chance that the name provided + // is a variable name which happens to collide with an existing enum or proto + // based type name. For this reason, within parsed only expressions, the + // constant should be treated as a value that can be shadowed by a runtime + // provided value. + absl::optional FindConstant(absl::string_view name, + int64_t expr_id) const; + + // FindDescriptor returns the protobuf message descriptor for the given name + // if one exists. + const google::protobuf::Descriptor* FindDescriptor(absl::string_view name, + int64_t expr_id) const; + + // FindLazyOverloads returns the set, possibly empty, of lazy overloads + // matching the given function signature. + std::vector FindLazyOverloads( + absl::string_view name, bool receiver_style, + const std::vector& types, int64_t expr_id = -1) const; + + // FindOverloads returns the set, possibly empty, of eager function overloads + // matching the given function signature. + std::vector FindOverloads( + absl::string_view name, bool receiver_style, + const std::vector& types, int64_t expr_id = -1) const; + + // FullyQualifiedNames returns the set of fully qualified names which may be + // derived from the base_name within the specified expression container. + std::vector FullyQualifiedNames(absl::string_view base_name, + int64_t expr_id = -1) const; + + private: + std::vector namespace_prefixes_; + absl::flat_hash_map enum_value_map_; + const CelFunctionRegistry* function_registry_; + const CelTypeRegistry* type_registry_; + bool resolve_qualified_type_identifiers_; +}; + +// ArgumentMatcher generates a function signature matcher for CelFunctions. +// TODO(issues/91): this is the same behavior as parsed exprs in the CPP +// evaluator (just check the right call style and number of arguments), but we +// should have enough type information in a checked expr to find a more +// specific candidate list. +inline std::vector ArgumentsMatcher(int argument_count) { + std::vector argument_matcher(argument_count); + for (int i = 0; i < argument_count; i++) { + argument_matcher[i] = CelValue::Type::kAny; + } + return argument_matcher; +} + +} // namespace runtime +} // namespace expr +} // namespace api +} // namespace google + +#endif // THIRD_PARTY_CEL_CPP_EVAL_COMPILER_RESOLVER_H_ diff --git a/eval/compiler/resolver_test.cc b/eval/compiler/resolver_test.cc new file mode 100644 index 000000000..08083925e --- /dev/null +++ b/eval/compiler/resolver_test.cc @@ -0,0 +1,199 @@ +#include "eval/compiler/resolver.h" + +#include + +#include "gmock/gmock.h" +#include "gtest/gtest.h" +#include "absl/status/status.h" +#include "eval/public/cel_function.h" +#include "eval/public/cel_function_registry.h" +#include "eval/public/cel_type_registry.h" +#include "eval/testutil/test_message.pb.h" +#include "base/status_macros.h" + +namespace google { +namespace api { +namespace expr { +namespace runtime { + +namespace { + +using testing::Eq; + +class FakeFunction : public CelFunction { + public: + explicit FakeFunction(const std::string& name) + : CelFunction(CelFunctionDescriptor{name, false, {}}) {} + + absl::Status Evaluate(absl::Span args, CelValue* result, + google::protobuf::Arena* arena) const override { + return absl::OkStatus(); + } +}; + +TEST(ResolverTest, TestFullyQualifiedNames) { + CelFunctionRegistry func_registry; + CelTypeRegistry type_registry; + Resolver resolver("google.api.expr", &func_registry, &type_registry); + + auto names = resolver.FullyQualifiedNames("simple_name"); + std::vector expected_names( + {"google.api.expr.simple_name", "google.api.simple_name", + "google.simple_name", "simple_name"}); + EXPECT_THAT(names, Eq(expected_names)); +} + +TEST(ResolverTest, TestFullyQualifiedNamesPartiallyQualifiedName) { + CelFunctionRegistry func_registry; + CelTypeRegistry type_registry; + Resolver resolver("google.api.expr", &func_registry, &type_registry); + + auto names = resolver.FullyQualifiedNames("expr.simple_name"); + std::vector expected_names( + {"google.api.expr.expr.simple_name", "google.api.expr.simple_name", + "google.expr.simple_name", "expr.simple_name"}); + EXPECT_THAT(names, Eq(expected_names)); +} + +TEST(ResolverTest, TestFullyQualifiedNamesAbsoluteName) { + CelFunctionRegistry func_registry; + CelTypeRegistry type_registry; + Resolver resolver("google.api.expr", &func_registry, &type_registry); + + auto names = resolver.FullyQualifiedNames(".google.api.expr.absolute_name"); + EXPECT_THAT(names.size(), Eq(1)); + EXPECT_THAT(names[0], Eq("google.api.expr.absolute_name")); +} + +TEST(ResolverTest, TestFindConstantEnum) { + CelFunctionRegistry func_registry; + CelTypeRegistry type_registry; + type_registry.Register(TestMessage::TestEnum_descriptor()); + Resolver resolver("google.api.expr.runtime.TestMessage", &func_registry, + &type_registry); + + auto enum_value = resolver.FindConstant("TestEnum.TEST_ENUM_1", -1); + EXPECT_TRUE(enum_value.has_value()); + EXPECT_TRUE(enum_value.value().IsInt64()); + EXPECT_THAT(enum_value.value().Int64OrDie(), Eq(1L)); + + enum_value = resolver.FindConstant( + ".google.api.expr.runtime.TestMessage.TestEnum.TEST_ENUM_2", -1); + EXPECT_TRUE(enum_value.has_value()); + EXPECT_TRUE(enum_value.value().IsInt64()); + EXPECT_THAT(enum_value.value().Int64OrDie(), Eq(2L)); +} + +TEST(ResolverTest, TestFindConstantUnqualifiedType) { + CelFunctionRegistry func_registry; + CelTypeRegistry type_registry; + Resolver resolver("cel", &func_registry, &type_registry); + + auto type_value = resolver.FindConstant("int", -1); + EXPECT_TRUE(type_value.has_value()); + EXPECT_TRUE(type_value.value().IsCelType()); + EXPECT_THAT(type_value.value().CelTypeOrDie().value(), Eq("int")); +} + +TEST(ResolverTest, TestFindConstantFullyQualifiedType) { + CelFunctionRegistry func_registry; + CelTypeRegistry type_registry; + Resolver resolver("cel", &func_registry, &type_registry); + + auto type_value = + resolver.FindConstant(".google.api.expr.runtime.TestMessage", -1); + EXPECT_TRUE(type_value.has_value()); + EXPECT_TRUE(type_value.value().IsCelType()); + EXPECT_THAT(type_value.value().CelTypeOrDie().value(), + Eq("google.api.expr.runtime.TestMessage")); +} + +TEST(ResolverTest, TestFindConstantQualifiedTypeDisabled) { + CelFunctionRegistry func_registry; + CelTypeRegistry type_registry; + Resolver resolver("", &func_registry, &type_registry, false); + auto type_value = + resolver.FindConstant(".google.api.expr.runtime.TestMessage", -1); + EXPECT_FALSE(type_value.has_value()); +} + +TEST(ResolverTest, TestFindDescriptorBySimpleName) { + CelFunctionRegistry func_registry; + CelTypeRegistry type_registry; + Resolver resolver("google.api.expr.runtime", &func_registry, &type_registry); + + auto desc_value = resolver.FindDescriptor("TestMessage", -1); + EXPECT_TRUE(desc_value != nullptr); + EXPECT_THAT(desc_value, Eq(TestMessage::GetDescriptor())); +} + +TEST(ResolverTest, TestFindDescriptorByQualifiedName) { + CelFunctionRegistry func_registry; + CelTypeRegistry type_registry; + Resolver resolver("google.api.expr.runtime", &func_registry, &type_registry); + + auto desc_value = + resolver.FindDescriptor(".google.api.expr.runtime.TestMessage", -1); + EXPECT_TRUE(desc_value != nullptr); + EXPECT_THAT(desc_value, Eq(TestMessage::GetDescriptor())); +} + +TEST(ResolverTest, TestFindDescriptorNotFound) { + CelFunctionRegistry func_registry; + CelTypeRegistry type_registry; + Resolver resolver("google.api.expr.runtime", &func_registry, &type_registry); + + auto desc_value = resolver.FindDescriptor("UndefinedMessage", -1); + EXPECT_TRUE(desc_value == nullptr); +} + +TEST(ResolverTest, TestFindOverloads) { + CelFunctionRegistry func_registry; + auto status = + func_registry.Register(std::make_unique("fake_func")); + ASSERT_OK(status); + status = func_registry.Register( + std::make_unique("cel.fake_ns_func")); + ASSERT_OK(status); + + CelTypeRegistry type_registry; + Resolver resolver("cel", &func_registry, &type_registry); + + auto overloads = + resolver.FindOverloads("fake_func", false, ArgumentsMatcher(0)); + EXPECT_THAT(overloads.size(), Eq(1)); + EXPECT_THAT(overloads[0]->descriptor().name(), Eq("fake_func")); + + overloads = + resolver.FindOverloads("fake_ns_func", false, ArgumentsMatcher(0)); + EXPECT_THAT(overloads.size(), Eq(1)); + EXPECT_THAT(overloads[0]->descriptor().name(), Eq("cel.fake_ns_func")); +} + +TEST(ResolverTest, TestFindLazyOverloads) { + CelFunctionRegistry func_registry; + auto status = func_registry.RegisterLazyFunction( + CelFunctionDescriptor{"fake_lazy_func", false, {}}); + ASSERT_OK(status); + status = func_registry.RegisterLazyFunction( + CelFunctionDescriptor{"cel.fake_lazy_ns_func", false, {}}); + ASSERT_OK(status); + + CelTypeRegistry type_registry; + Resolver resolver("cel", &func_registry, &type_registry); + + auto overloads = + resolver.FindLazyOverloads("fake_lazy_func", false, ArgumentsMatcher(0)); + EXPECT_THAT(overloads.size(), Eq(1)); + + overloads = resolver.FindLazyOverloads("fake_lazy_ns_func", false, + ArgumentsMatcher(0)); + EXPECT_THAT(overloads.size(), Eq(1)); +} + +} // namespace + +} // namespace runtime +} // namespace expr +} // namespace api +} // namespace google diff --git a/eval/eval/BUILD b/eval/eval/BUILD index 9dcfafae4..48d2ea33b 100644 --- a/eval/eval/BUILD +++ b/eval/eval/BUILD @@ -60,13 +60,9 @@ cc_library( deps = [ ":evaluator_core", ":expression_step_base", - "//eval/public:activation", - "//eval/public:cel_expression", "//eval/public:cel_value", "//eval/public/structs:cel_proto_wrapper", "@com_google_absl//absl/status:statusor", - "@com_google_absl//absl/strings", - "@com_google_googleapis//google/api/expr/v1alpha1:syntax_cc_proto", "@com_google_protobuf//:protobuf", ], ) @@ -131,7 +127,6 @@ cc_library( "//eval/public:cel_builtins", "//eval/public:cel_function", "//eval/public:cel_function_provider", - "//eval/public:cel_function_registry", "//eval/public:cel_value", "//eval/public:unknown_attribute_set", "//eval/public:unknown_function_result_set", @@ -206,6 +201,7 @@ cc_library( "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", "@com_google_googleapis//google/api/expr/v1alpha1:syntax_cc_proto", + "@com_google_protobuf//:protobuf", ], ) @@ -381,6 +377,7 @@ cc_test( "//base:status_macros", "//eval/public:cel_attribute", "//eval/public:cel_function", + "//eval/public:cel_function_registry", "//eval/public:cel_options", "//eval/public:cel_value", "//eval/public:unknown_function_result_set", @@ -461,11 +458,13 @@ cc_test( ":create_struct_step", ":ident_step", "//base:status_macros", + "//eval/public:cel_type_registry", "//eval/public/containers:container_backed_list_impl", "//eval/public/containers:container_backed_map_impl", "//eval/public/structs:cel_proto_wrapper", "//eval/testutil:test_message_cc_proto", "//testutil:util", + "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", "@com_google_googleapis//google/api/expr/v1alpha1:syntax_cc_proto", @@ -604,3 +603,31 @@ cc_test( "@com_google_googletest//:gtest_main", ], ) + +cc_library( + name = "shadowable_value_step", + srcs = ["shadowable_value_step.cc"], + hdrs = ["shadowable_value_step.h"], + deps = [ + ":evaluator_core", + ":expression_step_base", + "//eval/public:activation", + "//eval/public:cel_value", + "@com_google_absl//absl/status:statusor", + ], +) + +cc_test( + name = "shadowable_value_step_test", + size = "small", + srcs = ["shadowable_value_step_test.cc"], + deps = [ + ":evaluator_core", + ":shadowable_value_step", + "//base:status_macros", + "//eval/public:cel_value", + "@com_google_absl//absl/status:statusor", + "@com_google_googleapis//google/api/expr/v1alpha1:syntax_cc_proto", + "@com_google_googletest//:gtest_main", + ], +) diff --git a/eval/eval/attribute_trail.h b/eval/eval/attribute_trail.h index c2aefc6cb..48874726d 100644 --- a/eval/eval/attribute_trail.h +++ b/eval/eval/attribute_trail.h @@ -32,7 +32,7 @@ class AttributeTrail { AttributeTrail() : attribute_(nullptr) {} AttributeTrail(google::api::expr::v1alpha1::Expr root, google::protobuf::Arena* arena) : AttributeTrail(google::protobuf::Arena::Create( - arena, root, std::vector())) {} + arena, std::move(root), std::vector())) {} // Creates AttributeTrail with attribute path incremented by "qualifier". AttributeTrail Step(CelAttributeQualifier qualifier, @@ -52,7 +52,8 @@ class AttributeTrail { bool empty() const { return !attribute_; } private: - AttributeTrail(const CelAttribute* attribute) : attribute_(attribute) {} + explicit AttributeTrail(const CelAttribute* attribute) + : attribute_(attribute) {} const CelAttribute* attribute_; }; } // namespace runtime diff --git a/eval/eval/const_value_step.cc b/eval/eval/const_value_step.cc index 5cdc216c6..05ae1a61b 100644 --- a/eval/eval/const_value_step.cc +++ b/eval/eval/const_value_step.cc @@ -12,7 +12,6 @@ namespace expr { namespace runtime { using google::api::expr::v1alpha1::Constant; -using google::api::expr::v1alpha1::Expr; namespace { @@ -75,19 +74,14 @@ absl::optional ConvertConstant(const Constant* const_expr) { absl::StatusOr> CreateConstValueStep( CelValue value, int64_t expr_id, bool comes_from_ast) { - std::unique_ptr step = - absl::make_unique(value, expr_id, comes_from_ast); - return std::move(step); + return absl::make_unique(value, expr_id, comes_from_ast); } // Factory method for Constant(Enum value) - based Execution step absl::StatusOr> CreateConstValueStep( const google::protobuf::EnumValueDescriptor* value_descriptor, int64_t expr_id) { - CelValue value = CelValue::CreateInt64(value_descriptor->number()); - - std::unique_ptr step = - absl::make_unique(value, expr_id, false); - return std::move(step); + return absl::make_unique( + CelValue::CreateInt64(value_descriptor->number()), expr_id, false); } } // namespace runtime diff --git a/eval/eval/const_value_step.h b/eval/eval/const_value_step.h index e3c685cbf..01a071295 100644 --- a/eval/eval/const_value_step.h +++ b/eval/eval/const_value_step.h @@ -3,7 +3,6 @@ #include "absl/status/statusor.h" #include "eval/eval/evaluator_core.h" -#include "eval/public/activation.h" #include "eval/public/cel_value.h" namespace google { @@ -18,10 +17,6 @@ absl::optional ConvertConstant( absl::StatusOr> CreateConstValueStep( CelValue value, int64_t expr_id, bool comes_from_ast = true); -// Factory method for Constant(Enum value) - based Execution step -absl::StatusOr> CreateConstValueStep( - const google::protobuf::EnumValueDescriptor* value_descriptor, int64_t expr_id); - } // namespace runtime } // namespace expr } // namespace api diff --git a/eval/eval/container_access_step.cc b/eval/eval/container_access_step.cc index aeb2499f9..619a065d5 100644 --- a/eval/eval/container_access_step.cc +++ b/eval/eval/container_access_step.cc @@ -156,9 +156,7 @@ absl::Status ContainerAccessStep::Evaluate(ExecutionFrame* frame) const { // Factory method for Select - based Execution step absl::StatusOr> CreateContainerAccessStep( const google::api::expr::v1alpha1::Expr::Call*, int64_t expr_id) { - std::unique_ptr step = - absl::make_unique(expr_id); - return std::move(step); + return absl::make_unique(expr_id); } } // namespace runtime diff --git a/eval/eval/create_list_step.cc b/eval/eval/create_list_step.cc index f9da18357..21be321a9 100644 --- a/eval/eval/create_list_step.cc +++ b/eval/eval/create_list_step.cc @@ -65,9 +65,8 @@ absl::Status CreateListStep::Evaluate(ExecutionFrame* frame) const { absl::StatusOr> CreateCreateListStep( const google::api::expr::v1alpha1::Expr::CreateList* create_list_expr, int64_t expr_id) { - std::unique_ptr step = absl::make_unique( - expr_id, create_list_expr->elements_size()); - return std::move(step); + return absl::make_unique(expr_id, + create_list_expr->elements_size()); } } // namespace runtime diff --git a/eval/eval/create_struct_step.cc b/eval/eval/create_struct_step.cc index abd92e8d8..68f1c727b 100644 --- a/eval/eval/create_struct_step.cc +++ b/eval/eval/create_struct_step.cc @@ -1,5 +1,7 @@ #include "eval/eval/create_struct_step.h" +#include + #include "google/api/expr/v1alpha1/syntax.pb.h" #include "absl/status/status.h" #include "absl/status/statusor.h" @@ -15,9 +17,7 @@ namespace runtime { namespace { -using ::google::protobuf::Arena; using ::google::protobuf::Descriptor; -using ::google::protobuf::DescriptorPool; using ::google::protobuf::FieldDescriptor; using ::google::protobuf::Message; using ::google::protobuf::MessageFactory; @@ -272,22 +272,12 @@ absl::Status CreateStructStepForMap::Evaluate(ExecutionFrame* frame) const { absl::StatusOr> CreateCreateStructStep( const google::api::expr::v1alpha1::Expr::CreateStruct* create_struct_expr, - int64_t expr_id) { - if (!create_struct_expr->message_name().empty()) { + const Descriptor* message_desc, int64_t expr_id) { + if (message_desc != nullptr) { // TODO(issues/92): Support resolving a type name within a container. // Make message-creating step. std::vector entries; - const Descriptor* desc = - DescriptorPool::generated_pool()->FindMessageTypeByName( - create_struct_expr->message_name()); - - if (desc == nullptr) { - return absl::InvalidArgumentError( - "Error configuring message creation: message descriptor not found: " + - create_struct_expr->message_name()); - } - for (const auto& entry : create_struct_expr->entries()) { if (entry.field_key().empty()) { return absl::InvalidArgumentError( @@ -295,7 +285,7 @@ absl::StatusOr> CreateCreateStructStep( } const FieldDescriptor* field_desc = - desc->FindFieldByName(entry.field_key()); + message_desc->FindFieldByName(entry.field_key()); if (field_desc == nullptr) { return absl::InvalidArgumentError( "Error configuring message creation: field name not found"); @@ -303,12 +293,12 @@ absl::StatusOr> CreateCreateStructStep( entries.push_back({field_desc}); } - return absl::WrapUnique( - new CreateStructStepForMessage(expr_id, desc, std::move(entries))); + return std::make_unique(expr_id, message_desc, + std::move(entries)); } else { // Make map-creating step. - return absl::WrapUnique(new CreateStructStepForMap( - expr_id, create_struct_expr->entries_size())); + return std::make_unique( + expr_id, create_struct_expr->entries_size()); } } diff --git a/eval/eval/create_struct_step.h b/eval/eval/create_struct_step.h index 1c5035192..e10e6d4ab 100644 --- a/eval/eval/create_struct_step.h +++ b/eval/eval/create_struct_step.h @@ -1,6 +1,10 @@ #ifndef THIRD_PARTY_CEL_CPP_EVAL_EVAL_CREATE_STRUCT_STEP_H_ #define THIRD_PARTY_CEL_CPP_EVAL_EVAL_CREATE_STRUCT_STEP_H_ +#include + +#include "google/protobuf/descriptor.h" +#include "absl/status/status.h" #include "absl/status/statusor.h" #include "eval/eval/evaluator_core.h" #include "eval/eval/expression_step_base.h" @@ -10,10 +14,17 @@ namespace api { namespace expr { namespace runtime { -// Factory method for CreateList - based Execution step +// Factory method for CreateStruct - based Execution step absl::StatusOr> CreateCreateStructStep( const google::api::expr::v1alpha1::Expr::CreateStruct* create_struct_expr, - int64_t expr_id); + const google::protobuf::Descriptor* message_desc, int64_t expr_id); + +inline absl::StatusOr> CreateCreateStructStep( + const google::api::expr::v1alpha1::Expr::CreateStruct* create_struct_expr, + int64_t expr_id) { + return CreateCreateStructStep(create_struct_expr, /*message_desc=*/nullptr, + expr_id); +} } // namespace runtime } // namespace expr diff --git a/eval/eval/create_struct_step_test.cc b/eval/eval/create_struct_step_test.cc index d7eac46d8..7586b1b0b 100644 --- a/eval/eval/create_struct_step_test.cc +++ b/eval/eval/create_struct_step_test.cc @@ -3,9 +3,11 @@ #include "google/api/expr/v1alpha1/syntax.pb.h" #include "gmock/gmock.h" #include "gtest/gtest.h" +#include "absl/status/status.h" #include "absl/status/statusor.h" #include "absl/strings/str_cat.h" #include "eval/eval/ident_step.h" +#include "eval/public/cel_type_registry.h" #include "eval/public/containers/container_backed_list_impl.h" #include "eval/public/containers/container_backed_map_impl.h" #include "eval/public/structs/cel_proto_wrapper.h" @@ -38,6 +40,7 @@ absl::StatusOr RunExpression(absl::string_view field, google::protobuf::Arena* arena, bool enable_unknowns) { ExecutionPath path; + CelTypeRegistry type_registry; Expr expr0; Expr expr1; @@ -55,8 +58,12 @@ absl::StatusOr RunExpression(absl::string_view field, if (!step0_status.ok()) { return step0_status.status(); } - - auto step1_status = CreateCreateStructStep(create_struct, expr1.id()); + auto desc = type_registry.FindDescriptor(create_struct->message_name()); + if (desc == nullptr) { + return absl::Status(absl::StatusCode::kFailedPrecondition, + "missing proto message type"); + } + auto step1_status = CreateCreateStructStep(create_struct, desc, expr1.id()); if (!step1_status.ok()) { return step1_status.status(); @@ -113,7 +120,7 @@ void RunExpressionAndGetMessage(absl::string_view field, // Helper method. Creates simple pipeline containing CreateStruct step that // builds Map and runs it. absl::StatusOr RunCreateMapExpression( - const std::vector> values, + const std::vector>& values, google::protobuf::Arena* arena, bool enable_unknowns) { ExecutionPath path; Activation activation; @@ -174,14 +181,16 @@ class CreateCreateStructStepTest : public testing::TestWithParam {}; TEST_P(CreateCreateStructStepTest, TestEmptyMessageCreation) { ExecutionPath path; + CelTypeRegistry type_registry; Expr expr1; auto create_struct = expr1.mutable_struct_expr(); create_struct->set_message_name("google.api.expr.runtime.TestMessage"); + auto desc = type_registry.FindDescriptor(create_struct->message_name()); + ASSERT_TRUE(desc != nullptr); - auto step_status = CreateCreateStructStep(create_struct, expr1.id()); - + auto step_status = CreateCreateStructStep(create_struct, desc, expr1.id()); ASSERT_OK(step_status); path.push_back(std::move(step_status.value())); diff --git a/eval/eval/function_step.cc b/eval/eval/function_step.cc index 989c7daee..9d8977fc3 100644 --- a/eval/eval/function_step.cc +++ b/eval/eval/function_step.cc @@ -22,7 +22,6 @@ #include "eval/public/cel_builtins.h" #include "eval/public/cel_function.h" #include "eval/public/cel_function_provider.h" -#include "eval/public/cel_function_registry.h" #include "eval/public/cel_value.h" #include "eval/public/unknown_attribute_set.h" #include "eval/public/unknown_function_result_set.h" @@ -194,7 +193,7 @@ absl::Status AbstractFunctionStep::Evaluate(ExecutionFrame* frame) const { class EagerFunctionStep : public AbstractFunctionStep { public: - EagerFunctionStep(std::vector&& overloads, + EagerFunctionStep(std::vector& overloads, const std::string& name, size_t num_args, int64_t expr_id) : AbstractFunctionStep(name, num_args, expr_id), overloads_(overloads) {} @@ -230,7 +229,7 @@ class LazyFunctionStep : public AbstractFunctionStep { // at runtime. LazyFunctionStep(const std::string& name, size_t num_args, bool receiver_style, - const std::vector& providers, + std::vector& providers, int64_t expr_id) : AbstractFunctionStep(name, num_args, expr_id), receiver_style_(receiver_style), @@ -281,35 +280,23 @@ absl::StatusOr LazyFunctionStep::ResolveFunction( absl::StatusOr> CreateFunctionStep( const google::api::expr::v1alpha1::Expr::Call* call_expr, int64_t expr_id, - const CelFunctionRegistry& function_registry, - BuilderWarnings* builder_warnings) { + std::vector& lazy_overloads) { bool receiver_style = call_expr->has_target(); size_t num_args = call_expr->args_size() + (receiver_style ? 1 : 0); const std::string& name = call_expr->function(); - std::vector args(num_args, CelValue::Type::kAny); + return absl::make_unique(name, num_args, receiver_style, + lazy_overloads, expr_id); +} - std::vector lazy_overloads = - function_registry.FindLazyOverloads(name, receiver_style, args); - - if (!lazy_overloads.empty()) { - std::unique_ptr step = absl::make_unique( - name, num_args, receiver_style, lazy_overloads, expr_id); - return std::move(step); - } - - auto overloads = function_registry.FindOverloads(name, receiver_style, args); - - // No overloads found. - if (overloads.empty()) { - RETURN_IF_ERROR(builder_warnings->AddWarning( - absl::Status(absl::StatusCode::kInvalidArgument, - "No overloads provided for FunctionStep creation"))); - } - - std::unique_ptr step = absl::make_unique( - std::move(overloads), name, num_args, expr_id); - return std::move(step); +absl::StatusOr> CreateFunctionStep( + const google::api::expr::v1alpha1::Expr::Call* call_expr, int64_t expr_id, + std::vector& overloads) { + bool receiver_style = call_expr->has_target(); + size_t num_args = call_expr->args_size() + (receiver_style ? 1 : 0); + const std::string& name = call_expr->function(); + return absl::make_unique(overloads, name, num_args, + expr_id); } } // namespace runtime diff --git a/eval/eval/function_step.h b/eval/eval/function_step.h index 36b3da61e..6342d84ab 100644 --- a/eval/eval/function_step.h +++ b/eval/eval/function_step.h @@ -8,10 +8,9 @@ #include "google/api/expr/v1alpha1/syntax.pb.h" #include "absl/status/statusor.h" #include "eval/eval/evaluator_core.h" -#include "eval/eval/expression_build_warning.h" #include "eval/public/activation.h" #include "eval/public/cel_function.h" -#include "eval/public/cel_function_registry.h" +#include "eval/public/cel_function_provider.h" #include "eval/public/cel_value.h" namespace google { @@ -19,12 +18,18 @@ namespace api { namespace expr { namespace runtime { -// Factory method for Call - based Execution step -// Looks up function registry using data provided through Call parameter. +// Factory method for Call-based execution step where the function will be +// resolved at runtime (lazily) from an input Activation. absl::StatusOr> CreateFunctionStep( const google::api::expr::v1alpha1::Expr::Call* call, int64_t expr_id, - const CelFunctionRegistry& function_registry, - BuilderWarnings* builder_warnings); + std::vector& lazy_overloads); + +// Factory method for Call-based execution step where the function has been +// statically resolved from a set of eagerly functions configured in the +// CelFunctionRegistry. +absl::StatusOr> CreateFunctionStep( + const google::api::expr::v1alpha1::Expr::Call* call, int64_t expr_id, + std::vector& overloads); } // namespace runtime } // namespace expr diff --git a/eval/eval/function_step_test.cc b/eval/eval/function_step_test.cc index 296ecdd04..a27f33a8a 100644 --- a/eval/eval/function_step_test.cc +++ b/eval/eval/function_step_test.cc @@ -1,5 +1,8 @@ #include "eval/eval/function_step.h" +#include +#include + #include "google/api/expr/v1alpha1/syntax.pb.h" #include "gmock/gmock.h" #include "gtest/gtest.h" @@ -10,6 +13,7 @@ #include "eval/eval/ident_step.h" #include "eval/public/cel_attribute.h" #include "eval/public/cel_function.h" +#include "eval/public/cel_function_registry.h" #include "eval/public/cel_options.h" #include "eval/public/cel_value.h" #include "eval/public/structs/cel_proto_wrapper.h" @@ -171,6 +175,32 @@ void AddDefaults(CelFunctionRegistry& registry) { .ok()); } +std::vector ArgumentMatcher(int argument_count) { + std::vector argument_matcher(argument_count); + for (int i = 0; i < argument_count; i++) { + argument_matcher[i] = CelValue::Type::kAny; + } + return argument_matcher; +} + +std::vector ArgumentMatcher(const Expr::Call* call) { + return ArgumentMatcher(call->has_target() ? call->args_size() + 1 + : call->args_size()); +} + +absl::StatusOr> MakeTestFunctionStep( + const Expr::Call* call, const CelFunctionRegistry& registry) { + auto argument_matcher = ArgumentMatcher(call); + auto lazy_overloads = registry.FindLazyOverloads( + call->function(), call->has_target(), argument_matcher); + if (!lazy_overloads.empty()) { + return CreateFunctionStep(call, GetExprId(), lazy_overloads); + } + auto overloads = registry.FindOverloads(call->function(), call->has_target(), + argument_matcher); + return CreateFunctionStep(call, GetExprId(), overloads); +} + // Test common functions with varying levels of unknown support. class FunctionStepTest : public testing::TestWithParam { @@ -213,12 +243,9 @@ TEST_P(FunctionStepTest, SimpleFunctionTest) { Expr::Call call2 = ConstFunction::MakeCall("Const2"); Expr::Call add_call = AddFunction::MakeCall(); - auto step0_status = - CreateFunctionStep(&call1, GetExprId(), registry, &warnings); - auto step1_status = - CreateFunctionStep(&call2, GetExprId(), registry, &warnings); - auto step2_status = - CreateFunctionStep(&add_call, GetExprId(), registry, &warnings); + auto step0_status = MakeTestFunctionStep(&call1, registry); + auto step1_status = MakeTestFunctionStep(&call2, registry); + auto step2_status = MakeTestFunctionStep(&add_call, registry); ASSERT_OK(step0_status); ASSERT_OK(step1_status); @@ -254,10 +281,8 @@ TEST_P(FunctionStepTest, TestStackUnderflow) { Expr::Call call1 = ConstFunction::MakeCall("Const3"); Expr::Call add_call = AddFunction::MakeCall(); - auto step0_status = - CreateFunctionStep(&call1, GetExprId(), registry, &warnings); - auto step2_status = - CreateFunctionStep(&add_call, GetExprId(), registry, &warnings); + auto step0_status = MakeTestFunctionStep(&call1, registry); + auto step2_status = MakeTestFunctionStep(&add_call, registry); ASSERT_OK(step0_status); ASSERT_OK(step2_status); @@ -274,50 +299,6 @@ TEST_P(FunctionStepTest, TestStackUnderflow) { EXPECT_FALSE(status.ok()); } -// Test that creation fails if fail on warnings is set in the warnings -// collection. -TEST(FunctionStepTest, TestNoOverloadsOnCreation) { - CelFunctionRegistry registry; - BuilderWarnings warnings(true); - - Expr::Call call = ConstFunction::MakeCall("Const0"); - - // function step with empty overloads - auto step0_status = - CreateFunctionStep(&call, GetExprId(), registry, &warnings); - - EXPECT_FALSE(step0_status.ok()); -} - -// Test that no overloads error is warned, actual error delayed to runtime by -// default. -TEST_P(FunctionStepTest, TestNoOverloadsOnCreationDelayedError) { - CelFunctionRegistry registry; - ExecutionPath path; - Expr::Call call = ConstFunction::MakeCall("Const0"); - BuilderWarnings warnings; - - // function step with empty overloads - auto step0_status = - CreateFunctionStep(&call, GetExprId(), registry, &warnings); - - EXPECT_TRUE(step0_status.ok()); - EXPECT_THAT(warnings.warnings(), testing::SizeIs(1)); - - path.push_back(std::move(step0_status.value())); - - std::unique_ptr impl = GetExpression(std::move(path)); - - Activation activation; - google::protobuf::Arena arena; - - auto status = impl->Evaluate(activation, &arena); - ASSERT_OK(status); - - auto value = status.value(); - ASSERT_TRUE(value.IsError()); -} - // Test situation when no overloads match input arguments during evaluation. TEST_P(FunctionStepTest, TestNoMatchingOverloadsDuringEvaluation) { ExecutionPath path; @@ -336,12 +317,9 @@ TEST_P(FunctionStepTest, TestNoMatchingOverloadsDuringEvaluation) { // Add expects {int64_t, int64_t} but it's {int64_t, uint64_t}. Expr::Call add_call = AddFunction::MakeCall(); - auto step0_status = - CreateFunctionStep(&call1, GetExprId(), registry, &warnings); - auto step1_status = - CreateFunctionStep(&call2, GetExprId(), registry, &warnings); - auto step2_status = - CreateFunctionStep(&add_call, GetExprId(), registry, &warnings); + auto step0_status = MakeTestFunctionStep(&call1, registry); + auto step1_status = MakeTestFunctionStep(&call2, registry); + auto step2_status = MakeTestFunctionStep(&add_call, registry); ASSERT_OK(step0_status); ASSERT_OK(step1_status); @@ -368,8 +346,6 @@ TEST_P(FunctionStepTest, TestNoMatchingOverloadsDuringEvaluation) { TEST_P(FunctionStepTest, TestNoMatchingOverloadsDuringEvaluationErrorForwarding) { ExecutionPath path; - BuilderWarnings warnings; - CelFunctionRegistry registry; AddDefaults(registry); @@ -390,12 +366,9 @@ TEST_P(FunctionStepTest, Expr::Call call2 = ConstFunction::MakeCall("ConstError2"); Expr::Call add_call = AddFunction::MakeCall(); - auto step0_status = - CreateFunctionStep(&call1, GetExprId(), registry, &warnings); - auto step1_status = - CreateFunctionStep(&call2, GetExprId(), registry, &warnings); - auto step2_status = - CreateFunctionStep(&add_call, GetExprId(), registry, &warnings); + auto step0_status = MakeTestFunctionStep(&call1, registry); + auto step1_status = MakeTestFunctionStep(&call2, registry); + auto step2_status = MakeTestFunctionStep(&add_call, registry); ASSERT_OK(step0_status); ASSERT_OK(step1_status); @@ -443,12 +416,9 @@ TEST_P(FunctionStepTest, LazyFunctionTest) { Expr::Call call2 = ConstFunction::MakeCall("Const2"); Expr::Call add_call = AddFunction::MakeCall(); - auto step0_status = - CreateFunctionStep(&call1, GetExprId(), registry, &warnings); - auto step1_status = - CreateFunctionStep(&call2, GetExprId(), registry, &warnings); - auto step2_status = - CreateFunctionStep(&add_call, GetExprId(), registry, &warnings); + auto step0_status = MakeTestFunctionStep(&call1, registry); + auto step1_status = MakeTestFunctionStep(&call2, registry); + auto step2_status = MakeTestFunctionStep(&add_call, registry); ASSERT_OK(step0_status); ASSERT_OK(step1_status); @@ -479,7 +449,6 @@ TEST_P(FunctionStepTest, Activation activation; google::protobuf::Arena arena; CelFunctionRegistry registry; - BuilderWarnings warnings; AddDefaults(registry); @@ -506,12 +475,9 @@ TEST_P(FunctionStepTest, Expr::Call call2 = ConstFunction::MakeCall("ConstError2"); Expr::Call add_call = AddFunction::MakeCall(); - auto step0_status = - CreateFunctionStep(&call1, GetExprId(), registry, &warnings); - auto step1_status = - CreateFunctionStep(&call2, GetExprId(), registry, &warnings); - auto step2_status = - CreateFunctionStep(&add_call, GetExprId(), registry, &warnings); + auto step0_status = MakeTestFunctionStep(&call1, registry); + auto step1_status = MakeTestFunctionStep(&call2, registry); + auto step2_status = MakeTestFunctionStep(&add_call, registry); ASSERT_OK(step0_status); ASSERT_OK(step1_status); @@ -564,18 +530,17 @@ class FunctionStepTestUnknowns unknown_functions = false; break; } - return absl::make_unique(&expr, std::move(path), 0, + return absl::make_unique(&expr_, std::move(path), 0, std::set(), true, unknown_functions); } private: - Expr expr; + Expr expr_; }; TEST_P(FunctionStepTestUnknowns, PassedUnknownTest) { ExecutionPath path; - BuilderWarnings warnings; CelFunctionRegistry registry; AddDefaults(registry); @@ -584,12 +549,9 @@ TEST_P(FunctionStepTestUnknowns, PassedUnknownTest) { Expr::Call call2 = ConstFunction::MakeCall("ConstUnknown"); Expr::Call add_call = AddFunction::MakeCall(); - auto step0_status = - CreateFunctionStep(&call1, GetExprId(), registry, &warnings); - auto step1_status = - CreateFunctionStep(&call2, GetExprId(), registry, &warnings); - auto step2_status = - CreateFunctionStep(&add_call, GetExprId(), registry, &warnings); + auto step0_status = MakeTestFunctionStep(&call1, registry); + auto step1_status = MakeTestFunctionStep(&call2, registry); + auto step2_status = MakeTestFunctionStep(&add_call, registry); ASSERT_OK(step0_status); ASSERT_OK(step1_status); @@ -626,8 +588,7 @@ TEST_P(FunctionStepTestUnknowns, PartialUnknownHandlingTest) { Expr::Call call1 = SinkFunction::MakeCall(); auto step0_status = CreateIdentStep(&ident1, GetExprId()); - auto step1_status = - CreateFunctionStep(&call1, GetExprId(), registry, &warnings); + auto step1_status = MakeTestFunctionStep(&call1, registry); ASSERT_OK(step0_status); ASSERT_OK(step1_status); @@ -660,8 +621,6 @@ TEST_P(FunctionStepTestUnknowns, PartialUnknownHandlingTest) { TEST_P(FunctionStepTestUnknowns, UnknownVsErrorPrecedenceTest) { ExecutionPath path; - BuilderWarnings warnings; - CelFunctionRegistry registry; AddDefaults(registry); @@ -677,12 +636,9 @@ TEST_P(FunctionStepTestUnknowns, UnknownVsErrorPrecedenceTest) { Expr::Call call2 = ConstFunction::MakeCall("ConstUnknown"); Expr::Call add_call = AddFunction::MakeCall(); - auto step0_status = - CreateFunctionStep(&call1, GetExprId(), registry, &warnings); - auto step1_status = - CreateFunctionStep(&call2, GetExprId(), registry, &warnings); - auto step2_status = - CreateFunctionStep(&add_call, GetExprId(), registry, &warnings); + auto step0_status = MakeTestFunctionStep(&call1, registry); + auto step1_status = MakeTestFunctionStep(&call2, registry); + auto step2_status = MakeTestFunctionStep(&add_call, registry); ASSERT_OK(step0_status); ASSERT_OK(step1_status); @@ -725,8 +681,6 @@ MATCHER_P2(IsAdd, a, b, "") { TEST(FunctionStepTestUnknownFunctionResults, CaptureArgs) { ExecutionPath path; - BuilderWarnings warnings; - CelFunctionRegistry registry; ASSERT_OK(registry.Register( @@ -740,12 +694,9 @@ TEST(FunctionStepTestUnknownFunctionResults, CaptureArgs) { Expr::Call call2 = ConstFunction::MakeCall("Const3"); Expr::Call add_call = AddFunction::MakeCall(); - auto step0_status = - CreateFunctionStep(&call1, GetExprId(), registry, &warnings); - auto step1_status = - CreateFunctionStep(&call2, GetExprId(), registry, &warnings); - auto step2_status = - CreateFunctionStep(&add_call, GetExprId(), registry, &warnings); + auto step0_status = MakeTestFunctionStep(&call1, registry); + auto step1_status = MakeTestFunctionStep(&call2, registry); + auto step2_status = MakeTestFunctionStep(&add_call, registry); ASSERT_OK(step0_status); ASSERT_OK(step1_status); @@ -777,8 +728,6 @@ TEST(FunctionStepTestUnknownFunctionResults, CaptureArgs) { TEST(FunctionStepTestUnknownFunctionResults, MergeDownCaptureArgs) { ExecutionPath path; - BuilderWarnings warnings; - CelFunctionRegistry registry; ASSERT_OK(registry.Register( @@ -794,20 +743,13 @@ TEST(FunctionStepTestUnknownFunctionResults, MergeDownCaptureArgs) { Expr::Call call2 = ConstFunction::MakeCall("Const3"); Expr::Call add_call = AddFunction::MakeCall(); - auto step0_status = - CreateFunctionStep(&call1, GetExprId(), registry, &warnings); - auto step1_status = - CreateFunctionStep(&call2, GetExprId(), registry, &warnings); - auto step2_status = - CreateFunctionStep(&add_call, GetExprId(), registry, &warnings); - auto step3_status = - CreateFunctionStep(&call1, GetExprId(), registry, &warnings); - auto step4_status = - CreateFunctionStep(&call2, GetExprId(), registry, &warnings); - auto step5_status = - CreateFunctionStep(&add_call, GetExprId(), registry, &warnings); - auto step6_status = - CreateFunctionStep(&add_call, GetExprId(), registry, &warnings); + auto step0_status = MakeTestFunctionStep(&call1, registry); + auto step1_status = MakeTestFunctionStep(&call2, registry); + auto step2_status = MakeTestFunctionStep(&add_call, registry); + auto step3_status = MakeTestFunctionStep(&call1, registry); + auto step4_status = MakeTestFunctionStep(&call2, registry); + auto step5_status = MakeTestFunctionStep(&add_call, registry); + auto step6_status = MakeTestFunctionStep(&add_call, registry); ASSERT_OK(step0_status); ASSERT_OK(step1_status); @@ -847,8 +789,6 @@ TEST(FunctionStepTestUnknownFunctionResults, MergeDownCaptureArgs) { TEST(FunctionStepTestUnknownFunctionResults, MergeCaptureArgs) { ExecutionPath path; - BuilderWarnings warnings; - CelFunctionRegistry registry; ASSERT_OK(registry.Register( @@ -864,20 +804,13 @@ TEST(FunctionStepTestUnknownFunctionResults, MergeCaptureArgs) { Expr::Call call2 = ConstFunction::MakeCall("Const3"); Expr::Call add_call = AddFunction::MakeCall(); - auto step0_status = - CreateFunctionStep(&call1, GetExprId(), registry, &warnings); - auto step1_status = - CreateFunctionStep(&call2, GetExprId(), registry, &warnings); - auto step2_status = - CreateFunctionStep(&add_call, GetExprId(), registry, &warnings); - auto step3_status = - CreateFunctionStep(&call2, GetExprId(), registry, &warnings); - auto step4_status = - CreateFunctionStep(&call1, GetExprId(), registry, &warnings); - auto step5_status = - CreateFunctionStep(&add_call, GetExprId(), registry, &warnings); - auto step6_status = - CreateFunctionStep(&add_call, GetExprId(), registry, &warnings); + auto step0_status = MakeTestFunctionStep(&call1, registry); + auto step1_status = MakeTestFunctionStep(&call2, registry); + auto step2_status = MakeTestFunctionStep(&add_call, registry); + auto step3_status = MakeTestFunctionStep(&call2, registry); + auto step4_status = MakeTestFunctionStep(&call1, registry); + auto step5_status = MakeTestFunctionStep(&add_call, registry); + auto step6_status = MakeTestFunctionStep(&add_call, registry); ASSERT_OK(step0_status); ASSERT_OK(step1_status); @@ -907,7 +840,7 @@ TEST(FunctionStepTestUnknownFunctionResults, MergeCaptureArgs) { auto value = status.value(); - ASSERT_TRUE(value.IsUnknownSet()) << value.ErrorOrDie()->ToString(); + ASSERT_TRUE(value.IsUnknownSet()) << *value.ErrorOrDie(); // Arguments captured. EXPECT_THAT(value.UnknownSetOrDie() ->unknown_function_results() @@ -917,8 +850,6 @@ TEST(FunctionStepTestUnknownFunctionResults, MergeCaptureArgs) { TEST(FunctionStepTestUnknownFunctionResults, UnknownVsErrorPrecedenceTest) { ExecutionPath path; - BuilderWarnings warnings; - CelFunctionRegistry registry; CelError error0; @@ -937,12 +868,9 @@ TEST(FunctionStepTestUnknownFunctionResults, UnknownVsErrorPrecedenceTest) { Expr::Call call2 = ConstFunction::MakeCall("ConstUnknown"); Expr::Call add_call = AddFunction::MakeCall(); - auto step0_status = - CreateFunctionStep(&call1, GetExprId(), registry, &warnings); - auto step1_status = - CreateFunctionStep(&call2, GetExprId(), registry, &warnings); - auto step2_status = - CreateFunctionStep(&add_call, GetExprId(), registry, &warnings); + auto step0_status = MakeTestFunctionStep(&call1, registry); + auto step1_status = MakeTestFunctionStep(&call2, registry); + auto step2_status = MakeTestFunctionStep(&add_call, registry); ASSERT_OK(step0_status); ASSERT_OK(step1_status); diff --git a/eval/eval/ident_step.cc b/eval/eval/ident_step.cc index 9c99621bd..420aebaee 100644 --- a/eval/eval/ident_step.cc +++ b/eval/eval/ident_step.cc @@ -46,7 +46,7 @@ void IdentStep::DoEvaluate(ExecutionFrame* frame, CelValue* result, if (frame->enable_missing_attribute_errors() || frame->enable_unknowns()) { google::api::expr::v1alpha1::Expr expr; expr.mutable_ident_expr()->set_name(name_); - *trail = AttributeTrail(expr, frame->arena()); + *trail = AttributeTrail(std::move(expr), frame->arena()); } if (frame->enable_missing_attribute_errors() && !name_.empty() && @@ -102,9 +102,7 @@ absl::Status IdentStep::Evaluate(ExecutionFrame* frame) const { absl::StatusOr> CreateIdentStep( const google::api::expr::v1alpha1::Expr::Ident* ident_expr, int64_t expr_id) { - std::unique_ptr step = - absl::make_unique(ident_expr->name(), expr_id); - return std::move(step); + return absl::make_unique(ident_expr->name(), expr_id); } } // namespace runtime diff --git a/eval/eval/jump_step.cc b/eval/eval/jump_step.cc index 5c85a645b..311ab6154 100644 --- a/eval/eval/jump_step.cc +++ b/eval/eval/jump_step.cc @@ -100,19 +100,14 @@ class BoolCheckJumpStep : public JumpStepBase { absl::StatusOr> CreateCondJumpStep( bool jump_condition, bool leave_on_stack, absl::optional jump_offset, int64_t expr_id) { - std::unique_ptr step = absl::make_unique( - jump_condition, leave_on_stack, jump_offset, expr_id); - - return std::move(step); + return absl::make_unique(jump_condition, leave_on_stack, + jump_offset, expr_id); } // Factory method for Jump step. absl::StatusOr> CreateJumpStep( absl::optional jump_offset, int64_t expr_id) { - std::unique_ptr step = - absl::make_unique(jump_offset, expr_id); - - return std::move(step); + return absl::make_unique(jump_offset, expr_id); } // Factory method for Conditional Jump step. @@ -120,10 +115,7 @@ absl::StatusOr> CreateJumpStep( // If this value is an error or unknown, a jump is performed. absl::StatusOr> CreateBoolCheckJumpStep( absl::optional jump_offset, int64_t expr_id) { - std::unique_ptr step = - absl::make_unique(jump_offset, expr_id); - - return std::move(step); + return absl::make_unique(jump_offset, expr_id); } // TODO(issues/41) Make sure Unknowns are properly supported by ternary diff --git a/eval/eval/logic_step.cc b/eval/eval/logic_step.cc index ed4da4700..045116ac9 100644 --- a/eval/eval/logic_step.cc +++ b/eval/eval/logic_step.cc @@ -111,21 +111,14 @@ absl::Status LogicalOpStep::Evaluate(ExecutionFrame* frame) const { } // namespace - // Factory method for "And" Execution step absl::StatusOr> CreateAndStep(int64_t expr_id) { - std::unique_ptr step = - absl::make_unique(LogicalOpStep::OpType::AND, expr_id); - - return std::move(step); + return absl::make_unique(LogicalOpStep::OpType::AND, expr_id); } // Factory method for "Or" Execution step absl::StatusOr> CreateOrStep(int64_t expr_id) { - std::unique_ptr step = - absl::make_unique(LogicalOpStep::OpType::OR, expr_id); - - return std::move(step); + return absl::make_unique(LogicalOpStep::OpType::OR, expr_id); } } // namespace runtime diff --git a/eval/eval/select_step.cc b/eval/eval/select_step.cc index 507c13e54..cc83b4167 100644 --- a/eval/eval/select_step.cc +++ b/eval/eval/select_step.cc @@ -229,9 +229,8 @@ absl::Status SelectStep::Evaluate(ExecutionFrame* frame) const { absl::StatusOr> CreateSelectStep( const google::api::expr::v1alpha1::Expr::Select* select_expr, int64_t expr_id, absl::string_view select_path) { - std::unique_ptr step = absl::make_unique( + return absl::make_unique( select_expr->field(), select_expr->test_only(), expr_id, select_path); - return std::move(step); } } // namespace runtime diff --git a/eval/eval/shadowable_value_step.cc b/eval/eval/shadowable_value_step.cc new file mode 100644 index 000000000..e9bd9aed7 --- /dev/null +++ b/eval/eval/shadowable_value_step.cc @@ -0,0 +1,45 @@ +#include "eval/eval/shadowable_value_step.h" + +#include "absl/status/statusor.h" +#include "eval/eval/expression_step_base.h" +#include "eval/public/cel_value.h" + +namespace google { +namespace api { +namespace expr { +namespace runtime { + +namespace { + +class ShadowableValueStep : public ExpressionStepBase { + public: + ShadowableValueStep(const std::string& identifier, const CelValue& value, + int64_t expr_id) + : ExpressionStepBase(expr_id), identifier_(identifier), value_(value) {} + + absl::Status Evaluate(ExecutionFrame* frame) const override; + + private: + std::string identifier_; + CelValue value_; +}; + +absl::Status ShadowableValueStep::Evaluate(ExecutionFrame* frame) const { + auto var = frame->activation().FindValue(identifier_, frame->arena()); + frame->value_stack().Push(var.value_or(value_)); + return absl::OkStatus(); +} + +} // namespace + +absl::StatusOr> CreateShadowableValueStep( + const std::string& identifier, const CelValue& value, int64_t expr_id) { + std::unique_ptr step = + absl::make_unique(identifier, value, expr_id); + return std::move(step); +} + +} // namespace runtime +} // namespace expr +} // namespace api +} // namespace google diff --git a/eval/eval/shadowable_value_step.h b/eval/eval/shadowable_value_step.h new file mode 100644 index 000000000..7671e159d --- /dev/null +++ b/eval/eval/shadowable_value_step.h @@ -0,0 +1,25 @@ +#ifndef THIRD_PARTY_CEL_CPP_EVAL_EVAL_SHADOWABLE_VALUE_STEP_H_ +#define THIRD_PARTY_CEL_CPP_EVAL_EVAL_SHADOWABLE_VALUE_STEP_H_ + +#include "absl/status/statusor.h" +#include "eval/eval/evaluator_core.h" +#include "eval/public/activation.h" +#include "eval/public/cel_value.h" + +namespace google { +namespace api { +namespace expr { +namespace runtime { + +// Create an identifier resolution step with a default value that may be +// shadowed by an identifier of the same name within the runtime-provided +// Activation. +absl::StatusOr> CreateShadowableValueStep( + const std::string& identifier, const CelValue& value, int64_t expr_id); + +} // namespace runtime +} // namespace expr +} // namespace api +} // namespace google + +#endif // THIRD_PARTY_CEL_CPP_EVAL_EVAL_SHADOWABLE_VALUE_STEP_H_ diff --git a/eval/eval/shadowable_value_step_test.cc b/eval/eval/shadowable_value_step_test.cc new file mode 100644 index 000000000..15e25f8bc --- /dev/null +++ b/eval/eval/shadowable_value_step_test.cc @@ -0,0 +1,79 @@ +#include "eval/eval/shadowable_value_step.h" + +#include "google/api/expr/v1alpha1/syntax.pb.h" +#include "gmock/gmock.h" +#include "gtest/gtest.h" +#include "absl/status/statusor.h" +#include "eval/eval/evaluator_core.h" +#include "eval/public/cel_value.h" +#include "base/status_macros.h" + +namespace google { +namespace api { +namespace expr { +namespace runtime { + +namespace { + +using google::protobuf::Arena; +using testing::Eq; + +absl::StatusOr RunShadowableExpression(const std::string& identifier, + const CelValue& value, + const Activation& activation, + Arena* arena) { + auto step_status = CreateShadowableValueStep(identifier, value, 1); + if (!step_status.ok()) { + return step_status.status(); + } + + ExecutionPath path; + path.push_back(std::move(step_status.value())); + + google::api::expr::v1alpha1::Expr dummy_expr; + CelExpressionFlatImpl impl(&dummy_expr, std::move(path), 0, {}); + return impl.Evaluate(activation, arena); +} + +TEST(ShadowableValueStepTest, TestEvaluateNoShadowing) { + std::string type_name = "google.api.expr.runtime.TestMessage"; + + Activation activation; + Arena arena; + + auto type_value = + CelValue::CreateCelType(CelValue::CelTypeHolder(&type_name)); + auto status = + RunShadowableExpression(type_name, type_value, activation, &arena); + ASSERT_OK(status); + + auto value = status.value(); + ASSERT_TRUE(value.IsCelType()); + EXPECT_THAT(value.CelTypeOrDie().value(), Eq(type_name)); +} + +TEST(ShadowableValueStepTest, TestEvaluateShadowedIdentifier) { + std::string type_name = "int"; + auto shadow_value = CelValue::CreateInt64(1024L); + + Activation activation; + activation.InsertValue(type_name, shadow_value); + Arena arena; + + auto type_value = + CelValue::CreateCelType(CelValue::CelTypeHolder(&type_name)); + auto status = + RunShadowableExpression(type_name, type_value, activation, &arena); + ASSERT_OK(status); + + auto value = status.value(); + ASSERT_TRUE(value.IsInt64()); + EXPECT_THAT(value.Int64OrDie(), Eq(1024L)); +} + +} // namespace + +} // namespace runtime +} // namespace expr +} // namespace api +} // namespace google diff --git a/eval/eval/ternary_step.cc b/eval/eval/ternary_step.cc index a52430ad3..420cf10e5 100644 --- a/eval/eval/ternary_step.cc +++ b/eval/eval/ternary_step.cc @@ -69,10 +69,7 @@ absl::Status TernaryStep::Evaluate(ExecutionFrame* frame) const { absl::StatusOr> CreateTernaryStep( int64_t expr_id) { - std::unique_ptr step = - absl::make_unique(expr_id); - - return step; + return absl::make_unique(expr_id); } } // namespace runtime diff --git a/eval/public/BUILD b/eval/public/BUILD index 2ec095c48..2effb7e8e 100644 --- a/eval/public/BUILD +++ b/eval/public/BUILD @@ -180,6 +180,7 @@ cc_library( ":cel_function_adapter", ":cel_function_registry", ":cel_options", + "//base:unilib", "//eval/public/containers:container_backed_list_impl", "@com_google_absl//absl/numeric:int128", "@com_google_absl//absl/status", @@ -216,6 +217,7 @@ cc_library( ":activation", ":cel_function", ":cel_function_registry", + ":cel_type_registry", ":cel_value", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", @@ -336,7 +338,6 @@ cc_test( ":cel_value", ":unknown_attribute_set", ":unknown_set", - "//base:status_macros", "@com_google_absl//absl/status", "@com_google_absl//absl/strings", "@com_google_absl//absl/time", @@ -449,6 +450,34 @@ cc_test( ], ) +cc_library( + name = "cel_type_registry", + srcs = ["cel_type_registry.cc"], + hdrs = ["cel_type_registry.h"], + deps = [ + ":cel_value", + "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/container:node_hash_set", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/types:optional", + "@com_google_protobuf//:protobuf", + ], +) + +cc_test( + name = "cel_type_registry_test", + srcs = ["cel_type_registry_test.cc"], + deps = [ + ":cel_type_registry", + "//eval/testutil:test_message_cc_proto", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_googletest//:gtest_main", + "@com_google_protobuf//:protobuf", + ], +) + cc_test( name = "builtin_func_test", size = "small", diff --git a/eval/public/activation.h b/eval/public/activation.h index a6346699e..8e973140f 100644 --- a/eval/public/activation.h +++ b/eval/public/activation.h @@ -84,8 +84,7 @@ class Activation : public BaseActivation { google::protobuf::Arena* arena) const override; bool IsPathUnknown(absl::string_view path) const override { - return google::protobuf::util::FieldMaskUtil::IsPathInFieldMask(std::string(path), - unknown_paths_); + return google::protobuf::util::FieldMaskUtil::IsPathInFieldMask(path.data(), unknown_paths_); } // Insert a function into the activation (ie a lazily bound function). Returns diff --git a/eval/public/builtin_func_registrar.cc b/eval/public/builtin_func_registrar.cc index 34554d5d4..86fa2adbe 100644 --- a/eval/public/builtin_func_registrar.cc +++ b/eval/public/builtin_func_registrar.cc @@ -16,6 +16,7 @@ #include "eval/public/cel_options.h" #include "eval/public/containers/container_backed_list_impl.h" #include "re2/re2.h" +#include "base/unilib.h" namespace google { namespace api { @@ -1145,16 +1146,15 @@ absl::Status RegisterStringConversionFunctions( return absl::OkStatus(); } - // TODO(issues/82): ensure the bytes conversion to string handles UTF-8 - // properly, and avoids unncessary allocations. - // bytes -> string - auto status = FunctionAdapter:: - CreateAndRegister( + auto status = + FunctionAdapter::CreateAndRegister( builtin::kString, false, - [](Arena* arena, - CelValue::BytesHolder value) -> CelValue::StringHolder { - return CelValue::StringHolder( - Arena::Create(arena, std::string(value.value()))); + [](Arena* arena, CelValue::BytesHolder value) -> CelValue { + if (UniLib::IsStructurallyValid(value.value())) { + return CelValue::CreateStringView(value.value()); + } + return CreateErrorValue(arena, "invalid UTF-8 bytes value", + absl::StatusCode::kInvalidArgument); }, registry); if (!status.ok()) return status; diff --git a/eval/public/builtin_func_test.cc b/eval/public/builtin_func_test.cc index 5b3f7b267..008f3bf08 100644 --- a/eval/public/builtin_func_test.cc +++ b/eval/public/builtin_func_test.cc @@ -1686,6 +1686,15 @@ TEST_F(BuiltinsTest, BytesToString) { EXPECT_EQ(result_value.StringOrDie().value(), "abcd"); } +TEST_F(BuiltinsTest, BytesToStringInvalid) { + std::string input = "\xFF"; + std::vector args = {CelValue::CreateBytes(&input)}; + CelValue result_value; + ASSERT_NO_FATAL_FAILURE( + PerformRun(builtin::kString, {}, args, &result_value)); + ASSERT_TRUE(result_value.IsError()); +} + TEST_F(BuiltinsTest, StringToString) { std::string input = "abcd"; std::vector args = {CelValue::CreateString(&input)}; diff --git a/eval/public/cel_expr_builder_factory.cc b/eval/public/cel_expr_builder_factory.cc index 5e23df1ca..b8cd42dd0 100644 --- a/eval/public/cel_expr_builder_factory.cc +++ b/eval/public/cel_expr_builder_factory.cc @@ -18,6 +18,8 @@ std::unique_ptr CreateCelExpressionBuilder( builder->set_comprehension_max_iterations( options.comprehension_max_iterations); builder->set_fail_on_warnings(options.fail_on_warnings); + builder->set_enable_qualified_type_identifiers( + options.enable_qualified_type_identifiers); switch (options.unknown_processing) { case UnknownProcessingOptions::kAttributeAndFunction: diff --git a/eval/public/cel_expression.h b/eval/public/cel_expression.h index 524051cc1..8a5372bca 100644 --- a/eval/public/cel_expression.h +++ b/eval/public/cel_expression.h @@ -2,6 +2,7 @@ #define THIRD_PARTY_CEL_CPP_EVAL_PUBLIC_CEL_EXPRESSION_H_ #include +#include #include "google/api/expr/v1alpha1/checked.pb.h" #include "google/api/expr/v1alpha1/syntax.pb.h" @@ -10,6 +11,7 @@ #include "eval/public/activation.h" #include "eval/public/cel_function.h" #include "eval/public/cel_function_registry.h" +#include "eval/public/cel_type_registry.h" #include "eval/public/cel_value.h" namespace google { @@ -76,7 +78,9 @@ class CelExpression { class CelExpressionBuilder { public: CelExpressionBuilder() - : registry_(absl::make_unique()), container_("") {} + : func_registry_(absl::make_unique()), + type_registry_(absl::make_unique()), + container_("") {} virtual ~CelExpressionBuilder() {} @@ -118,21 +122,17 @@ class CelExpressionBuilder { // CelFunction registry. Extension function should be registered with it // prior to expression creation. - CelFunctionRegistry* GetRegistry() const { return registry_.get(); } + CelFunctionRegistry* GetRegistry() const { return func_registry_.get(); } - // Enums registered with the builder. - const std::set& resolvable_enums() const { - return resolvable_enums_; - } + // CEL Type registry. Provides a means to resolve the CEL built-in types to + // CelValue instances, and to extend the set of types and enums known to + // expressions by registering them ahead of time. + CelTypeRegistry* GetTypeRegistry() const { return type_registry_.get(); } // Add Enum to the list of resolvable by the builder. - void AddResolvableEnum(const google::protobuf::EnumDescriptor* enum_descriptor) { - resolvable_enums_.emplace(enum_descriptor); - } - - // Remove Enum from the list of resolvable by the builder. - void RemoveResolvableEnum(const google::protobuf::EnumDescriptor* enum_descriptor) { - resolvable_enums_.erase(enum_descriptor); + void ABSL_DEPRECATED("Use GetTypeRegistry()->Register() instead") + AddResolvableEnum(const google::protobuf::EnumDescriptor* enum_descriptor) { + type_registry_->Register(enum_descriptor); } void set_container(std::string container) { @@ -142,8 +142,8 @@ class CelExpressionBuilder { absl::string_view container() const { return container_; } private: - std::unique_ptr registry_; - std::set resolvable_enums_; + std::unique_ptr func_registry_; + std::unique_ptr type_registry_; std::string container_; }; diff --git a/eval/public/cel_function_registry.cc b/eval/public/cel_function_registry.cc index 34202afe4..04755b490 100644 --- a/eval/public/cel_function_registry.cc +++ b/eval/public/cel_function_registry.cc @@ -41,7 +41,7 @@ std::vector CelFunctionRegistry::FindOverloads( const std::vector& types) const { std::vector matched_funcs; - auto overloads = functions_.find(std::string(name)); + auto overloads = functions_.find(name); if (overloads == functions_.end()) { return matched_funcs; } @@ -60,7 +60,7 @@ std::vector CelFunctionRegistry::FindLazyOverloads( const std::vector& types) const { std::vector matched_funcs; - auto overloads = functions_.find(std::string(name)); + auto overloads = functions_.find(name); if (overloads == functions_.end()) { return matched_funcs; } diff --git a/eval/public/cel_options.h b/eval/public/cel_options.h index de10230cb..dcd774c6c 100644 --- a/eval/public/cel_options.h +++ b/eval/public/cel_options.h @@ -72,6 +72,14 @@ struct InterpreterOptions { // Treat builder warnings as fatal errors. bool fail_on_warnings = true; + + // Enable the resolution of qualified type identifiers as type values instead + // of field selections. + // + // This toggle may cause certain identifiers which overlap with CEL built-in + // type or with protobuf message types linked into the binary to be resolved + // as static type values rather than as per-eval variables. + bool enable_qualified_type_identifiers = false; }; } // namespace runtime diff --git a/eval/public/cel_type_registry.cc b/eval/public/cel_type_registry.cc new file mode 100644 index 000000000..565ae3e27 --- /dev/null +++ b/eval/public/cel_type_registry.cc @@ -0,0 +1,75 @@ +#include "eval/public/cel_type_registry.h" + +#include "google/protobuf/descriptor.h" +#include "absl/container/node_hash_set.h" +#include "absl/status/status.h" +#include "absl/types/optional.h" +#include "eval/public/cel_value.h" + +namespace google { +namespace api { +namespace expr { +namespace runtime { + +namespace { + +const absl::node_hash_set& GetCoreTypes() { + static const auto* const kCoreTypes = + new absl::node_hash_set{{"bool"}, + {"bytes"}, + {"double"}, + {"google.protobuf.Duration"}, + {"google.protobuf.Timestamp"}, + {"int"}, + {"list"}, + {"map"}, + {"null_type"}, + {"string"}, + {"type"}, + {"uint"}}; + return *kCoreTypes; +} + +} // namespace + +CelTypeRegistry::CelTypeRegistry() : types_(GetCoreTypes()), enums_() {} + +void CelTypeRegistry::Register(std::string fully_qualified_type_name) { + // Registers the fully qualified type name as a CEL type. + types_.insert(std::move(fully_qualified_type_name)); +} + +void CelTypeRegistry::Register(const google::protobuf::EnumDescriptor* enum_descriptor) { + enums_.insert(enum_descriptor); +} + +const google::protobuf::Descriptor* CelTypeRegistry::FindDescriptor( + absl::string_view fully_qualified_type_name) const { + return google::protobuf::DescriptorPool::generated_pool()->FindMessageTypeByName( + fully_qualified_type_name.data()); +} + +absl::optional CelTypeRegistry::FindType( + absl::string_view fully_qualified_type_name) const { + // Searches through explicitly registered type names first. + auto type = types_.find(fully_qualified_type_name); + // The CelValue returned by this call will remain valid as long as the + // CelExpression and associated builder stay in scope. + if (type != types_.end()) { + return CelValue::CreateCelTypeView(*type); + } + + // By default falls back to looking at whether the protobuf descriptor is + // linked into the binary. In the future, this functionality may be disabled, + // but this is most consistent with the current CEL C++ behavior. + auto desc = FindDescriptor(fully_qualified_type_name); + if (desc != nullptr) { + return CelValue::CreateCelTypeView(desc->full_name()); + } + return absl::nullopt; +} + +} // namespace runtime +} // namespace expr +} // namespace api +} // namespace google diff --git a/eval/public/cel_type_registry.h b/eval/public/cel_type_registry.h new file mode 100644 index 000000000..f8d236895 --- /dev/null +++ b/eval/public/cel_type_registry.h @@ -0,0 +1,74 @@ +#ifndef THIRD_PARTY_CEL_CPP_EVAL_PUBLIC_CEL_TYPE_REGISTRY_H_ +#define THIRD_PARTY_CEL_CPP_EVAL_PUBLIC_CEL_TYPE_REGISTRY_H_ + +#include "google/protobuf/descriptor.h" +#include "absl/container/flat_hash_set.h" +#include "absl/container/node_hash_set.h" +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" +#include "eval/public/cel_value.h" + +namespace google { +namespace api { +namespace expr { +namespace runtime { + +// CelTypeRegistry manages the set of registered types available for use within +// object literal construction, enum comparisons, and type testing. +// +// The CelTypeRegistry is intended to live for the duration of all CelExpression +// values created by a given CelExpressionBuilder and one is created by default +// within the standard CelExpressionBuilder. +// +// By default, all core CEL types and all linked protobuf message types are +// implicitly registered by way of the generated descriptor pool. In the future, +// such type registrations may be explicit to avoid accidentally exposing linked +// protobuf types to CEL which were intended to remain internal. +class CelTypeRegistry { + public: + CelTypeRegistry(); + + ~CelTypeRegistry() {} + + // Register a fully qualified type name as a valid type for use within CEL + // expressions. + // + // This call establishes a CelValue type instance that can be used in runtime + // comparisons, and may have implications in the future about which protobuf + // message types linked into the binary may also be used by CEL. + // + // Type registration must be performed prior to CelExpression creation. + void Register(std::string fully_qualified_type_name); + + // Register an enum whose values may be used within CEL expressions. + // + // Enum registration must be performed prior to CelExpression creation. + void Register(const google::protobuf::EnumDescriptor* enum_descriptor); + + // Find a protobuf Descriptor given a fully qualified protobuf type name. + const google::protobuf::Descriptor* FindDescriptor( + absl::string_view fully_qualified_type_name) const; + + // Find a type's CelValue instance by its fully qualified name. + absl::optional FindType( + absl::string_view fully_qualified_type_name) const; + + // Return the set of enums configured within the type registry. + inline const absl::flat_hash_set& Enums() + const { + return enums_; + } + + private: + // pointer-stability is required for the strings in the types set, which is + // why a node_hash_set is used instead of another container type. + absl::node_hash_set types_; + absl::flat_hash_set enums_; +}; + +} // namespace runtime +} // namespace expr +} // namespace api +} // namespace google + +#endif // THIRD_PARTY_CEL_CPP_EVAL_PUBLIC_CEL_TYPE_REGISTRY_H_ diff --git a/eval/public/cel_type_registry_test.cc b/eval/public/cel_type_registry_test.cc new file mode 100644 index 000000000..3117722da --- /dev/null +++ b/eval/public/cel_type_registry_test.cc @@ -0,0 +1,86 @@ +#include "eval/public/cel_type_registry.h" + +#include "google/protobuf/any.pb.h" +#include "gmock/gmock.h" +#include "gtest/gtest.h" +#include "absl/container/flat_hash_map.h" +#include "eval/testutil/test_message.pb.h" + +namespace google { +namespace api { +namespace expr { +namespace runtime { + +namespace { + +using testing::Eq; + +TEST(CelTypeRegistryTest, TestRegisterEnumDescriptor) { + CelTypeRegistry registry; + registry.Register(TestMessage::TestEnum_descriptor()); + + absl::flat_hash_set enum_set; + for (auto enum_desc : registry.Enums()) { + enum_set.insert(enum_desc->full_name()); + } + absl::flat_hash_set expected_set; + expected_set.insert({"google.api.expr.runtime.TestMessage.TestEnum"}); + EXPECT_THAT(enum_set, Eq(expected_set)); +} + +TEST(CelTypeRegistryTest, TestRegisterTypeName) { + CelTypeRegistry registry; + + // Register the type, scoping the type name lifecycle to the nested block. + { + std::string custom_type = "custom_type"; + registry.Register(custom_type); + } + + auto type = registry.FindType("custom_type"); + ASSERT_TRUE(type.has_value()); + EXPECT_TRUE(type.value().IsCelType()); + EXPECT_THAT(type.value().CelTypeOrDie().value(), Eq("custom_type")); +} + +TEST(CelTypeRegistryTest, TestFindDescriptorFound) { + CelTypeRegistry registry; + auto desc = registry.FindDescriptor("google.protobuf.Any"); + ASSERT_TRUE(desc != nullptr); + EXPECT_THAT(desc->full_name(), Eq("google.protobuf.Any")); +} + +TEST(CelTypeRegistryTest, TestFindDescriptorNotFound) { + CelTypeRegistry registry; + auto desc = registry.FindDescriptor("missing.MessageType"); + EXPECT_TRUE(desc == nullptr); +} + +TEST(CelTypeRegistryTest, TestFindTypeCoreTypeFound) { + CelTypeRegistry registry; + auto type = registry.FindType("int"); + ASSERT_TRUE(type.has_value()); + EXPECT_TRUE(type.value().IsCelType()); + EXPECT_THAT(type.value().CelTypeOrDie().value(), Eq("int")); +} + +TEST(CelTypeRegistryTest, TestFindTypeProtobufTypeFound) { + CelTypeRegistry registry; + auto type = registry.FindType("google.protobuf.Any"); + ASSERT_TRUE(type.has_value()); + EXPECT_TRUE(type.value().IsCelType()); + EXPECT_THAT(type.value().CelTypeOrDie().value(), Eq("google.protobuf.Any")); +} + +TEST(CelTypeRegistryTest, TestFindTypeNotRegisteredTypeNotFound) { + CelTypeRegistry registry; + auto type = registry.FindType("missing.MessageType"); + EXPECT_FALSE(type.has_value()); +} + +} // namespace + +} // namespace runtime +} // namespace expr +} // namespace api +} // namespace google diff --git a/eval/public/cel_value.h b/eval/public/cel_value.h index 76be6f511..a7c34bdce 100644 --- a/eval/public/cel_value.h +++ b/eval/public/cel_value.h @@ -162,6 +162,9 @@ class CelValue { static CelValue CreateString(StringHolder holder) { return CelValue(holder); } + // Returns a string value from a string_view. Warning: the caller is + // responsible for the lifecycle of the backing string. Prefer CreateString + // instead. static CelValue CreateStringView(absl::string_view value) { return CelValue(StringHolder(value)); } @@ -205,6 +208,14 @@ class CelValue { return CelValue(holder); } + static CelValue CreateCelTypeView(absl::string_view value) { + // This factory method is used for dealing with string references which + // come from protobuf objects or other containers which promise pointer + // stability. In general, this is a risky method to use and should not + // be invoked outside the core CEL library. + return CelValue(CelTypeHolder(value)); + } + static CelValue CreateError(const CelError *value) { CheckNullPointer(value, Type::kError); return CelValue(value); diff --git a/eval/public/cel_value_test.cc b/eval/public/cel_value_test.cc index f9538965e..dee3d667e 100644 --- a/eval/public/cel_value_test.cc +++ b/eval/public/cel_value_test.cc @@ -15,7 +15,6 @@ namespace expr { namespace runtime { using testing::Eq; -using testing::UnorderedPointwise; class DummyMap : public CelMap { public: @@ -257,6 +256,13 @@ TEST(CelValueTest, TestCelType) { EXPECT_THAT(value_bytes.type(), Eq(CelValue::Type::kBytes)); EXPECT_THAT(value_bytes.ObtainCelType().CelTypeOrDie().value(), Eq("bytes")); + std::string msg_type_str = "google.api.expr.runtime.TestMessage"; + CelValue msg_type = CelValue::CreateCelTypeView(msg_type_str); + EXPECT_TRUE(msg_type.IsCelType()); + EXPECT_THAT(msg_type.CelTypeOrDie().value(), + Eq("google.api.expr.runtime.TestMessage")); + EXPECT_THAT(msg_type.type(), Eq(CelValue::Type::kCelType)); + UnknownSet unknown_set; CelValue value_unknown = CelValue::CreateUnknownSet(&unknown_set); EXPECT_THAT(value_unknown.type(), Eq(CelValue::Type::kUnknownSet)); diff --git a/eval/public/structs/cel_proto_wrapper.cc b/eval/public/structs/cel_proto_wrapper.cc index 78c29f463..77d318c27 100644 --- a/eval/public/structs/cel_proto_wrapper.cc +++ b/eval/public/structs/cel_proto_wrapper.cc @@ -155,7 +155,7 @@ CelValue ValueFromMessage(const Any* any_value, Arena* arena) { std::string full_name = std::string(type_url.substr(pos + 1)); const Descriptor* nested_descriptor = - DescriptorPool::generated_pool()->FindMessageTypeByName(full_name); + DescriptorPool::generated_pool()->FindMessageTypeByName(full_name.data()); if (nested_descriptor == nullptr) { // Descriptor not found for the type diff --git a/parser/BUILD b/parser/BUILD index f6378ae37..3f2be39f1 100644 --- a/parser/BUILD +++ b/parser/BUILD @@ -97,6 +97,7 @@ cc_library( ":cel_cc_parser", "//common:operators", "@antlr4_runtimes//:cpp", + "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/memory", "@com_google_absl//absl/strings", "@com_google_absl//absl/strings:str_format", diff --git a/parser/balancer.cc b/parser/balancer.cc index 88cfbd6d9..6cd85a2e5 100644 --- a/parser/balancer.cc +++ b/parser/balancer.cc @@ -9,10 +9,13 @@ namespace parser { ExpressionBalancer::ExpressionBalancer(std::shared_ptr sf, std::string function, Expr expr) - : sf_(sf), function_(function), terms_{expr}, ops_{} {} + : sf_(std::move(sf)), + function_(std::move(function)), + terms_{std::move(expr)}, + ops_{} {} void ExpressionBalancer::addTerm(int64_t op, Expr term) { - terms_.push_back(term); + terms_.push_back(std::move(term)); ops_.push_back(op); } @@ -39,7 +42,8 @@ Expr ExpressionBalancer::balancedTree(int lo, int hi) { } else { right = balancedTree(mid + 1, hi); } - return sf_->newGlobalCall(ops_[mid], function_, {left, right}); + return sf_->newGlobalCall(ops_[mid], function_, + {std::move(left), std::move(right)}); } } // namespace parser diff --git a/parser/macro.cc b/parser/macro.cc index 7eb8559c1..c7e72898f 100644 --- a/parser/macro.cc +++ b/parser/macro.cc @@ -26,8 +26,8 @@ std::vector Macro::AllMacros() { // The macro "has(m.f)" which tests the presence of a field, avoiding the // need to specify the field as a string. Macro(CelOperator::HAS, 1, - [](std::shared_ptr sf, int64_t macro_id, Expr* target, - const std::vector& args) { + [](const std::shared_ptr& sf, int64_t macro_id, + const Expr& target, const std::vector& args) { if (!args.empty() && args[0].has_select_expr()) { const auto& sel_expr = args[0].select_expr(); return sf->newPresenceTestForMacro(macro_id, sel_expr.operand(), @@ -43,8 +43,8 @@ std::vector Macro::AllMacros() { // in range the predicate holds. Macro( CelOperator::ALL, 2, - [](std::shared_ptr sf, int64_t macro_id, Expr* target, - const std::vector& args) { + [](const std::shared_ptr& sf, int64_t macro_id, + const Expr& target, const std::vector& args) { return sf->newQuantifierExprForMacro(SourceFactory::QUANTIFIER_ALL, macro_id, target, args); }, @@ -54,8 +54,8 @@ std::vector Macro::AllMacros() { // one element in range the predicate holds. Macro( CelOperator::EXISTS, 2, - [](std::shared_ptr sf, int64_t macro_id, Expr* target, - const std::vector& args) { + [](const std::shared_ptr& sf, int64_t macro_id, + const Expr& target, const std::vector& args) { return sf->newQuantifierExprForMacro( SourceFactory::QUANTIFIER_EXISTS, macro_id, target, args); }, @@ -65,8 +65,8 @@ std::vector Macro::AllMacros() { // exactly one element in range the predicate holds. Macro( CelOperator::EXISTS_ONE, 2, - [](std::shared_ptr sf, int64_t macro_id, Expr* target, - const std::vector& args) { + [](const std::shared_ptr& sf, int64_t macro_id, + const Expr& target, const std::vector& args) { return sf->newQuantifierExprForMacro( SourceFactory::QUANTIFIER_EXISTS_ONE, macro_id, target, args); }, @@ -77,8 +77,8 @@ std::vector Macro::AllMacros() { // the range. Macro( CelOperator::MAP, 2, - [](std::shared_ptr sf, int64_t macro_id, Expr* target, - const std::vector& args) { + [](const std::shared_ptr& sf, int64_t macro_id, + const Expr& target, const std::vector& args) { return sf->newMapForMacro(macro_id, target, args); }, /* receiver style*/ true), @@ -89,8 +89,8 @@ std::vector Macro::AllMacros() { // variables are filtered out. Macro( CelOperator::MAP, 3, - [](std::shared_ptr sf, int64_t macro_id, Expr* target, - const std::vector& args) { + [](const std::shared_ptr& sf, int64_t macro_id, + const Expr& target, const std::vector& args) { return sf->newMapForMacro(macro_id, target, args); }, /* receiver style*/ true), @@ -100,8 +100,8 @@ std::vector Macro::AllMacros() { // predicate is false. Macro( CelOperator::FILTER, 2, - [](std::shared_ptr sf, int64_t macro_id, Expr* target, - const std::vector& args) { + [](const std::shared_ptr& sf, int64_t macro_id, + const Expr& target, const std::vector& args) { return sf->newFilterExprForMacro(macro_id, target, args); }, /* receiver style*/ true), diff --git a/parser/macro.h b/parser/macro.h index 7a30e725c..6277593da 100644 --- a/parser/macro.h +++ b/parser/macro.h @@ -4,6 +4,7 @@ #include #include #include +#include #include "google/api/expr/v1alpha1/syntax.pb.h" @@ -19,11 +20,11 @@ class SourceFactory; // MacroExpander converts the target and args of a function call that matches a // Macro. // -// Note: when the Macros.IsReceiverStyle() is true, the target argument will be -// empty. +// Note: when the Macros.IsReceiverStyle() is true, the target argument will +// be Expr::default_instance(). using MacroExpander = - std::function sf, int64_t macro_id, Expr*, - const std::vector&)>; + std::function& sf, int64_t macro_id, + const Expr&, const std::vector&)>; // Macro interface for describing the function signature to match and the // MacroExpander to apply. @@ -39,7 +40,7 @@ class Macro { receiver_style_(receiver_style), var_arg_style_(false), arg_count_(arg_count), - expander_(expander) {} + expander_(std::move(expander)) {} Macro(const std::string& function, MacroExpander expander, bool receiver_style = false) @@ -47,7 +48,7 @@ class Macro { receiver_style_(receiver_style), var_arg_style_(true), arg_count_(0), - expander_(expander) {} + expander_(std::move(expander)) {} // Function name to match. std::string function() const { return function_; } @@ -73,9 +74,9 @@ class Macro { // parsed call signature. const MacroExpander& expander() const { return expander_; } - Expr expand(std::shared_ptr sf, int64_t macro_id, Expr* target, - const std::vector& args) { - return expander_(sf, macro_id, target, args); + Expr expand(const std::shared_ptr& sf, int64_t macro_id, + const Expr& target, const std::vector& args) { + return expander_(std::move(sf), macro_id, target, args); } static std::vector AllMacros(); diff --git a/parser/parser.cc b/parser/parser.cc index 35b804f42..40dce202e 100644 --- a/parser/parser.cc +++ b/parser/parser.cc @@ -14,23 +14,27 @@ namespace api { namespace expr { namespace parser { -using antlr4::ANTLRInputStream; -using antlr4::CommonTokenStream; -using antlr4::ParseCancellationException; -using antlr4::ParserRuleContext; +using ::antlr4::ANTLRInputStream; +using ::antlr4::CommonTokenStream; +using ::antlr4::ParseCancellationException; +using ::antlr4::ParserRuleContext; -using antlr4::tree::ErrorNode; -using antlr4::tree::TerminalNode; +using ::antlr4::tree::ErrorNode; +using ::antlr4::tree::ParseTreeListener; +using ::antlr4::tree::TerminalNode; -using google::api::expr::v1alpha1::Expr; -using google::api::expr::v1alpha1::ParsedExpr; +using ::google::api::expr::v1alpha1::Expr; +using ::google::api::expr::v1alpha1::ParsedExpr; + +using ::cel_grammar::CelLexer; +using ::cel_grammar::CelParser; namespace { // ExprRecursionListener extends the standard ANTLR CelParser to ensure that // recursive entries into the 'expr' rule are limited to a configurable depth so // as to prevent stack overflows. -class ExprRecursionListener : public ::antlr4::tree::ParseTreeListener { +class ExprRecursionListener : public ParseTreeListener { public: ExprRecursionListener( const int max_recursion_depth = kDefaultMaxRecursionDepth) @@ -50,7 +54,7 @@ void ExprRecursionListener::enterEveryRule(ParserRuleContext* ctx) { // Throw a ParseCancellationException since the parsing would otherwise // continue if this were treated as a syntax error and the problem would // continue to manifest. - if (ctx->getRuleIndex() == ::cel_grammar::CelParser::RuleExpr) { + if (ctx->getRuleIndex() == CelParser::RuleExpr) { if (recursion_depth_ >= max_recursion_depth_) { throw ParseCancellationException( absl::StrFormat("Expression recursion limit exceeded. limit: %d", @@ -61,7 +65,7 @@ void ExprRecursionListener::enterEveryRule(ParserRuleContext* ctx) { } void ExprRecursionListener::exitEveryRule(ParserRuleContext* ctx) { - if (ctx->getRuleIndex() == ::cel_grammar::CelParser::RuleExpr) { + if (ctx->getRuleIndex() == CelParser::RuleExpr) { recursion_depth_--; } } @@ -91,9 +95,9 @@ absl::StatusOr EnrichedParse( const std::string& expression, const std::vector& macros, const std::string& description, const int max_recursion_depth) { ANTLRInputStream input(expression); - ::cel_grammar::CelLexer lexer(&input); + CelLexer lexer(&input); CommonTokenStream tokens(&lexer); - ::cel_grammar::CelParser parser(&tokens); + CelParser parser(&tokens); ExprRecursionListener listener(max_recursion_depth); ParserVisitor visitor(description, expression, max_recursion_depth, macros); @@ -107,12 +111,12 @@ absl::StatusOr EnrichedParse( // std::shared_ptr error_strategy(new BailErrorStrategy()); // parser.setErrorHandler(error_strategy); - ::cel_grammar::CelParser::StartContext* root; + CelParser::StartContext* root; try { root = parser.start(); - } catch (ParseCancellationException& e) { + } catch (const ParseCancellationException& e) { return absl::CancelledError(e.what()); - } catch (std::exception& e) { + } catch (const std::exception& e) { return absl::AbortedError(e.what()); } @@ -124,10 +128,11 @@ absl::StatusOr EnrichedParse( // root is deleted as part of the parser context ParsedExpr parsed_expr; - parsed_expr.mutable_expr()->CopyFrom(expr); - parsed_expr.mutable_source_info()->CopyFrom(visitor.sourceInfo()); + *(parsed_expr.mutable_expr()) = std::move(expr); auto enriched_source_info = visitor.enrichedSourceInfo(); - return VerboseParsedExpr(parsed_expr, enriched_source_info); + *(parsed_expr.mutable_source_info()) = visitor.sourceInfo(); + return VerboseParsedExpr(std::move(parsed_expr), + std::move(enriched_source_info)); } } // namespace parser diff --git a/parser/parser.h b/parser/parser.h index 46bb561d1..7227c8fed 100644 --- a/parser/parser.h +++ b/parser/parser.h @@ -16,10 +16,10 @@ constexpr int kDefaultMaxRecursionDepth = 250; class VerboseParsedExpr { public: - VerboseParsedExpr(const google::api::expr::v1alpha1::ParsedExpr& parsed_expr, - const EnrichedSourceInfo& enriched_source_info) - : parsed_expr_(parsed_expr), - enriched_source_info_(enriched_source_info) {} + VerboseParsedExpr(google::api::expr::v1alpha1::ParsedExpr parsed_expr, + EnrichedSourceInfo enriched_source_info) + : parsed_expr_(std::move(parsed_expr)), + enriched_source_info_(std::move(enriched_source_info)) {} const google::api::expr::v1alpha1::ParsedExpr& parsed_expr() const { return parsed_expr_; diff --git a/parser/source_factory.cc b/parser/source_factory.cc index 3573f5289..3052b2407 100644 --- a/parser/source_factory.cc +++ b/parser/source_factory.cc @@ -4,6 +4,7 @@ #include #include "google/protobuf/struct.pb.h" +#include "absl/container/flat_hash_set.h" #include "absl/memory/memory.h" #include "absl/strings/numbers.h" #include "absl/strings/str_format.h" @@ -36,9 +37,10 @@ SourceFactory::SourceFactory(const std::string& expression) int64_t SourceFactory::id(const antlr4::Token* token) { int64_t new_id = next_id_; positions_.emplace( - new_id, SourceLocation{(int32_t)token->getLine(), - (int32_t)token->getCharPositionInLine(), - (int32_t)token->getStopIndex(), line_offsets_}); + new_id, + SourceLocation{static_cast(token->getLine()), + static_cast(token->getCharPositionInLine()), + static_cast(token->getStopIndex()), line_offsets_}); next_id_ += 1; return new_id; } @@ -86,9 +88,8 @@ Expr SourceFactory::newGlobalCall(int64_t id, const std::string& function, Expr expr = newExpr(id); auto call_expr = expr.mutable_call_expr(); call_expr->set_function(function); - std::for_each(args.begin(), args.end(), [&call_expr](const Expr& e) { - call_expr->add_args()->CopyFrom(e); - }); + std::for_each(args.begin(), args.end(), + [&call_expr](const Expr& e) { *call_expr->add_args() = e; }); return expr; } @@ -99,15 +100,14 @@ Expr SourceFactory::newGlobalCallForMacro(int64_t macro_id, } Expr SourceFactory::newReceiverCall(int64_t id, const std::string& function, - Expr& target, + const Expr& target, const std::vector& args) { Expr expr = newExpr(id); auto call_expr = expr.mutable_call_expr(); call_expr->set_function(function); - call_expr->mutable_target()->CopyFrom(target); - std::for_each(args.begin(), args.end(), [&call_expr](const Expr& e) { - call_expr->add_args()->CopyFrom(e); - }); + *call_expr->mutable_target() = target; + std::for_each(args.begin(), args.end(), + [&call_expr](const Expr& e) { *call_expr->add_args() = e; }); return expr; } @@ -130,7 +130,7 @@ Expr SourceFactory::newSelect( const std::string& field) { Expr expr = newExpr(ctx->op); auto select_expr = expr.mutable_select_expr(); - select_expr->mutable_operand()->CopyFrom(operand); + *select_expr->mutable_operand() = operand; select_expr->set_field(field); return expr; } @@ -139,14 +139,14 @@ Expr SourceFactory::newPresenceTestForMacro(int64_t macro_id, const Expr& operan const std::string& field) { Expr expr = newExpr(nextMacroId(macro_id)); auto select_expr = expr.mutable_select_expr(); - select_expr->mutable_operand()->CopyFrom(operand); + *select_expr->mutable_operand() = operand; select_expr->set_field(field); select_expr->set_test_only(true); return expr; } Expr SourceFactory::newObject( - int64_t obj_id, std::string type_name, + int64_t obj_id, const std::string& type_name, const std::vector& entries) { auto expr = newExpr(obj_id); auto struct_expr = expr.mutable_struct_expr(); @@ -163,7 +163,7 @@ Expr::CreateStruct::Entry SourceFactory::newObjectField( Expr::CreateStruct::Entry entry; entry.set_id(field_id); entry.set_field_key(field); - entry.mutable_value()->CopyFrom(value); + *entry.mutable_value() = value; return entry; } @@ -176,12 +176,12 @@ Expr SourceFactory::newComprehension(int64_t id, const std::string& iter_var, Expr expr = newExpr(id); auto comp_expr = expr.mutable_comprehension_expr(); comp_expr->set_iter_var(iter_var); - comp_expr->mutable_iter_range()->CopyFrom(iter_range); + *comp_expr->mutable_iter_range() = iter_range; comp_expr->set_accu_var(accu_var); - comp_expr->mutable_accu_init()->CopyFrom(accu_init); - comp_expr->mutable_loop_condition()->CopyFrom(condition); - comp_expr->mutable_loop_step()->CopyFrom(step); - comp_expr->mutable_result()->CopyFrom(result); + *comp_expr->mutable_accu_init() = accu_init; + *comp_expr->mutable_loop_condition() = condition; + *comp_expr->mutable_loop_step() = step; + *comp_expr->mutable_result() = result; return expr; } @@ -197,14 +197,13 @@ Expr SourceFactory::foldForMacro(int64_t macro_id, const std::string& iter_var, Expr SourceFactory::newList(int64_t list_id, const std::vector& elems) { auto expr = newExpr(list_id); auto list_expr = expr.mutable_list_expr(); - std::for_each(elems.begin(), elems.end(), [list_expr](const Expr& e) { - list_expr->add_elements()->CopyFrom(e); - }); + std::for_each(elems.begin(), elems.end(), + [list_expr](const Expr& e) { *list_expr->add_elements() = e; }); return expr; } Expr SourceFactory::newQuantifierExprForMacro( - SourceFactory::QuantifierKind kind, int64_t macro_id, Expr* target, + SourceFactory::QuantifierKind kind, int64_t macro_id, const Expr& target, const std::vector& args) { if (args.empty()) { return Expr(); @@ -264,11 +263,11 @@ Expr SourceFactory::newQuantifierExprForMacro( break; } } - return foldForMacro(macro_id, v, *target, AccumulatorName, init, condition, + return foldForMacro(macro_id, v, target, AccumulatorName, init, condition, step, result); } -Expr SourceFactory::newFilterExprForMacro(int64_t macro_id, Expr* target, +Expr SourceFactory::newFilterExprForMacro(int64_t macro_id, const Expr& target, const std::vector& args) { if (args.empty()) { return Expr(); @@ -291,7 +290,7 @@ Expr SourceFactory::newFilterExprForMacro(int64_t macro_id, Expr* target, {accu_expr, newListForMacro(macro_id, {args[0]})}); step = newGlobalCallForMacro(macro_id, CelOperator::CONDITIONAL, {filter, step, accu_expr}); - return foldForMacro(macro_id, v, *target, AccumulatorName, init, condition, + return foldForMacro(macro_id, v, target, AccumulatorName, init, condition, step, accu_expr); } @@ -311,7 +310,7 @@ Expr SourceFactory::newMap( return expr; } -Expr SourceFactory::newMapForMacro(int64_t macro_id, Expr* target, +Expr SourceFactory::newMapForMacro(int64_t macro_id, const Expr& target, const std::vector& args) { if (args.empty()) { return Expr(); @@ -345,7 +344,7 @@ Expr SourceFactory::newMapForMacro(int64_t macro_id, Expr* target, step = newGlobalCallForMacro(macro_id, CelOperator::CONDITIONAL, {filter, step, accu_expr}); } - return foldForMacro(macro_id, v, *target, AccumulatorName, init, condition, + return foldForMacro(macro_id, v, target, AccumulatorName, init, condition, step, accu_expr); } @@ -354,8 +353,8 @@ Expr::CreateStruct::Entry SourceFactory::newMapEntry(int64_t entry_id, const Expr& value) { Expr::CreateStruct::Entry entry; entry.set_id(entry_id); - entry.mutable_map_key()->CopyFrom(key); - entry.mutable_value()->CopyFrom(value); + *entry.mutable_map_key() = key; + *entry.mutable_value() = value; return entry; } @@ -508,12 +507,11 @@ std::string SourceFactory::errorMessage(const std::string& description, } bool SourceFactory::isReserved(const std::string& ident_name) { - static std::vector reserved_words = { - "as", "break", "const", "continue", "else", "false", "for", - "function", "if", "import", "in", "let", "loop", "package", - "namespace", "null", "return", "true", "var", "void", "while"}; - return std::find(reserved_words.begin(), reserved_words.end(), ident_name) != - reserved_words.end(); + static const auto* reserved_words = new absl::flat_hash_set( + {"as", "break", "const", "continue", "else", "false", "for", + "function", "if", "import", "in", "let", "loop", "package", + "namespace", "null", "return", "true", "var", "void", "while"}); + return reserved_words->find(ident_name) != reserved_words->end(); } google::api::expr::v1alpha1::SourceInfo SourceFactory::sourceInfo() const { @@ -537,7 +535,7 @@ EnrichedSourceInfo SourceFactory::enrichedSourceInfo() const { [&offset](const std::pair& loc) { offset.insert({loc.first, {loc.second.offset, loc.second.offset_end}}); }); - return EnrichedSourceInfo(offset); + return EnrichedSourceInfo(std::move(offset)); } void SourceFactory::calcLineOffsets(const std::string& expression) { diff --git a/parser/source_factory.h b/parser/source_factory.h index 79d766f45..34617c951 100644 --- a/parser/source_factory.h +++ b/parser/source_factory.h @@ -19,8 +19,8 @@ using google::api::expr::v1alpha1::Expr; class EnrichedSourceInfo { public: - EnrichedSourceInfo(const std::map>& offsets) - : offsets_(offsets) {} + EnrichedSourceInfo(std::map> offsets) + : offsets_(std::move(offsets)) {} const std::map>& offsets() const { return offsets_; @@ -56,7 +56,7 @@ class SourceFactory { struct Error { Error(std::string message, SourceLocation location) - : message(message), location(location) {} + : message(std::move(message)), location(location) {} std::string message; SourceLocation location; }; @@ -86,15 +86,15 @@ class SourceFactory { const std::vector& args); Expr newGlobalCallForMacro(int64_t macro_id, const std::string& function, const std::vector& args); - Expr newReceiverCall(int64_t id, const std::string& function, Expr& target, - const std::vector& args); + Expr newReceiverCall(int64_t id, const std::string& function, + const Expr& target, const std::vector& args); Expr newIdent(const antlr4::Token* token, const std::string& ident_name); Expr newIdentForMacro(int64_t macro_id, const std::string& ident_name); Expr newSelect(::cel_grammar::CelParser::SelectOrCallContext* ctx, Expr& operand, const std::string& field); Expr newPresenceTestForMacro(int64_t macro_id, const Expr& operand, const std::string& field); - Expr newObject(int64_t obj_id, std::string type_name, + Expr newObject(int64_t obj_id, const std::string& type_name, const std::vector& entries); Expr::CreateStruct::Entry newObjectField(int64_t field_id, const std::string& field, @@ -109,15 +109,16 @@ class SourceFactory { const Expr& accu_init, const Expr& condition, const Expr& step, const Expr& result); Expr newQuantifierExprForMacro(QuantifierKind kind, int64_t macro_id, - Expr* target, const std::vector& args); - Expr newFilterExprForMacro(int64_t macro_id, Expr* target, + const Expr& target, + const std::vector& args); + Expr newFilterExprForMacro(int64_t macro_id, const Expr& target, const std::vector& args); Expr newList(int64_t list_id, const std::vector& elems); Expr newListForMacro(int64_t macro_id, const std::vector& elems); Expr newMap(int64_t map_id, const std::vector& entries); - Expr newMapForMacro(int64_t macro_id, Expr* target, + Expr newMapForMacro(int64_t macro_id, const Expr& target, const std::vector& args); Expr::CreateStruct::Entry newMapEntry(int64_t entry_id, const Expr& key, const Expr& value); diff --git a/parser/visitor.cc b/parser/visitor.cc index 8a4f1930e..9851c3155 100644 --- a/parser/visitor.cc +++ b/parser/visitor.cc @@ -490,35 +490,40 @@ std::string ParserVisitor::errorMessage() const { return sf_->errorMessage(description_, expression_); } -Expr ParserVisitor::globalCallOrMacro(int64_t expr_id, std::string function, - std::vector args) { +Expr ParserVisitor::globalCallOrMacro(int64_t expr_id, + const std::string& function, + const std::vector& args) { Expr macro_expr; - if (expandMacro(expr_id, function, nullptr, args, ¯o_expr)) { + if (expandMacro(expr_id, function, Expr::default_instance(), args, + ¯o_expr)) { return macro_expr; } return sf_->newGlobalCall(expr_id, function, args); } -Expr ParserVisitor::receiverCallOrMacro(int64_t expr_id, std::string function, - Expr target, std::vector args) { +Expr ParserVisitor::receiverCallOrMacro(int64_t expr_id, + const std::string& function, + const Expr& target, + const std::vector& args) { Expr macro_expr; - if (expandMacro(expr_id, function, &target, args, ¯o_expr)) { + if (expandMacro(expr_id, function, target, args, ¯o_expr)) { return macro_expr; } return sf_->newReceiverCall(expr_id, function, target, args); } -bool ParserVisitor::expandMacro(int64_t expr_id, std::string function, - Expr* target, std::vector args, +bool ParserVisitor::expandMacro(int64_t expr_id, const std::string& function, + const Expr& target, + const std::vector& args, Expr* macro_expr) { std::string macro_key = absl::StrFormat("%s:%d:%s", function, args.size(), - target ? "true" : "false"); + target.id() != 0 ? "true" : "false"); auto m = macros_.find(macro_key); if (m == macros_.end()) { - std::string var_arg_macro_key = - absl::StrFormat("%s:*:%s", function, target ? "true" : "false"); + std::string var_arg_macro_key = absl::StrFormat( + "%s:*:%s", function, target.id() != 0 ? "true" : "false"); m = macros_.find(var_arg_macro_key); if (m == macros_.end()) { return false; @@ -527,7 +532,7 @@ bool ParserVisitor::expandMacro(int64_t expr_id, std::string function, Expr expr = m->second.expand(sf_, expr_id, target, args); if (expr.expr_kind_case() != Expr::EXPR_KIND_NOT_SET) { - macro_expr->CopyFrom(expr); + *macro_expr = std::move(expr); return true; } return false; diff --git a/parser/visitor.h b/parser/visitor.h index 7f5cb738a..b5a819114 100644 --- a/parser/visitor.h +++ b/parser/visitor.h @@ -89,12 +89,13 @@ class ParserVisitor : public ::cel_grammar::CelBaseVisitor, std::string errorMessage() const; private: - Expr globalCallOrMacro(int64_t expr_id, std::string function, - std::vector args); - Expr receiverCallOrMacro(int64_t expr_id, std::string function, Expr target, - std::vector args); - bool expandMacro(int64_t expr_id, std::string function, Expr* target, - std::vector args, Expr* macro_expr); + Expr globalCallOrMacro(int64_t expr_id, const std::string& function, + const std::vector& args); + Expr receiverCallOrMacro(int64_t expr_id, const std::string& function, + const Expr& target, const std::vector& args); + bool expandMacro(int64_t expr_id, const std::string& function, + const Expr& target, const std::vector& args, + Expr* macro_expr); std::string unquote(antlr4::ParserRuleContext* ctx, const std::string& s, bool is_bytes); std::string extractQualifiedName(antlr4::ParserRuleContext* ctx, diff --git a/testutil/BUILD b/testutil/BUILD index 450474c48..eb71cff9b 100644 --- a/testutil/BUILD +++ b/testutil/BUILD @@ -49,6 +49,7 @@ cc_library( cc_library( name = "test_data_io", + testonly = True, srcs = [ "test_data_io.cc", ], @@ -76,6 +77,7 @@ cc_library( # third_party/cel/spec/testdata/unique_values.textpb cc_binary( name = "test_data_gen", + testonly = True, srcs = [ "test_data_gen.cc", ], @@ -110,6 +112,7 @@ cc_test( cc_library( name = "util", + testonly = True, hdrs = [ "util.h", ],