From 39616cd5fcda1375fedbb6955dccc992ccf49f20 Mon Sep 17 00:00:00 2001 From: CEL Dev Team Date: Wed, 6 Jan 2021 17:44:48 -0500 Subject: [PATCH 1/6] internal changes PiperOrigin-RevId: 350431125 --- common/type.cc | 4 ++-- eval/eval/create_struct_step.cc | 2 +- eval/public/activation.h | 2 +- eval/public/structs/cel_proto_wrapper.cc | 2 +- parser/parser.cc | 2 +- parser/source_factory.h | 2 +- 6 files changed, 7 insertions(+), 7 deletions(-) diff --git a/common/type.cc b/common/type.cc index ad9fa0ddc..fb8097f5c 100644 --- a/common/type.cc +++ b/common/type.cc @@ -72,7 +72,7 @@ UnrecognizedType::UnrecognizedType(absl::string_view full_name) : string_rep_(absl::StrCat("type(\"", full_name, "\")")), hash_code_(internal::Hash(full_name)) { assert(google::protobuf::DescriptorPool::generated_pool()->FindMessageTypeByName( - std::string(full_name)) == nullptr); + std.data()::string(full_name)) == nullptr); } absl::string_view UnrecognizedType::full_name() const { @@ -89,7 +89,7 @@ Type::Type(const std::string& full_name) auto obj_desc = google::protobuf::DescriptorPool::generated_pool()->FindMessageTypeByName( - full_name); + full_name.data()); if (obj_desc != nullptr) { data_ = ObjectType(obj_desc); return; diff --git a/eval/eval/create_struct_step.cc b/eval/eval/create_struct_step.cc index abd92e8d8..1f93447ee 100644 --- a/eval/eval/create_struct_step.cc +++ b/eval/eval/create_struct_step.cc @@ -280,7 +280,7 @@ absl::StatusOr> CreateCreateStructStep( const Descriptor* desc = DescriptorPool::generated_pool()->FindMessageTypeByName( - create_struct_expr->message_name()); + create_struct_expr.data()->message_name()); if (desc == nullptr) { return absl::InvalidArgumentError( diff --git a/eval/public/activation.h b/eval/public/activation.h index a6346699e..1c7271b9f 100644 --- a/eval/public/activation.h +++ b/eval/public/activation.h @@ -84,7 +84,7 @@ class Activation : public BaseActivation { google::protobuf::Arena* arena) const override; bool 
IsPathUnknown(absl::string_view path) const override { - return google::protobuf::util::FieldMaskUtil::IsPathInFieldMask(std::string(path), + return google::protobuf::util::FieldMaskUtil::IsPathInFieldMask(std.data()::string(path), unknown_paths_); } diff --git a/eval/public/structs/cel_proto_wrapper.cc b/eval/public/structs/cel_proto_wrapper.cc index 78c29f463..77d318c27 100644 --- a/eval/public/structs/cel_proto_wrapper.cc +++ b/eval/public/structs/cel_proto_wrapper.cc @@ -155,7 +155,7 @@ CelValue ValueFromMessage(const Any* any_value, Arena* arena) { std::string full_name = std::string(type_url.substr(pos + 1)); const Descriptor* nested_descriptor = - DescriptorPool::generated_pool()->FindMessageTypeByName(full_name); + DescriptorPool::generated_pool()->FindMessageTypeByName(full_name.data()); if (nested_descriptor == nullptr) { // Descriptor not found for the type diff --git a/parser/parser.cc b/parser/parser.cc index 35b804f42..6d9b6ebbe 100644 --- a/parser/parser.cc +++ b/parser/parser.cc @@ -7,7 +7,7 @@ #include "parser/cel_grammar.inc/cel_grammar/CelParser.h" #include "parser/source_factory.h" #include "parser/visitor.h" -#include "antlr4-runtime.h" +#include "third_party/java/antlr4/v4_7_1/Cpp/src/antlr4-runtime.h" namespace google { namespace api { diff --git a/parser/source_factory.h b/parser/source_factory.h index 79d766f45..ac89b95a3 100644 --- a/parser/source_factory.h +++ b/parser/source_factory.h @@ -8,7 +8,7 @@ #include "google/api/expr/v1alpha1/syntax.pb.h" #include "absl/types/optional.h" #include "parser/cel_grammar.inc/cel_grammar/CelParser.h" -#include "antlr4-runtime.h" +#include "third_party/java/antlr4/v4_7_1/Cpp/src/antlr4-runtime.h" namespace google { namespace api { From f34755776c0485b4a5660ab4392dffa7be4f7c60 Mon Sep 17 00:00:00 2001 From: CEL Dev Team Date: Mon, 18 Jan 2021 17:12:57 -0500 Subject: [PATCH 2/6] Apply clang-tidy suggestions PiperOrigin-RevId: 352452757 --- common/value.h | 2 +- eval/eval/attribute_trail.h | 5 +- 
eval/eval/const_value_step.cc | 11 ++--- eval/eval/container_access_step.cc | 4 +- eval/eval/create_list_step.cc | 5 +- eval/eval/create_struct_step.cc | 10 ++-- eval/eval/function_step.cc | 10 ++-- eval/eval/ident_step.cc | 6 +-- eval/eval/jump_step.cc | 16 ++----- eval/eval/logic_step.cc | 11 +---- eval/eval/select_step.cc | 3 +- eval/eval/ternary_step.cc | 5 +- eval/public/cel_expression.h | 4 ++ parser/BUILD | 1 + parser/balancer.cc | 10 ++-- parser/macro.cc | 28 +++++------ parser/macro.h | 19 ++++---- parser/parser.cc | 43 +++++++++-------- parser/parser.h | 8 ++-- parser/source_factory.cc | 74 +++++++++++++++--------------- parser/source_factory.h | 19 ++++---- parser/visitor.cc | 29 +++++++----- parser/visitor.h | 13 +++--- 23 files changed, 164 insertions(+), 172 deletions(-) diff --git a/common/value.h b/common/value.h index 777021a4b..e0b97ba37 100644 --- a/common/value.h +++ b/common/value.h @@ -330,7 +330,7 @@ class Container : public SharedValue { } template static Value GetValue(V&& value) { - return Value::From(std::move(value)); + return Value::From(std::forward(value)); } private: diff --git a/eval/eval/attribute_trail.h b/eval/eval/attribute_trail.h index c2aefc6cb..48874726d 100644 --- a/eval/eval/attribute_trail.h +++ b/eval/eval/attribute_trail.h @@ -32,7 +32,7 @@ class AttributeTrail { AttributeTrail() : attribute_(nullptr) {} AttributeTrail(google::api::expr::v1alpha1::Expr root, google::protobuf::Arena* arena) : AttributeTrail(google::protobuf::Arena::Create( - arena, root, std::vector())) {} + arena, std::move(root), std::vector())) {} // Creates AttributeTrail with attribute path incremented by "qualifier". 
AttributeTrail Step(CelAttributeQualifier qualifier, @@ -52,7 +52,8 @@ class AttributeTrail { bool empty() const { return !attribute_; } private: - AttributeTrail(const CelAttribute* attribute) : attribute_(attribute) {} + explicit AttributeTrail(const CelAttribute* attribute) + : attribute_(attribute) {} const CelAttribute* attribute_; }; } // namespace runtime diff --git a/eval/eval/const_value_step.cc b/eval/eval/const_value_step.cc index 5cdc216c6..3a0a0fe08 100644 --- a/eval/eval/const_value_step.cc +++ b/eval/eval/const_value_step.cc @@ -75,19 +75,14 @@ absl::optional ConvertConstant(const Constant* const_expr) { absl::StatusOr> CreateConstValueStep( CelValue value, int64_t expr_id, bool comes_from_ast) { - std::unique_ptr step = - absl::make_unique(value, expr_id, comes_from_ast); - return std::move(step); + return absl::make_unique(value, expr_id, comes_from_ast); } // Factory method for Constant(Enum value) - based Execution step absl::StatusOr> CreateConstValueStep( const google::protobuf::EnumValueDescriptor* value_descriptor, int64_t expr_id) { - CelValue value = CelValue::CreateInt64(value_descriptor->number()); - - std::unique_ptr step = - absl::make_unique(value, expr_id, false); - return std::move(step); + return absl::make_unique( + CelValue::CreateInt64(value_descriptor->number()), expr_id, false); } } // namespace runtime diff --git a/eval/eval/container_access_step.cc b/eval/eval/container_access_step.cc index aeb2499f9..619a065d5 100644 --- a/eval/eval/container_access_step.cc +++ b/eval/eval/container_access_step.cc @@ -156,9 +156,7 @@ absl::Status ContainerAccessStep::Evaluate(ExecutionFrame* frame) const { // Factory method for Select - based Execution step absl::StatusOr> CreateContainerAccessStep( const google::api::expr::v1alpha1::Expr::Call*, int64_t expr_id) { - std::unique_ptr step = - absl::make_unique(expr_id); - return std::move(step); + return absl::make_unique(expr_id); } } // namespace runtime diff --git 
a/eval/eval/create_list_step.cc b/eval/eval/create_list_step.cc index f9da18357..21be321a9 100644 --- a/eval/eval/create_list_step.cc +++ b/eval/eval/create_list_step.cc @@ -65,9 +65,8 @@ absl::Status CreateListStep::Evaluate(ExecutionFrame* frame) const { absl::StatusOr> CreateCreateListStep( const google::api::expr::v1alpha1::Expr::CreateList* create_list_expr, int64_t expr_id) { - std::unique_ptr step = absl::make_unique( - expr_id, create_list_expr->elements_size()); - return std::move(step); + return absl::make_unique(expr_id, + create_list_expr->elements_size()); } } // namespace runtime diff --git a/eval/eval/create_struct_step.cc b/eval/eval/create_struct_step.cc index 1f93447ee..63d0621dd 100644 --- a/eval/eval/create_struct_step.cc +++ b/eval/eval/create_struct_step.cc @@ -1,5 +1,7 @@ #include "eval/eval/create_struct_step.h" +#include + #include "google/api/expr/v1alpha1/syntax.pb.h" #include "absl/status/status.h" #include "absl/status/statusor.h" @@ -303,12 +305,12 @@ absl::StatusOr> CreateCreateStructStep( entries.push_back({field_desc}); } - return absl::WrapUnique( - new CreateStructStepForMessage(expr_id, desc, std::move(entries))); + return std::make_unique(expr_id, desc, + std::move(entries)); } else { // Make map-creating step. 
- return absl::WrapUnique(new CreateStructStepForMap( - expr_id, create_struct_expr->entries_size())); + return std::make_unique( + expr_id, create_struct_expr->entries_size()); } } diff --git a/eval/eval/function_step.cc b/eval/eval/function_step.cc index 989c7daee..0ce103762 100644 --- a/eval/eval/function_step.cc +++ b/eval/eval/function_step.cc @@ -293,9 +293,8 @@ absl::StatusOr> CreateFunctionStep( function_registry.FindLazyOverloads(name, receiver_style, args); if (!lazy_overloads.empty()) { - std::unique_ptr step = absl::make_unique( - name, num_args, receiver_style, lazy_overloads, expr_id); - return std::move(step); + return absl::make_unique(name, num_args, receiver_style, + lazy_overloads, expr_id); } auto overloads = function_registry.FindOverloads(name, receiver_style, args); @@ -307,9 +306,8 @@ absl::StatusOr> CreateFunctionStep( "No overloads provided for FunctionStep creation"))); } - std::unique_ptr step = absl::make_unique( - std::move(overloads), name, num_args, expr_id); - return std::move(step); + return absl::make_unique(std::move(overloads), name, + num_args, expr_id); } } // namespace runtime diff --git a/eval/eval/ident_step.cc b/eval/eval/ident_step.cc index 9c99621bd..420aebaee 100644 --- a/eval/eval/ident_step.cc +++ b/eval/eval/ident_step.cc @@ -46,7 +46,7 @@ void IdentStep::DoEvaluate(ExecutionFrame* frame, CelValue* result, if (frame->enable_missing_attribute_errors() || frame->enable_unknowns()) { google::api::expr::v1alpha1::Expr expr; expr.mutable_ident_expr()->set_name(name_); - *trail = AttributeTrail(expr, frame->arena()); + *trail = AttributeTrail(std::move(expr), frame->arena()); } if (frame->enable_missing_attribute_errors() && !name_.empty() && @@ -102,9 +102,7 @@ absl::Status IdentStep::Evaluate(ExecutionFrame* frame) const { absl::StatusOr> CreateIdentStep( const google::api::expr::v1alpha1::Expr::Ident* ident_expr, int64_t expr_id) { - std::unique_ptr step = - absl::make_unique(ident_expr->name(), expr_id); - return 
std::move(step); + return absl::make_unique(ident_expr->name(), expr_id); } } // namespace runtime diff --git a/eval/eval/jump_step.cc b/eval/eval/jump_step.cc index 5c85a645b..311ab6154 100644 --- a/eval/eval/jump_step.cc +++ b/eval/eval/jump_step.cc @@ -100,19 +100,14 @@ class BoolCheckJumpStep : public JumpStepBase { absl::StatusOr> CreateCondJumpStep( bool jump_condition, bool leave_on_stack, absl::optional jump_offset, int64_t expr_id) { - std::unique_ptr step = absl::make_unique( - jump_condition, leave_on_stack, jump_offset, expr_id); - - return std::move(step); + return absl::make_unique(jump_condition, leave_on_stack, + jump_offset, expr_id); } // Factory method for Jump step. absl::StatusOr> CreateJumpStep( absl::optional jump_offset, int64_t expr_id) { - std::unique_ptr step = - absl::make_unique(jump_offset, expr_id); - - return std::move(step); + return absl::make_unique(jump_offset, expr_id); } // Factory method for Conditional Jump step. @@ -120,10 +115,7 @@ absl::StatusOr> CreateJumpStep( // If this value is an error or unknown, a jump is performed. 
absl::StatusOr> CreateBoolCheckJumpStep( absl::optional jump_offset, int64_t expr_id) { - std::unique_ptr step = - absl::make_unique(jump_offset, expr_id); - - return std::move(step); + return absl::make_unique(jump_offset, expr_id); } // TODO(issues/41) Make sure Unknowns are properly supported by ternary diff --git a/eval/eval/logic_step.cc b/eval/eval/logic_step.cc index ed4da4700..045116ac9 100644 --- a/eval/eval/logic_step.cc +++ b/eval/eval/logic_step.cc @@ -111,21 +111,14 @@ absl::Status LogicalOpStep::Evaluate(ExecutionFrame* frame) const { } // namespace - // Factory method for "And" Execution step absl::StatusOr> CreateAndStep(int64_t expr_id) { - std::unique_ptr step = - absl::make_unique(LogicalOpStep::OpType::AND, expr_id); - - return std::move(step); + return absl::make_unique(LogicalOpStep::OpType::AND, expr_id); } // Factory method for "Or" Execution step absl::StatusOr> CreateOrStep(int64_t expr_id) { - std::unique_ptr step = - absl::make_unique(LogicalOpStep::OpType::OR, expr_id); - - return std::move(step); + return absl::make_unique(LogicalOpStep::OpType::OR, expr_id); } } // namespace runtime diff --git a/eval/eval/select_step.cc b/eval/eval/select_step.cc index 507c13e54..cc83b4167 100644 --- a/eval/eval/select_step.cc +++ b/eval/eval/select_step.cc @@ -229,9 +229,8 @@ absl::Status SelectStep::Evaluate(ExecutionFrame* frame) const { absl::StatusOr> CreateSelectStep( const google::api::expr::v1alpha1::Expr::Select* select_expr, int64_t expr_id, absl::string_view select_path) { - std::unique_ptr step = absl::make_unique( + return absl::make_unique( select_expr->field(), select_expr->test_only(), expr_id, select_path); - return std::move(step); } } // namespace runtime diff --git a/eval/eval/ternary_step.cc b/eval/eval/ternary_step.cc index a52430ad3..420cf10e5 100644 --- a/eval/eval/ternary_step.cc +++ b/eval/eval/ternary_step.cc @@ -69,10 +69,7 @@ absl::Status TernaryStep::Evaluate(ExecutionFrame* frame) const { absl::StatusOr> 
CreateTernaryStep( int64_t expr_id) { - std::unique_ptr step = - absl::make_unique(expr_id); - - return step; + return absl::make_unique(expr_id); } } // namespace runtime diff --git a/eval/public/cel_expression.h b/eval/public/cel_expression.h index 524051cc1..79037630b 100644 --- a/eval/public/cel_expression.h +++ b/eval/public/cel_expression.h @@ -121,6 +121,10 @@ class CelExpressionBuilder { CelFunctionRegistry* GetRegistry() const { return registry_.get(); } // Enums registered with the builder. + // + // TODO(issues/105): this should not be std::set as the ordering of pointers + // is inconsistent across processes and should be absl::node_hash_map or + // absl::flat_hash_map const std::set& resolvable_enums() const { return resolvable_enums_; } diff --git a/parser/BUILD b/parser/BUILD index f6378ae37..3f2be39f1 100644 --- a/parser/BUILD +++ b/parser/BUILD @@ -97,6 +97,7 @@ cc_library( ":cel_cc_parser", "//common:operators", "@antlr4_runtimes//:cpp", + "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/memory", "@com_google_absl//absl/strings", "@com_google_absl//absl/strings:str_format", diff --git a/parser/balancer.cc b/parser/balancer.cc index 88cfbd6d9..6cd85a2e5 100644 --- a/parser/balancer.cc +++ b/parser/balancer.cc @@ -9,10 +9,13 @@ namespace parser { ExpressionBalancer::ExpressionBalancer(std::shared_ptr sf, std::string function, Expr expr) - : sf_(sf), function_(function), terms_{expr}, ops_{} {} + : sf_(std::move(sf)), + function_(std::move(function)), + terms_{std::move(expr)}, + ops_{} {} void ExpressionBalancer::addTerm(int64_t op, Expr term) { - terms_.push_back(term); + terms_.push_back(std::move(term)); ops_.push_back(op); } @@ -39,7 +42,8 @@ Expr ExpressionBalancer::balancedTree(int lo, int hi) { } else { right = balancedTree(mid + 1, hi); } - return sf_->newGlobalCall(ops_[mid], function_, {left, right}); + return sf_->newGlobalCall(ops_[mid], function_, + {std::move(left), std::move(right)}); } } // namespace parser 
diff --git a/parser/macro.cc b/parser/macro.cc index 7eb8559c1..c7e72898f 100644 --- a/parser/macro.cc +++ b/parser/macro.cc @@ -26,8 +26,8 @@ std::vector Macro::AllMacros() { // The macro "has(m.f)" which tests the presence of a field, avoiding the // need to specify the field as a string. Macro(CelOperator::HAS, 1, - [](std::shared_ptr sf, int64_t macro_id, Expr* target, - const std::vector& args) { + [](const std::shared_ptr& sf, int64_t macro_id, + const Expr& target, const std::vector& args) { if (!args.empty() && args[0].has_select_expr()) { const auto& sel_expr = args[0].select_expr(); return sf->newPresenceTestForMacro(macro_id, sel_expr.operand(), @@ -43,8 +43,8 @@ std::vector Macro::AllMacros() { // in range the predicate holds. Macro( CelOperator::ALL, 2, - [](std::shared_ptr sf, int64_t macro_id, Expr* target, - const std::vector& args) { + [](const std::shared_ptr& sf, int64_t macro_id, + const Expr& target, const std::vector& args) { return sf->newQuantifierExprForMacro(SourceFactory::QUANTIFIER_ALL, macro_id, target, args); }, @@ -54,8 +54,8 @@ std::vector Macro::AllMacros() { // one element in range the predicate holds. Macro( CelOperator::EXISTS, 2, - [](std::shared_ptr sf, int64_t macro_id, Expr* target, - const std::vector& args) { + [](const std::shared_ptr& sf, int64_t macro_id, + const Expr& target, const std::vector& args) { return sf->newQuantifierExprForMacro( SourceFactory::QUANTIFIER_EXISTS, macro_id, target, args); }, @@ -65,8 +65,8 @@ std::vector Macro::AllMacros() { // exactly one element in range the predicate holds. Macro( CelOperator::EXISTS_ONE, 2, - [](std::shared_ptr sf, int64_t macro_id, Expr* target, - const std::vector& args) { + [](const std::shared_ptr& sf, int64_t macro_id, + const Expr& target, const std::vector& args) { return sf->newQuantifierExprForMacro( SourceFactory::QUANTIFIER_EXISTS_ONE, macro_id, target, args); }, @@ -77,8 +77,8 @@ std::vector Macro::AllMacros() { // the range. 
Macro( CelOperator::MAP, 2, - [](std::shared_ptr sf, int64_t macro_id, Expr* target, - const std::vector& args) { + [](const std::shared_ptr& sf, int64_t macro_id, + const Expr& target, const std::vector& args) { return sf->newMapForMacro(macro_id, target, args); }, /* receiver style*/ true), @@ -89,8 +89,8 @@ std::vector Macro::AllMacros() { // variables are filtered out. Macro( CelOperator::MAP, 3, - [](std::shared_ptr sf, int64_t macro_id, Expr* target, - const std::vector& args) { + [](const std::shared_ptr& sf, int64_t macro_id, + const Expr& target, const std::vector& args) { return sf->newMapForMacro(macro_id, target, args); }, /* receiver style*/ true), @@ -100,8 +100,8 @@ std::vector Macro::AllMacros() { // predicate is false. Macro( CelOperator::FILTER, 2, - [](std::shared_ptr sf, int64_t macro_id, Expr* target, - const std::vector& args) { + [](const std::shared_ptr& sf, int64_t macro_id, + const Expr& target, const std::vector& args) { return sf->newFilterExprForMacro(macro_id, target, args); }, /* receiver style*/ true), diff --git a/parser/macro.h b/parser/macro.h index 7a30e725c..6277593da 100644 --- a/parser/macro.h +++ b/parser/macro.h @@ -4,6 +4,7 @@ #include #include #include +#include #include "google/api/expr/v1alpha1/syntax.pb.h" @@ -19,11 +20,11 @@ class SourceFactory; // MacroExpander converts the target and args of a function call that matches a // Macro. // -// Note: when the Macros.IsReceiverStyle() is true, the target argument will be -// empty. +// Note: when the Macros.IsReceiverStyle() is true, the target argument will +// be Expr::default_instance(). using MacroExpander = - std::function sf, int64_t macro_id, Expr*, - const std::vector&)>; + std::function& sf, int64_t macro_id, + const Expr&, const std::vector&)>; // Macro interface for describing the function signature to match and the // MacroExpander to apply. 
@@ -39,7 +40,7 @@ class Macro { receiver_style_(receiver_style), var_arg_style_(false), arg_count_(arg_count), - expander_(expander) {} + expander_(std::move(expander)) {} Macro(const std::string& function, MacroExpander expander, bool receiver_style = false) @@ -47,7 +48,7 @@ class Macro { receiver_style_(receiver_style), var_arg_style_(true), arg_count_(0), - expander_(expander) {} + expander_(std::move(expander)) {} // Function name to match. std::string function() const { return function_; } @@ -73,9 +74,9 @@ class Macro { // parsed call signature. const MacroExpander& expander() const { return expander_; } - Expr expand(std::shared_ptr sf, int64_t macro_id, Expr* target, - const std::vector& args) { - return expander_(sf, macro_id, target, args); + Expr expand(const std::shared_ptr& sf, int64_t macro_id, + const Expr& target, const std::vector& args) { + return expander_(std::move(sf), macro_id, target, args); } static std::vector AllMacros(); diff --git a/parser/parser.cc b/parser/parser.cc index 6d9b6ebbe..9b02dabc1 100644 --- a/parser/parser.cc +++ b/parser/parser.cc @@ -14,23 +14,27 @@ namespace api { namespace expr { namespace parser { -using antlr4::ANTLRInputStream; -using antlr4::CommonTokenStream; -using antlr4::ParseCancellationException; -using antlr4::ParserRuleContext; +using ::antlr4::ANTLRInputStream; +using ::antlr4::CommonTokenStream; +using ::antlr4::ParseCancellationException; +using ::antlr4::ParserRuleContext; -using antlr4::tree::ErrorNode; -using antlr4::tree::TerminalNode; +using ::antlr4::tree::ErrorNode; +using ::antlr4::tree::ParseTreeListener; +using ::antlr4::tree::TerminalNode; -using google::api::expr::v1alpha1::Expr; -using google::api::expr::v1alpha1::ParsedExpr; +using ::google::api::expr::v1alpha1::Expr; +using ::google::api::expr::v1alpha1::ParsedExpr; + +using ::cel_grammar::CelLexer; +using ::cel_grammar::CelParser; namespace { // ExprRecursionListener extends the standard ANTLR CelParser to ensure that // recursive 
entries into the 'expr' rule are limited to a configurable depth so // as to prevent stack overflows. -class ExprRecursionListener : public ::antlr4::tree::ParseTreeListener { +class ExprRecursionListener : public ParseTreeListener { public: ExprRecursionListener( const int max_recursion_depth = kDefaultMaxRecursionDepth) @@ -50,7 +54,7 @@ void ExprRecursionListener::enterEveryRule(ParserRuleContext* ctx) { // Throw a ParseCancellationException since the parsing would otherwise // continue if this were treated as a syntax error and the problem would // continue to manifest. - if (ctx->getRuleIndex() == ::cel_grammar::CelParser::RuleExpr) { + if (ctx->getRuleIndex() == CelParser::RuleExpr) { if (recursion_depth_ >= max_recursion_depth_) { throw ParseCancellationException( absl::StrFormat("Expression recursion limit exceeded. limit: %d", @@ -61,7 +65,7 @@ void ExprRecursionListener::enterEveryRule(ParserRuleContext* ctx) { } void ExprRecursionListener::exitEveryRule(ParserRuleContext* ctx) { - if (ctx->getRuleIndex() == ::cel_grammar::CelParser::RuleExpr) { + if (ctx->getRuleIndex() == CelParser::RuleExpr) { recursion_depth_--; } } @@ -91,9 +95,9 @@ absl::StatusOr EnrichedParse( const std::string& expression, const std::vector& macros, const std::string& description, const int max_recursion_depth) { ANTLRInputStream input(expression); - ::cel_grammar::CelLexer lexer(&input); + CelLexer lexer(&input); CommonTokenStream tokens(&lexer); - ::cel_grammar::CelParser parser(&tokens); + CelParser parser(&tokens); ExprRecursionListener listener(max_recursion_depth); ParserVisitor visitor(description, expression, max_recursion_depth, macros); @@ -107,12 +111,12 @@ absl::StatusOr EnrichedParse( // std::shared_ptr error_strategy(new BailErrorStrategy()); // parser.setErrorHandler(error_strategy); - ::cel_grammar::CelParser::StartContext* root; + CelParser::StartContext* root; try { root = parser.start(); - } catch (ParseCancellationException& e) { + } catch (const 
ParseCancellationException& e) { return absl::CancelledError(e.what()); - } catch (std::exception& e) { + } catch (const std::exception& e) { return absl::AbortedError(e.what()); } @@ -124,10 +128,11 @@ absl::StatusOr EnrichedParse( // root is deleted as part of the parser context ParsedExpr parsed_expr; - parsed_expr.mutable_expr()->CopyFrom(expr); - parsed_expr.mutable_source_info()->CopyFrom(visitor.sourceInfo()); + *(parsed_expr.mutable_expr()) = std::move(expr); auto enriched_source_info = visitor.enrichedSourceInfo(); - return VerboseParsedExpr(parsed_expr, enriched_source_info); + *(parsed_expr.mutable_source_info()) = visitor.sourceInfo(); + return VerboseParsedExpr(std::move(parsed_expr), + std::move(enriched_source_info)); } } // namespace parser diff --git a/parser/parser.h b/parser/parser.h index 46bb561d1..7227c8fed 100644 --- a/parser/parser.h +++ b/parser/parser.h @@ -16,10 +16,10 @@ constexpr int kDefaultMaxRecursionDepth = 250; class VerboseParsedExpr { public: - VerboseParsedExpr(const google::api::expr::v1alpha1::ParsedExpr& parsed_expr, - const EnrichedSourceInfo& enriched_source_info) - : parsed_expr_(parsed_expr), - enriched_source_info_(enriched_source_info) {} + VerboseParsedExpr(google::api::expr::v1alpha1::ParsedExpr parsed_expr, + EnrichedSourceInfo enriched_source_info) + : parsed_expr_(std::move(parsed_expr)), + enriched_source_info_(std::move(enriched_source_info)) {} const google::api::expr::v1alpha1::ParsedExpr& parsed_expr() const { return parsed_expr_; diff --git a/parser/source_factory.cc b/parser/source_factory.cc index 3573f5289..3052b2407 100644 --- a/parser/source_factory.cc +++ b/parser/source_factory.cc @@ -4,6 +4,7 @@ #include #include "google/protobuf/struct.pb.h" +#include "absl/container/flat_hash_set.h" #include "absl/memory/memory.h" #include "absl/strings/numbers.h" #include "absl/strings/str_format.h" @@ -36,9 +37,10 @@ SourceFactory::SourceFactory(const std::string& expression) int64_t SourceFactory::id(const 
antlr4::Token* token) { int64_t new_id = next_id_; positions_.emplace( - new_id, SourceLocation{(int32_t)token->getLine(), - (int32_t)token->getCharPositionInLine(), - (int32_t)token->getStopIndex(), line_offsets_}); + new_id, + SourceLocation{static_cast(token->getLine()), + static_cast(token->getCharPositionInLine()), + static_cast(token->getStopIndex()), line_offsets_}); next_id_ += 1; return new_id; } @@ -86,9 +88,8 @@ Expr SourceFactory::newGlobalCall(int64_t id, const std::string& function, Expr expr = newExpr(id); auto call_expr = expr.mutable_call_expr(); call_expr->set_function(function); - std::for_each(args.begin(), args.end(), [&call_expr](const Expr& e) { - call_expr->add_args()->CopyFrom(e); - }); + std::for_each(args.begin(), args.end(), + [&call_expr](const Expr& e) { *call_expr->add_args() = e; }); return expr; } @@ -99,15 +100,14 @@ Expr SourceFactory::newGlobalCallForMacro(int64_t macro_id, } Expr SourceFactory::newReceiverCall(int64_t id, const std::string& function, - Expr& target, + const Expr& target, const std::vector& args) { Expr expr = newExpr(id); auto call_expr = expr.mutable_call_expr(); call_expr->set_function(function); - call_expr->mutable_target()->CopyFrom(target); - std::for_each(args.begin(), args.end(), [&call_expr](const Expr& e) { - call_expr->add_args()->CopyFrom(e); - }); + *call_expr->mutable_target() = target; + std::for_each(args.begin(), args.end(), + [&call_expr](const Expr& e) { *call_expr->add_args() = e; }); return expr; } @@ -130,7 +130,7 @@ Expr SourceFactory::newSelect( const std::string& field) { Expr expr = newExpr(ctx->op); auto select_expr = expr.mutable_select_expr(); - select_expr->mutable_operand()->CopyFrom(operand); + *select_expr->mutable_operand() = operand; select_expr->set_field(field); return expr; } @@ -139,14 +139,14 @@ Expr SourceFactory::newPresenceTestForMacro(int64_t macro_id, const Expr& operan const std::string& field) { Expr expr = newExpr(nextMacroId(macro_id)); auto select_expr = 
expr.mutable_select_expr(); - select_expr->mutable_operand()->CopyFrom(operand); + *select_expr->mutable_operand() = operand; select_expr->set_field(field); select_expr->set_test_only(true); return expr; } Expr SourceFactory::newObject( - int64_t obj_id, std::string type_name, + int64_t obj_id, const std::string& type_name, const std::vector& entries) { auto expr = newExpr(obj_id); auto struct_expr = expr.mutable_struct_expr(); @@ -163,7 +163,7 @@ Expr::CreateStruct::Entry SourceFactory::newObjectField( Expr::CreateStruct::Entry entry; entry.set_id(field_id); entry.set_field_key(field); - entry.mutable_value()->CopyFrom(value); + *entry.mutable_value() = value; return entry; } @@ -176,12 +176,12 @@ Expr SourceFactory::newComprehension(int64_t id, const std::string& iter_var, Expr expr = newExpr(id); auto comp_expr = expr.mutable_comprehension_expr(); comp_expr->set_iter_var(iter_var); - comp_expr->mutable_iter_range()->CopyFrom(iter_range); + *comp_expr->mutable_iter_range() = iter_range; comp_expr->set_accu_var(accu_var); - comp_expr->mutable_accu_init()->CopyFrom(accu_init); - comp_expr->mutable_loop_condition()->CopyFrom(condition); - comp_expr->mutable_loop_step()->CopyFrom(step); - comp_expr->mutable_result()->CopyFrom(result); + *comp_expr->mutable_accu_init() = accu_init; + *comp_expr->mutable_loop_condition() = condition; + *comp_expr->mutable_loop_step() = step; + *comp_expr->mutable_result() = result; return expr; } @@ -197,14 +197,13 @@ Expr SourceFactory::foldForMacro(int64_t macro_id, const std::string& iter_var, Expr SourceFactory::newList(int64_t list_id, const std::vector& elems) { auto expr = newExpr(list_id); auto list_expr = expr.mutable_list_expr(); - std::for_each(elems.begin(), elems.end(), [list_expr](const Expr& e) { - list_expr->add_elements()->CopyFrom(e); - }); + std::for_each(elems.begin(), elems.end(), + [list_expr](const Expr& e) { *list_expr->add_elements() = e; }); return expr; } Expr SourceFactory::newQuantifierExprForMacro( - 
SourceFactory::QuantifierKind kind, int64_t macro_id, Expr* target, + SourceFactory::QuantifierKind kind, int64_t macro_id, const Expr& target, const std::vector& args) { if (args.empty()) { return Expr(); @@ -264,11 +263,11 @@ Expr SourceFactory::newQuantifierExprForMacro( break; } } - return foldForMacro(macro_id, v, *target, AccumulatorName, init, condition, + return foldForMacro(macro_id, v, target, AccumulatorName, init, condition, step, result); } -Expr SourceFactory::newFilterExprForMacro(int64_t macro_id, Expr* target, +Expr SourceFactory::newFilterExprForMacro(int64_t macro_id, const Expr& target, const std::vector& args) { if (args.empty()) { return Expr(); @@ -291,7 +290,7 @@ Expr SourceFactory::newFilterExprForMacro(int64_t macro_id, Expr* target, {accu_expr, newListForMacro(macro_id, {args[0]})}); step = newGlobalCallForMacro(macro_id, CelOperator::CONDITIONAL, {filter, step, accu_expr}); - return foldForMacro(macro_id, v, *target, AccumulatorName, init, condition, + return foldForMacro(macro_id, v, target, AccumulatorName, init, condition, step, accu_expr); } @@ -311,7 +310,7 @@ Expr SourceFactory::newMap( return expr; } -Expr SourceFactory::newMapForMacro(int64_t macro_id, Expr* target, +Expr SourceFactory::newMapForMacro(int64_t macro_id, const Expr& target, const std::vector& args) { if (args.empty()) { return Expr(); @@ -345,7 +344,7 @@ Expr SourceFactory::newMapForMacro(int64_t macro_id, Expr* target, step = newGlobalCallForMacro(macro_id, CelOperator::CONDITIONAL, {filter, step, accu_expr}); } - return foldForMacro(macro_id, v, *target, AccumulatorName, init, condition, + return foldForMacro(macro_id, v, target, AccumulatorName, init, condition, step, accu_expr); } @@ -354,8 +353,8 @@ Expr::CreateStruct::Entry SourceFactory::newMapEntry(int64_t entry_id, const Expr& value) { Expr::CreateStruct::Entry entry; entry.set_id(entry_id); - entry.mutable_map_key()->CopyFrom(key); - entry.mutable_value()->CopyFrom(value); + *entry.mutable_map_key() = 
key; + *entry.mutable_value() = value; return entry; } @@ -508,12 +507,11 @@ std::string SourceFactory::errorMessage(const std::string& description, } bool SourceFactory::isReserved(const std::string& ident_name) { - static std::vector reserved_words = { - "as", "break", "const", "continue", "else", "false", "for", - "function", "if", "import", "in", "let", "loop", "package", - "namespace", "null", "return", "true", "var", "void", "while"}; - return std::find(reserved_words.begin(), reserved_words.end(), ident_name) != - reserved_words.end(); + static const auto* reserved_words = new absl::flat_hash_set( + {"as", "break", "const", "continue", "else", "false", "for", + "function", "if", "import", "in", "let", "loop", "package", + "namespace", "null", "return", "true", "var", "void", "while"}); + return reserved_words->find(ident_name) != reserved_words->end(); } google::api::expr::v1alpha1::SourceInfo SourceFactory::sourceInfo() const { @@ -537,7 +535,7 @@ EnrichedSourceInfo SourceFactory::enrichedSourceInfo() const { [&offset](const std::pair& loc) { offset.insert({loc.first, {loc.second.offset, loc.second.offset_end}}); }); - return EnrichedSourceInfo(offset); + return EnrichedSourceInfo(std::move(offset)); } void SourceFactory::calcLineOffsets(const std::string& expression) { diff --git a/parser/source_factory.h b/parser/source_factory.h index ac89b95a3..6047b0262 100644 --- a/parser/source_factory.h +++ b/parser/source_factory.h @@ -19,8 +19,8 @@ using google::api::expr::v1alpha1::Expr; class EnrichedSourceInfo { public: - EnrichedSourceInfo(const std::map>& offsets) - : offsets_(offsets) {} + EnrichedSourceInfo(std::map> offsets) + : offsets_(std::move(offsets)) {} const std::map>& offsets() const { return offsets_; @@ -56,7 +56,7 @@ class SourceFactory { struct Error { Error(std::string message, SourceLocation location) - : message(message), location(location) {} + : message(std::move(message)), location(location) {} std::string message; SourceLocation 
location; }; @@ -86,15 +86,15 @@ class SourceFactory { const std::vector& args); Expr newGlobalCallForMacro(int64_t macro_id, const std::string& function, const std::vector& args); - Expr newReceiverCall(int64_t id, const std::string& function, Expr& target, - const std::vector& args); + Expr newReceiverCall(int64_t id, const std::string& function, + const Expr& target, const std::vector& args); Expr newIdent(const antlr4::Token* token, const std::string& ident_name); Expr newIdentForMacro(int64_t macro_id, const std::string& ident_name); Expr newSelect(::cel_grammar::CelParser::SelectOrCallContext* ctx, Expr& operand, const std::string& field); Expr newPresenceTestForMacro(int64_t macro_id, const Expr& operand, const std::string& field); - Expr newObject(int64_t obj_id, std::string type_name, + Expr newObject(int64_t obj_id, const std::string& type_name, const std::vector& entries); Expr::CreateStruct::Entry newObjectField(int64_t field_id, const std::string& field, @@ -109,15 +109,16 @@ class SourceFactory { const Expr& accu_init, const Expr& condition, const Expr& step, const Expr& result); Expr newQuantifierExprForMacro(QuantifierKind kind, int64_t macro_id, - Expr* target, const std::vector& args); - Expr newFilterExprForMacro(int64_t macro_id, Expr* target, + const Expr& target, + const std::vector& args); + Expr newFilterExprForMacro(int64_t macro_id, const Expr& target, const std::vector& args); Expr newList(int64_t list_id, const std::vector& elems); Expr newListForMacro(int64_t macro_id, const std::vector& elems); Expr newMap(int64_t map_id, const std::vector& entries); - Expr newMapForMacro(int64_t macro_id, Expr* target, + Expr newMapForMacro(int64_t macro_id, const Expr& target, const std::vector& args); Expr::CreateStruct::Entry newMapEntry(int64_t entry_id, const Expr& key, const Expr& value); diff --git a/parser/visitor.cc b/parser/visitor.cc index 8a4f1930e..9851c3155 100644 --- a/parser/visitor.cc +++ b/parser/visitor.cc @@ -490,35 +490,40 @@ 
std::string ParserVisitor::errorMessage() const { return sf_->errorMessage(description_, expression_); } -Expr ParserVisitor::globalCallOrMacro(int64_t expr_id, std::string function, - std::vector args) { +Expr ParserVisitor::globalCallOrMacro(int64_t expr_id, + const std::string& function, + const std::vector& args) { Expr macro_expr; - if (expandMacro(expr_id, function, nullptr, args, &macro_expr)) { + if (expandMacro(expr_id, function, Expr::default_instance(), args, + &macro_expr)) { return macro_expr; } return sf_->newGlobalCall(expr_id, function, args); } -Expr ParserVisitor::receiverCallOrMacro(int64_t expr_id, std::string function, - Expr target, std::vector args) { +Expr ParserVisitor::receiverCallOrMacro(int64_t expr_id, + const std::string& function, + const Expr& target, + const std::vector& args) { Expr macro_expr; - if (expandMacro(expr_id, function, &target, args, &macro_expr)) { + if (expandMacro(expr_id, function, target, args, &macro_expr)) { return macro_expr; } return sf_->newReceiverCall(expr_id, function, target, args); } -bool ParserVisitor::expandMacro(int64_t expr_id, std::string function, - Expr* target, std::vector args, +bool ParserVisitor::expandMacro(int64_t expr_id, const std::string& function, + const Expr& target, + const std::vector& args, Expr* macro_expr) { std::string macro_key = absl::StrFormat("%s:%d:%s", function, args.size(), - target ? "true" : "false"); + target.id() != 0 ? "true" : "false"); auto m = macros_.find(macro_key); if (m == macros_.end()) { - std::string var_arg_macro_key = - absl::StrFormat("%s:*:%s", function, target ? "true" : "false"); + std::string var_arg_macro_key = absl::StrFormat( + "%s:*:%s", function, target.id() != 0 ? 
"true" : "false"); m = macros_.find(var_arg_macro_key); if (m == macros_.end()) { return false; @@ -527,7 +532,7 @@ bool ParserVisitor::expandMacro(int64_t expr_id, std::string function, Expr expr = m->second.expand(sf_, expr_id, target, args); if (expr.expr_kind_case() != Expr::EXPR_KIND_NOT_SET) { - macro_expr->CopyFrom(expr); + *macro_expr = std::move(expr); return true; } return false; diff --git a/parser/visitor.h b/parser/visitor.h index 7f5cb738a..b5a819114 100644 --- a/parser/visitor.h +++ b/parser/visitor.h @@ -89,12 +89,13 @@ class ParserVisitor : public ::cel_grammar::CelBaseVisitor, std::string errorMessage() const; private: - Expr globalCallOrMacro(int64_t expr_id, std::string function, - std::vector args); - Expr receiverCallOrMacro(int64_t expr_id, std::string function, Expr target, - std::vector args); - bool expandMacro(int64_t expr_id, std::string function, Expr* target, - std::vector args, Expr* macro_expr); + Expr globalCallOrMacro(int64_t expr_id, const std::string& function, + const std::vector& args); + Expr receiverCallOrMacro(int64_t expr_id, const std::string& function, + const Expr& target, const std::vector& args); + bool expandMacro(int64_t expr_id, const std::string& function, + const Expr& target, const std::vector& args, + Expr* macro_expr); std::string unquote(antlr4::ParserRuleContext* ctx, const std::string& s, bool is_bytes); std::string extractQualifiedName(antlr4::ParserRuleContext* ctx, From b19dad9ac2b9f5280c3f0c48b5d3acabee5523bb Mon Sep 17 00:00:00 2001 From: CEL Dev Team Date: Fri, 22 Jan 2021 16:04:42 -0500 Subject: [PATCH 3/6] Internal change PiperOrigin-RevId: 353303190 --- parser/parser.cc | 2 +- parser/source_factory.h | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/parser/parser.cc b/parser/parser.cc index 9b02dabc1..40dce202e 100644 --- a/parser/parser.cc +++ b/parser/parser.cc @@ -7,7 +7,7 @@ #include "parser/cel_grammar.inc/cel_grammar/CelParser.h" #include "parser/source_factory.h" #include 
"parser/visitor.h" -#include "third_party/java/antlr4/v4_7_1/Cpp/src/antlr4-runtime.h" +#include "antlr4-runtime.h" namespace google { namespace api { diff --git a/parser/source_factory.h b/parser/source_factory.h index 6047b0262..34617c951 100644 --- a/parser/source_factory.h +++ b/parser/source_factory.h @@ -8,7 +8,7 @@ #include "google/api/expr/v1alpha1/syntax.pb.h" #include "absl/types/optional.h" #include "parser/cel_grammar.inc/cel_grammar/CelParser.h" -#include "third_party/java/antlr4/v4_7_1/Cpp/src/antlr4-runtime.h" +#include "antlr4-runtime.h" namespace google { namespace api { From bca699e41525b6040988a7646b3517bb8fff9840 Mon Sep 17 00:00:00 2001 From: tswadell Date: Mon, 25 Jan 2021 13:42:42 -0500 Subject: [PATCH 4/6] Introduce CelTypeRegistry for tracking type identifiers with updates in type resolution. Type and enum constant values may be shadowed by variables with the same name provided within an Activation in order to preserve backward compatibility with potential existing usages of the library. This change makes it possible to find core CEL type names as identifiers, and if the `enable_qualified_type_identifiers` option is enabled, then qualified names which appear within Select expressions can be resolved to types which have either been registered with the CelTypeRegistry or to protobuf type names which have been linked into the binary. 
PiperOrigin-RevId: 353680095 --- conformance/BUILD | 10 +- conformance/server.cc | 31 +- eval/compiler/BUILD | 34 +++ eval/compiler/flat_expr_builder.cc | 209 +++++++------ eval/compiler/flat_expr_builder.h | 10 +- eval/compiler/flat_expr_builder_test.cc | 9 +- eval/compiler/qualified_reference_resolver.cc | 104 ++----- eval/compiler/qualified_reference_resolver.h | 5 +- .../qualified_reference_resolver_test.cc | 289 +++++++++++++++--- eval/compiler/resolver.cc | 164 ++++++++++ eval/compiler/resolver.h | 92 ++++++ eval/compiler/resolver_test.cc | 198 ++++++++++++ eval/eval/BUILD | 37 ++- eval/eval/const_value_step.cc | 1 - eval/eval/const_value_step.h | 5 - eval/eval/create_struct_step.cc | 20 +- eval/eval/create_struct_step.h | 15 +- eval/eval/create_struct_step_test.cc | 19 +- eval/eval/function_step.cc | 39 +-- eval/eval/function_step.h | 17 +- eval/eval/function_step_test.cc | 226 +++++--------- eval/eval/shadowable_value_step.cc | 45 +++ eval/eval/shadowable_value_step.h | 25 ++ eval/eval/shadowable_value_step_test.cc | 79 +++++ eval/public/BUILD | 30 +- eval/public/activation.h | 3 +- eval/public/cel_expr_builder_factory.cc | 2 + eval/public/cel_expression.h | 36 +-- eval/public/cel_function_registry.cc | 4 +- eval/public/cel_options.h | 8 + eval/public/cel_type_registry.cc | 75 +++++ eval/public/cel_type_registry.h | 74 +++++ eval/public/cel_type_registry_test.cc | 86 ++++++ eval/public/cel_value.h | 8 + eval/public/cel_value_test.cc | 8 +- 35 files changed, 1533 insertions(+), 484 deletions(-) create mode 100644 eval/compiler/resolver.cc create mode 100644 eval/compiler/resolver.h create mode 100644 eval/compiler/resolver_test.cc create mode 100644 eval/eval/shadowable_value_step.cc create mode 100644 eval/eval/shadowable_value_step.h create mode 100644 eval/eval/shadowable_value_step_test.cc create mode 100644 eval/public/cel_type_registry.cc create mode 100644 eval/public/cel_type_registry.h create mode 100644 eval/public/cel_type_registry_test.cc diff 
--git a/conformance/BUILD b/conformance/BUILD index 6d50aaed4..b538e25d4 100644 --- a/conformance/BUILD +++ b/conformance/BUILD @@ -19,9 +19,8 @@ ALL_TESTS = [ "@com_google_cel_spec//tests/simple:testdata/macros.textproto", "@com_google_cel_spec//tests/simple:testdata/namespace.textproto", "@com_google_cel_spec//tests/simple:testdata/plumbing.textproto", - # TODO(issues/92): Support for parse-only proto message creation within a container. - # "@com_google_cel_spec//tests/simple:testdata/proto2.textproto", - # "@com_google_cel_spec//tests/simple:testdata/proto3.textproto", + "@com_google_cel_spec//tests/simple:testdata/proto2.textproto", + "@com_google_cel_spec//tests/simple:testdata/proto3.textproto", "@com_google_cel_spec//tests/simple:testdata/string.textproto", "@com_google_cel_spec//tests/simple:testdata/timestamps.textproto", "@com_google_cel_spec//tests/simple:testdata/unknowns.textproto", @@ -93,16 +92,11 @@ cc_binary( "--skip_test=conversions/uint/double_nearest,double_nearest_int,double_half_away", # TODO(issues/82): Unexpected behavior when converting invalid bytes to string. "--skip_test=conversions/string/bytes_invalid", - # TODO(issues/83): Missing type() conversion functions - "--skip_test=conversions/type", # TODO(issues/96): Well-known type conversion support. "--skip_test=proto2/literal_wellknown", "--skip_test=proto3/literal_wellknown", "--skip_test=proto2/empty_field/wkt", "--skip_test=proto3/empty_field/wkt", - # TODO(issues/92): Support for parse-only proto message creation within a container. 
- "--skip_test=proto2/has/undefined", - "--skip_test=proto3/has/undefined", # Requires container support "--skip_test=namespace/namespace/self_eval_container_lookup_unchecked", "--skip_test=basic/self_eval_nonzeroish/self_eval_bytes_invalid_utf8", diff --git a/conformance/server.cc b/conformance/server.cc index 5022c668a..8b9ddac35 100644 --- a/conformance/server.cc +++ b/conformance/server.cc @@ -25,8 +25,6 @@ #include "proto/test/v1/proto3/test_all_types.pb.h" -using absl::Status; -using absl::StatusCode; using ::google::protobuf::Arena; using ::google::protobuf::util::JsonStringToMessage; using ::google::protobuf::util::MessageToJsonString; @@ -42,10 +40,10 @@ class ConformanceServiceImpl { public: explicit ConformanceServiceImpl(std::unique_ptr builder) : builder_(std::move(builder)), - proto2Tests_(&google::api::expr::test::v1::proto2::TestAllTypes:: - default_instance()), - proto3Tests_(&google::api::expr::test::v1::proto3::TestAllTypes:: - default_instance()) {} + proto2_tests_(&google::api::expr::test::v1::proto2::TestAllTypes:: + default_instance()), + proto3_tests_(&google::api::expr::test::v1::proto3::TestAllTypes:: + default_instance()) {} void Parse(const v1alpha1::ParseRequest* request, v1alpha1::ParseResponse* response) { @@ -63,7 +61,7 @@ class ConformanceServiceImpl { } else { google::api::expr::v1alpha1::ParsedExpr out; (out).MergeFrom(parse_status.value()); - response->mutable_parsed_expr()->CopyFrom(out); + *response->mutable_parsed_expr() = out; } } @@ -87,6 +85,7 @@ class ConformanceServiceImpl { google::api::expr::v1alpha1::SourceInfo source_info; google::api::expr::v1alpha1::Expr out; (out).MergeFrom(*expr); + builder_->set_container(request->container()); auto cel_expression_status = builder_->CreateExpression(&out, &source_info); if (!cel_expression_status.ok()) { @@ -144,13 +143,14 @@ class ConformanceServiceImpl { private: std::unique_ptr builder_; - const google::api::expr::test::v1::proto2::TestAllTypes* proto2Tests_; - const 
google::api::expr::test::v1::proto3::TestAllTypes* proto3Tests_; + const google::api::expr::test::v1::proto2::TestAllTypes* proto2_tests_; + const google::api::expr::test::v1::proto3::TestAllTypes* proto3_tests_; }; int RunServer(bool optimize) { google::protobuf::Arena arena; InterpreterOptions options; + options.enable_qualified_type_identifiers = true; if (optimize) { std::cerr << "Enabling optimizations" << std::endl; @@ -160,14 +160,15 @@ int RunServer(bool optimize) { std::unique_ptr builder = CreateCelExpressionBuilder(options); - builder->AddResolvableEnum( + auto type_registry = builder->GetTypeRegistry(); + type_registry->Register( google::api::expr::test::v1::proto2::GlobalEnum_descriptor()); - builder->AddResolvableEnum( + type_registry->Register( google::api::expr::test::v1::proto3::GlobalEnum_descriptor()); - builder->AddResolvableEnum(google::api::expr::test::v1::proto2::TestAllTypes:: - NestedEnum_descriptor()); - builder->AddResolvableEnum(google::api::expr::test::v1::proto3::TestAllTypes:: - NestedEnum_descriptor()); + type_registry->Register(google::api::expr::test::v1::proto2::TestAllTypes:: + NestedEnum_descriptor()); + type_registry->Register(google::api::expr::test::v1::proto3::TestAllTypes:: + NestedEnum_descriptor()); auto register_status = RegisterBuiltinFunctions(builder->GetRegistry()); if (!register_status.ok()) { std::cerr << "Failed to initialize: " << register_status.ToString() diff --git a/eval/compiler/BUILD b/eval/compiler/BUILD index 8e654c3bf..b82f8475d 100644 --- a/eval/compiler/BUILD +++ b/eval/compiler/BUILD @@ -17,6 +17,7 @@ cc_library( deps = [ ":constant_folding", ":qualified_reference_resolver", + ":resolver", "//base:status_macros", "//eval/eval:comprehension_step", "//eval/eval:const_value_step", @@ -30,6 +31,7 @@ cc_library( "//eval/eval:jump_step", "//eval/eval:logic_step", "//eval/eval:select_step", + "//eval/eval:shadowable_value_step", "//eval/eval:ternary_step", "//eval/public:ast_traverse", 
"//eval/public:ast_visitor", @@ -147,6 +149,7 @@ cc_library( "qualified_reference_resolver.h", ], deps = [ + ":resolver", "//base:status_macros", "//eval/eval:const_value_step", "//eval/eval:expression_build_warning", @@ -162,6 +165,21 @@ cc_library( ], ) +cc_library( + name = "resolver", + srcs = ["resolver.cc"], + hdrs = ["resolver.h"], + deps = [ + "//eval/public:cel_builtins", + "//eval/public:cel_function_registry", + "//eval/public:cel_type_registry", + "//eval/public:cel_value", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/types:optional", + "@com_google_protobuf//:protobuf", + ], +) + cc_test( name = "qualified_reference_resolver_test", srcs = [ @@ -174,6 +192,7 @@ cc_test( "//eval/public:cel_builtins", "//eval/public:cel_function", "//eval/public:cel_function_registry", + "//eval/public:cel_type_registry", "//testutil:util", "@com_google_absl//absl/status", "@com_google_absl//absl/types:optional", @@ -203,3 +222,18 @@ cc_test( "@com_google_protobuf//:protobuf", ], ) + +cc_test( + name = "resolver_test", + size = "small", + srcs = ["resolver_test.cc"], + deps = [ + ":resolver", + "//eval/public:cel_function", + "//eval/public:cel_function_registry", + "//eval/public:cel_type_registry", + "//eval/testutil:test_message_cc_proto", + "@com_google_absl//absl/status", + "@com_google_googletest//:gtest_main", + ], +) diff --git a/eval/compiler/flat_expr_builder.cc b/eval/compiler/flat_expr_builder.cc index 78928580d..3fd567396 100644 --- a/eval/compiler/flat_expr_builder.cc +++ b/eval/compiler/flat_expr_builder.cc @@ -11,6 +11,7 @@ #include "absl/strings/string_view.h" #include "eval/compiler/constant_folding.h" #include "eval/compiler/qualified_reference_resolver.h" +#include "eval/compiler/resolver.h" #include "eval/eval/comprehension_step.h" #include "eval/eval/const_value_step.h" #include "eval/eval/container_access_step.h" @@ -23,6 +24,7 @@ #include "eval/eval/jump_step.h" #include "eval/eval/logic_step.h" #include "eval/eval/select_step.h" 
+#include "eval/eval/shadowable_value_step.h" #include "eval/eval/ternary_step.h" #include "eval/public/ast_traverse.h" #include "eval/public/ast_visitor.h" @@ -149,59 +151,20 @@ class ComprehensionVisitor : public CondVisitor { class FlatExprVisitor : public AstVisitor { public: FlatExprVisitor( - const CelFunctionRegistry* function_registry, ExecutionPath* path, - bool short_circuiting, - const std::set& enums, - absl::string_view container, + const Resolver& resolver, ExecutionPath* path, bool short_circuiting, const absl::flat_hash_map& constant_idents, bool enable_comprehension, BuilderWarnings* warnings, std::set* iter_variable_names) - : flattened_path_(path), + : resolver_(resolver), + flattened_path_(path), progress_status_(absl::OkStatus()), resolved_select_expr_(nullptr), - function_registry_(function_registry), short_circuiting_(short_circuiting), constant_idents_(constant_idents), enable_comprehension_(enable_comprehension), builder_warnings_(warnings), iter_variable_names_(iter_variable_names) { GOOGLE_CHECK(iter_variable_names_); - - auto container_elements = absl::StrSplit(container, '.'); - - // Build list of prefixes from container. Non-empty prefixes must end with - // ".", otherwise prefix "abc.xy" will match "abc.xyz.EnumName". - std::string prefix = ""; - std::vector prefixes; - prefixes.push_back(prefix); - for (const auto& elem : container_elements) { - absl::StrAppend(&prefix, elem, "."); - prefixes.push_back(prefix); - } - - for (const auto& prefix : prefixes) { - for (auto enum_desc : enums) { - absl::string_view enum_name = enum_desc->full_name(); - if (!absl::StartsWith(enum_name, prefix)) { - continue; - } - - auto remainder = absl::StripPrefix(enum_name, prefix); - for (int i = 0; i < enum_desc->value_count(); i++) { - auto value_desc = enum_desc->value(i); - if (value_desc) { - // "prefixes" container is ascending-ordered. As such, we will be - // assigning enum reference to the deepest available. - // E.g. 
if both a.b.c.Name and a.b.Name are available, and - // we try to reference "Name" with the scope of "a.b.c", - // it will be resolved to "a.b.c.Name". - auto key = absl::StrCat(remainder, !remainder.empty() ? "." : "", - value_desc->name()); - enum_map_[key] = value_desc; - } - } - } - } } void PostVisitConst(const Constant* const_expr, const Expr* expr, @@ -227,7 +190,7 @@ class FlatExprVisitor : public AstVisitor { return; } - std::string path(ident_expr->name()); + const std::string& path = ident_expr->name(); // Automatically replace constant idents with the backing CEL values. auto constant = constant_idents_.find(path); @@ -236,41 +199,38 @@ class FlatExprVisitor : public AstVisitor { return; } - // Attempt to resolve the identifier as an enum. - const google::protobuf::EnumValueDescriptor* value_desc = nullptr; - auto it = enum_map_.find(path); - if (it != enum_map_.end()) { - value_desc = it->second; - } - - // Attempt to resolve a parsed-only expression as a namespaced identifier. + // Attempt to resolve a select expression as a namespaced identifier for an + // enum or type constant value. + absl::optional const_value = absl::nullopt; while (!namespace_stack_.empty()) { - const auto& select_node = namespace_stack_.back(); + const auto& select_node = namespace_stack_.front(); // Generate path in format "<ident>.<field 0>.<field 1>...". 
- absl::StrAppend(&path, ".", select_node.second); - namespace_map_[select_node.first] = path; - - // Attempt to match namespace - auto it = enum_map_.find(path); - if (it != enum_map_.end()) { - resolved_select_expr_ = select_node.first; - value_desc = it->second; - } - namespace_stack_.pop_back(); - } - - if (resolved_select_expr_) { - if (!resolved_select_expr_->has_select_expr()) { - progress_status_ = absl::InternalError("Unexpected Expr type"); + auto select_expr = select_node.first; + auto qualified_path = absl::StrCat(path, ".", select_node.second); + namespace_map_[select_expr] = qualified_path; + + // Attempt to find a constant enum or type value which matches the + // qualified path present in the expression. Whether the identifier + // can be resolved to a type instance depends on whether the option to + // 'enable_qualified_type_identifiers' is set to true. + const_value = resolver_.FindConstant(qualified_path, select_expr->id()); + if (const_value.has_value()) { + AddStep(CreateShadowableValueStep(qualified_path, const_value.value(), + select_expr->id())); + resolved_select_expr_ = select_expr; + namespace_stack_.clear(); return; } - AddStep(CreateConstValueStep(value_desc, resolved_select_expr_->id())); - return; + namespace_stack_.pop_front(); } - if (value_desc) { - AddStep(CreateConstValueStep(value_desc, expr->id())); + + // Attempt to resolve a simple identifier as an enum or type constant value. + const_value = resolver_.FindConstant(path, expr->id()); + if (const_value.has_value()) { + AddStep(CreateShadowableValueStep(path, const_value.value(), expr->id())); return; } + AddStep(CreateIdentStep(ident_expr, expr->id())); } @@ -283,8 +243,24 @@ class FlatExprVisitor : public AstVisitor { // select_expr. // Chain of multiple SELECT ending with IDENT can represent namespaced // entity. 
- if (select_expr->operand().has_ident_expr() || - select_expr->operand().has_select_expr()) { + if (!select_expr->test_only() && + (select_expr->operand().has_ident_expr() || + select_expr->operand().has_select_expr())) { + // select expressions are pushed in reverse order: + // google.type.Expr is pushed as: + // - field: 'Expr' + // - field: 'type' + // - id: 'google' + // + // The search order though is as follows: + // - id: 'google.type.Expr' + // - id: 'google.type', field: 'Expr' + // - id: 'google', field: 'type', field: 'Expr' + for (int i = 0; i < namespace_stack_.size(); i++) { + auto ns = namespace_stack_[i]; + namespace_stack_[i] = { + ns.first, absl::StrCat(select_expr->field(), ".", ns.second)}; + } namespace_stack_.push_back({expr, select_expr->field()}); } else { namespace_stack_.clear(); @@ -311,7 +287,6 @@ class FlatExprVisitor : public AstVisitor { } std::string select_path = ""; - auto it = namespace_map_.find(expr); if (it != namespace_map_.end()) { select_path = it->second; @@ -364,16 +339,47 @@ class FlatExprVisitor : public AstVisitor { if (cond_visitor) { cond_visitor->PostVisit(expr); cond_visitor_stack_.pop(); - } else { - // Special case for "_[_]". - if (call_expr->function() == builtin::kIndex) { - AddStep(CreateContainerAccessStep(call_expr, expr->id())); + return; + } + + // Special case for "_[_]". + if (call_expr->function() == builtin::kIndex) { + AddStep(CreateContainerAccessStep(call_expr, expr->id())); + return; + } + + // Establish the search criteria for a given function. + absl::string_view function = call_expr->function(); + bool receiver_style = call_expr->has_target(); + size_t num_args = call_expr->args_size() + (receiver_style ? 1 : 0); + auto arguments_matcher = ArgumentsMatcher(num_args); + + // First, search for lazily defined function overloads. + // Lazy functions shadow eager functions with the same signature. 
+ auto lazy_overloads = resolver_.FindLazyOverloads( + function, receiver_style, arguments_matcher, expr->id()); + if (!lazy_overloads.empty()) { + AddStep(CreateFunctionStep(call_expr, expr->id(), lazy_overloads)); + return; + } + + // Second, search for eagerly defined function overloads. + auto overloads = resolver_.FindOverloads(function, receiver_style, + arguments_matcher, expr->id()); + if (overloads.empty()) { + // Create a warning that the overload could not be found. Depending on the + // builder_warnings configuration, this could result in termination of the + // CelExpression creation or an inspectable warning for use within runtime + // logging. + auto status = builder_warnings_->AddWarning( + absl::Status(absl::StatusCode::kInvalidArgument, + "No overloads provided for FunctionStep creation")); + if (!status.ok()) { + SetProgressStatusError(status); return; } - // For regular functions, just create one based on registry. - AddStep(CreateFunctionStep(call_expr, expr->id(), *function_registry_, - builder_warnings_)); } + AddStep(CreateFunctionStep(call_expr, expr->id(), overloads)); } void PreVisitComprehension(const Comprehension*, const Expr* expr, @@ -445,7 +451,27 @@ class FlatExprVisitor : public AstVisitor { if (!progress_status_.ok()) { return; } - AddStep(CreateCreateStructStep(struct_expr, expr->id())); + + // If the message name is empty, this signals that a map should be created. + auto message_name = struct_expr->message_name(); + if (message_name.empty()) { + AddStep(CreateCreateStructStep(struct_expr, expr->id())); + return; + } + + // If the message name is not empty, then the message name must be resolved + // within the container, and if a descriptor is found, then a proto message + // creation step will be created. 
+ auto message_desc = resolver_.FindDescriptor(message_name, expr->id()); + if (message_desc != nullptr) { + AddStep(CreateCreateStructStep(struct_expr, message_desc, expr->id())); + return; + } + + // Otherwise, the message descriptor was not linked into the binary. + SetProgressStatusError(absl::InvalidArgumentError( + "Error configuring message creation: message descriptor not found: " + + message_name)); } absl::Status progress_status() const { return progress_status_; } @@ -484,6 +510,7 @@ class FlatExprVisitor : public AstVisitor { } private: + const Resolver& resolver_; ExecutionPath* flattened_path_; absl::Status progress_status_; @@ -500,12 +527,6 @@ class FlatExprVisitor : public AstVisitor { // field is used as marker suppressing CelExpression creation for SELECTs. const Expr* resolved_select_expr_; - // Fully resolved enum value names. - absl::node_hash_map - enum_map_; - - const CelFunctionRegistry* function_registry_; - bool short_circuiting_; const absl::flat_hash_map& constant_idents_; @@ -724,6 +745,8 @@ FlatExprBuilder::CreateExpressionImpl( std::vector* warnings) const { ExecutionPath execution_path; BuilderWarnings warnings_builder(fail_on_warnings_); + Resolver resolver(container(), GetRegistry(), GetTypeRegistry(), + enable_qualified_type_identifiers_); if (absl::StartsWith(container(), ".") || absl::EndsWith(container(), ".")) { return absl::InvalidArgumentError( @@ -740,9 +763,8 @@ FlatExprBuilder::CreateExpressionImpl( // available, we can skip the reference resolve step here if it's already // done. 
if (reference_map != nullptr && !reference_map->empty()) { - absl::StatusOr> rewritten = - ResolveReferences(*effective_expr, *reference_map, *GetRegistry(), - container(), &warnings_builder); + absl::StatusOr> rewritten = ResolveReferences( + *effective_expr, *reference_map, resolver, &warnings_builder); if (!rewritten.ok()) { return rewritten.status(); } @@ -763,9 +785,8 @@ FlatExprBuilder::CreateExpressionImpl( } std::set iter_variable_names; - FlatExprVisitor visitor(this->GetRegistry(), &execution_path, - shortcircuiting_, resolvable_enums(), container(), - idents, enable_comprehension_, &warnings_builder, + FlatExprVisitor visitor(resolver, &execution_path, shortcircuiting_, idents, + enable_comprehension_, &warnings_builder, &iter_variable_names); AstTraverse(effective_expr, source_info, &visitor); diff --git a/eval/compiler/flat_expr_builder.h b/eval/compiler/flat_expr_builder.h index 3b18ba9da..970c8a7ff 100644 --- a/eval/compiler/flat_expr_builder.h +++ b/eval/compiler/flat_expr_builder.h @@ -24,7 +24,8 @@ class FlatExprBuilder : public CelExpressionBuilder { constant_arena_(nullptr), enable_comprehension_(true), comprehension_max_iterations_(0), - fail_on_warnings_(true) {} + fail_on_warnings_(true), + enable_qualified_type_identifiers_(false) {} // set_enable_unknowns controls support for unknowns in expressions created. void set_enable_unknowns(bool enabled) { enable_unknowns_ = enabled; } @@ -65,6 +66,12 @@ class FlatExprBuilder : public CelExpressionBuilder { fail_on_warnings_ = should_fail; } + // set_enable_qualified_type_identifiers controls whether select expressions + // may be treated as constant type identifiers during CelExpression creation. 
+ void set_enable_qualified_type_identifiers(bool enabled) { + enable_qualified_type_identifiers_ = enabled; + } + absl::StatusOr> CreateExpression( const google::api::expr::v1alpha1::Expr* expr, const google::api::expr::v1alpha1::SourceInfo* source_info) const override; @@ -98,6 +105,7 @@ class FlatExprBuilder : public CelExpressionBuilder { bool enable_comprehension_; int comprehension_max_iterations_; bool fail_on_warnings_; + bool enable_qualified_type_identifiers_; }; } // namespace runtime diff --git a/eval/compiler/flat_expr_builder_test.cc b/eval/compiler/flat_expr_builder_test.cc index 2a65d2bc3..0a1fcf258 100644 --- a/eval/compiler/flat_expr_builder_test.cc +++ b/eval/compiler/flat_expr_builder_test.cc @@ -146,7 +146,6 @@ TEST(FlatExprBuilderTest, DelayedFunctionResolutionErrors) { CelValue result = eval_status.value(); ASSERT_TRUE(result.IsError()); - EXPECT_THAT(result.ErrorOrDie()->message(), Eq("No matching overloads found")); @@ -966,7 +965,7 @@ TEST(FlatExprBuilderTest, SimpleEnumTest) { cur_expr->mutable_ident_expr()->set_name(enum_name_parts[0]); FlatExprBuilder builder; - builder.AddResolvableEnum(TestMessage::TestEnum_descriptor()); + builder.GetTypeRegistry()->Register(TestMessage::TestEnum_descriptor()); auto build_status = builder.CreateExpression(&expr, &source_info); ASSERT_OK(build_status); @@ -997,7 +996,7 @@ TEST(FlatExprBuilderTest, SimpleEnumIdentTest) { cur_expr->mutable_ident_expr()->set_name(enum_name); FlatExprBuilder builder; - builder.AddResolvableEnum(TestMessage::TestEnum_descriptor()); + builder.GetTypeRegistry()->Register(TestMessage::TestEnum_descriptor()); auto build_status = builder.CreateExpression(&expr, &source_info); ASSERT_OK(build_status); @@ -1068,8 +1067,8 @@ void EvalExpressionWithEnum(absl::string_view enum_name, cur_expr->mutable_ident_expr()->set_name(enum_name_parts[0]); FlatExprBuilder builder; - builder.AddResolvableEnum(TestMessage::TestEnum_descriptor()); - builder.AddResolvableEnum(TestEnum_descriptor()); 
+ builder.GetTypeRegistry()->Register(TestMessage::TestEnum_descriptor()); + builder.GetTypeRegistry()->Register(TestEnum_descriptor()); builder.set_container(std::string(container)); auto build_status = builder.CreateExpression(&expr, &source_info); diff --git a/eval/compiler/qualified_reference_resolver.cc b/eval/compiler/qualified_reference_resolver.cc index 819238226..21f39c5af 100644 --- a/eval/compiler/qualified_reference_resolver.cc +++ b/eval/compiler/qualified_reference_resolver.cc @@ -55,81 +55,40 @@ absl::optional ToNamespace(const Expr& expr) { } } -// Shape matcher for CelFunctions. -// TODO(issues/91): this is the same behavior as parsed exprs in the CPP -// evaluator (just check the right call style and number of arguments), but we -// should have enough type information in a checked expr to find a more -// specific candidate list. -std::vector ArgumentMatcher(int argument_count) { - std::vector argument_matcher(argument_count); - for (int i = 0; i < argument_count; i++) { - argument_matcher[i] = CelValue::Type::kAny; - } - return argument_matcher; -} - -bool OverloadExists(const CelFunctionRegistry& registry, absl::string_view name, - const std::vector& argument_matcher, +bool OverloadExists(const Resolver& resolver, absl::string_view name, + const std::vector& arguments_matcher, bool receiver_style = false) { - return !registry.FindOverloads(name, receiver_style, argument_matcher) + return !resolver.FindOverloads(name, receiver_style, arguments_matcher) .empty() || - !registry.FindLazyOverloads(name, receiver_style, argument_matcher) + !resolver.FindLazyOverloads(name, receiver_style, arguments_matcher) .empty(); } -// Handles checking for the most specific (most qualified) function that matches -// the call shape. 
-class ContainerLookupHelper { - public: - ContainerLookupHelper(absl::string_view container, - const CelFunctionRegistry& registry) - : registry_(registry) { - auto container_elements = absl::StrSplit(container, '.'); - std::string prefix = ""; - namespace_prefixes_.push_back(prefix); - for (const auto& elem : container_elements) { - // Tolerate trailing / leading '.'. - if (elem.empty()) { - continue; - } - absl::StrAppend(&prefix, elem, "."); - namespace_prefixes_.insert(namespace_prefixes_.begin(), prefix); - } +// Return the qualified name of the most qualified matching overload, or +// nullopt if no matches are found. +absl::optional BestOverloadMatch(const Resolver& resolver, + absl::string_view base_name, + int argument_count) { + if (IsSpecialFunction(base_name)) { + return std::string(base_name); } - - // Return the qualified name of the most qualified matching overload, or - // nullopt if no matches are found. - absl::optional BestOverloadMatch(absl::string_view base_name, - int argument_count) { - if (IsSpecialFunction(base_name)) { - return std::string(base_name); + auto arguments_matcher = ArgumentsMatcher(argument_count); + // Check from most qualified to least qualified for a matching overload. + auto names = resolver.FullyQualifiedNames(base_name); + for (auto name = names.begin(); name != names.end(); ++name) { + if (OverloadExists(resolver, *name, arguments_matcher)) { + return *name; } - auto argument_matcher = ArgumentMatcher(argument_count); - // Check from most qualified to least qualified for a matching overload. - for (auto iter = namespace_prefixes_.begin(); - iter != namespace_prefixes_.end(); ++iter) { - std::string resolved_name = absl::StrCat(*iter, base_name); - if (OverloadExists(registry_, resolved_name, argument_matcher)) { - return resolved_name; - } - } - return absl::nullopt; } - - private: - // Namespace prefixes to check in most to least specific order. 
- std::vector namespace_prefixes_; - const CelFunctionRegistry& registry_; -}; + return absl::nullopt; +} class ReferenceResolver { public: ReferenceResolver(const google::protobuf::Map& reference_map, - const CelFunctionRegistry& registry, - BuilderWarnings* warnings, absl::string_view container) + const Resolver& resolver, BuilderWarnings* warnings) : reference_map_(reference_map), - container_lookup_helper_(container, registry), - registry_(registry), + resolver_(resolver), warnings_(warnings) {} // Attempt to resolve references in expr. Return true if part of the @@ -247,8 +206,7 @@ class ReferenceResolver { std::string resolved_name = absl::StrCat(maybe_namespace.value(), ".", call_expr->function()); auto maybe_resolved_function = - container_lookup_helper_.BestOverloadMatch(resolved_name, - arg_num); + BestOverloadMatch(resolver_, resolved_name, arg_num); if (maybe_resolved_function.has_value()) { call_expr->set_function(maybe_resolved_function.value()); call_expr->clear_target(); @@ -259,8 +217,8 @@ class ReferenceResolver { } else { // Not a receiver style function call. Check to see if it is a namespaced // function using a shorthand inside the expression container. - auto maybe_resolved_function = container_lookup_helper_.BestOverloadMatch( - call_expr->function(), arg_num); + auto maybe_resolved_function = + BestOverloadMatch(resolver_, call_expr->function(), arg_num); if (!maybe_resolved_function.has_value()) { RETURN_IF_ERROR(warnings_->AddWarning(absl::InvalidArgumentError( absl::StrCat("No overload found in reference resolve step for ", @@ -273,8 +231,8 @@ class ReferenceResolver { // For parity, if we didn't rewrite the receiver call style function, // check that an overload is provided in the builder. 
if (call_expr->has_target() && - !OverloadExists(registry_, call_expr->function(), - ArgumentMatcher(arg_num + 1), + !OverloadExists(resolver_, call_expr->function(), + ArgumentsMatcher(arg_num + 1), /* receiver_style= */ true)) { RETURN_IF_ERROR(warnings_->AddWarning(absl::InvalidArgumentError( absl::StrCat("No overload found in reference resolve step for ", @@ -353,8 +311,7 @@ class ReferenceResolver { } const google::protobuf::Map& reference_map_; - ContainerLookupHelper container_lookup_helper_; - const CelFunctionRegistry& registry_; + const Resolver& resolver_; BuilderWarnings* warnings_; }; @@ -362,11 +319,10 @@ class ReferenceResolver { absl::StatusOr> ResolveReferences( const Expr& expr, const google::protobuf::Map& reference_map, - const CelFunctionRegistry& registry, absl::string_view container, - BuilderWarnings* warnings) { + const Resolver& resolver, BuilderWarnings* warnings) { Expr out(expr); - ReferenceResolver resolver(reference_map, registry, warnings, container); - absl::StatusOr rewrite_result = resolver.Rewrite(&out); + ReferenceResolver ref_resolver(reference_map, resolver, warnings); + absl::StatusOr rewrite_result = ref_resolver.Rewrite(&out); if (!rewrite_result.ok()) { return rewrite_result.status(); } else if (rewrite_result.value()) { diff --git a/eval/compiler/qualified_reference_resolver.h b/eval/compiler/qualified_reference_resolver.h index fe8bd8b54..069653412 100644 --- a/eval/compiler/qualified_reference_resolver.h +++ b/eval/compiler/qualified_reference_resolver.h @@ -7,8 +7,8 @@ #include "absl/status/status.h" #include "absl/status/statusor.h" #include "absl/strings/string_view.h" +#include "eval/compiler/resolver.h" #include "eval/eval/expression_build_warning.h" -#include "eval/public/cel_function_registry.h" namespace google { namespace api { @@ -23,8 +23,7 @@ namespace runtime { absl::StatusOr> ResolveReferences( const google::api::expr::v1alpha1::Expr& expr, const google::protobuf::Map& reference_map, - const 
CelFunctionRegistry& registry, absl::string_view container, - BuilderWarnings* warnings); + const Resolver& resolver, BuilderWarnings* warnings); } // namespace runtime } // namespace expr diff --git a/eval/compiler/qualified_reference_resolver_test.cc b/eval/compiler/qualified_reference_resolver_test.cc index 695aef4e3..e36630b2d 100644 --- a/eval/compiler/qualified_reference_resolver_test.cc +++ b/eval/compiler/qualified_reference_resolver_test.cc @@ -9,6 +9,7 @@ #include "eval/public/cel_builtins.h" #include "eval/public/cel_function.h" #include "eval/public/cel_function_registry.h" +#include "eval/public/cel_type_registry.h" #include "testutil/util.h" #include "base/status_macros.h" @@ -84,8 +85,11 @@ TEST(ResolveReferences, Basic) { reference_map[2].set_name("foo.bar.var1"); reference_map[5].set_name("bar.foo.var2"); BuilderWarnings warnings; - CelFunctionRegistry registry; - auto result = ResolveReferences(expr, reference_map, registry, "", &warnings); + CelFunctionRegistry func_registry; + CelTypeRegistry type_registry; + Resolver registry("", &func_registry, &type_registry); + + auto result = ResolveReferences(expr, reference_map, registry, &warnings); ASSERT_OK(result); EXPECT_THAT(result.value(), Optional(EqualsProto(R"( id: 1 @@ -106,8 +110,11 @@ TEST(ResolveReferences, ReturnsNulloptIfNoChanges) { Expr expr = ParseTestProto(kExpr); google::protobuf::Map reference_map; BuilderWarnings warnings; - CelFunctionRegistry registry; - auto result = ResolveReferences(expr, reference_map, registry, "", &warnings); + CelFunctionRegistry func_registry; + CelTypeRegistry type_registry; + Resolver registry("", &func_registry, &type_registry); + + auto result = ResolveReferences(expr, reference_map, registry, &warnings); ASSERT_OK(result.status()); EXPECT_THAT(result.value(), Eq(absl::nullopt)); } @@ -116,11 +123,13 @@ TEST(ResolveReferences, NamespacedIdent) { Expr expr = ParseTestProto(kExpr); google::protobuf::Map reference_map; BuilderWarnings warnings; - 
CelFunctionRegistry registry; - + CelFunctionRegistry func_registry; + CelTypeRegistry type_registry; + Resolver registry("", &func_registry, &type_registry); reference_map[2].set_name("foo.bar.var1"); reference_map[7].set_name("namespace_x.bar"); - auto result = ResolveReferences(expr, reference_map, registry, "", &warnings); + + auto result = ResolveReferences(expr, reference_map, registry, &warnings); ASSERT_OK(result.status()); EXPECT_THAT(result.value(), Optional(EqualsProto(R"( id: 1 @@ -168,9 +177,12 @@ TEST(ResolveReferences, WarningOnPresenceTest) { })"); google::protobuf::Map reference_map; BuilderWarnings warnings; - CelFunctionRegistry registry; + CelFunctionRegistry func_registry; + CelTypeRegistry type_registry; + Resolver registry("", &func_registry, &type_registry); reference_map[1].set_name("foo.bar.var1"); - auto result = ResolveReferences(expr, reference_map, registry, "", &warnings); + + auto result = ResolveReferences(expr, reference_map, registry, &warnings); ASSERT_OK(result.status()); EXPECT_THAT(result.value(), Eq(absl::nullopt)); EXPECT_THAT( @@ -210,13 +222,16 @@ constexpr char kEnumExpr[] = R"( TEST(ResolveReferences, EnumConstReferenceUsed) { Expr expr = ParseTestProto(kEnumExpr); google::protobuf::Map reference_map; - CelFunctionRegistry registry; - ASSERT_OK(RegisterBuiltinFunctions(®istry)); + CelFunctionRegistry func_registry; + ASSERT_OK(RegisterBuiltinFunctions(&func_registry)); + CelTypeRegistry type_registry; + Resolver registry("", &func_registry, &type_registry); reference_map[2].set_name("foo.bar.var1"); reference_map[5].set_name("bar.foo.Enum.ENUM_VAL1"); reference_map[5].mutable_value()->set_int64_value(9); BuilderWarnings warnings; - auto result = ResolveReferences(expr, reference_map, registry, "", &warnings); + + auto result = ResolveReferences(expr, reference_map, registry, &warnings); ASSERT_OK(result); EXPECT_THAT(result.value(), Optional(EqualsProto(R"( id: 1 @@ -236,13 +251,16 @@ TEST(ResolveReferences, 
EnumConstReferenceUsed) { TEST(ResolveReferences, ConstReferenceSkipped) { Expr expr = ParseTestProto(kExpr); google::protobuf::Map reference_map; - CelFunctionRegistry registry; - ASSERT_OK(RegisterBuiltinFunctions(®istry)); + CelFunctionRegistry func_registry; + ASSERT_OK(RegisterBuiltinFunctions(&func_registry)); + CelTypeRegistry type_registry; + Resolver registry("", &func_registry, &type_registry); reference_map[2].set_name("foo.bar.var1"); reference_map[2].mutable_value()->set_bool_value(true); reference_map[5].set_name("bar.foo.var2"); BuilderWarnings warnings; - auto result = ResolveReferences(expr, reference_map, registry, "", &warnings); + + auto result = ResolveReferences(expr, reference_map, registry, &warnings); ASSERT_OK(result); EXPECT_THAT(result.value(), Optional(EqualsProto(R"( id: 1 @@ -292,16 +310,19 @@ call_expr { TEST(ResolveReferences, FunctionReferenceBasic) { Expr expr = ParseTestProto(kExtensionAndExpr); google::protobuf::Map reference_map; - CelFunctionRegistry registry; - ASSERT_OK(registry.RegisterLazyFunction( + CelFunctionRegistry func_registry; + ASSERT_OK(func_registry.RegisterLazyFunction( CelFunctionDescriptor("boolean_and", false, { CelValue::Type::kBool, CelValue::Type::kBool, }))); + CelTypeRegistry type_registry; + Resolver registry("", &func_registry, &type_registry); BuilderWarnings warnings; reference_map[1].add_overload_id("udf_boolean_and"); - auto result = ResolveReferences(expr, reference_map, registry, "", &warnings); + + auto result = ResolveReferences(expr, reference_map, registry, &warnings); ASSERT_OK(result.status()); EXPECT_THAT(result.value(), Eq(absl::nullopt)); } @@ -309,10 +330,13 @@ TEST(ResolveReferences, FunctionReferenceBasic) { TEST(ResolveReferences, FunctionReferenceMissingOverloadDetected) { Expr expr = ParseTestProto(kExtensionAndExpr); google::protobuf::Map reference_map; - CelFunctionRegistry registry; + CelFunctionRegistry func_registry; + CelTypeRegistry type_registry; + Resolver registry("", 
&func_registry, &type_registry); BuilderWarnings warnings; reference_map[1].add_overload_id("udf_boolean_and"); - auto result = ResolveReferences(expr, reference_map, registry, "", &warnings); + + auto result = ResolveReferences(expr, reference_map, registry, &warnings); ASSERT_OK(result.status()); EXPECT_THAT(result.value(), Eq(absl::nullopt)); EXPECT_THAT(warnings.warnings(), @@ -339,12 +363,14 @@ TEST(ResolveReferences, SpecialBuiltinsNotWarned) { for (const char* builtin_fn : special_builtins) { google::protobuf::Map reference_map; // Builtins aren't in the function registry. - CelFunctionRegistry registry; + CelFunctionRegistry func_registry; + CelTypeRegistry type_registry; + Resolver registry("", &func_registry, &type_registry); BuilderWarnings warnings; reference_map[1].add_overload_id(absl::StrCat("builtin.", builtin_fn)); expr.mutable_call_expr()->set_function(builtin_fn); - auto result = - ResolveReferences(expr, reference_map, registry, "", &warnings); + + auto result = ResolveReferences(expr, reference_map, registry, &warnings); ASSERT_OK(result.status()); EXPECT_THAT(result.value(), Eq(absl::nullopt)); EXPECT_THAT(warnings.warnings(), IsEmpty()); @@ -355,10 +381,13 @@ TEST(ResolveReferences, FunctionReferenceMissingOverloadDetectedAndMissingReference) { Expr expr = ParseTestProto(kExtensionAndExpr); google::protobuf::Map reference_map; - CelFunctionRegistry registry; + CelFunctionRegistry func_registry; + CelTypeRegistry type_registry; + Resolver registry("", &func_registry, &type_registry); BuilderWarnings warnings; reference_map[1].set_name("udf_boolean_and"); - auto result = ResolveReferences(expr, reference_map, registry, "", &warnings); + + auto result = ResolveReferences(expr, reference_map, registry, &warnings); ASSERT_OK(result.status()); EXPECT_THAT(result.value(), Eq(absl::nullopt)); EXPECT_THAT( @@ -374,9 +403,12 @@ TEST(ResolveReferences, FunctionReferenceToWrongExprKind) { Expr expr = ParseTestProto(kExtensionAndExpr); 
google::protobuf::Map reference_map; BuilderWarnings warnings; - CelFunctionRegistry registry; + CelFunctionRegistry func_registry; + CelTypeRegistry type_registry; + Resolver registry("", &func_registry, &type_registry); reference_map[2].add_overload_id("udf_boolean_and"); - auto result = ResolveReferences(expr, reference_map, registry, "", &warnings); + + auto result = ResolveReferences(expr, reference_map, registry, &warnings); ASSERT_OK(result.status()); EXPECT_THAT(result.value(), Eq(absl::nullopt)); EXPECT_THAT(warnings.warnings(), @@ -405,11 +437,14 @@ TEST(ResolveReferences, FunctionReferenceWithTargetNoChange) { Expr expr = ParseTestProto(kReceiverCallExtensionAndExpr); google::protobuf::Map reference_map; BuilderWarnings warnings; - CelFunctionRegistry registry; - ASSERT_OK(registry.RegisterLazyFunction(CelFunctionDescriptor( + CelFunctionRegistry func_registry; + ASSERT_OK(func_registry.RegisterLazyFunction(CelFunctionDescriptor( "boolean_and", true, {CelValue::Type::kBool, CelValue::Type::kBool}))); + CelTypeRegistry type_registry; + Resolver registry("", &func_registry, &type_registry); reference_map[1].add_overload_id("udf_boolean_and"); - auto result = ResolveReferences(expr, reference_map, registry, "", &warnings); + + auto result = ResolveReferences(expr, reference_map, registry, &warnings); ASSERT_OK(result.status()); EXPECT_THAT(result.value(), Eq(absl::nullopt)); EXPECT_THAT(warnings.warnings(), IsEmpty()); @@ -420,9 +455,12 @@ TEST(ResolveReferences, Expr expr = ParseTestProto(kReceiverCallExtensionAndExpr); google::protobuf::Map reference_map; BuilderWarnings warnings; - CelFunctionRegistry registry; + CelFunctionRegistry func_registry; + CelTypeRegistry type_registry; + Resolver registry("", &func_registry, &type_registry); reference_map[1].add_overload_id("udf_boolean_and"); - auto result = ResolveReferences(expr, reference_map, registry, "", &warnings); + + auto result = ResolveReferences(expr, reference_map, registry, &warnings); 
ASSERT_OK(result.status()); EXPECT_THAT(result.value(), Eq(absl::nullopt)); EXPECT_THAT(warnings.warnings(), @@ -433,11 +471,14 @@ TEST(ResolveReferences, FunctionReferenceWithTargetToNamespacedFunction) { Expr expr = ParseTestProto(kReceiverCallExtensionAndExpr); google::protobuf::Map reference_map; BuilderWarnings warnings; - CelFunctionRegistry registry; - ASSERT_OK(registry.RegisterLazyFunction(CelFunctionDescriptor( + CelFunctionRegistry func_registry; + ASSERT_OK(func_registry.RegisterLazyFunction(CelFunctionDescriptor( "ext.boolean_and", false, {CelValue::Type::kBool}))); + CelTypeRegistry type_registry; + Resolver registry("", &func_registry, &type_registry); reference_map[1].add_overload_id("udf_boolean_and"); - auto result = ResolveReferences(expr, reference_map, registry, "", &warnings); + + auto result = ResolveReferences(expr, reference_map, registry, &warnings); ASSERT_OK(result.status()); EXPECT_THAT(result.value(), Optional(EqualsProto(R"( id: 1 @@ -456,13 +497,15 @@ TEST(ResolveReferences, FunctionReferenceWithTargetToNamespacedFunctionInContainer) { Expr expr = ParseTestProto(kReceiverCallExtensionAndExpr); google::protobuf::Map reference_map; - BuilderWarnings warnings; - CelFunctionRegistry registry; reference_map[1].add_overload_id("udf_boolean_and"); - ASSERT_OK(registry.RegisterLazyFunction(CelFunctionDescriptor( + BuilderWarnings warnings; + CelFunctionRegistry func_registry; + ASSERT_OK(func_registry.RegisterLazyFunction(CelFunctionDescriptor( "com.google.ext.boolean_and", false, {CelValue::Type::kBool}))); - auto result = - ResolveReferences(expr, reference_map, registry, "com.google", &warnings); + CelTypeRegistry type_registry; + Resolver registry("com.google", &func_registry, &type_registry); + + auto result = ResolveReferences(expr, reference_map, registry, &warnings); ASSERT_OK(result.status()); EXPECT_THAT(result.value(), Optional(EqualsProto(R"( id: 1 @@ -507,19 +550,177 @@ TEST(ResolveReferences, 
FunctionReferenceWithHasTargetNoChange) { Expr expr = ParseTestProto(kReceiverCallHasExtensionAndExpr); google::protobuf::Map reference_map; BuilderWarnings warnings; - CelFunctionRegistry registry; - ASSERT_OK(registry.RegisterLazyFunction(CelFunctionDescriptor( + CelFunctionRegistry func_registry; + ASSERT_OK(func_registry.RegisterLazyFunction(CelFunctionDescriptor( "boolean_and", true, {CelValue::Type::kBool, CelValue::Type::kBool}))); - ASSERT_OK(registry.RegisterLazyFunction(CelFunctionDescriptor( + ASSERT_OK(func_registry.RegisterLazyFunction(CelFunctionDescriptor( "ext.option.boolean_and", true, {CelValue::Type::kBool}))); + CelTypeRegistry type_registry; + Resolver registry("", &func_registry, &type_registry); reference_map[1].add_overload_id("udf_boolean_and"); - auto result = ResolveReferences(expr, reference_map, registry, "", &warnings); + + auto result = ResolveReferences(expr, reference_map, registry, &warnings); ASSERT_OK(result.status()); // The target is unchanged because it is a test_only select. 
EXPECT_THAT(result.value(), Eq(absl::nullopt)); EXPECT_THAT(warnings.warnings(), IsEmpty()); } +constexpr char kComprehensionExpr[] = R"( +id:17 +comprehension_expr: { + iter_var:"i" + iter_range:{ + id:1 + list_expr:{ + elements:{ + id:2 + const_expr:{int64_value:1} + } + elements:{ + id:3 + ident_expr:{name:"ENUM"} + } + elements:{ + id:4 + const_expr:{int64_value:3} + } + } + } + accu_var:"__result__" + accu_init: { + id:10 + const_expr:{bool_value:false} + } + loop_condition:{ + id:13 + call_expr:{ + function:"@not_strictly_false" + args:{ + id:12 + call_expr:{ + function:"!_" + args:{ + id:11 + ident_expr:{name:"__result__"} + } + } + } + } + } + loop_step:{ + id:15 + call_expr: { + function:"_||_" + args:{ + id:14 + ident_expr: {name:"__result__"} + } + args:{ + id:8 + call_expr:{ + function:"_==_" + args:{ + id:7 ident_expr:{name:"ENUM"} + } + args:{ + id:9 ident_expr:{name:"i"} + } + } + } + } + } + result:{id:16 ident_expr:{name:"__result__"}} +} +)"; +TEST(ResolveReferences, EnumConstReferenceUsedInComprehension) { + Expr expr = ParseTestProto(kComprehensionExpr); + google::protobuf::Map reference_map; + CelFunctionRegistry func_registry; + ASSERT_OK(RegisterBuiltinFunctions(&func_registry)); + CelTypeRegistry type_registry; + Resolver registry("", &func_registry, &type_registry); + reference_map[3].set_name("ENUM"); + reference_map[3].mutable_value()->set_int64_value(2); + reference_map[7].set_name("ENUM"); + reference_map[7].mutable_value()->set_int64_value(2); + BuilderWarnings warnings; + + auto result = ResolveReferences(expr, reference_map, registry, &warnings); + ASSERT_OK(result); + EXPECT_THAT(result.value(), Optional(EqualsProto(R"( + id: 17 + comprehension_expr { + iter_var: "i" + iter_range { + id: 1 + list_expr { + elements { + id: 2 + const_expr { int64_value: 1 } + } + elements { + id: 3 + const_expr { int64_value: 2 } + } + elements { + id: 4 + const_expr { int64_value: 3 } + } + } + } + accu_var: "__result__" + accu_init { + id: 10 + 
const_expr { bool_value: false } + } + loop_condition { + id: 13 + call_expr { + function: "@not_strictly_false" + args { + id: 12 + call_expr { + function: "!_" + args { + id: 11 + ident_expr { name: "__result__" } + } + } + } + } + } + loop_step { + id: 15 + call_expr { + function: "_||_" + args { + id: 14 + ident_expr { name: "__result__" } + } + args { + id: 8 + call_expr { + function: "_==_" + args { + id: 7 + const_expr { int64_value: 2 } + } + args { + id: 9 + ident_expr { name: "i" } + } + } + } + } + } + result { + id: 16 + ident_expr { name: "__result__" } + } + })"))); +} + } // namespace } // namespace runtime diff --git a/eval/compiler/resolver.cc b/eval/compiler/resolver.cc new file mode 100644 index 000000000..1e9599843 --- /dev/null +++ b/eval/compiler/resolver.cc @@ -0,0 +1,164 @@ +#include "eval/compiler/resolver.h" + +#include "google/protobuf/descriptor.h" +#include "absl/strings/match.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/str_split.h" +#include "absl/types/optional.h" +#include "eval/public/cel_builtins.h" + +namespace google { +namespace api { +namespace expr { +namespace runtime { + +Resolver::Resolver(absl::string_view container, + const CelFunctionRegistry* function_registry, + const CelTypeRegistry* type_registry, + bool resolve_qualified_type_identifiers) + : namespace_prefixes_(), + enum_value_map_(), + function_registry_(function_registry), + type_registry_(type_registry), + resolve_qualified_type_identifiers_(resolve_qualified_type_identifiers) { + // The constructor for the registry determines the set of possible namespace + // prefixes which may appear within the given expression container, and also + // eagerly maps possible enum names to enum values. + + auto container_elements = absl::StrSplit(container, '.'); + std::string prefix = ""; + namespace_prefixes_.push_back(prefix); + for (const auto& elem : container_elements) { + // Tolerate trailing / leading '.'. 
+ if (elem.empty()) { + continue; + } + absl::StrAppend(&prefix, elem, "."); + namespace_prefixes_.insert(namespace_prefixes_.begin(), prefix); + } + + for (const auto& prefix : namespace_prefixes_) { + for (auto enum_desc : type_registry->Enums()) { + absl::string_view enum_name = enum_desc->full_name(); + if (!absl::StartsWith(enum_name, prefix)) { + continue; + } + + auto remainder = absl::StripPrefix(enum_name, prefix); + for (int i = 0; i < enum_desc->value_count(); i++) { + auto value_desc = enum_desc->value(i); + if (value_desc) { + // "prefixes" container is ascending-ordered. As such, we will be + // assigning enum reference to the deepest available. + // E.g. if both a.b.c.Name and a.b.Name are available, and + // we try to reference "Name" with the scope of "a.b.c", + // it will be resolved to "a.b.c.Name". + auto key = absl::StrCat(remainder, !remainder.empty() ? "." : "", + value_desc->name()); + enum_value_map_[key] = CelValue::CreateInt64(value_desc->number()); + } + } + } + } +} + +std::vector Resolver::FullyQualifiedNames(absl::string_view name, + int64_t expr_id) const { + // TODO(issues/105): refactor the reference resolution into this method. + // and handle the case where this id is in the reference map as either a + // function name or identifier name. + std::vector names; + // Handle the case where the name contains a leading '.' indicating it is + // already fully-qualified. + if (absl::StartsWith(name, ".")) { + std::string fully_qualified_name = std::string(name.substr(1)); + names.push_back(fully_qualified_name); + return names; + } + + // namespace prefixes is guaranteed to contain at least empty string, so this + // function will always produce at least one result. 
+ for (const auto& prefix : namespace_prefixes_) { + std::string fully_qualified_name = absl::StrCat(prefix, name); + names.push_back(fully_qualified_name); + } + return names; +} + +absl::optional Resolver::FindConstant(absl::string_view name, + int64_t expr_id) const { + auto names = FullyQualifiedNames(name, expr_id); + for (const auto& name : names) { + // Attempt to resolve the fully qualified name to a known enum. + auto enum_entry = enum_value_map_.find(name); + if (enum_entry != enum_value_map_.end()) { + return enum_entry->second; + } + // Conditionally resolve fully qualified names as type values if the option + // to do so is configured in the expression builder. If the type name is + // not qualified, then it too may be returned as a constant value. + if (resolve_qualified_type_identifiers_ || !absl::StrContains(name, '.')) { + auto type_value = type_registry_->FindType(name); + if (type_value.has_value()) { + return type_value.value(); + } + } + } + return absl::nullopt; +} + +std::vector Resolver::FindOverloads( + absl::string_view name, bool receiver_style, + const std::vector& types, int64_t expr_id) const { + // Resolve the fully qualified names and then search the function registry + // for possible matches. + std::vector funcs; + auto names = FullyQualifiedNames(name, expr_id); + for (auto it = names.begin(); it != names.end(); it++) { + // Only one set of overloads is returned along the namespace hierarchy as + // the function name resolution follows the same behavior as variable name + // resolution, meaning the most specific definition wins. This is different + // from how C++ namespaces work, as they will accumulate the overload set + // over the namespace hierarchy. 
+ funcs = function_registry_->FindOverloads(*it, receiver_style, types); + if (!funcs.empty()) { + return funcs; + } + } + return funcs; +} + +std::vector Resolver::FindLazyOverloads( + absl::string_view name, bool receiver_style, + const std::vector& types, int64_t expr_id) const { + // Resolve the fully qualified names and then search the function registry + // for possible matches. + std::vector funcs; + auto names = FullyQualifiedNames(name, expr_id); + for (const auto& name : names) { + funcs = function_registry_->FindLazyOverloads(name, receiver_style, types); + if (!funcs.empty()) { + return funcs; + } + } + return funcs; +} + +const google::protobuf::Descriptor* Resolver::FindDescriptor(absl::string_view name, + int64_t expr_id) const { + // Resolve the fully qualified names and then defer to the type registry + // for possible matches. + auto names = FullyQualifiedNames(name, expr_id); + for (const auto& name : names) { + auto desc = type_registry_->FindDescriptor(name); + if (desc != nullptr) { + return desc; + } + } + return nullptr; +} + +} // namespace runtime +} // namespace expr +} // namespace api +} // namespace google diff --git a/eval/compiler/resolver.h b/eval/compiler/resolver.h new file mode 100644 index 000000000..1f79867f1 --- /dev/null +++ b/eval/compiler/resolver.h @@ -0,0 +1,92 @@ +#ifndef THIRD_PARTY_CEL_CPP_EVAL_COMPILER_RESOLVER_H_ +#define THIRD_PARTY_CEL_CPP_EVAL_COMPILER_RESOLVER_H_ + +#include "absl/strings/string_view.h" +#include "absl/types/optional.h" +#include "eval/public/cel_function_registry.h" +#include "eval/public/cel_type_registry.h" +#include "eval/public/cel_value.h" + +namespace google { +namespace api { +namespace expr { +namespace runtime { + +// Resolver assists with finding functions and types within a container. +// +// This class builds on top of the CelFunctionRegistry and CelTypeRegistry by +// layering on the namespace resolution rules of CEL onto the calls provided +// by each of these libraries. 
+// +// TODO(issues/105): refactor the Resolver to consider CheckedExpr metadata +// for reference resolution. +class Resolver { + public: + Resolver(absl::string_view container, + const CelFunctionRegistry* function_registry, + const CelTypeRegistry* type_registry, + bool resolve_qualified_type_identifiers = true); + + ~Resolver() {} + + // FindConstant will return an enum constant value or a type value if one + // exists for the given name. + // + // Since enums and type identifiers are specified as (potentially) qualified + // names within an expression, there is the chance that the name provided + // is a variable name which happens to collide with an existing enum or proto + // based type name. For this reason, within parsed only expressions, the + // constant should be treated as a value that can be shadowed by a runtime + // provided value. + absl::optional FindConstant(absl::string_view name, + int64_t expr_id) const; + + // FindDescriptor returns the protobuf message descriptor for the given name + // if one exists. + const google::protobuf::Descriptor* FindDescriptor(absl::string_view name, + int64_t expr_id) const; + + // FindLazyOverloads returns the set, possibly empty, of lazy overloads + // matching the given function signature. + std::vector FindLazyOverloads( + absl::string_view name, bool receiver_style, + const std::vector& types, int64_t expr_id = -1) const; + + // FindOverloads returns the set, possibly empty, of eager function overloads + // matching the given function signature. + std::vector FindOverloads( + absl::string_view name, bool receiver_style, + const std::vector& types, int64_t expr_id = -1) const; + + // FullyQualifiedNames returns the set of fully qualified names which may be + // derived from the base_name within the specified expression container. 
+ std::vector FullyQualifiedNames(absl::string_view base_name, + int64_t expr_id = -1) const; + + private: + std::vector namespace_prefixes_; + absl::flat_hash_map enum_value_map_; + const CelFunctionRegistry* function_registry_; + const CelTypeRegistry* type_registry_; + bool resolve_qualified_type_identifiers_; +}; + +// ArgumentMatcher generates a function signature matcher for CelFunctions. +// TODO(issues/91): this is the same behavior as parsed exprs in the CPP +// evaluator (just check the right call style and number of arguments), but we +// should have enough type information in a checked expr to find a more +// specific candidate list. +inline std::vector ArgumentsMatcher(int argument_count) { + std::vector argument_matcher(argument_count); + for (int i = 0; i < argument_count; i++) { + argument_matcher[i] = CelValue::Type::kAny; + } + return argument_matcher; +} + +} // namespace runtime +} // namespace expr +} // namespace api +} // namespace google + +#endif // THIRD_PARTY_CEL_CPP_EVAL_COMPILER_RESOLVER_H_ diff --git a/eval/compiler/resolver_test.cc b/eval/compiler/resolver_test.cc new file mode 100644 index 000000000..980ebd4ef --- /dev/null +++ b/eval/compiler/resolver_test.cc @@ -0,0 +1,198 @@ +#include "eval/compiler/resolver.h" + +#include + +#include "gmock/gmock.h" +#include "gtest/gtest.h" +#include "absl/status/status.h" +#include "eval/public/cel_function.h" +#include "eval/public/cel_function_registry.h" +#include "eval/public/cel_type_registry.h" +#include "eval/testutil/test_message.pb.h" + +namespace google { +namespace api { +namespace expr { +namespace runtime { + +namespace { + +using testing::Eq; + +class FakeFunction : public CelFunction { + public: + explicit FakeFunction(const std::string& name) + : CelFunction(CelFunctionDescriptor{name, false, {}}) {} + + absl::Status Evaluate(absl::Span args, CelValue* result, + google::protobuf::Arena* arena) const override { + return absl::OkStatus(); + } +}; + +TEST(ResolverTest, 
TestFullyQualifiedNames) { + CelFunctionRegistry func_registry; + CelTypeRegistry type_registry; + Resolver resolver("google.api.expr", &func_registry, &type_registry); + + auto names = resolver.FullyQualifiedNames("simple_name"); + std::vector expected_names( + {"google.api.expr.simple_name", "google.api.simple_name", + "google.simple_name", "simple_name"}); + EXPECT_THAT(names, Eq(expected_names)); +} + +TEST(ResolverTest, TestFullyQualifiedNamesPartiallyQualifiedName) { + CelFunctionRegistry func_registry; + CelTypeRegistry type_registry; + Resolver resolver("google.api.expr", &func_registry, &type_registry); + + auto names = resolver.FullyQualifiedNames("expr.simple_name"); + std::vector expected_names( + {"google.api.expr.expr.simple_name", "google.api.expr.simple_name", + "google.expr.simple_name", "expr.simple_name"}); + EXPECT_THAT(names, Eq(expected_names)); +} + +TEST(ResolverTest, TestFullyQualifiedNamesAbsoluteName) { + CelFunctionRegistry func_registry; + CelTypeRegistry type_registry; + Resolver resolver("google.api.expr", &func_registry, &type_registry); + + auto names = resolver.FullyQualifiedNames(".google.api.expr.absolute_name"); + EXPECT_THAT(names.size(), Eq(1)); + EXPECT_THAT(names[0], Eq("google.api.expr.absolute_name")); +} + +TEST(ResolverTest, TestFindConstantEnum) { + CelFunctionRegistry func_registry; + CelTypeRegistry type_registry; + type_registry.Register(TestMessage::TestEnum_descriptor()); + Resolver resolver("google.api.expr.runtime.TestMessage", &func_registry, + &type_registry); + + auto enum_value = resolver.FindConstant("TestEnum.TEST_ENUM_1", -1); + EXPECT_TRUE(enum_value.has_value()); + EXPECT_TRUE(enum_value.value().IsInt64()); + EXPECT_THAT(enum_value.value().Int64OrDie(), Eq(1L)); + + enum_value = resolver.FindConstant( + ".google.api.expr.runtime.TestMessage.TestEnum.TEST_ENUM_2", -1); + EXPECT_TRUE(enum_value.has_value()); + EXPECT_TRUE(enum_value.value().IsInt64()); + EXPECT_THAT(enum_value.value().Int64OrDie(), 
Eq(2L)); +} + +TEST(ResolverTest, TestFindConstantUnqualifiedType) { + CelFunctionRegistry func_registry; + CelTypeRegistry type_registry; + Resolver resolver("cel", &func_registry, &type_registry); + + auto type_value = resolver.FindConstant("int", -1); + EXPECT_TRUE(type_value.has_value()); + EXPECT_TRUE(type_value.value().IsCelType()); + EXPECT_THAT(type_value.value().CelTypeOrDie().value(), Eq("int")); +} + +TEST(ResolverTest, TestFindConstantFullyQualifiedType) { + CelFunctionRegistry func_registry; + CelTypeRegistry type_registry; + Resolver resolver("cel", &func_registry, &type_registry); + + auto type_value = + resolver.FindConstant(".google.api.expr.runtime.TestMessage", -1); + EXPECT_TRUE(type_value.has_value()); + EXPECT_TRUE(type_value.value().IsCelType()); + EXPECT_THAT(type_value.value().CelTypeOrDie().value(), + Eq("google.api.expr.runtime.TestMessage")); +} + +TEST(ResolverTest, TestFindConstantQualifiedTypeDisabled) { + CelFunctionRegistry func_registry; + CelTypeRegistry type_registry; + Resolver resolver("", &func_registry, &type_registry, false); + auto type_value = + resolver.FindConstant(".google.api.expr.runtime.TestMessage", -1); + EXPECT_FALSE(type_value.has_value()); +} + +TEST(ResolverTest, TestFindDescriptorBySimpleName) { + CelFunctionRegistry func_registry; + CelTypeRegistry type_registry; + Resolver resolver("google.api.expr.runtime", &func_registry, &type_registry); + + auto desc_value = resolver.FindDescriptor("TestMessage", -1); + EXPECT_TRUE(desc_value != nullptr); + EXPECT_THAT(desc_value, Eq(TestMessage::GetDescriptor())); +} + +TEST(ResolverTest, TestFindDescriptorByQualifiedName) { + CelFunctionRegistry func_registry; + CelTypeRegistry type_registry; + Resolver resolver("google.api.expr.runtime", &func_registry, &type_registry); + + auto desc_value = + resolver.FindDescriptor(".google.api.expr.runtime.TestMessage", -1); + EXPECT_TRUE(desc_value != nullptr); + EXPECT_THAT(desc_value, Eq(TestMessage::GetDescriptor())); +} + 
+TEST(ResolverTest, TestFindDescriptorNotFound) { + CelFunctionRegistry func_registry; + CelTypeRegistry type_registry; + Resolver resolver("google.api.expr.runtime", &func_registry, &type_registry); + + auto desc_value = resolver.FindDescriptor("UndefinedMessage", -1); + EXPECT_TRUE(desc_value == nullptr); +} + +TEST(ResolverTest, TestFindOverloads) { + CelFunctionRegistry func_registry; + auto status = + func_registry.Register(std::make_unique("fake_func")); + ASSERT_OK(status); + status = func_registry.Register( + std::make_unique("cel.fake_ns_func")); + ASSERT_OK(status); + + CelTypeRegistry type_registry; + Resolver resolver("cel", &func_registry, &type_registry); + + auto overloads = + resolver.FindOverloads("fake_func", false, ArgumentsMatcher(0)); + EXPECT_THAT(overloads.size(), Eq(1)); + EXPECT_THAT(overloads[0]->descriptor().name(), Eq("fake_func")); + + overloads = + resolver.FindOverloads("fake_ns_func", false, ArgumentsMatcher(0)); + EXPECT_THAT(overloads.size(), Eq(1)); + EXPECT_THAT(overloads[0]->descriptor().name(), Eq("cel.fake_ns_func")); +} + +TEST(ResolverTest, TestFindLazyOverloads) { + CelFunctionRegistry func_registry; + auto status = func_registry.RegisterLazyFunction( + CelFunctionDescriptor{"fake_lazy_func", false, {}}); + ASSERT_OK(status); + status = func_registry.RegisterLazyFunction( + CelFunctionDescriptor{"cel.fake_lazy_ns_func", false, {}}); + ASSERT_OK(status); + + CelTypeRegistry type_registry; + Resolver resolver("cel", &func_registry, &type_registry); + + auto overloads = + resolver.FindLazyOverloads("fake_lazy_func", false, ArgumentsMatcher(0)); + EXPECT_THAT(overloads.size(), Eq(1)); + + overloads = resolver.FindLazyOverloads("fake_lazy_ns_func", false, + ArgumentsMatcher(0)); + EXPECT_THAT(overloads.size(), Eq(1)); +} + +} // namespace + +} // namespace runtime +} // namespace expr +} // namespace api +} // namespace google diff --git a/eval/eval/BUILD b/eval/eval/BUILD index 9dcfafae4..48d2ea33b 100644 --- a/eval/eval/BUILD 
+++ b/eval/eval/BUILD @@ -60,13 +60,9 @@ cc_library( deps = [ ":evaluator_core", ":expression_step_base", - "//eval/public:activation", - "//eval/public:cel_expression", "//eval/public:cel_value", "//eval/public/structs:cel_proto_wrapper", "@com_google_absl//absl/status:statusor", - "@com_google_absl//absl/strings", - "@com_google_googleapis//google/api/expr/v1alpha1:syntax_cc_proto", "@com_google_protobuf//:protobuf", ], ) @@ -131,7 +127,6 @@ cc_library( "//eval/public:cel_builtins", "//eval/public:cel_function", "//eval/public:cel_function_provider", - "//eval/public:cel_function_registry", "//eval/public:cel_value", "//eval/public:unknown_attribute_set", "//eval/public:unknown_function_result_set", @@ -206,6 +201,7 @@ cc_library( "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", "@com_google_googleapis//google/api/expr/v1alpha1:syntax_cc_proto", + "@com_google_protobuf//:protobuf", ], ) @@ -381,6 +377,7 @@ cc_test( "//base:status_macros", "//eval/public:cel_attribute", "//eval/public:cel_function", + "//eval/public:cel_function_registry", "//eval/public:cel_options", "//eval/public:cel_value", "//eval/public:unknown_function_result_set", @@ -461,11 +458,13 @@ cc_test( ":create_struct_step", ":ident_step", "//base:status_macros", + "//eval/public:cel_type_registry", "//eval/public/containers:container_backed_list_impl", "//eval/public/containers:container_backed_map_impl", "//eval/public/structs:cel_proto_wrapper", "//eval/testutil:test_message_cc_proto", "//testutil:util", + "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", "@com_google_googleapis//google/api/expr/v1alpha1:syntax_cc_proto", @@ -604,3 +603,31 @@ cc_test( "@com_google_googletest//:gtest_main", ], ) + +cc_library( + name = "shadowable_value_step", + srcs = ["shadowable_value_step.cc"], + hdrs = ["shadowable_value_step.h"], + deps = [ + ":evaluator_core", + ":expression_step_base", + "//eval/public:activation", + 
"//eval/public:cel_value", + "@com_google_absl//absl/status:statusor", + ], +) + +cc_test( + name = "shadowable_value_step_test", + size = "small", + srcs = ["shadowable_value_step_test.cc"], + deps = [ + ":evaluator_core", + ":shadowable_value_step", + "//base:status_macros", + "//eval/public:cel_value", + "@com_google_absl//absl/status:statusor", + "@com_google_googleapis//google/api/expr/v1alpha1:syntax_cc_proto", + "@com_google_googletest//:gtest_main", + ], +) diff --git a/eval/eval/const_value_step.cc b/eval/eval/const_value_step.cc index 3a0a0fe08..05ae1a61b 100644 --- a/eval/eval/const_value_step.cc +++ b/eval/eval/const_value_step.cc @@ -12,7 +12,6 @@ namespace expr { namespace runtime { using google::api::expr::v1alpha1::Constant; -using google::api::expr::v1alpha1::Expr; namespace { diff --git a/eval/eval/const_value_step.h b/eval/eval/const_value_step.h index e3c685cbf..01a071295 100644 --- a/eval/eval/const_value_step.h +++ b/eval/eval/const_value_step.h @@ -3,7 +3,6 @@ #include "absl/status/statusor.h" #include "eval/eval/evaluator_core.h" -#include "eval/public/activation.h" #include "eval/public/cel_value.h" namespace google { @@ -18,10 +17,6 @@ absl::optional ConvertConstant( absl::StatusOr> CreateConstValueStep( CelValue value, int64_t expr_id, bool comes_from_ast = true); -// Factory method for Constant(Enum value) - based Execution step -absl::StatusOr> CreateConstValueStep( - const google::protobuf::EnumValueDescriptor* value_descriptor, int64_t expr_id); - } // namespace runtime } // namespace expr } // namespace api diff --git a/eval/eval/create_struct_step.cc b/eval/eval/create_struct_step.cc index 63d0621dd..68f1c727b 100644 --- a/eval/eval/create_struct_step.cc +++ b/eval/eval/create_struct_step.cc @@ -17,9 +17,7 @@ namespace runtime { namespace { -using ::google::protobuf::Arena; using ::google::protobuf::Descriptor; -using ::google::protobuf::DescriptorPool; using ::google::protobuf::FieldDescriptor; using ::google::protobuf::Message; 
using ::google::protobuf::MessageFactory; @@ -274,22 +272,12 @@ absl::Status CreateStructStepForMap::Evaluate(ExecutionFrame* frame) const { absl::StatusOr> CreateCreateStructStep( const google::api::expr::v1alpha1::Expr::CreateStruct* create_struct_expr, - int64_t expr_id) { - if (!create_struct_expr->message_name().empty()) { + const Descriptor* message_desc, int64_t expr_id) { + if (message_desc != nullptr) { // TODO(issues/92): Support resolving a type name within a container. // Make message-creating step. std::vector entries; - const Descriptor* desc = - DescriptorPool::generated_pool()->FindMessageTypeByName( - create_struct_expr.data()->message_name()); - - if (desc == nullptr) { - return absl::InvalidArgumentError( - "Error configuring message creation: message descriptor not found: " + - create_struct_expr->message_name()); - } - for (const auto& entry : create_struct_expr->entries()) { if (entry.field_key().empty()) { return absl::InvalidArgumentError( @@ -297,7 +285,7 @@ absl::StatusOr> CreateCreateStructStep( } const FieldDescriptor* field_desc = - desc->FindFieldByName(entry.field_key()); + message_desc->FindFieldByName(entry.field_key()); if (field_desc == nullptr) { return absl::InvalidArgumentError( "Error configuring message creation: field name not found"); @@ -305,7 +293,7 @@ absl::StatusOr> CreateCreateStructStep( entries.push_back({field_desc}); } - return std::make_unique(expr_id, desc, + return std::make_unique(expr_id, message_desc, std::move(entries)); } else { // Make map-creating step. 
diff --git a/eval/eval/create_struct_step.h b/eval/eval/create_struct_step.h index 1c5035192..e10e6d4ab 100644 --- a/eval/eval/create_struct_step.h +++ b/eval/eval/create_struct_step.h @@ -1,6 +1,10 @@ #ifndef THIRD_PARTY_CEL_CPP_EVAL_EVAL_CREATE_STRUCT_STEP_H_ #define THIRD_PARTY_CEL_CPP_EVAL_EVAL_CREATE_STRUCT_STEP_H_ +#include + +#include "google/protobuf/descriptor.h" +#include "absl/status/status.h" #include "absl/status/statusor.h" #include "eval/eval/evaluator_core.h" #include "eval/eval/expression_step_base.h" @@ -10,10 +14,17 @@ namespace api { namespace expr { namespace runtime { -// Factory method for CreateList - based Execution step +// Factory method for CreateStruct - based Execution step absl::StatusOr> CreateCreateStructStep( const google::api::expr::v1alpha1::Expr::CreateStruct* create_struct_expr, - int64_t expr_id); + const google::protobuf::Descriptor* message_desc, int64_t expr_id); + +inline absl::StatusOr> CreateCreateStructStep( + const google::api::expr::v1alpha1::Expr::CreateStruct* create_struct_expr, + int64_t expr_id) { + return CreateCreateStructStep(create_struct_expr, /*message_desc=*/nullptr, + expr_id); +} } // namespace runtime } // namespace expr diff --git a/eval/eval/create_struct_step_test.cc b/eval/eval/create_struct_step_test.cc index d7eac46d8..7586b1b0b 100644 --- a/eval/eval/create_struct_step_test.cc +++ b/eval/eval/create_struct_step_test.cc @@ -3,9 +3,11 @@ #include "google/api/expr/v1alpha1/syntax.pb.h" #include "gmock/gmock.h" #include "gtest/gtest.h" +#include "absl/status/status.h" #include "absl/status/statusor.h" #include "absl/strings/str_cat.h" #include "eval/eval/ident_step.h" +#include "eval/public/cel_type_registry.h" #include "eval/public/containers/container_backed_list_impl.h" #include "eval/public/containers/container_backed_map_impl.h" #include "eval/public/structs/cel_proto_wrapper.h" @@ -38,6 +40,7 @@ absl::StatusOr RunExpression(absl::string_view field, google::protobuf::Arena* arena, bool 
enable_unknowns) { ExecutionPath path; + CelTypeRegistry type_registry; Expr expr0; Expr expr1; @@ -55,8 +58,12 @@ absl::StatusOr RunExpression(absl::string_view field, if (!step0_status.ok()) { return step0_status.status(); } - - auto step1_status = CreateCreateStructStep(create_struct, expr1.id()); + auto desc = type_registry.FindDescriptor(create_struct->message_name()); + if (desc == nullptr) { + return absl::Status(absl::StatusCode::kFailedPrecondition, + "missing proto message type"); + } + auto step1_status = CreateCreateStructStep(create_struct, desc, expr1.id()); if (!step1_status.ok()) { return step1_status.status(); @@ -113,7 +120,7 @@ void RunExpressionAndGetMessage(absl::string_view field, // Helper method. Creates simple pipeline containing CreateStruct step that // builds Map and runs it. absl::StatusOr RunCreateMapExpression( - const std::vector> values, + const std::vector>& values, google::protobuf::Arena* arena, bool enable_unknowns) { ExecutionPath path; Activation activation; @@ -174,14 +181,16 @@ class CreateCreateStructStepTest : public testing::TestWithParam {}; TEST_P(CreateCreateStructStepTest, TestEmptyMessageCreation) { ExecutionPath path; + CelTypeRegistry type_registry; Expr expr1; auto create_struct = expr1.mutable_struct_expr(); create_struct->set_message_name("google.api.expr.runtime.TestMessage"); + auto desc = type_registry.FindDescriptor(create_struct->message_name()); + ASSERT_TRUE(desc != nullptr); - auto step_status = CreateCreateStructStep(create_struct, expr1.id()); - + auto step_status = CreateCreateStructStep(create_struct, desc, expr1.id()); ASSERT_OK(step_status); path.push_back(std::move(step_status.value())); diff --git a/eval/eval/function_step.cc b/eval/eval/function_step.cc index 0ce103762..9d8977fc3 100644 --- a/eval/eval/function_step.cc +++ b/eval/eval/function_step.cc @@ -22,7 +22,6 @@ #include "eval/public/cel_builtins.h" #include "eval/public/cel_function.h" #include "eval/public/cel_function_provider.h" 
-#include "eval/public/cel_function_registry.h" #include "eval/public/cel_value.h" #include "eval/public/unknown_attribute_set.h" #include "eval/public/unknown_function_result_set.h" @@ -194,7 +193,7 @@ absl::Status AbstractFunctionStep::Evaluate(ExecutionFrame* frame) const { class EagerFunctionStep : public AbstractFunctionStep { public: - EagerFunctionStep(std::vector&& overloads, + EagerFunctionStep(std::vector& overloads, const std::string& name, size_t num_args, int64_t expr_id) : AbstractFunctionStep(name, num_args, expr_id), overloads_(overloads) {} @@ -230,7 +229,7 @@ class LazyFunctionStep : public AbstractFunctionStep { // at runtime. LazyFunctionStep(const std::string& name, size_t num_args, bool receiver_style, - const std::vector& providers, + std::vector& providers, int64_t expr_id) : AbstractFunctionStep(name, num_args, expr_id), receiver_style_(receiver_style), @@ -281,33 +280,23 @@ absl::StatusOr LazyFunctionStep::ResolveFunction( absl::StatusOr> CreateFunctionStep( const google::api::expr::v1alpha1::Expr::Call* call_expr, int64_t expr_id, - const CelFunctionRegistry& function_registry, - BuilderWarnings* builder_warnings) { + std::vector& lazy_overloads) { bool receiver_style = call_expr->has_target(); size_t num_args = call_expr->args_size() + (receiver_style ? 1 : 0); const std::string& name = call_expr->function(); - std::vector args(num_args, CelValue::Type::kAny); + return absl::make_unique(name, num_args, receiver_style, + lazy_overloads, expr_id); +} - std::vector lazy_overloads = - function_registry.FindLazyOverloads(name, receiver_style, args); - - if (!lazy_overloads.empty()) { - return absl::make_unique(name, num_args, receiver_style, - lazy_overloads, expr_id); - } - - auto overloads = function_registry.FindOverloads(name, receiver_style, args); - - // No overloads found. 
- if (overloads.empty()) { - RETURN_IF_ERROR(builder_warnings->AddWarning( - absl::Status(absl::StatusCode::kInvalidArgument, - "No overloads provided for FunctionStep creation"))); - } - - return absl::make_unique(std::move(overloads), name, - num_args, expr_id); +absl::StatusOr> CreateFunctionStep( + const google::api::expr::v1alpha1::Expr::Call* call_expr, int64_t expr_id, + std::vector& overloads) { + bool receiver_style = call_expr->has_target(); + size_t num_args = call_expr->args_size() + (receiver_style ? 1 : 0); + const std::string& name = call_expr->function(); + return absl::make_unique(overloads, name, num_args, + expr_id); } } // namespace runtime diff --git a/eval/eval/function_step.h b/eval/eval/function_step.h index 36b3da61e..6342d84ab 100644 --- a/eval/eval/function_step.h +++ b/eval/eval/function_step.h @@ -8,10 +8,9 @@ #include "google/api/expr/v1alpha1/syntax.pb.h" #include "absl/status/statusor.h" #include "eval/eval/evaluator_core.h" -#include "eval/eval/expression_build_warning.h" #include "eval/public/activation.h" #include "eval/public/cel_function.h" -#include "eval/public/cel_function_registry.h" +#include "eval/public/cel_function_provider.h" #include "eval/public/cel_value.h" namespace google { @@ -19,12 +18,18 @@ namespace api { namespace expr { namespace runtime { -// Factory method for Call - based Execution step -// Looks up function registry using data provided through Call parameter. +// Factory method for Call-based execution step where the function will be +// resolved at runtime (lazily) from an input Activation. absl::StatusOr> CreateFunctionStep( const google::api::expr::v1alpha1::Expr::Call* call, int64_t expr_id, - const CelFunctionRegistry& function_registry, - BuilderWarnings* builder_warnings); + std::vector& lazy_overloads); + +// Factory method for Call-based execution step where the function has been +// statically resolved from a set of eager functions configured in the +// CelFunctionRegistry. 
+absl::StatusOr> CreateFunctionStep( + const google::api::expr::v1alpha1::Expr::Call* call, int64_t expr_id, + std::vector& overloads); } // namespace runtime } // namespace expr diff --git a/eval/eval/function_step_test.cc b/eval/eval/function_step_test.cc index 296ecdd04..a27f33a8a 100644 --- a/eval/eval/function_step_test.cc +++ b/eval/eval/function_step_test.cc @@ -1,5 +1,8 @@ #include "eval/eval/function_step.h" +#include +#include + #include "google/api/expr/v1alpha1/syntax.pb.h" #include "gmock/gmock.h" #include "gtest/gtest.h" @@ -10,6 +13,7 @@ #include "eval/eval/ident_step.h" #include "eval/public/cel_attribute.h" #include "eval/public/cel_function.h" +#include "eval/public/cel_function_registry.h" #include "eval/public/cel_options.h" #include "eval/public/cel_value.h" #include "eval/public/structs/cel_proto_wrapper.h" @@ -171,6 +175,32 @@ void AddDefaults(CelFunctionRegistry& registry) { .ok()); } +std::vector ArgumentMatcher(int argument_count) { + std::vector argument_matcher(argument_count); + for (int i = 0; i < argument_count; i++) { + argument_matcher[i] = CelValue::Type::kAny; + } + return argument_matcher; +} + +std::vector ArgumentMatcher(const Expr::Call* call) { + return ArgumentMatcher(call->has_target() ? call->args_size() + 1 + : call->args_size()); +} + +absl::StatusOr> MakeTestFunctionStep( + const Expr::Call* call, const CelFunctionRegistry& registry) { + auto argument_matcher = ArgumentMatcher(call); + auto lazy_overloads = registry.FindLazyOverloads( + call->function(), call->has_target(), argument_matcher); + if (!lazy_overloads.empty()) { + return CreateFunctionStep(call, GetExprId(), lazy_overloads); + } + auto overloads = registry.FindOverloads(call->function(), call->has_target(), + argument_matcher); + return CreateFunctionStep(call, GetExprId(), overloads); +} + // Test common functions with varying levels of unknown support. 
class FunctionStepTest : public testing::TestWithParam { @@ -213,12 +243,9 @@ TEST_P(FunctionStepTest, SimpleFunctionTest) { Expr::Call call2 = ConstFunction::MakeCall("Const2"); Expr::Call add_call = AddFunction::MakeCall(); - auto step0_status = - CreateFunctionStep(&call1, GetExprId(), registry, &warnings); - auto step1_status = - CreateFunctionStep(&call2, GetExprId(), registry, &warnings); - auto step2_status = - CreateFunctionStep(&add_call, GetExprId(), registry, &warnings); + auto step0_status = MakeTestFunctionStep(&call1, registry); + auto step1_status = MakeTestFunctionStep(&call2, registry); + auto step2_status = MakeTestFunctionStep(&add_call, registry); ASSERT_OK(step0_status); ASSERT_OK(step1_status); @@ -254,10 +281,8 @@ TEST_P(FunctionStepTest, TestStackUnderflow) { Expr::Call call1 = ConstFunction::MakeCall("Const3"); Expr::Call add_call = AddFunction::MakeCall(); - auto step0_status = - CreateFunctionStep(&call1, GetExprId(), registry, &warnings); - auto step2_status = - CreateFunctionStep(&add_call, GetExprId(), registry, &warnings); + auto step0_status = MakeTestFunctionStep(&call1, registry); + auto step2_status = MakeTestFunctionStep(&add_call, registry); ASSERT_OK(step0_status); ASSERT_OK(step2_status); @@ -274,50 +299,6 @@ TEST_P(FunctionStepTest, TestStackUnderflow) { EXPECT_FALSE(status.ok()); } -// Test that creation fails if fail on warnings is set in the warnings -// collection. -TEST(FunctionStepTest, TestNoOverloadsOnCreation) { - CelFunctionRegistry registry; - BuilderWarnings warnings(true); - - Expr::Call call = ConstFunction::MakeCall("Const0"); - - // function step with empty overloads - auto step0_status = - CreateFunctionStep(&call, GetExprId(), registry, &warnings); - - EXPECT_FALSE(step0_status.ok()); -} - -// Test that no overloads error is warned, actual error delayed to runtime by -// default. 
-TEST_P(FunctionStepTest, TestNoOverloadsOnCreationDelayedError) { - CelFunctionRegistry registry; - ExecutionPath path; - Expr::Call call = ConstFunction::MakeCall("Const0"); - BuilderWarnings warnings; - - // function step with empty overloads - auto step0_status = - CreateFunctionStep(&call, GetExprId(), registry, &warnings); - - EXPECT_TRUE(step0_status.ok()); - EXPECT_THAT(warnings.warnings(), testing::SizeIs(1)); - - path.push_back(std::move(step0_status.value())); - - std::unique_ptr impl = GetExpression(std::move(path)); - - Activation activation; - google::protobuf::Arena arena; - - auto status = impl->Evaluate(activation, &arena); - ASSERT_OK(status); - - auto value = status.value(); - ASSERT_TRUE(value.IsError()); -} - // Test situation when no overloads match input arguments during evaluation. TEST_P(FunctionStepTest, TestNoMatchingOverloadsDuringEvaluation) { ExecutionPath path; @@ -336,12 +317,9 @@ TEST_P(FunctionStepTest, TestNoMatchingOverloadsDuringEvaluation) { // Add expects {int64_t, int64_t} but it's {int64_t, uint64_t}. 
Expr::Call add_call = AddFunction::MakeCall(); - auto step0_status = - CreateFunctionStep(&call1, GetExprId(), registry, &warnings); - auto step1_status = - CreateFunctionStep(&call2, GetExprId(), registry, &warnings); - auto step2_status = - CreateFunctionStep(&add_call, GetExprId(), registry, &warnings); + auto step0_status = MakeTestFunctionStep(&call1, registry); + auto step1_status = MakeTestFunctionStep(&call2, registry); + auto step2_status = MakeTestFunctionStep(&add_call, registry); ASSERT_OK(step0_status); ASSERT_OK(step1_status); @@ -368,8 +346,6 @@ TEST_P(FunctionStepTest, TestNoMatchingOverloadsDuringEvaluation) { TEST_P(FunctionStepTest, TestNoMatchingOverloadsDuringEvaluationErrorForwarding) { ExecutionPath path; - BuilderWarnings warnings; - CelFunctionRegistry registry; AddDefaults(registry); @@ -390,12 +366,9 @@ TEST_P(FunctionStepTest, Expr::Call call2 = ConstFunction::MakeCall("ConstError2"); Expr::Call add_call = AddFunction::MakeCall(); - auto step0_status = - CreateFunctionStep(&call1, GetExprId(), registry, &warnings); - auto step1_status = - CreateFunctionStep(&call2, GetExprId(), registry, &warnings); - auto step2_status = - CreateFunctionStep(&add_call, GetExprId(), registry, &warnings); + auto step0_status = MakeTestFunctionStep(&call1, registry); + auto step1_status = MakeTestFunctionStep(&call2, registry); + auto step2_status = MakeTestFunctionStep(&add_call, registry); ASSERT_OK(step0_status); ASSERT_OK(step1_status); @@ -443,12 +416,9 @@ TEST_P(FunctionStepTest, LazyFunctionTest) { Expr::Call call2 = ConstFunction::MakeCall("Const2"); Expr::Call add_call = AddFunction::MakeCall(); - auto step0_status = - CreateFunctionStep(&call1, GetExprId(), registry, &warnings); - auto step1_status = - CreateFunctionStep(&call2, GetExprId(), registry, &warnings); - auto step2_status = - CreateFunctionStep(&add_call, GetExprId(), registry, &warnings); + auto step0_status = MakeTestFunctionStep(&call1, registry); + auto step1_status = 
MakeTestFunctionStep(&call2, registry); + auto step2_status = MakeTestFunctionStep(&add_call, registry); ASSERT_OK(step0_status); ASSERT_OK(step1_status); @@ -479,7 +449,6 @@ TEST_P(FunctionStepTest, Activation activation; google::protobuf::Arena arena; CelFunctionRegistry registry; - BuilderWarnings warnings; AddDefaults(registry); @@ -506,12 +475,9 @@ TEST_P(FunctionStepTest, Expr::Call call2 = ConstFunction::MakeCall("ConstError2"); Expr::Call add_call = AddFunction::MakeCall(); - auto step0_status = - CreateFunctionStep(&call1, GetExprId(), registry, &warnings); - auto step1_status = - CreateFunctionStep(&call2, GetExprId(), registry, &warnings); - auto step2_status = - CreateFunctionStep(&add_call, GetExprId(), registry, &warnings); + auto step0_status = MakeTestFunctionStep(&call1, registry); + auto step1_status = MakeTestFunctionStep(&call2, registry); + auto step2_status = MakeTestFunctionStep(&add_call, registry); ASSERT_OK(step0_status); ASSERT_OK(step1_status); @@ -564,18 +530,17 @@ class FunctionStepTestUnknowns unknown_functions = false; break; } - return absl::make_unique(&expr, std::move(path), 0, + return absl::make_unique(&expr_, std::move(path), 0, std::set(), true, unknown_functions); } private: - Expr expr; + Expr expr_; }; TEST_P(FunctionStepTestUnknowns, PassedUnknownTest) { ExecutionPath path; - BuilderWarnings warnings; CelFunctionRegistry registry; AddDefaults(registry); @@ -584,12 +549,9 @@ TEST_P(FunctionStepTestUnknowns, PassedUnknownTest) { Expr::Call call2 = ConstFunction::MakeCall("ConstUnknown"); Expr::Call add_call = AddFunction::MakeCall(); - auto step0_status = - CreateFunctionStep(&call1, GetExprId(), registry, &warnings); - auto step1_status = - CreateFunctionStep(&call2, GetExprId(), registry, &warnings); - auto step2_status = - CreateFunctionStep(&add_call, GetExprId(), registry, &warnings); + auto step0_status = MakeTestFunctionStep(&call1, registry); + auto step1_status = MakeTestFunctionStep(&call2, registry); + auto 
step2_status = MakeTestFunctionStep(&add_call, registry); ASSERT_OK(step0_status); ASSERT_OK(step1_status); @@ -626,8 +588,7 @@ TEST_P(FunctionStepTestUnknowns, PartialUnknownHandlingTest) { Expr::Call call1 = SinkFunction::MakeCall(); auto step0_status = CreateIdentStep(&ident1, GetExprId()); - auto step1_status = - CreateFunctionStep(&call1, GetExprId(), registry, &warnings); + auto step1_status = MakeTestFunctionStep(&call1, registry); ASSERT_OK(step0_status); ASSERT_OK(step1_status); @@ -660,8 +621,6 @@ TEST_P(FunctionStepTestUnknowns, PartialUnknownHandlingTest) { TEST_P(FunctionStepTestUnknowns, UnknownVsErrorPrecedenceTest) { ExecutionPath path; - BuilderWarnings warnings; - CelFunctionRegistry registry; AddDefaults(registry); @@ -677,12 +636,9 @@ TEST_P(FunctionStepTestUnknowns, UnknownVsErrorPrecedenceTest) { Expr::Call call2 = ConstFunction::MakeCall("ConstUnknown"); Expr::Call add_call = AddFunction::MakeCall(); - auto step0_status = - CreateFunctionStep(&call1, GetExprId(), registry, &warnings); - auto step1_status = - CreateFunctionStep(&call2, GetExprId(), registry, &warnings); - auto step2_status = - CreateFunctionStep(&add_call, GetExprId(), registry, &warnings); + auto step0_status = MakeTestFunctionStep(&call1, registry); + auto step1_status = MakeTestFunctionStep(&call2, registry); + auto step2_status = MakeTestFunctionStep(&add_call, registry); ASSERT_OK(step0_status); ASSERT_OK(step1_status); @@ -725,8 +681,6 @@ MATCHER_P2(IsAdd, a, b, "") { TEST(FunctionStepTestUnknownFunctionResults, CaptureArgs) { ExecutionPath path; - BuilderWarnings warnings; - CelFunctionRegistry registry; ASSERT_OK(registry.Register( @@ -740,12 +694,9 @@ TEST(FunctionStepTestUnknownFunctionResults, CaptureArgs) { Expr::Call call2 = ConstFunction::MakeCall("Const3"); Expr::Call add_call = AddFunction::MakeCall(); - auto step0_status = - CreateFunctionStep(&call1, GetExprId(), registry, &warnings); - auto step1_status = - CreateFunctionStep(&call2, GetExprId(), registry, 
&warnings); - auto step2_status = - CreateFunctionStep(&add_call, GetExprId(), registry, &warnings); + auto step0_status = MakeTestFunctionStep(&call1, registry); + auto step1_status = MakeTestFunctionStep(&call2, registry); + auto step2_status = MakeTestFunctionStep(&add_call, registry); ASSERT_OK(step0_status); ASSERT_OK(step1_status); @@ -777,8 +728,6 @@ TEST(FunctionStepTestUnknownFunctionResults, CaptureArgs) { TEST(FunctionStepTestUnknownFunctionResults, MergeDownCaptureArgs) { ExecutionPath path; - BuilderWarnings warnings; - CelFunctionRegistry registry; ASSERT_OK(registry.Register( @@ -794,20 +743,13 @@ TEST(FunctionStepTestUnknownFunctionResults, MergeDownCaptureArgs) { Expr::Call call2 = ConstFunction::MakeCall("Const3"); Expr::Call add_call = AddFunction::MakeCall(); - auto step0_status = - CreateFunctionStep(&call1, GetExprId(), registry, &warnings); - auto step1_status = - CreateFunctionStep(&call2, GetExprId(), registry, &warnings); - auto step2_status = - CreateFunctionStep(&add_call, GetExprId(), registry, &warnings); - auto step3_status = - CreateFunctionStep(&call1, GetExprId(), registry, &warnings); - auto step4_status = - CreateFunctionStep(&call2, GetExprId(), registry, &warnings); - auto step5_status = - CreateFunctionStep(&add_call, GetExprId(), registry, &warnings); - auto step6_status = - CreateFunctionStep(&add_call, GetExprId(), registry, &warnings); + auto step0_status = MakeTestFunctionStep(&call1, registry); + auto step1_status = MakeTestFunctionStep(&call2, registry); + auto step2_status = MakeTestFunctionStep(&add_call, registry); + auto step3_status = MakeTestFunctionStep(&call1, registry); + auto step4_status = MakeTestFunctionStep(&call2, registry); + auto step5_status = MakeTestFunctionStep(&add_call, registry); + auto step6_status = MakeTestFunctionStep(&add_call, registry); ASSERT_OK(step0_status); ASSERT_OK(step1_status); @@ -847,8 +789,6 @@ TEST(FunctionStepTestUnknownFunctionResults, MergeDownCaptureArgs) { 
TEST(FunctionStepTestUnknownFunctionResults, MergeCaptureArgs) { ExecutionPath path; - BuilderWarnings warnings; - CelFunctionRegistry registry; ASSERT_OK(registry.Register( @@ -864,20 +804,13 @@ TEST(FunctionStepTestUnknownFunctionResults, MergeCaptureArgs) { Expr::Call call2 = ConstFunction::MakeCall("Const3"); Expr::Call add_call = AddFunction::MakeCall(); - auto step0_status = - CreateFunctionStep(&call1, GetExprId(), registry, &warnings); - auto step1_status = - CreateFunctionStep(&call2, GetExprId(), registry, &warnings); - auto step2_status = - CreateFunctionStep(&add_call, GetExprId(), registry, &warnings); - auto step3_status = - CreateFunctionStep(&call2, GetExprId(), registry, &warnings); - auto step4_status = - CreateFunctionStep(&call1, GetExprId(), registry, &warnings); - auto step5_status = - CreateFunctionStep(&add_call, GetExprId(), registry, &warnings); - auto step6_status = - CreateFunctionStep(&add_call, GetExprId(), registry, &warnings); + auto step0_status = MakeTestFunctionStep(&call1, registry); + auto step1_status = MakeTestFunctionStep(&call2, registry); + auto step2_status = MakeTestFunctionStep(&add_call, registry); + auto step3_status = MakeTestFunctionStep(&call2, registry); + auto step4_status = MakeTestFunctionStep(&call1, registry); + auto step5_status = MakeTestFunctionStep(&add_call, registry); + auto step6_status = MakeTestFunctionStep(&add_call, registry); ASSERT_OK(step0_status); ASSERT_OK(step1_status); @@ -907,7 +840,7 @@ TEST(FunctionStepTestUnknownFunctionResults, MergeCaptureArgs) { auto value = status.value(); - ASSERT_TRUE(value.IsUnknownSet()) << value.ErrorOrDie()->ToString(); + ASSERT_TRUE(value.IsUnknownSet()) << *value.ErrorOrDie(); // Arguments captured. 
EXPECT_THAT(value.UnknownSetOrDie() ->unknown_function_results() @@ -917,8 +850,6 @@ TEST(FunctionStepTestUnknownFunctionResults, MergeCaptureArgs) { TEST(FunctionStepTestUnknownFunctionResults, UnknownVsErrorPrecedenceTest) { ExecutionPath path; - BuilderWarnings warnings; - CelFunctionRegistry registry; CelError error0; @@ -937,12 +868,9 @@ TEST(FunctionStepTestUnknownFunctionResults, UnknownVsErrorPrecedenceTest) { Expr::Call call2 = ConstFunction::MakeCall("ConstUnknown"); Expr::Call add_call = AddFunction::MakeCall(); - auto step0_status = - CreateFunctionStep(&call1, GetExprId(), registry, &warnings); - auto step1_status = - CreateFunctionStep(&call2, GetExprId(), registry, &warnings); - auto step2_status = - CreateFunctionStep(&add_call, GetExprId(), registry, &warnings); + auto step0_status = MakeTestFunctionStep(&call1, registry); + auto step1_status = MakeTestFunctionStep(&call2, registry); + auto step2_status = MakeTestFunctionStep(&add_call, registry); ASSERT_OK(step0_status); ASSERT_OK(step1_status); diff --git a/eval/eval/shadowable_value_step.cc b/eval/eval/shadowable_value_step.cc new file mode 100644 index 000000000..e9bd9aed7 --- /dev/null +++ b/eval/eval/shadowable_value_step.cc @@ -0,0 +1,45 @@ +#include "eval/eval/shadowable_value_step.h" + +#include "absl/status/statusor.h" +#include "eval/eval/expression_step_base.h" +#include "eval/public/cel_value.h" + +namespace google { +namespace api { +namespace expr { +namespace runtime { + +namespace { + +class ShadowableValueStep : public ExpressionStepBase { + public: + ShadowableValueStep(const std::string& identifier, const CelValue& value, + int64_t expr_id) + : ExpressionStepBase(expr_id), identifier_(identifier), value_(value) {} + + absl::Status Evaluate(ExecutionFrame* frame) const override; + + private: + std::string identifier_; + CelValue value_; +}; + +absl::Status ShadowableValueStep::Evaluate(ExecutionFrame* frame) const { + auto var = frame->activation().FindValue(identifier_, 
frame->arena()); + frame->value_stack().Push(var.value_or(value_)); + return absl::OkStatus(); +} + +} // namespace + +absl::StatusOr> CreateShadowableValueStep( + const std::string& identifier, const CelValue& value, int64_t expr_id) { + std::unique_ptr step = + absl::make_unique(identifier, value, expr_id); + return std::move(step); +} + +} // namespace runtime +} // namespace expr +} // namespace api +} // namespace google diff --git a/eval/eval/shadowable_value_step.h b/eval/eval/shadowable_value_step.h new file mode 100644 index 000000000..7671e159d --- /dev/null +++ b/eval/eval/shadowable_value_step.h @@ -0,0 +1,25 @@ +#ifndef THIRD_PARTY_CEL_CPP_EVAL_EVAL_SHADOWABLE_VALUE_STEP_H_ +#define THIRD_PARTY_CEL_CPP_EVAL_EVAL_SHADOWABLE_VALUE_STEP_H_ + +#include "absl/status/statusor.h" +#include "eval/eval/evaluator_core.h" +#include "eval/public/activation.h" +#include "eval/public/cel_value.h" + +namespace google { +namespace api { +namespace expr { +namespace runtime { + +// Create an identifier resolution step with a default value that may be +// shadowed by an identifier of the same name within the runtime-provided +// Activation. 
+absl::StatusOr> CreateShadowableValueStep( + const std::string& identifier, const CelValue& value, int64_t expr_id); + +} // namespace runtime +} // namespace expr +} // namespace api +} // namespace google + +#endif // THIRD_PARTY_CEL_CPP_EVAL_EVAL_SHADOWABLE_VALUE_STEP_H_ diff --git a/eval/eval/shadowable_value_step_test.cc b/eval/eval/shadowable_value_step_test.cc new file mode 100644 index 000000000..15e25f8bc --- /dev/null +++ b/eval/eval/shadowable_value_step_test.cc @@ -0,0 +1,79 @@ +#include "eval/eval/shadowable_value_step.h" + +#include "google/api/expr/v1alpha1/syntax.pb.h" +#include "gmock/gmock.h" +#include "gtest/gtest.h" +#include "absl/status/statusor.h" +#include "eval/eval/evaluator_core.h" +#include "eval/public/cel_value.h" +#include "base/status_macros.h" + +namespace google { +namespace api { +namespace expr { +namespace runtime { + +namespace { + +using google::protobuf::Arena; +using testing::Eq; + +absl::StatusOr RunShadowableExpression(const std::string& identifier, + const CelValue& value, + const Activation& activation, + Arena* arena) { + auto step_status = CreateShadowableValueStep(identifier, value, 1); + if (!step_status.ok()) { + return step_status.status(); + } + + ExecutionPath path; + path.push_back(std::move(step_status.value())); + + google::api::expr::v1alpha1::Expr dummy_expr; + CelExpressionFlatImpl impl(&dummy_expr, std::move(path), 0, {}); + return impl.Evaluate(activation, arena); +} + +TEST(ShadowableValueStepTest, TestEvaluateNoShadowing) { + std::string type_name = "google.api.expr.runtime.TestMessage"; + + Activation activation; + Arena arena; + + auto type_value = + CelValue::CreateCelType(CelValue::CelTypeHolder(&type_name)); + auto status = + RunShadowableExpression(type_name, type_value, activation, &arena); + ASSERT_OK(status); + + auto value = status.value(); + ASSERT_TRUE(value.IsCelType()); + EXPECT_THAT(value.CelTypeOrDie().value(), Eq(type_name)); +} + +TEST(ShadowableValueStepTest, 
TestEvaluateShadowedIdentifier) { + std::string type_name = "int"; + auto shadow_value = CelValue::CreateInt64(1024L); + + Activation activation; + activation.InsertValue(type_name, shadow_value); + Arena arena; + + auto type_value = + CelValue::CreateCelType(CelValue::CelTypeHolder(&type_name)); + auto status = + RunShadowableExpression(type_name, type_value, activation, &arena); + ASSERT_OK(status); + + auto value = status.value(); + ASSERT_TRUE(value.IsInt64()); + EXPECT_THAT(value.Int64OrDie(), Eq(1024L)); +} + +} // namespace + +} // namespace runtime +} // namespace expr +} // namespace api +} // namespace google diff --git a/eval/public/BUILD b/eval/public/BUILD index 2ec095c48..3722053af 100644 --- a/eval/public/BUILD +++ b/eval/public/BUILD @@ -216,6 +216,7 @@ cc_library( ":activation", ":cel_function", ":cel_function_registry", + ":cel_type_registry", ":cel_value", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", @@ -336,7 +337,6 @@ cc_test( ":cel_value", ":unknown_attribute_set", ":unknown_set", - "//base:status_macros", "@com_google_absl//absl/status", "@com_google_absl//absl/strings", "@com_google_absl//absl/time", @@ -449,6 +449,34 @@ cc_test( ], ) +cc_library( + name = "cel_type_registry", + srcs = ["cel_type_registry.cc"], + hdrs = ["cel_type_registry.h"], + deps = [ + ":cel_value", + "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/container:node_hash_set", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/types:optional", + "@com_google_protobuf//:protobuf", + ], +) + +cc_test( + name = "cel_type_registry_test", + srcs = ["cel_type_registry_test.cc"], + deps = [ + ":cel_type_registry", + "//eval/testutil:test_message_cc_proto", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_googleapis//google/api/expr/v1alpha1:syntax_cc_proto", + "@com_google_googletest//:gtest_main", + ], +) + 
cc_test( name = "builtin_func_test", size = "small", diff --git a/eval/public/activation.h b/eval/public/activation.h index 1c7271b9f..8e973140f 100644 --- a/eval/public/activation.h +++ b/eval/public/activation.h @@ -84,8 +84,7 @@ class Activation : public BaseActivation { google::protobuf::Arena* arena) const override; bool IsPathUnknown(absl::string_view path) const override { - return google::protobuf::util::FieldMaskUtil::IsPathInFieldMask(std.data()::string(path), - unknown_paths_); + return google::protobuf::util::FieldMaskUtil::IsPathInFieldMask(path.data(), unknown_paths_); } // Insert a function into the activation (ie a lazily bound function). Returns diff --git a/eval/public/cel_expr_builder_factory.cc b/eval/public/cel_expr_builder_factory.cc index 5e23df1ca..b8cd42dd0 100644 --- a/eval/public/cel_expr_builder_factory.cc +++ b/eval/public/cel_expr_builder_factory.cc @@ -18,6 +18,8 @@ std::unique_ptr CreateCelExpressionBuilder( builder->set_comprehension_max_iterations( options.comprehension_max_iterations); builder->set_fail_on_warnings(options.fail_on_warnings); + builder->set_enable_qualified_type_identifiers( + options.enable_qualified_type_identifiers); switch (options.unknown_processing) { case UnknownProcessingOptions::kAttributeAndFunction: diff --git a/eval/public/cel_expression.h b/eval/public/cel_expression.h index 79037630b..8a5372bca 100644 --- a/eval/public/cel_expression.h +++ b/eval/public/cel_expression.h @@ -2,6 +2,7 @@ #define THIRD_PARTY_CEL_CPP_EVAL_PUBLIC_CEL_EXPRESSION_H_ #include +#include #include "google/api/expr/v1alpha1/checked.pb.h" #include "google/api/expr/v1alpha1/syntax.pb.h" @@ -10,6 +11,7 @@ #include "eval/public/activation.h" #include "eval/public/cel_function.h" #include "eval/public/cel_function_registry.h" +#include "eval/public/cel_type_registry.h" #include "eval/public/cel_value.h" namespace google { @@ -76,7 +78,9 @@ class CelExpression { class CelExpressionBuilder { public: CelExpressionBuilder() - : 
registry_(absl::make_unique()), container_("") {} + : func_registry_(absl::make_unique()), + type_registry_(absl::make_unique()), + container_("") {} virtual ~CelExpressionBuilder() {} @@ -118,25 +122,17 @@ class CelExpressionBuilder { // CelFunction registry. Extension function should be registered with it // prior to expression creation. - CelFunctionRegistry* GetRegistry() const { return registry_.get(); } - - // Enums registered with the builder. - // - // TODO(issues/105): this should not be std::set as the ordering of pointers - // is inconsistent across processes and should be absl::node_hash_map or - // absl::flat_hash_map - const std::set& resolvable_enums() const { - return resolvable_enums_; - } + CelFunctionRegistry* GetRegistry() const { return func_registry_.get(); } - // Add Enum to the list of resolvable by the builder. - void AddResolvableEnum(const google::protobuf::EnumDescriptor* enum_descriptor) { - resolvable_enums_.emplace(enum_descriptor); - } + // CEL Type registry. Provides a means to resolve the CEL built-in types to + // CelValue instances, and to extend the set of types and enums known to + // expressions by registering them ahead of time. + CelTypeRegistry* GetTypeRegistry() const { return type_registry_.get(); } - // Remove Enum from the list of resolvable by the builder. - void RemoveResolvableEnum(const google::protobuf::EnumDescriptor* enum_descriptor) { - resolvable_enums_.erase(enum_descriptor); + // Add Enum to the list of resolvable by the builder. 
+ void ABSL_DEPRECATED("Use GetTypeRegistry()->Register() instead") + AddResolvableEnum(const google::protobuf::EnumDescriptor* enum_descriptor) { + type_registry_->Register(enum_descriptor); } void set_container(std::string container) { @@ -146,8 +142,8 @@ class CelExpressionBuilder { absl::string_view container() const { return container_; } private: - std::unique_ptr registry_; - std::set resolvable_enums_; + std::unique_ptr func_registry_; + std::unique_ptr type_registry_; std::string container_; }; diff --git a/eval/public/cel_function_registry.cc b/eval/public/cel_function_registry.cc index 34202afe4..04755b490 100644 --- a/eval/public/cel_function_registry.cc +++ b/eval/public/cel_function_registry.cc @@ -41,7 +41,7 @@ std::vector CelFunctionRegistry::FindOverloads( const std::vector& types) const { std::vector matched_funcs; - auto overloads = functions_.find(std::string(name)); + auto overloads = functions_.find(name); if (overloads == functions_.end()) { return matched_funcs; } @@ -60,7 +60,7 @@ std::vector CelFunctionRegistry::FindLazyOverloads( const std::vector& types) const { std::vector matched_funcs; - auto overloads = functions_.find(std::string(name)); + auto overloads = functions_.find(name); if (overloads == functions_.end()) { return matched_funcs; } diff --git a/eval/public/cel_options.h b/eval/public/cel_options.h index de10230cb..dcd774c6c 100644 --- a/eval/public/cel_options.h +++ b/eval/public/cel_options.h @@ -72,6 +72,14 @@ struct InterpreterOptions { // Treat builder warnings as fatal errors. bool fail_on_warnings = true; + + // Enable the resolution of qualified type identifiers as type values instead + // of field selections. + // + // This toggle may cause certain identifiers which overlap with CEL built-in + // type or with protobuf message types linked into the binary to be resolved + // as static type values rather than as per-eval variables. 
+ bool enable_qualified_type_identifiers = false; }; } // namespace runtime diff --git a/eval/public/cel_type_registry.cc b/eval/public/cel_type_registry.cc new file mode 100644 index 000000000..565ae3e27 --- /dev/null +++ b/eval/public/cel_type_registry.cc @@ -0,0 +1,75 @@ +#include "eval/public/cel_type_registry.h" + +#include "google/protobuf/descriptor.h" +#include "absl/container/node_hash_set.h" +#include "absl/status/status.h" +#include "absl/types/optional.h" +#include "eval/public/cel_value.h" + +namespace google { +namespace api { +namespace expr { +namespace runtime { + +namespace { + +const absl::node_hash_set& GetCoreTypes() { + static const auto* const kCoreTypes = + new absl::node_hash_set{{"bool"}, + {"bytes"}, + {"double"}, + {"google.protobuf.Duration"}, + {"google.protobuf.Timestamp"}, + {"int"}, + {"list"}, + {"map"}, + {"null_type"}, + {"string"}, + {"type"}, + {"uint"}}; + return *kCoreTypes; +} + +} // namespace + +CelTypeRegistry::CelTypeRegistry() : types_(GetCoreTypes()), enums_() {} + +void CelTypeRegistry::Register(std::string fully_qualified_type_name) { + // Registers the fully qualified type name as a CEL type. + types_.insert(std::move(fully_qualified_type_name)); +} + +void CelTypeRegistry::Register(const google::protobuf::EnumDescriptor* enum_descriptor) { + enums_.insert(enum_descriptor); +} + +const google::protobuf::Descriptor* CelTypeRegistry::FindDescriptor( + absl::string_view fully_qualified_type_name) const { + return google::protobuf::DescriptorPool::generated_pool()->FindMessageTypeByName( + fully_qualified_type_name.data()); +} + +absl::optional CelTypeRegistry::FindType( + absl::string_view fully_qualified_type_name) const { + // Searches through explicitly registered type names first. + auto type = types_.find(fully_qualified_type_name); + // The CelValue returned by this call will remain valid as long as the + // CelExpression and associated builder stay in scope. 
+ if (type != types_.end()) { + return CelValue::CreateCelTypeView(*type); + } + + // By default falls back to looking at whether the protobuf descriptor is + // linked into the binary. In the future, this functionality may be disabled, + // but this is most consistent with the current CEL C++ behavior. + auto desc = FindDescriptor(fully_qualified_type_name); + if (desc != nullptr) { + return CelValue::CreateCelTypeView(desc->full_name()); + } + return absl::nullopt; +} + +} // namespace runtime +} // namespace expr +} // namespace api +} // namespace google diff --git a/eval/public/cel_type_registry.h b/eval/public/cel_type_registry.h new file mode 100644 index 000000000..f8d236895 --- /dev/null +++ b/eval/public/cel_type_registry.h @@ -0,0 +1,74 @@ +#ifndef THIRD_PARTY_CEL_CPP_EVAL_PUBLIC_CEL_TYPE_REGISTRY_H_ +#define THIRD_PARTY_CEL_CPP_EVAL_PUBLIC_CEL_TYPE_REGISTRY_H_ + +#include "google/protobuf/descriptor.h" +#include "absl/container/flat_hash_set.h" +#include "absl/container/node_hash_set.h" +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" +#include "eval/public/cel_value.h" + +namespace google { +namespace api { +namespace expr { +namespace runtime { + +// CelTypeRegistry manages the set of registered types available for use within +// object literal construction, enum comparisons, and type testing. +// +// The CelTypeRegistry is intended to live for the duration of all CelExpression +// values created by a given CelExpressionBuilder and one is created by default +// within the standard CelExpressionBuilder. +// +// By default, all core CEL types and all linked protobuf message types are +// implicitly registered by way of the generated descriptor pool. In the future, +// such type registrations may be explicit to avoid accidentally exposing linked +// protobuf types to CEL which were intended to remain internal. 
+class CelTypeRegistry { + public: + CelTypeRegistry(); + + ~CelTypeRegistry() {} + + // Register a fully qualified type name as a valid type for use within CEL + // expressions. + // + // This call establishes a CelValue type instance that can be used in runtime + // comparisons, and may have implications in the future about which protobuf + // message types linked into the binary may also be used by CEL. + // + // Type registration must be performed prior to CelExpression creation. + void Register(std::string fully_qualified_type_name); + + // Register an enum whose values may be used within CEL expressions. + // + // Enum registration must be performed prior to CelExpression creation. + void Register(const google::protobuf::EnumDescriptor* enum_descriptor); + + // Find a protobuf Descriptor given a fully qualified protobuf type name. + const google::protobuf::Descriptor* FindDescriptor( + absl::string_view fully_qualified_type_name) const; + + // Find a type's CelValue instance by its fully qualified name. + absl::optional FindType( + absl::string_view fully_qualified_type_name) const; + + // Return the set of enums configured within the type registry. + inline const absl::flat_hash_set& Enums() + const { + return enums_; + } + + private: + // pointer-stability is required for the strings in the types set, which is + // why a node_hash_set is used instead of another container type. 
+ absl::node_hash_set types_; + absl::flat_hash_set enums_; +}; + +} // namespace runtime +} // namespace expr +} // namespace api +} // namespace google + +#endif // THIRD_PARTY_CEL_CPP_EVAL_PUBLIC_CEL_TYPE_REGISTRY_H_ diff --git a/eval/public/cel_type_registry_test.cc b/eval/public/cel_type_registry_test.cc new file mode 100644 index 000000000..d90cce2ca --- /dev/null +++ b/eval/public/cel_type_registry_test.cc @@ -0,0 +1,86 @@ +#include "eval/public/cel_type_registry.h" + +#include "google/api/expr/v1alpha1/syntax.pb.h" +#include "gmock/gmock.h" +#include "gtest/gtest.h" +#include "absl/container/flat_hash_map.h" +#include "eval/testutil/test_message.pb.h" + +namespace google { +namespace api { +namespace expr { +namespace runtime { + +namespace { + +using testing::Eq; + +TEST(CelTypeRegistryTest, TestRegisterEnumDescriptor) { + CelTypeRegistry registry; + registry.Register(TestMessage::TestEnum_descriptor()); + + absl::flat_hash_set enum_set; + for (auto enum_desc : registry.Enums()) { + enum_set.insert(enum_desc->full_name()); + } + absl::flat_hash_set expected_set; + expected_set.insert({"google.api.expr.runtime.TestMessage.TestEnum"}); + EXPECT_THAT(enum_set, Eq(expected_set)); +} + +TEST(CelTypeRegistryTest, TestRegisterTypeName) { + CelTypeRegistry registry; + + // Register the type, scoping the type name lifecycle to the nested block. 
+ { + std::string custom_type = "custom_type"; + registry.Register(custom_type); + } + + auto type = registry.FindType("custom_type"); + ASSERT_TRUE(type.has_value()); + EXPECT_TRUE(type.value().IsCelType()); + EXPECT_THAT(type.value().CelTypeOrDie().value(), Eq("custom_type")); +} + +TEST(CelTypeRegistryTest, TestFindDescriptorFound) { + CelTypeRegistry registry; + auto desc = registry.FindDescriptor("google.api.expr.Expr"); + ASSERT_TRUE(desc != nullptr); + EXPECT_THAT(desc->full_name(), Eq("google.api.expr.Expr")); +} + +TEST(CelTypeRegistryTest, TestFindDescriptorNotFound) { + CelTypeRegistry registry; + auto desc = registry.FindDescriptor("missing.MessageType"); + EXPECT_TRUE(desc == nullptr); +} + +TEST(CelTypeRegistryTest, TestFindTypeCoreTypeFound) { + CelTypeRegistry registry; + auto type = registry.FindType("int"); + ASSERT_TRUE(type.has_value()); + EXPECT_TRUE(type.value().IsCelType()); + EXPECT_THAT(type.value().CelTypeOrDie().value(), Eq("int")); +} + +TEST(CelTypeRegistryTest, TestFindTypeProtobufTypeFound) { + CelTypeRegistry registry; + auto type = registry.FindType("google.api.expr.Expr"); + ASSERT_TRUE(type.has_value()); + EXPECT_TRUE(type.value().IsCelType()); + EXPECT_THAT(type.value().CelTypeOrDie().value(), Eq("google.api.expr.Expr")); +} + +TEST(CelTypeRegistryTest, TestFindTypeNotRegisteredTypeNotFound) { + CelTypeRegistry registry; + auto type = registry.FindType("missing.MessageType"); + EXPECT_FALSE(type.has_value()); +} + +} // namespace + +} // namespace runtime +} // namespace expr +} // namespace api +} // namespace google diff --git a/eval/public/cel_value.h b/eval/public/cel_value.h index 76be6f511..73bbf2623 100644 --- a/eval/public/cel_value.h +++ b/eval/public/cel_value.h @@ -205,6 +205,14 @@ class CelValue { return CelValue(holder); } + static CelValue CreateCelTypeView(absl::string_view value) { + // This factory method is used for dealing with string references which + // come from protobuf objects or other containers which 
promise pointer + // stability. In general, this is a risky method to use and should not + // be invoked outside the core CEL library. + return CelValue(CelTypeHolder(value)); + } + static CelValue CreateError(const CelError *value) { CheckNullPointer(value, Type::kError); return CelValue(value); diff --git a/eval/public/cel_value_test.cc b/eval/public/cel_value_test.cc index f9538965e..dee3d667e 100644 --- a/eval/public/cel_value_test.cc +++ b/eval/public/cel_value_test.cc @@ -15,7 +15,6 @@ namespace expr { namespace runtime { using testing::Eq; -using testing::UnorderedPointwise; class DummyMap : public CelMap { public: @@ -257,6 +256,13 @@ TEST(CelValueTest, TestCelType) { EXPECT_THAT(value_bytes.type(), Eq(CelValue::Type::kBytes)); EXPECT_THAT(value_bytes.ObtainCelType().CelTypeOrDie().value(), Eq("bytes")); + std::string msg_type_str = "google.api.expr.runtime.TestMessage"; + CelValue msg_type = CelValue::CreateCelTypeView(msg_type_str); + EXPECT_TRUE(msg_type.IsCelType()); + EXPECT_THAT(msg_type.CelTypeOrDie().value(), + Eq("google.api.expr.runtime.TestMessage")); + EXPECT_THAT(msg_type.type(), Eq(CelValue::Type::kCelType)); + UnknownSet unknown_set; CelValue value_unknown = CelValue::CreateUnknownSet(&unknown_set); EXPECT_THAT(value_unknown.type(), Eq(CelValue::Type::kUnknownSet)); From e53c05c61b201f683fe4bec80f6cbd0f99bdcf32 Mon Sep 17 00:00:00 2001 From: kuat Date: Mon, 25 Jan 2021 17:53:29 -0500 Subject: [PATCH 5/6] Enforce utf-8 validity in string() conversion from bytes. Relies on SIMD-optimized library https://github.com/cyb70289/utf8. 
PiperOrigin-RevId: 353736473 --- conformance/BUILD | 2 -- eval/public/BUILD | 1 + eval/public/builtin_func_registrar.cc | 18 +++++++++--------- eval/public/builtin_func_test.cc | 9 +++++++++ eval/public/cel_value.h | 3 +++ testutil/BUILD | 3 +++ 6 files changed, 25 insertions(+), 11 deletions(-) diff --git a/conformance/BUILD b/conformance/BUILD index b538e25d4..b357cef48 100644 --- a/conformance/BUILD +++ b/conformance/BUILD @@ -90,8 +90,6 @@ cc_binary( # uncommented when the spec changes to truncation rather than rounding. "--skip_test=conversions/int/double_nearest,double_nearest_neg,double_half_away_neg,double_half_away_pos", "--skip_test=conversions/uint/double_nearest,double_nearest_int,double_half_away", - # TODO(issues/82): Unexpected behavior when converting invalid bytes to string. - "--skip_test=conversions/string/bytes_invalid", # TODO(issues/96): Well-known type conversion support. "--skip_test=proto2/literal_wellknown", "--skip_test=proto3/literal_wellknown", diff --git a/eval/public/BUILD b/eval/public/BUILD index 3722053af..4bd638717 100644 --- a/eval/public/BUILD +++ b/eval/public/BUILD @@ -180,6 +180,7 @@ cc_library( ":cel_function_adapter", ":cel_function_registry", ":cel_options", + "//base:unilib", "//eval/public/containers:container_backed_list_impl", "@com_google_absl//absl/numeric:int128", "@com_google_absl//absl/status", diff --git a/eval/public/builtin_func_registrar.cc b/eval/public/builtin_func_registrar.cc index 34554d5d4..86fa2adbe 100644 --- a/eval/public/builtin_func_registrar.cc +++ b/eval/public/builtin_func_registrar.cc @@ -16,6 +16,7 @@ #include "eval/public/cel_options.h" #include "eval/public/containers/container_backed_list_impl.h" #include "re2/re2.h" +#include "base/unilib.h" namespace google { namespace api { @@ -1145,16 +1146,15 @@ absl::Status RegisterStringConversionFunctions( return absl::OkStatus(); } - // TODO(issues/82): ensure the bytes conversion to string handles UTF-8 - // properly, and avoids unncessary 
allocations. - // bytes -> string - auto status = FunctionAdapter:: - CreateAndRegister( + auto status = + FunctionAdapter::CreateAndRegister( builtin::kString, false, - [](Arena* arena, - CelValue::BytesHolder value) -> CelValue::StringHolder { - return CelValue::StringHolder( - Arena::Create(arena, std::string(value.value()))); + [](Arena* arena, CelValue::BytesHolder value) -> CelValue { + if (UniLib::IsStructurallyValid(value.value())) { + return CelValue::CreateStringView(value.value()); + } + return CreateErrorValue(arena, "invalid UTF-8 bytes value", + absl::StatusCode::kInvalidArgument); }, registry); if (!status.ok()) return status; diff --git a/eval/public/builtin_func_test.cc b/eval/public/builtin_func_test.cc index 5b3f7b267..008f3bf08 100644 --- a/eval/public/builtin_func_test.cc +++ b/eval/public/builtin_func_test.cc @@ -1686,6 +1686,15 @@ TEST_F(BuiltinsTest, BytesToString) { EXPECT_EQ(result_value.StringOrDie().value(), "abcd"); } +TEST_F(BuiltinsTest, BytesToStringInvalid) { + std::string input = "\xFF"; + std::vector args = {CelValue::CreateBytes(&input)}; + CelValue result_value; + ASSERT_NO_FATAL_FAILURE( + PerformRun(builtin::kString, {}, args, &result_value)); + ASSERT_TRUE(result_value.IsError()); +} + TEST_F(BuiltinsTest, StringToString) { std::string input = "abcd"; std::vector args = {CelValue::CreateString(&input)}; diff --git a/eval/public/cel_value.h b/eval/public/cel_value.h index 73bbf2623..a7c34bdce 100644 --- a/eval/public/cel_value.h +++ b/eval/public/cel_value.h @@ -162,6 +162,9 @@ class CelValue { static CelValue CreateString(StringHolder holder) { return CelValue(holder); } + // Returns a string value from a string_view. Warning: the caller is + // responsible for the lifecycle of the backing string. Prefer CreateString + // instead. 
static CelValue CreateStringView(absl::string_view value) { return CelValue(StringHolder(value)); } diff --git a/testutil/BUILD b/testutil/BUILD index 450474c48..eb71cff9b 100644 --- a/testutil/BUILD +++ b/testutil/BUILD @@ -49,6 +49,7 @@ cc_library( cc_library( name = "test_data_io", + testonly = True, srcs = [ "test_data_io.cc", ], @@ -76,6 +77,7 @@ cc_library( # third_party/cel/spec/testdata/unique_values.textpb cc_binary( name = "test_data_gen", + testonly = True, srcs = [ "test_data_gen.cc", ], @@ -110,6 +112,7 @@ cc_test( cc_library( name = "util", + testonly = True, hdrs = [ "util.h", ], From f029e3e1ffe3a8b539f12ecde89c38d879b72e6d Mon Sep 17 00:00:00 2001 From: kuat Date: Thu, 28 Jan 2021 21:06:06 -0500 Subject: [PATCH 6/6] OSS export. PiperOrigin-RevId: 354438896 --- base/BUILD | 14 +++++++++++++ base/unilib.cc | 19 ++++++++++++++++++ base/unilib.h | 29 +++++++++++++++++++++++++++ common/type.cc | 2 +- eval/compiler/BUILD | 1 + eval/compiler/flat_expr_builder.cc | 2 +- eval/compiler/resolver.cc | 2 +- eval/compiler/resolver_test.cc | 1 + eval/public/BUILD | 2 +- eval/public/cel_type_registry_test.cc | 10 ++++----- 10 files changed, 73 insertions(+), 9 deletions(-) create mode 100644 base/unilib.cc create mode 100644 base/unilib.h diff --git a/base/BUILD b/base/BUILD index 61633b38f..a22e7f742 100644 --- a/base/BUILD +++ b/base/BUILD @@ -10,3 +10,17 @@ cc_library( "status_macros.h", ], ) + +cc_library( + name = "unilib", + srcs = [ + "unilib.cc", + ], + hdrs = [ + "unilib.h", + ], + deps = [ + "@com_github_google_flatbuffers//:flatbuffers", + "@com_google_absl//absl/strings", + ], +) diff --git a/base/unilib.cc b/base/unilib.cc new file mode 100644 index 000000000..6355af23e --- /dev/null +++ b/base/unilib.cc @@ -0,0 +1,19 @@ +#include "base/unilib.h" + +#include "flatbuffers/util.h" + +namespace UniLib { + +// Detects whether a string is valid UTF-8. 
+bool IsStructurallyValid(absl::string_view str) { + const char *s = &str[0]; + const char *const sEnd = s + str.length(); + while (s < sEnd) { + if (flatbuffers::FromUTF8(&s) < 0) { + return false; + } + } + return true; +} + +} // namespace UniLib diff --git a/base/unilib.h b/base/unilib.h new file mode 100644 index 000000000..3eb1e4958 --- /dev/null +++ b/base/unilib.h @@ -0,0 +1,29 @@ +/* + * Copyright 2021 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef THIRD_PARTY_CEL_CPP_BASE_UNILIB_H_ +#define THIRD_PARTY_CEL_CPP_BASE_UNILIB_H_ + +#include "absl/strings/string_view.h" + +namespace UniLib { + +// Detects whether a string is valid UTF-8. 
+bool IsStructurallyValid(absl::string_view str); + +} // namespace UniLib + +#endif // THIRD_PARTY_CEL_CPP_BASE_UNILIB_H_ diff --git a/common/type.cc b/common/type.cc index fb8097f5c..81ff9b1e6 100644 --- a/common/type.cc +++ b/common/type.cc @@ -72,7 +72,7 @@ UnrecognizedType::UnrecognizedType(absl::string_view full_name) : string_rep_(absl::StrCat("type(\"", full_name, "\")")), hash_code_(internal::Hash(full_name)) { assert(google::protobuf::DescriptorPool::generated_pool()->FindMessageTypeByName( - std.data()::string(full_name)) == nullptr); + full_name.data()) == nullptr); } absl::string_view UnrecognizedType::full_name() const { diff --git a/eval/compiler/BUILD b/eval/compiler/BUILD index b82f8475d..8de2fd32b 100644 --- a/eval/compiler/BUILD +++ b/eval/compiler/BUILD @@ -229,6 +229,7 @@ cc_test( srcs = ["resolver_test.cc"], deps = [ ":resolver", + "//base:status_macros", "//eval/public:cel_function", "//eval/public:cel_function_registry", "//eval/public:cel_type_registry", diff --git a/eval/compiler/flat_expr_builder.cc b/eval/compiler/flat_expr_builder.cc index 3fd567396..4e52d76a6 100644 --- a/eval/compiler/flat_expr_builder.cc +++ b/eval/compiler/flat_expr_builder.cc @@ -256,7 +256,7 @@ class FlatExprVisitor : public AstVisitor { // - id: 'google.type.Expr' // - id: 'google.type', field: 'Expr' // - id: 'google', field: 'type', field: 'Expr' - for (int i = 0; i < namespace_stack_.size(); i++) { + for (size_t i = 0; i < namespace_stack_.size(); i++) { auto ns = namespace_stack_[i]; namespace_stack_[i] = { ns.first, absl::StrCat(select_expr->field(), ".", ns.second)}; diff --git a/eval/compiler/resolver.cc b/eval/compiler/resolver.cc index 1e9599843..0a6856e7f 100644 --- a/eval/compiler/resolver.cc +++ b/eval/compiler/resolver.cc @@ -97,7 +97,7 @@ absl::optional Resolver::FindConstant(absl::string_view name, // Conditionally resolve fully qualified names as type values if the option // to do so is configured in the expression builder. 
If the type name is // not qualified, then it too may be returned as a constant value. - if (resolve_qualified_type_identifiers_ || !absl::StrContains(name, '.')) { + if (resolve_qualified_type_identifiers_ || !absl::StrContains(name, ".")) { auto type_value = type_registry_->FindType(name); if (type_value.has_value()) { return type_value.value(); diff --git a/eval/compiler/resolver_test.cc b/eval/compiler/resolver_test.cc index 980ebd4ef..08083925e 100644 --- a/eval/compiler/resolver_test.cc +++ b/eval/compiler/resolver_test.cc @@ -9,6 +9,7 @@ #include "eval/public/cel_function_registry.h" #include "eval/public/cel_type_registry.h" #include "eval/testutil/test_message.pb.h" +#include "base/status_macros.h" namespace google { namespace api { diff --git a/eval/public/BUILD b/eval/public/BUILD index 4bd638717..2effb7e8e 100644 --- a/eval/public/BUILD +++ b/eval/public/BUILD @@ -473,8 +473,8 @@ cc_test( ":cel_type_registry", "//eval/testutil:test_message_cc_proto", "@com_google_absl//absl/container:flat_hash_map", - "@com_google_googleapis//google/api/expr/v1alpha1:syntax_cc_proto", "@com_google_googletest//:gtest_main", + "@com_google_protobuf//:protobuf", ], ) diff --git a/eval/public/cel_type_registry_test.cc b/eval/public/cel_type_registry_test.cc index d90cce2ca..3117722da 100644 --- a/eval/public/cel_type_registry_test.cc +++ b/eval/public/cel_type_registry_test.cc @@ -1,6 +1,6 @@ #include "eval/public/cel_type_registry.h" -#include "google/api/expr/v1alpha1/syntax.pb.h" +#include "google/protobuf/any.pb.h" #include "gmock/gmock.h" #include "gtest/gtest.h" #include "absl/container/flat_hash_map.h" @@ -45,9 +45,9 @@ TEST(CelTypeRegistryTest, TestRegisterTypeName) { TEST(CelTypeRegistryTest, TestFindDescriptorFound) { CelTypeRegistry registry; - auto desc = registry.FindDescriptor("google.api.expr.Expr"); + auto desc = registry.FindDescriptor("google.protobuf.Any"); ASSERT_TRUE(desc != nullptr); - EXPECT_THAT(desc->full_name(), Eq("google.api.expr.Expr")); + 
EXPECT_THAT(desc->full_name(), Eq("google.protobuf.Any")); } TEST(CelTypeRegistryTest, TestFindDescriptorNotFound) { @@ -66,10 +66,10 @@ TEST(CelTypeRegistryTest, TestFindTypeCoreTypeFound) { TEST(CelTypeRegistryTest, TestFindTypeProtobufTypeFound) { CelTypeRegistry registry; - auto type = registry.FindType("google.api.expr.Expr"); + auto type = registry.FindType("google.protobuf.Any"); ASSERT_TRUE(type.has_value()); EXPECT_TRUE(type.value().IsCelType()); - EXPECT_THAT(type.value().CelTypeOrDie().value(), Eq("google.api.expr.Expr")); + EXPECT_THAT(type.value().CelTypeOrDie().value(), Eq("google.protobuf.Any")); } TEST(CelTypeRegistryTest, TestFindTypeNotRegisteredTypeNotFound) {