From 92f5d5ea4547aba4f2a69ca12168591330d18702 Mon Sep 17 00:00:00 2001 From: CEL Dev Team Date: Wed, 26 Feb 2020 16:58:03 -0500 Subject: [PATCH] Export of internal changes -- 297443901: BEGIN_PUBLIC Use absl::Status. Add status_macros.h for the explicit status macros import. Make `CelExpressionFlatImpl` non-copyable to avoid a potential issue with its vector of unique pointers. END_PUBLIC Replace `util::Status` with `absl::Status` and canonical codes throughout. Add `static_cast(char)` to get rid of the compiler warning about signed chars. Fully qualify Expr and SourceInfo in a few places to facilitate v1alpha1 rewrite. Change tests using `CelExpressionFlatImpl` since clang-8 is unable to synthesize the move constructor for its vector of unique pointers. Refactor code using `StatusIs` since OSS testing has no status macros. Refactor RegisterBuiltinFunctions into a few functions to reduce its length. -- 297297141: Disable printing of test values This is done in preparation for cl/296247273 which changes the names of the color constants. The intention is to reinstate the additional information display after that has landed. BEGIN_PUBLIC Temporarily disable printing of test values END_PUBLIC -- 297266899: LSC: Add std:: qualifications to all references to std::string and std::basic_string. Adding these qualifications will make google3 C++ more portable, allow the global using std::string declarations in the google3 copy of the C++ standard library to be deleted, and bring google3 C++ in line with how the rest of the world uses C++. This completes the work started in go/lsc-add-std and go/std-type-qualification. LSC documentation: go/lsc-std-string Tested: tap_presubmit: http://test/OCL:296203366:BASE:296180783:1582224434313:c3cf9bf1 Some tests failed; test failures are believed to be unrelated to this CL -- 296969963: Mock CEL object -- 296958295: LSC: Replace gtl swiss table forwarding headers with their //third_party/absl/container equivalents. Background: gtl/util/(node|flat)_hash_(map|set).h are deprecated forwarding headers. This CL cleans up references to those header files. LSC document: go/absl-cleanup-lsc Tested: TAP --sample ran all affected tests and none failed http://test/OCL:296524469:BASE:296537645:1582335970682:be92947f -- 296507025: BEGIN_PUBLIC Note that interpreter tracing behaves differently when the constant folding is enabled. END_PUBLIC -- 296273889: BEGIN_PUBLIC Add end to end tests for unknown functions and attributes. Add benchmarks (mostly characterizing worst-case merge performance). END_PUBLIC -- 296273780: BEGIN_PUBLIC Add support for unknown function results. END_PUBLIC -- 296269903: BEGIN_PUBLIC Add option for unknown function results. Final implementation in next change. END_PUBLIC -- 296236495: BEGIN_PUBLIC Add support for nested comprehensions. Before this, reusing the accumulator or iter var in a nested comprehension could put the evaluation frame in a bad state and cause a status. END_PUBLIC -- 295763189: BEGIN_PUBLIC Update references to unknownattribute set that should refer to an unknown set. END_PUBLIC -- 295056234: BEGIN_PUBLIC Add UnknownSet type -- this just coordinates updating the underlying attribute set and function results set. END_PUBLIC -- 295026542: BEGIN_PUBLIC Add UnknownFunctionResult and UnknownFunctionResultSet types. END_PUBLIC -- 294785557: Project import generated by Copybara. -- 293733801: absl::GetWeekday() no longer requires an absl::CivilDay argument. Any absl::Civil* variety will do ... they all have a weekday. Tested: tap_presubmit: http://test/OCL:293299235:BASE:293282957:1580966223484:4091c83f Some tests failed; test failures are believed to be unrelated to this CL -- 293705273: BEGIN_PUBLIC Adding test that verify behavior of CEL ternary operator for different cases (true/false/error/unknown) END_PUBLIC -- 292626176: BEGIN_PUBLIC Fix ConstantFoldingTransform bug. END_PUBLIC bugfix: add struct_expr.message_name in ConstantFoldingTransform -- 292579803: Updating Guitar workflow BUILD target dependencies to reflect recent runs. go/prax is a tool that discovers missing dependencies in Guitar workflow BUILD targets by analyzing recent workflow runs. This CL was automatically created by go/prax to fix some of your BUILD files. go/prax found an update to this BUILD file to more closely match dependencies. NOTICE: This CL adds dependencies to one or more Guitar workflow BUILD targets. Please be aware that while workflows without dependencies always run on presubmits, once dependencies are added, Guitar will only run them when affected by the pending CL. Missing dependencies may lead to workflows not running when they should. IMPORTANT: Not all dependencies can be extracted from a Guitar invocation. Here is a list of targets that are not available in previous Guitar invocations: * The SandMan GCL definition (because not a target). This can be added manually using a borgcfg_library target with all the SandMan GCL files. * All the MPMs that are just fetched (not built) by the Borg jobs started by SandMan. The genmpm targets can be added manually. Make sure to manually add all the missing dependencies to your workflow or Guitar will skip the test at presubmit, even though the CL may cause the test to fail. For more details: go/guitar-3-deps This is based on your BUILD file @ cl/292377696 If you do not think this is correct, please file a bug @ go/prax-file-a-bug This CL looks good? Just LGTM and Approve it! What else can you do? * Suggest a fix on the CL (go/how-to-suggest-fix). * Revert this CL, by replying "REVERT: " * Reassign it to a more suitable reviewer with sufficient approval rights. * Set enable_auto_deps=False in your BUILD target (go/gbe#common-attributes-behavior) to stop receiving PRAX updates for this BUILD target. -- 292565730: BEGIN_PUBLIC Separate out AttributeTrail and UnknownsUtility into their own build targets. END_PUBLIC clang-migrate spec: old_header: "third_party/cel/cpp/eval/eval/evaluator_core.h" new_header: "third_party/cel/cpp/eval/eval/unknowns_utility.h" old_ns: "google::api::expr::runtime" new_ns: "google::api::expr::runtime" renames { kind: CLASS old_symbol: "google::api::expr::runtime::UnknownsUtility" new_symbol: "google::api::expr::runtime::UnknownsUtiltiy" } old_depend_on_new: true clang-migrate spec: old_header: "third_party/cel/cpp/eval/eval/evaluator_core.h" new_header: "third_party/cel/cpp/eval/eval/attribute_trail.h" old_ns: "google::api::expr::runtime" new_ns: "google::api::expr::runtime" renames { kind: CLASS old_symbol: "google::api::expr::runtime::AttributeTrail" new_symbol: "google::api::expr::runtime::AttributeTrail" } old_depend_on_new: true -- 292387727: Parse CEL AST expression and extract destination IP prefix set BEGIN_PUBLIC Internal change END_PUBLIC - Arcus sends CEL AST expression to NATM. - This parser will extract destination IP prefix set from the AST expression. -- 291997941: BEGIN_PUBLIC Adding Unknowns support to CEL C++ Evaluator - Adding unknown_patterns to Activation - Apply changes to evaluator core to support unknowns: - attributes stored in value stack along with corresponding values; - patterns stored in execution frame - Unknowns support in ExecutionStep implementations. END_PUBLIC -- 291392455: BEGIN_PUBLIC CEL C++ interpreter: add the missing int->string conversion function. END_PUBLIC -- 290287046: Update form definition compilation script to parse expression using the CEL cpp library instead of parsing & checking it via the expr service. BEGIN_PUBLIC Internal change END_PUBLIC -- 287599493: BEGIN_PUBLIC Add option for opting out of treating warnings as fatal errors. END_PUBLIC -- 286093278: BEGIN_PUBLIC Expose evaluation state from Expression parsing. CelExpression can now initialize and use an opaque "CelEvaluationState". This type holds all the mutable state used by CelExpression during the Trace and Evaluate calls. Instead of allocating and initializing this state each time Evalue or Trace is called, the user can elect to Initialize a CelEvaluationState prior to those calls and pass it into the newly provided overrides. This allows the user to remove these allocations when the CelExpression is called frequently under certain conditions (performance, realtime). END_PUBLIC -- 286061190: BEGIN_PUBLIC Remove active allocation paths from ExecutionFrame IterVars. END_PUBLIC The current method of setting the variable's values (and clearing them) involves adding/removing from a map. Instead, the expression builder now tracks all variable names used in the expression, and those are used to initialize the frame's map. During evaluation, the variables are now set on the existing entry in the map, resulting in no allocations from the map during these operations. A "cleared" value is represented by a default value CelValue, not a non-existing map entry. (allocations from CelValue::operator=() still need to be addressed). -- 286039529: BEGIN_PUBLIC Remove active allocation paths from ValueStack. END_PUBLIC The ValueStack no longer directly manipulates a vector, but treats it as a fixed size array and just adjusts values up and down as it inserts them. -- 285120619: Remove unused absl::string_view variables. For various weird reasons, Clang does not currently warn about unused absl::string_view objects (because it thinks the constructor may have side-effects). In order to make absl::string_view interchangable with std::string_view, this needs to be cleaned up. This patch removes all unused string_view objects. It /does not/ attempt to fix any underlying bugs signaled by the unused variable. If you have a better fix, please suggest it. Tested: TAP --sample ran all affected tests and none failed http://test/OCL:285082973:BASE:285111449:1576127415846:786161bf -- 285016935: BEGIN_PUBLIC Add method for getting build warnings associated with a built expression. END_PUBLIC -- 284968502: BEGIN_PUBLIC Fix the naming for AddResolvableEnum/RemoveResolvableEnum methods in CelExpressionBuilder END_PUBLIC -- 284848907: BEGIN_PUBLIC Demote function resolution errors at evaluation time to a warning. Client must opt out of treating warning as a fatal error (maintains current behavior). END_PUBLIC The use case for this is for clients that may have slower rollouts than the reference environment. (e.g. expression and compiler expects all evalutors to have a new function but some clients lag and don't add the new function to their registries yet.) -- 284618467: BEGIN_PUBLIC Squash public notes on export. END_PUBLIC -- 284594209: BEGIN_PUBLIC Removed direct access to ExecutionFrame iteration variable map. END_PUBLIC This change allows us to revamp the inner workings of the ExecutionFrame's iterator variable map without exposing those changes to client code of the class. PiperOrigin-RevId: 297443901 --- base/BUILD | 20 +- base/canonical_errors.cc | 137 ---- base/canonical_errors.h | 71 -- base/status.cc | 97 --- base/status.h | 196 ----- base/status_macros.h | 53 ++ base/statusor.cc | 12 +- base/statusor.h | 29 +- base/statusor_internals.h | 24 +- common/escaping.cc | 8 +- common/escaping_test.cc | 2 + common/operators.cc | 176 ++-- common/type.cc | 3 +- common/value.h | 8 +- common/value_test.cc | 3 +- conformance/BUILD | 10 +- eval/compiler/BUILD | 39 +- eval/compiler/constant_folding.cc | 9 +- eval/compiler/constant_folding_test.cc | 66 +- eval/compiler/flat_expr_builder.cc | 102 ++- eval/compiler/flat_expr_builder.h | 31 +- .../flat_expr_builder_comprehensions_test.cc | 180 +++++ eval/compiler/flat_expr_builder_test.cc | 302 ++++++- eval/eval/BUILD | 178 +++- eval/eval/attribute_trail.cc | 25 + eval/eval/attribute_trail.h | 63 ++ eval/eval/attribute_trail_test.cc | 41 + eval/eval/comprehension_step.cc | 47 +- eval/eval/comprehension_step.h | 6 +- eval/eval/const_value_step.cc | 6 +- eval/eval/const_value_step_test.cc | 23 +- eval/eval/container_access_step.cc | 93 ++- eval/eval/container_access_step_test.cc | 212 +++-- eval/eval/create_list_step.cc | 35 +- eval/eval/create_list_step_test.cc | 107 ++- eval/eval/create_struct_step.cc | 94 ++- eval/eval/create_struct_step_test.cc | 232 +++--- eval/eval/evaluator_core.cc | 134 +++- eval/eval/evaluator_core.h | 248 ++++-- eval/eval/evaluator_core_test.cc | 130 ++- eval/eval/expression_build_warning.cc | 19 + eval/eval/expression_build_warning.h | 36 + eval/eval/expression_build_warning_test.cc | 36 + eval/eval/field_access.cc | 45 +- eval/eval/field_access.h | 24 +- eval/eval/function_step.cc | 97 ++- eval/eval/function_step.h | 4 +- eval/eval/function_step_test.cc | 757 ++++++++++++++++-- eval/eval/ident_step.cc | 76 +- eval/eval/ident_step_test.cc | 71 +- eval/eval/jump_step.cc | 24 +- eval/eval/jump_step.h | 4 +- eval/eval/logic_step.cc | 44 +- eval/eval/logic_step_test.cc | 319 ++++++++ eval/eval/select_step.cc | 93 ++- eval/eval/select_step_test.cc | 318 ++++++-- eval/eval/unknowns_utility.cc | 96 +++ eval/eval/unknowns_utility.h | 68 ++ eval/eval/unknowns_utility_test.cc | 145 ++++ eval/public/BUILD | 87 +- eval/public/activation.cc | 9 +- eval/public/activation.h | 30 +- eval/public/activation_bind_helper.cc | 12 +- eval/public/activation_bind_helper.h | 6 +- eval/public/activation_bind_helper_test.cc | 10 +- eval/public/activation_test.cc | 13 +- eval/public/builtin_func_registrar.cc | 517 ++++++------ eval/public/builtin_func_registrar.h | 2 +- eval/public/builtin_func_test.cc | 32 +- eval/public/cel_expr_builder_factory.cc | 15 + eval/public/cel_expression.h | 40 +- eval/public/cel_function.h | 6 +- eval/public/cel_function_adapter.h | 81 +- eval/public/cel_function_adapter_test.cc | 19 +- eval/public/cel_function_provider.cc | 2 +- eval/public/cel_function_provider_test.cc | 15 +- eval/public/cel_function_registry.cc | 16 +- eval/public/cel_function_registry.h | 6 +- eval/public/cel_function_registry_test.cc | 23 +- eval/public/cel_options.h | 20 + eval/public/cel_value.cc | 42 +- eval/public/cel_value.h | 31 +- eval/public/cel_value_test.cc | 18 +- eval/public/extension_func_registrar.cc | 4 +- eval/public/extension_func_registrar.h | 2 +- eval/public/extension_func_test.cc | 15 +- eval/public/unknown_attribute_set.h | 24 +- eval/public/unknown_attribute_set_test.cc | 6 +- eval/public/unknown_function_result_set.cc | 161 ++++ eval/public/unknown_function_result_set.h | 75 ++ .../unknown_function_result_set_test.cc | 487 +++++++++++ eval/public/unknown_set.h | 52 ++ eval/public/unknown_set_test.cc | 129 +++ eval/public/value_export_util.cc | 23 +- eval/public/value_export_util.h | 7 +- eval/public/value_export_util_test.cc | 43 +- eval/tests/BUILD | 38 +- eval/tests/end_to_end_test.cc | 13 +- eval/tests/mock_cel_expression.h | 44 + eval/tests/unknowns_end_to_end_test.cc | 305 +++++++ internal/cel_printer.h | 8 +- internal/hash_util.cc | 4 +- parser/BUILD | 3 +- parser/parser.cc | 7 +- parser/parser_test.cc | 3 + protoutil/type_registry.cc | 5 +- testutil/test_data_io.cc | 4 +- testutil/test_data_util.cc | 2 +- 108 files changed, 6170 insertions(+), 1874 deletions(-) delete mode 100644 base/canonical_errors.cc delete mode 100644 base/canonical_errors.h delete mode 100644 base/status.cc delete mode 100644 base/status.h create mode 100644 base/status_macros.h create mode 100644 eval/compiler/flat_expr_builder_comprehensions_test.cc create mode 100644 eval/eval/attribute_trail.cc create mode 100644 eval/eval/attribute_trail.h create mode 100644 eval/eval/attribute_trail_test.cc create mode 100644 eval/eval/expression_build_warning.cc create mode 100644 eval/eval/expression_build_warning.h create mode 100644 eval/eval/expression_build_warning_test.cc create mode 100644 eval/eval/logic_step_test.cc create mode 100644 eval/eval/unknowns_utility.cc create mode 100644 eval/eval/unknowns_utility.h create mode 100644 eval/eval/unknowns_utility_test.cc create mode 100644 eval/public/unknown_function_result_set.cc create mode 100644 eval/public/unknown_function_result_set.h create mode 100644 eval/public/unknown_function_result_set_test.cc create mode 100644 eval/public/unknown_set.h create mode 100644 eval/public/unknown_set_test.cc create mode 100644 eval/tests/mock_cel_expression.h create mode 100644 eval/tests/unknowns_end_to_end_test.cc diff --git a/base/BUILD b/base/BUILD index 8b92d6fc2..2ddce9ff0 100644 --- a/base/BUILD +++ b/base/BUILD @@ -5,26 +5,26 @@ package( ) cc_library( - name = "status", + name = "statusor", srcs = [ - "canonical_errors.cc", - "status.cc", "statusor.cc", ], hdrs = [ - "canonical_errors.h", - "status.h", "statusor.h", "statusor_internals.h", ], copts = ["-std=c++14"], deps = [ - "@com_google_absl//absl/base", "@com_google_absl//absl/base:core_headers", - "@com_google_absl//absl/base:log_severity", - "@com_google_absl//absl/container:node_hash_map", "@com_google_absl//absl/meta:type_traits", - "@com_google_absl//absl/strings", - "@com_google_absl//absl/types:optional", + "@com_google_absl//absl/status", ], ) + +cc_library( + name = "status_macros", + hdrs = [ + "status_macros.h", + ], + copts = ["-std=c++14"], +) diff --git a/base/canonical_errors.cc b/base/canonical_errors.cc deleted file mode 100644 index 4588e0f88..000000000 --- a/base/canonical_errors.cc +++ /dev/null @@ -1,137 +0,0 @@ -/* - * Copyright 2019 Google LLC - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "base/canonical_errors.h" - -namespace cel_base { - -Status AbortedError(absl::string_view message) { - return Status(ABORTED, message); -} - -Status AlreadyExistsError(absl::string_view message) { - return Status(ALREADY_EXISTS, message); -} - -Status CancelledError(absl::string_view message) { - return Status(CANCELLED, message); -} - -Status DataLossError(absl::string_view message) { - return Status(DATA_LOSS, message); -} - -Status DeadlineExceededError(absl::string_view message) { - return Status(DEADLINE_EXCEEDED, message); -} - -Status FailedPreconditionError(absl::string_view message) { - return Status(FAILED_PRECONDITION, message); -} - -Status InternalError(absl::string_view message) { - return Status(INTERNAL, message); -} - -Status InvalidArgumentError(absl::string_view message) { - return Status(INVALID_ARGUMENT, message); -} - -Status NotFoundError(absl::string_view message) { - return Status(NOT_FOUND, message); -} - -Status OutOfRangeError(absl::string_view message) { - return Status(OUT_OF_RANGE, message); -} - -Status PermissionDeniedError(absl::string_view message) { - return Status(PERMISSION_DENIED, message); -} - -Status ResourceExhaustedError(absl::string_view message) { - return Status(RESOURCE_EXHAUSTED, message); -} - -Status UnauthenticatedError(absl::string_view message) { - return Status(UNAUTHENTICATED, message); -} - -Status UnavailableError(absl::string_view message) { - return Status(UNAVAILABLE, message); -} - -Status UnimplementedError(absl::string_view message) { - return Status(UNIMPLEMENTED, message); -} - -Status UnknownError(absl::string_view message) { - return Status(UNKNOWN, message); -} - -bool IsAborted(const Status& status) { return status.code() == ABORTED; } - -bool IsAlreadyExists(const Status& status) { - return status.code() == ALREADY_EXISTS; -} - -bool IsCancelled(const Status& status) { return status.code() == CANCELLED; } - -bool IsDataLoss(const Status& status) { return status.code() == DATA_LOSS; } - -bool IsDeadlineExceeded(const Status& status) { - return status.code() == DEADLINE_EXCEEDED; -} - -bool IsFailedPrecondition(const Status& status) { - return status.code() == FAILED_PRECONDITION; -} - -bool IsInternal(const Status& status) { return status.code() == INTERNAL; } - -bool IsInvalidArgument(const Status& status) { - return status.code() == INVALID_ARGUMENT; -} - -bool IsNotFound(const Status& status) { return status.code() == NOT_FOUND; } - -bool IsOutOfRange(const Status& status) { - return status.code() == OUT_OF_RANGE; -} - -bool IsPermissionDenied(const Status& status) { - return status.code() == PERMISSION_DENIED; -} - -bool IsResourceExhausted(const Status& status) { - return status.code() == RESOURCE_EXHAUSTED; -} - -bool IsUnauthenticated(const Status& status) { - return status.code() == UNAUTHENTICATED; -} - -bool IsUnavailable(const Status& status) { - return status.code() == UNAVAILABLE; -} - -bool IsUnimplemented(const Status& status) { - return status.code() == UNIMPLEMENTED; -} - -bool IsUnknown(const Status& status) { return status.code() == UNKNOWN; } - -} // namespace cel_base diff --git a/base/canonical_errors.h b/base/canonical_errors.h deleted file mode 100644 index a498f8172..000000000 --- a/base/canonical_errors.h +++ /dev/null @@ -1,71 +0,0 @@ -/* - * Copyright 2019 Google LLC - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef THIRD_PARTY_CEL_CPP_BASE_CANONICAL_ERRORS_H_ -#define THIRD_PARTY_CEL_CPP_BASE_CANONICAL_ERRORS_H_ - -// This file declares a set of functions for working with Status objects from -// the canonical errors. There are functions to easily generate such -// status object and function for classifying them. -#include "absl/base/attributes.h" -#include "absl/base/macros.h" -#include "absl/strings/string_view.h" -#include "base/status.h" - -namespace cel_base { - -// Each of the functions below creates a canonical error with the given -// message. The error code of the returned status object matches the name of -// the function. -Status AbortedError(absl::string_view message); -Status AlreadyExistsError(absl::string_view message); -Status CancelledError(absl::string_view message); -Status DataLossError(absl::string_view message); -Status DeadlineExceededError(absl::string_view message); -Status FailedPreconditionError(absl::string_view message); -Status InternalError(absl::string_view message); -Status InvalidArgumentError(absl::string_view message); -Status NotFoundError(absl::string_view message); -Status OutOfRangeError(absl::string_view message); -Status PermissionDeniedError(absl::string_view message); -Status ResourceExhaustedError(absl::string_view message); -Status UnauthenticatedError(absl::string_view message); -Status UnavailableError(absl::string_view message); -Status UnimplementedError(absl::string_view message); -Status UnknownError(absl::string_view message); - -// Each of the functions below returns true if the given status matches the -// canonical error code implied by the function's name. -ABSL_MUST_USE_RESULT bool IsAborted(const Status& status); -ABSL_MUST_USE_RESULT bool IsAlreadyExists(const Status& status); -ABSL_MUST_USE_RESULT bool IsCancelled(const Status& status); -ABSL_MUST_USE_RESULT bool IsDataLoss(const Status& status); -ABSL_MUST_USE_RESULT bool IsDeadlineExceeded(const Status& status); -ABSL_MUST_USE_RESULT bool IsFailedPrecondition(const Status& status); -ABSL_MUST_USE_RESULT bool IsInternal(const Status& status); -ABSL_MUST_USE_RESULT bool IsInvalidArgument(const Status& status); -ABSL_MUST_USE_RESULT bool IsNotFound(const Status& status); -ABSL_MUST_USE_RESULT bool IsOutOfRange(const Status& status); -ABSL_MUST_USE_RESULT bool IsPermissionDenied(const Status& status); -ABSL_MUST_USE_RESULT bool IsResourceExhausted(const Status& status); -ABSL_MUST_USE_RESULT bool IsUnauthenticated(const Status& status); -ABSL_MUST_USE_RESULT bool IsUnavailable(const Status& status); -ABSL_MUST_USE_RESULT bool IsUnimplemented(const Status& status); -ABSL_MUST_USE_RESULT bool IsUnknown(const Status& status); - -} // namespace cel_base - -#endif // THIRD_PARTY_CEL_CPP_BASE_CANONICAL_ERRORS_H_ diff --git a/base/status.cc b/base/status.cc deleted file mode 100644 index ee59645d6..000000000 --- a/base/status.cc +++ /dev/null @@ -1,97 +0,0 @@ -/* - * Copyright 2019 Google LLC - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "base/status.h" - -#include "absl/strings/str_cat.h" -#include "absl/types/optional.h" - -namespace cel_base { - -std::string StatusCodeToString(StatusCode code) { - switch (code) { - case StatusCode::kOk: - return "OK"; - case StatusCode::kCancelled: - return "CANCELLED"; - case StatusCode::kUnknown: - return "UNKNOWN"; - case StatusCode::kInvalidArgument: - return "INVALID_ARGUMENT"; - case StatusCode::kDeadlineExceeded: - return "DEADLINE_EXCEEDED"; - case StatusCode::kNotFound: - return "NOT_FOUND"; - case StatusCode::kAlreadyExists: - return "ALREADY_EXISTS"; - case StatusCode::kPermissionDenied: - return "PERMISSION_DENIED"; - case StatusCode::kUnauthenticated: - return "UNAUTHENTICATED"; - case StatusCode::kResourceExhausted: - return "RESOURCE_EXHAUSTED"; - case StatusCode::kFailedPrecondition: - return "FAILED_PRECONDITION"; - case StatusCode::kAborted: - return "ABORTED"; - case StatusCode::kOutOfRange: - return "OUT_OF_RANGE"; - case StatusCode::kUnimplemented: - return "UNIMPLEMENTED"; - case StatusCode::kInternal: - return "INTERNAL"; - case StatusCode::kUnavailable: - return "UNAVAILABLE"; - case StatusCode::kDataLoss: - return "DATA_LOSS"; - default: - return ""; - } -} - -std::ostream& operator<<(std::ostream& os, StatusCode code) { - return os << StatusCodeToString(code); -} - -Status::Status(StatusCode code, absl::string_view message) - : code_(code), message_(code == StatusCode::kOk ? "" : message) {} - -std::string Status::ToString() const { - return ok() ? "OK" : absl::StrCat(StatusCodeToString(code()), ": ", message_); -} - -void Status::SetPayload(absl::string_view type_url, const StatusCord& payload) { - if (!ok()) { - payload_.try_emplace(std::string(type_url), payload); - } -} - -absl::optional Status::GetPayload( - absl::string_view type_url) const { - auto it = payload_.find(std::string(type_url)); - if (it == payload_.end()) return absl::nullopt; - return it->second; -} - -void Status::ErasePayload(absl::string_view type_url) { - payload_.erase(std::string(type_url)); -} - -std::ostream& operator<<(std::ostream& os, const Status& x) { - return os << x.ToString(); -} - -} // namespace cel_base diff --git a/base/status.h b/base/status.h deleted file mode 100644 index 2e215339e..000000000 --- a/base/status.h +++ /dev/null @@ -1,196 +0,0 @@ -/* - * Copyright 2019 Google LLC - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef THIRD_PARTY_CEL_CPP_BASE_STATUS_H_ -#define THIRD_PARTY_CEL_CPP_BASE_STATUS_H_ - -#include -#include -#include - -#include "absl/base/attributes.h" -#include "absl/container/node_hash_map.h" -#include "absl/strings/string_view.h" -#include "absl/types/optional.h" - -namespace cel_base { - -// replace with cel_base::StatusCord once available. -using StatusCord = std::string; - -enum class StatusCode { - kOk = 0, - kCancelled = 1, - kUnknown = 2, - kInvalidArgument = 3, - kDeadlineExceeded = 4, - kNotFound = 5, - kAlreadyExists = 6, - kPermissionDenied = 7, - kResourceExhausted = 8, - kFailedPrecondition = 9, - kAborted = 10, - kOutOfRange = 11, - kUnimplemented = 12, - kInternal = 13, - kUnavailable = 14, - kDataLoss = 15, - kUnauthenticated = 16, - kDoNotUseReservedForFutureExpansionUseDefaultInSwitchInstead_ = 20 -}; - -std::string StatusCodeToString(StatusCode e); - -std::ostream& operator<<(std::ostream& os, StatusCode code); - -// Handle both for now. This is meant to be _very short lived. Once internal -// code can safely use the new naming, we will switch that that and drop this. -constexpr StatusCode OK = StatusCode::kOk; -constexpr StatusCode CANCELLED = StatusCode::kCancelled; -constexpr StatusCode UNKNOWN = StatusCode::kUnknown; -constexpr StatusCode INVALID_ARGUMENT = StatusCode::kInvalidArgument; -constexpr StatusCode DEADLINE_EXCEEDED = StatusCode::kDeadlineExceeded; -constexpr StatusCode NOT_FOUND = StatusCode::kNotFound; -constexpr StatusCode ALREADY_EXISTS = StatusCode::kAlreadyExists; -constexpr StatusCode PERMISSION_DENIED = StatusCode::kPermissionDenied; -constexpr StatusCode UNAUTHENTICATED = StatusCode::kUnauthenticated; -constexpr StatusCode RESOURCE_EXHAUSTED = StatusCode::kResourceExhausted; -constexpr StatusCode FAILED_PRECONDITION = StatusCode::kFailedPrecondition; -constexpr StatusCode ABORTED = StatusCode::kAborted; -constexpr StatusCode OUT_OF_RANGE = StatusCode::kOutOfRange; -constexpr StatusCode UNIMPLEMENTED = StatusCode::kUnimplemented; -constexpr StatusCode INTERNAL = StatusCode::kInternal; -constexpr StatusCode UNAVAILABLE = StatusCode::kUnavailable; -constexpr StatusCode DATA_LOSS = StatusCode::kDataLoss; - -class ABSL_MUST_USE_RESULT Status; - -class Status final { - public: - // Builds an OK Status. - Status() = default; - - // Constructs a Status object containing a status code and message. - // If `code == StatusCode::kOk`, `msg` is ignored and an object identical to - // an OK status is constructed. - Status(StatusCode code, absl::string_view message); - - // Return the error message (if any). - absl::string_view message() const { return message_; } - - // Returns true if the Status is OK. - ABSL_MUST_USE_RESULT bool ok() const; - - // Deprecated. Use code(). - int error_code() const; - - // Deprecated. Use message(). - std::string error_message() const; - - // Deprecated. Use code(). - StatusCode CanonicalCode() const; - - // If "ok()", does nothing. Else adds the given `payload` specified, by - // `type_url` as an additional payload. - void SetPayload(absl::string_view type_url, const StatusCord& payload); - - bool operator==(const Status& x) const; - bool operator!=(const Status& x) const; - - void Update(const Status& rhs) { - if (ok()) *this = rhs; - } - - // Return a combination of the error code name and message. - // Note, no guarantees are made as to the exact nature of the returned string. - // Subject to change at any time. - std::string ToString() const; - - // Deprecated. Just returns self. - Status ToCanonical() const; - - // Ignores any errors. This method does nothing except potentially suppress - // complaints from any tools that are checking that errors are not dropped on - // the floor. - void IgnoreError() const; - - // Returns the stored status code. - StatusCode code() const { return code_; } - - // Retrieve a single value associated with `type_url`. Returns absl::nullopt - // if no value is associated with `type_url`. - absl::optional GetPayload(const absl::string_view type_url) const; - - // Erase the payload associated with `type_url`, if present. - void ErasePayload(absl::string_view type_url); - - void ForEachPayload( - const std::function& visitor) - const; - - private: - StatusCode code_ = StatusCode::kOk; - std::string message_; - // Structured error payload. String is a 'type_url' for example, a proto - // descriptor full name. - absl::node_hash_map payload_; -}; - -inline bool Status::ok() const { return StatusCode::kOk == code_; } - -inline int Status::error_code() const { return static_cast(code()); } - -inline std::string Status::error_message() const { - return std::string(message()); -} - -inline StatusCode Status::CanonicalCode() const { return code(); } - -inline Status Status::ToCanonical() const { return *this; } - -inline bool Status::operator==(const Status& x) const { - return (code_ == x.code_) && (message_ == x.message_) && - (payload_ == x.payload_); -} - -inline bool Status::operator!=(const Status& x) const { return !(*this == x); } - -inline void Status::IgnoreError() const { - // no-op -} - -inline void Status::ForEachPayload( - const std::function& visitor) - const { - for (auto it = payload_.begin(); it != payload_.end(); ++it) { - visitor(it->first, it->second); - } -} - -// Prints a human-readable representation of 'x' to 'os'. -std::ostream& operator<<(std::ostream& os, const Status& x); - -// Constructs an OK status object. -inline Status OkStatus() { return Status(); } - -// This is better than GOOGLE_CHECK((val).ok()) because the embedded -// error string gets printed by the CHECK_EQ. -#define CHECK_OK(val) \ - ABSL_RAW_CHECK(val == ::cel_base::OkStatus(), "Status not OK") - -} // namespace cel_base - -#endif // THIRD_PARTY_CEL_CPP_BASE_STATUS_H_ diff --git a/base/status_macros.h b/base/status_macros.h new file mode 100644 index 000000000..675fe656a --- /dev/null +++ b/base/status_macros.h @@ -0,0 +1,53 @@ +/* + * Copyright 2020 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef THIRD_PARTY_CEL_CPP_BASE_STATUS_MACROS_H_ +#define THIRD_PARTY_CEL_CPP_BASE_STATUS_MACROS_H_ + +#include // for use with down_cast<> + +#include + +// Early-returns the status if it is in error; otherwise, proceeds. +// +// The argument expression is guaranteed to be evaluated exactly once. +#define RETURN_IF_ERROR(__status) \ + do { \ + auto _status = __status; \ + if (!_status.ok()) { \ + return _status; \ + } \ + } while (false) + +template // use like this: down_cast(foo); +inline To down_cast(From* f) { // so we only accept pointers + static_assert( + (std::is_base_of::type>::value), + "target type not derived from source type"); + + // We skip the assert and hence the dynamic_cast if RTTI is disabled. +#if !defined(__GNUC__) || defined(__GXX_RTTI) + // Uses RTTI in dbg and fastbuild. asserts are disabled in opt builds. + assert(f == nullptr || dynamic_cast(f) != nullptr); +#endif // !defined(__GNUC__) || defined(__GXX_RTTI) + + return static_cast(f); +} + +#define ASSERT_OK(expression) ASSERT_TRUE(expression.ok()) +#define EXPECT_OK(expression) EXPECT_TRUE(expression.ok()) + +#endif // THIRD_PARTY_CEL_CPP_BASE_STATUS_MACROS_H_ diff --git a/base/statusor.cc b/base/statusor.cc index 95386e056..5215922dc 100644 --- a/base/statusor.cc +++ b/base/statusor.cc @@ -18,24 +18,18 @@ #include -#include "base/status.h" - namespace cel_base { namespace statusor_internal { -void Helper::HandleInvalidStatusCtorArg(Status* status) { +void Helper::HandleInvalidStatusCtorArg(absl::Status* status) { const char* kMessage = "An OK status is not a valid constructor argument to StatusOr"; - ABSL_RAW_CHECK(false, kMessage); // In optimized builds, we will fall back to ::util::error::INTERNAL. - *status = Status(StatusCode::kInternal, kMessage); + *status = absl::Status(absl::StatusCode::kInternal, kMessage); } -void Helper::Crash(const Status&) { - ABSL_RAW_CHECK(false, "Attempting to fetch value instead of handling error"); - abort(); -} +void Helper::Crash(const absl::Status&) { abort(); } } // namespace statusor_internal diff --git a/base/statusor.h b/base/statusor.h index 0a26bade2..26f43d208 100644 --- a/base/statusor.h +++ b/base/statusor.h @@ -70,7 +70,6 @@ #include "absl/base/attributes.h" #include "absl/base/macros.h" -#include "base/status.h" #include "base/statusor_internals.h" namespace cel_base { @@ -144,8 +143,8 @@ class StatusOr : private statusor_internal::StatusOrData, // REQUIRES: !status.ok(). This requirement is DCHECKed. // In optimized builds, passing OkStatus() here will have the effect // of passing INTERNAL as a fallback. - StatusOr(const Status& status); - StatusOr& operator=(const Status& status); + StatusOr(const absl::Status& status); + StatusOr& operator=(const absl::Status& status); // Similar to the `const T&` overload. // @@ -153,8 +152,8 @@ class StatusOr : private statusor_internal::StatusOrData, StatusOr(T&& value); // RValue versions of the operations declared above. - StatusOr(Status&& status); - StatusOr& operator=(Status&& status); + StatusOr(absl::Status&& status); + StatusOr& operator=(absl::Status&& status); // Returns this->ok() explicit operator bool() const { return ok(); } @@ -164,8 +163,8 @@ class StatusOr : private statusor_internal::StatusOrData, // Returns a reference to our status. If this contains a T, then // returns OkStatus(). - const Status& status() const&; - Status status() &&; + const absl::Status& status() const&; + absl::Status status() &&; // Returns a reference to our current value, or CHECK-fails if !this->ok(). If // you have already checked the status using this->ok() or operator bool(), @@ -237,16 +236,16 @@ class StatusOr : private statusor_internal::StatusOrData, // Implementation details for StatusOr template -StatusOr::StatusOr() : Base(Status(UNKNOWN, "")) {} +StatusOr::StatusOr() : Base(absl::Status(absl::StatusCode::kUnknown, "")) {} template StatusOr::StatusOr(const T& value) : Base(value) {} template -StatusOr::StatusOr(const Status& status) : Base(status) {} +StatusOr::StatusOr(const absl::Status& status) : Base(status) {} template -StatusOr& StatusOr::operator=(const Status& status) { +StatusOr& StatusOr::operator=(const absl::Status& status) { this->Assign(status); return *this; } @@ -255,10 +254,10 @@ template StatusOr::StatusOr(T&& value) : Base(std::move(value)) {} template -StatusOr::StatusOr(Status&& status) : Base(std::move(status)) {} +StatusOr::StatusOr(absl::Status&& status) : Base(std::move(status)) {} template -StatusOr& StatusOr::operator=(Status&& status) { +StatusOr& StatusOr::operator=(absl::Status&& status) { this->Assign(std::move(status)); return *this; } @@ -295,12 +294,12 @@ inline StatusOr& StatusOr::operator=(StatusOr&& other) { } template -const Status& StatusOr::status() const& { +const absl::Status& StatusOr::status() const& { return this->status_; } template -Status StatusOr::status() && { - return ok() ? OkStatus() : std::move(this->status_); +absl::Status StatusOr::status() && { + return ok() ? absl::OkStatus() : std::move(this->status_); } template diff --git a/base/statusor_internals.h b/base/statusor_internals.h index fab2bec94..a524f7848 100644 --- a/base/statusor_internals.h +++ b/base/statusor_internals.h @@ -23,7 +23,7 @@ #include "absl/base/attributes.h" #include "absl/meta/type_traits.h" -#include "base/status.h" +#include "absl/status/status.h" namespace cel_base { @@ -32,8 +32,8 @@ namespace statusor_internal { class Helper { public: // Move type-agnostic error handling to the .cc. - static void HandleInvalidStatusCtorArg(Status*); - ABSL_ATTRIBUTE_NORETURN static void Crash(const Status& status); + static void HandleInvalidStatusCtorArg(absl::Status*); + ABSL_ATTRIBUTE_NORETURN static void Crash(const absl::Status& status); }; // Construct an instance of T in `p` through placement new, passing Args... to @@ -100,10 +100,10 @@ class StatusOrData { explicit StatusOrData(const T& value) : data_(value) { MakeStatus(); } explicit StatusOrData(T&& value) : data_(std::move(value)) { MakeStatus(); } - explicit StatusOrData(const Status& status) : status_(status) { + explicit StatusOrData(const absl::Status& status) : status_(status) { EnsureNotOk(); } - explicit StatusOrData(Status&& status) : status_(std::move(status)) { + explicit StatusOrData(absl::Status&& status) : status_(std::move(status)) { EnsureNotOk(); } @@ -140,7 +140,7 @@ class StatusOrData { MakeValue(value); } else { MakeValue(value); - status_ = OkStatus(); + status_ = absl::OkStatus(); } } @@ -150,17 +150,17 @@ class StatusOrData { MakeValue(std::move(value)); } else { MakeValue(std::move(value)); - status_ = OkStatus(); + status_ = absl::OkStatus(); } } - void Assign(const Status& status) { + void Assign(const absl::Status& status) { Clear(); status_ = status; EnsureNotOk(); } - void Assign(Status&& status) { + void Assign(absl::Status&& status) { Clear(); status_ = std::move(status); EnsureNotOk(); @@ -175,7 +175,7 @@ class StatusOrData { // Eg. in the copy constructor we use the default constructor of // Status in the ok() path to avoid an extra Ref call. union { - Status status_; + absl::Status status_; }; // data_ is active iff status_.ok()==true @@ -210,8 +210,8 @@ class StatusOrData { // argument. template void MakeStatus(Args&&... args) { - statusor_internal::PlacementNew(&status_, - std::forward(args)...); + statusor_internal::PlacementNew(&status_, + std::forward(args)...); } }; diff --git a/common/escaping.cc b/common/escaping.cc index b7f8b16c4..39aba6d0a 100644 --- a/common/escaping.cc +++ b/common/escaping.cc @@ -27,11 +27,11 @@ inline std::pair unhex(char c) { // least 4 bytes long. Return the number of bytes written. inline int get_utf8(absl::string_view s, char* buffer) { buffer[0] = s[0]; - if (s[0] < 0x80 || s.size() < 2) return 1; + if (static_cast(s[0]) < 0x80 || s.size() < 2) return 1; buffer[1] = s[1]; - if (s[0] < 0xE0 || s.size() < 3) return 2; + if (static_cast(s[0]) < 0xE0 || s.size() < 3) return 2; buffer[2] = s[2]; - if (s[0] < 0xF0 || s.size() < 4) return 3; + if (static_cast(s[0]) < 0xF0 || s.size() < 4) return 3; buffer[3] = s[3]; return 4; } @@ -84,7 +84,7 @@ inline std::tuple unescape_char( char c = s[0]; // 1. Character is not an escape sequence. - if (c >= 0x80 && !is_bytes) { + if (static_cast(c) >= 0x80 && !is_bytes) { char tmp[5]; int len = get_utf8(s, tmp); tmp[len] = '\0'; diff --git a/common/escaping_test.cc b/common/escaping_test.cc index 85cdc7f8b..8275b48ec 100644 --- a/common/escaping_test.cc +++ b/common/escaping_test.cc @@ -66,6 +66,7 @@ class UnescapeTest : public testing::TestWithParam {}; TEST_P(UnescapeTest, Unescape) { const TestInfo& test_info = GetParam(); + /* ::testing::internal::ColoredPrintf(::testing::internal::COLOR_GREEN, "[ ]"); ::testing::internal::ColoredPrintf(::testing::internal::COLOR_DEFAULT, @@ -82,6 +83,7 @@ TEST_P(UnescapeTest, Unescape) { ::testing::internal::ColoredPrintf(::testing::internal::COLOR_YELLOW, " Expecting ERROR\n"); } + */ auto result = unescape(test_info.I, test_info.is_bytes); if (test_info.O == EXPECT_ERROR) { diff --git a/common/operators.cc b/common/operators.cc index 803409a77..669469c9a 100644 --- a/common/operators.cc +++ b/common/operators.cc @@ -14,85 +14,91 @@ namespace { // Expr to textual mapping, e.g., from "_&&_" to "&&". const std::map& UnaryOperators() { - static std::shared_ptr> unaries_map = [&]() { - auto u = - std::make_shared>(std::map{ - {CelOperator::NEGATE, "-"}, {CelOperator::LOGICAL_NOT, "!"}}); - return u; - }(); + static std::shared_ptr> unaries_map = + [&]() { + auto u = std::make_shared>( + std::map{ + {CelOperator::NEGATE, "-"}, {CelOperator::LOGICAL_NOT, "!"}}); + return u; + }(); return *unaries_map; } const std::map& BinaryOperators() { - static std::shared_ptr> binops_map = [&]() { - auto c = std::make_shared>( - std::map{{CelOperator::LOGICAL_OR, "||"}, - {CelOperator::LOGICAL_AND, "&&"}, - {CelOperator::LESS_EQUALS, "<="}, - {CelOperator::LESS, "<"}, - {CelOperator::GREATER_EQUALS, ">="}, - {CelOperator::GREATER, ">"}, - {CelOperator::EQUALS, "=="}, - {CelOperator::NOT_EQUALS, "!="}, - {CelOperator::IN_DEPRECATED, "in"}, - {CelOperator::IN, "in"}, - {CelOperator::ADD, "+"}, - {CelOperator::SUBTRACT, "-"}, - {CelOperator::MULTIPLY, "*"}, - {CelOperator::DIVIDE, "/"}, - {CelOperator::MODULO, "%"}}); - return c; - }(); + static std::shared_ptr> binops_map = + [&]() { + auto c = std::make_shared>( + std::map{ + {CelOperator::LOGICAL_OR, "||"}, + {CelOperator::LOGICAL_AND, "&&"}, + {CelOperator::LESS_EQUALS, "<="}, + {CelOperator::LESS, "<"}, + {CelOperator::GREATER_EQUALS, ">="}, + {CelOperator::GREATER, ">"}, + {CelOperator::EQUALS, "=="}, + {CelOperator::NOT_EQUALS, "!="}, + {CelOperator::IN_DEPRECATED, "in"}, + {CelOperator::IN, "in"}, + {CelOperator::ADD, "+"}, + {CelOperator::SUBTRACT, "-"}, + {CelOperator::MULTIPLY, "*"}, + {CelOperator::DIVIDE, "/"}, + {CelOperator::MODULO, "%"}}); + return c; + }(); return *binops_map; } const std::map& ReverseOperators() { - static std::shared_ptr> operators_map = [&]() { - auto c = - std::make_shared>(std::map{ - {"+", CelOperator::ADD}, - {"-", CelOperator::SUBTRACT}, - {"*", CelOperator::MULTIPLY}, - {"/", CelOperator::DIVIDE}, - {"%", CelOperator::MODULO}, - {"==", CelOperator::EQUALS}, - {"!=", CelOperator::NOT_EQUALS}, - {">", CelOperator::GREATER}, - {">=", CelOperator::GREATER_EQUALS}, - {"<", CelOperator::LESS}, - {"<=", CelOperator::LESS_EQUALS}, - {"&&", CelOperator::LOGICAL_AND}, - {"!", CelOperator::LOGICAL_NOT}, - {"||", CelOperator::LOGICAL_OR}, - {"in", CelOperator::IN}, - }); - return c; - }(); + static std::shared_ptr> operators_map = + [&]() { + auto c = std::make_shared>( + std::map{ + {"+", CelOperator::ADD}, + {"-", CelOperator::SUBTRACT}, + {"*", CelOperator::MULTIPLY}, + {"/", CelOperator::DIVIDE}, + {"%", CelOperator::MODULO}, + {"==", CelOperator::EQUALS}, + {"!=", CelOperator::NOT_EQUALS}, + {">", CelOperator::GREATER}, + {">=", CelOperator::GREATER_EQUALS}, + {"<", CelOperator::LESS}, + {"<=", CelOperator::LESS_EQUALS}, + {"&&", CelOperator::LOGICAL_AND}, + {"!", CelOperator::LOGICAL_NOT}, + {"||", CelOperator::LOGICAL_OR}, + {"in", CelOperator::IN}, + }); + return c; + }(); return *operators_map; } const std::map& Operators() { - static std::shared_ptr> operators_map = [&]() { - auto c = std::make_shared>( - std::map{{CelOperator::ADD, "+"}, - {CelOperator::SUBTRACT, "-"}, - {CelOperator::MULTIPLY, "*"}, - {CelOperator::DIVIDE, "/"}, - {CelOperator::MODULO, "%"}, - {CelOperator::EQUALS, "=="}, - {CelOperator::NOT_EQUALS, "!="}, - {CelOperator::GREATER, ">"}, - {CelOperator::GREATER_EQUALS, ">="}, - {CelOperator::LESS, "<"}, - {CelOperator::LESS_EQUALS, "<="}, - {CelOperator::LOGICAL_AND, "&&"}, - {CelOperator::LOGICAL_NOT, "!"}, - {CelOperator::LOGICAL_OR, "||"}, - {CelOperator::IN, "in"}, - {CelOperator::IN_DEPRECATED, "in"}, - {CelOperator::NEGATE, "-"}}); - return c; - }(); + static std::shared_ptr> operators_map = + [&]() { + auto c = std::make_shared>( + std::map{ + {CelOperator::ADD, "+"}, + {CelOperator::SUBTRACT, "-"}, + {CelOperator::MULTIPLY, "*"}, + {CelOperator::DIVIDE, "/"}, + {CelOperator::MODULO, "%"}, + {CelOperator::EQUALS, "=="}, + {CelOperator::NOT_EQUALS, "!="}, + {CelOperator::GREATER, ">"}, + {CelOperator::GREATER_EQUALS, ">="}, + {CelOperator::LESS, "<"}, + {CelOperator::LESS_EQUALS, "<="}, + {CelOperator::LOGICAL_AND, "&&"}, + {CelOperator::LOGICAL_NOT, "!"}, + {CelOperator::LOGICAL_OR, "||"}, + {CelOperator::IN, "in"}, + {CelOperator::IN_DEPRECATED, "in"}, + {CelOperator::NEGATE, "-"}}); + return c; + }(); return *operators_map; } @@ -102,30 +108,30 @@ const std::map& Precedences() { auto c = std::make_shared>( std::map{{CelOperator::CONDITIONAL, 8}, - {CelOperator::LOGICAL_OR, 7}, + {CelOperator::LOGICAL_OR, 7}, - {CelOperator::LOGICAL_AND, 6}, + {CelOperator::LOGICAL_AND, 6}, - {CelOperator::EQUALS, 5}, - {CelOperator::GREATER, 5}, - {CelOperator::GREATER_EQUALS, 5}, - {CelOperator::IN, 5}, - {CelOperator::LESS, 5}, - {CelOperator::LESS_EQUALS, 5}, - {CelOperator::NOT_EQUALS, 5}, - {CelOperator::IN_DEPRECATED, 5}, + {CelOperator::EQUALS, 5}, + {CelOperator::GREATER, 5}, + {CelOperator::GREATER_EQUALS, 5}, + {CelOperator::IN, 5}, + {CelOperator::LESS, 5}, + {CelOperator::LESS_EQUALS, 5}, + {CelOperator::NOT_EQUALS, 5}, + {CelOperator::IN_DEPRECATED, 5}, - {CelOperator::ADD, 4}, - {CelOperator::SUBTRACT, 4}, + {CelOperator::ADD, 4}, + {CelOperator::SUBTRACT, 4}, - {CelOperator::DIVIDE, 3}, - {CelOperator::MODULO, 3}, - {CelOperator::MULTIPLY, 3}, + {CelOperator::DIVIDE, 3}, + {CelOperator::MODULO, 3}, + {CelOperator::MULTIPLY, 3}, - {CelOperator::LOGICAL_NOT, 2}, - {CelOperator::NEGATE, 2}, + {CelOperator::LOGICAL_NOT, 2}, + {CelOperator::NEGATE, 2}, - {CelOperator::INDEX, 1}}); + {CelOperator::INDEX, 1}}); return c; }(); return *precedence_map; diff --git a/common/type.cc b/common/type.cc index d0b02637f..ad9fa0ddc 100644 --- a/common/type.cc +++ b/common/type.cc @@ -79,7 +79,8 @@ absl::string_view UnrecognizedType::full_name() const { return absl::string_view(string_rep_).substr(6, string_rep_.size() - 8); } -Type::Type(const std::string& full_name) : data_(BasicType(BasicTypeValue::kNull)) { +Type::Type(const std::string& full_name) + : data_(BasicType(BasicTypeValue::kNull)) { auto itr = kBasicTypeMap->find(full_name); if (itr != kBasicTypeMap->end()) { data_ = itr->second; diff --git a/common/value.h b/common/value.h index 944be76e4..777021a4b 100644 --- a/common/value.h +++ b/common/value.h @@ -506,14 +506,18 @@ class Object : public Container { Value Value::FromDouble(double value) { return Create(value); } Value Value::FromString(absl::string_view value) { return Create(value); } -Value Value::FromString(const std::string& value) { return Create(value); } +Value Value::FromString(const std::string& value) { + return Create(value); +} Value Value::FromString(std::string&& value) { return Create(std::move(value)); } Value Value::FromBytes(absl::string_view value) { return Create(value); } -Value Value::FromBytes(const std::string& value) { return Create(value); } +Value Value::FromBytes(const std::string& value) { + return Create(value); +} Value Value::FromBytes(std::string&& value) { return Create(std::move(value)); } diff --git a/common/value_test.cc b/common/value_test.cc index 2c4def0d4..3a6a60c3a 100644 --- a/common/value_test.cc +++ b/common/value_test.cc @@ -118,7 +118,8 @@ struct ValueTestCase { return Value::Kind::kNull; } - static ValueTestCase ForInline(const Value& value, const std::string& debug_string, + static ValueTestCase ForInline(const Value& value, + const std::string& debug_string, const std::string& type_debug_string) { return ValueTestCase{value, value, value.kind(), value, debug_string, type_debug_string, diff --git a/conformance/BUILD b/conformance/BUILD index 2a402ceb0..2aacf821a 100644 --- a/conformance/BUILD +++ b/conformance/BUILD @@ -33,7 +33,7 @@ cc_binary( srcs = ["server.cc"], copts = ["-std=c++14"], deps = [ - "//base:status", + "//base:statusor", "//eval/eval:container_backed_list_impl", "//eval/eval:container_backed_map_impl", "//eval/public:builtin_func_registrar", @@ -59,14 +59,18 @@ cc_binary( "--check_server=$(location @com_google_cel_go//server/main:cel_server)", # Requires container support "--skip_test=basic/namespace/self_eval_container_lookup,self_eval_container_lookup_unchecked", + "--skip_test=basic/self_eval_nonzeroish/self_eval_bytes_invalid_utf8", # Requires heteregenous equality spec clarification "--skip_test=comparisons/eq_literal/eq_bytes", "--skip_test=comparisons/ne_literal/not_ne_bytes", "--skip_test=comparisons/in_list_literal/elem_in_mixed_type_list_error", "--skip_test=comparisons/in_map_literal/key_in_mixed_key_type_map_error", + "--skip_test=fields/in/singleton", # Requires qualified bindings error message relaxation "--skip_test=fields/qualified_identifier_resolution/ident_with_longest_prefix_check,int64_field_select_unsupported,list_field_select_unsupported,map_key_null,qualified_identifier_resolution_unchecked", "--skip_test=integer_math/int64_math/int64_overflow_positive,int64_overflow_negative,uint64_overflow_positive,uint64_overflow_negative", + "--skip_test=string/size/one_unicode,unicode", + "--skip_test=string/bytes_concat/left_unit", ] + ["$(location " + test + ")" for test in ALL_TESTS], data = [ ":server", @@ -93,4 +97,8 @@ sh_test( "@com_google_cel_go//server/main:cel_server", "@com_google_cel_spec//tests/simple:simple_test", ] + DASHBOARD_TESTS, + visibility = [ + "//:__subpackages__", + "//third_party/cel:__pkg__", + ], ) diff --git a/eval/compiler/BUILD b/eval/compiler/BUILD index d0636922f..93eaf8444 100644 --- a/eval/compiler/BUILD +++ b/eval/compiler/BUILD @@ -17,13 +17,14 @@ cc_library( copts = ["-std=c++14"], deps = [ ":constant_folding", - "//base:status", + "//base:status_macros", "//eval/eval:comprehension_step", "//eval/eval:const_value_step", "//eval/eval:container_access_step", "//eval/eval:create_list_step", "//eval/eval:create_struct_step", "//eval/eval:evaluator_core", + "//eval/eval:expression_build_warning", "//eval/eval:function_step", "//eval/eval:ident_step", "//eval/eval:jump_step", @@ -47,9 +48,44 @@ cc_test( copts = ["-std=c++14"], deps = [ ":flat_expr_builder", + "//base:status_macros", "//eval/public:builtin_func_registrar", + "//eval/public:cel_attribute", + "//eval/public:cel_builtins", + "//eval/public:cel_expression", + "//eval/public:cel_options", + "//eval/public:cel_value", + "//eval/public:unknown_attribute_set", + "//eval/public:unknown_set", + "//eval/testutil:test_message_cc_proto", + "@com_github_google_googletest//:gtest_main", + "@com_google_absl//absl/status", + "@com_google_absl//absl/strings", + "@com_google_googleapis//google/api/expr/v1alpha1:syntax_cc_proto", + "@com_google_protobuf//:protobuf", + ], +) + +cc_test( + name = "flat_expr_builder_comprehensions_test", + srcs = [ + "flat_expr_builder_comprehensions_test.cc", + ], + copts = ["-std=c++14"], + deps = [ + ":flat_expr_builder", + "//base:status_macros", + "//eval/public:builtin_func_registrar", + "//eval/public:cel_attribute", + "//eval/public:cel_builtins", + "//eval/public:cel_expression", + "//eval/public:cel_options", + "//eval/public:cel_value", + "//eval/public:unknown_attribute_set", + "//eval/public:unknown_set", "//eval/testutil:test_message_cc_proto", "@com_github_google_googletest//:gtest_main", + "@com_google_absl//absl/status", "@com_google_absl//absl/strings", "@com_google_googleapis//google/api/expr/v1alpha1:syntax_cc_proto", "@com_google_protobuf//:protobuf", @@ -86,6 +122,7 @@ cc_test( copts = ["-std=c++14"], deps = [ ":constant_folding", + "//base:status_macros", "//eval/public:builtin_func_registrar", "//eval/public:cel_function_registry", "//eval/testutil:test_message_cc_proto", diff --git a/eval/compiler/constant_folding.cc b/eval/compiler/constant_folding.cc index f9ccce95d..970ba7107 100644 --- a/eval/compiler/constant_folding.cc +++ b/eval/compiler/constant_folding.cc @@ -150,8 +150,8 @@ class ConstantFoldingTransform { } case Expr::kStructExpr: { auto struct_expr = out->mutable_struct_expr(); + struct_expr->set_message_name(expr.struct_expr().message_name()); int entries_size = expr.struct_expr().entries_size(); - bool all_constant = true; for (int i = 0; i < entries_size; i++) { auto& entry = expr.struct_expr().entries(i); auto new_entry = struct_expr->add_entries(); @@ -161,16 +161,13 @@ class ConstantFoldingTransform { new_entry->set_field_key(entry.field_key()); break; case Expr::CreateStruct::Entry::kMapKey: - all_constant = - Transform(entry.map_key(), new_entry->mutable_map_key()) && - all_constant; + Transform(entry.map_key(), new_entry->mutable_map_key()); break; default: GOOGLE_LOG(ERROR) << "Unsupported Entry kind: " << entry.key_kind_case(); break; } - all_constant = Transform(entry.value(), new_entry->mutable_value()) && - all_constant; + Transform(entry.value(), new_entry->mutable_value()); } return false; } diff --git a/eval/compiler/constant_folding_test.cc b/eval/compiler/constant_folding_test.cc index facfd7986..f7e24e7e4 100644 --- a/eval/compiler/constant_folding_test.cc +++ b/eval/compiler/constant_folding_test.cc @@ -8,6 +8,8 @@ #include "eval/public/builtin_func_registrar.h" #include "eval/public/cel_function_registry.h" #include "eval/testutil/test_message.pb.h" +#include "base/status_macros.h" + namespace google { namespace api { namespace expr { @@ -43,6 +45,62 @@ TEST(ConstantFoldingTest, Select) { EXPECT_TRUE(idents.empty()); } +// Validate struct message creation +TEST(ConstantFoldingTest, StructMessage) { + Expr expr; + // {"field1": "y", "field2": "t"} + google::protobuf::TextFormat::ParseFromString( + R"pb( + id: 5 + struct_expr { + entries { + id: 11 + field_key: "field1" + value { const_expr { string_value: "value1" } } + } + entries { + id: 7 + field_key: "field2" + value { const_expr { int64_value: 12 } } + } + message_name: "MyProto" + })pb", + &expr); + + google::protobuf::Arena arena; + CelFunctionRegistry registry; + + absl::flat_hash_map idents; + Expr out; + FoldConstants(expr, registry, &arena, idents, &out); + + Expr expected; + google::protobuf::TextFormat::ParseFromString(R"( + id: 5 + struct_expr { + entries { + id: 11 + field_key: "field1" + value { ident_expr { name: "$v0" } } + } + entries { + id: 7 + field_key: "field2" + value { ident_expr { name: "$v1" } } + } + message_name: "MyProto" + })", + &expected); + google::protobuf::util::MessageDifferencer md; + EXPECT_TRUE(md.Compare(out, expected)) << out.DebugString(); + + EXPECT_EQ(idents.size(), 2); + EXPECT_TRUE(idents["$v0"].IsString()); + EXPECT_EQ(idents["$v0"].StringOrDie().value(), "value1"); + EXPECT_TRUE(idents["$v1"].IsInt64()); + EXPECT_EQ(idents["$v1"].Int64OrDie(), 12); +} + // Validate struct creation is not folded but recursed into TEST(ConstantFoldingTest, StructComprehension) { Expr expr; @@ -151,7 +209,7 @@ TEST(ConstantFoldingTest, LogicApplication) { google::protobuf::Arena arena; CelFunctionRegistry registry; - ASSERT_TRUE(RegisterBuiltinFunctions(®istry).ok()); + ASSERT_OK(RegisterBuiltinFunctions(®istry)); absl::flat_hash_map idents; Expr out; @@ -184,7 +242,7 @@ TEST(ConstantFoldingTest, FunctionApplication) { google::protobuf::Arena arena; CelFunctionRegistry registry; - ASSERT_TRUE(RegisterBuiltinFunctions(®istry).ok()); + ASSERT_OK(RegisterBuiltinFunctions(®istry)); absl::flat_hash_map idents; Expr out; @@ -218,7 +276,7 @@ TEST(ConstantFoldingTest, FunctionApplicationWithReceiver) { google::protobuf::Arena arena; CelFunctionRegistry registry; - ASSERT_TRUE(RegisterBuiltinFunctions(®istry).ok()); + ASSERT_OK(RegisterBuiltinFunctions(®istry)); absl::flat_hash_map idents; Expr out; @@ -251,7 +309,7 @@ TEST(ConstantFoldingTest, FunctionApplicationNoOverload) { google::protobuf::Arena arena; CelFunctionRegistry registry; - ASSERT_TRUE(RegisterBuiltinFunctions(®istry).ok()); + ASSERT_OK(RegisterBuiltinFunctions(®istry)); absl::flat_hash_map idents; Expr out; diff --git a/eval/compiler/flat_expr_builder.cc b/eval/compiler/flat_expr_builder.cc index a856afc89..d2cd46209 100644 --- a/eval/compiler/flat_expr_builder.cc +++ b/eval/compiler/flat_expr_builder.cc @@ -11,6 +11,7 @@ #include "eval/eval/create_list_step.h" #include "eval/eval/create_struct_step.h" #include "eval/eval/evaluator_core.h" +#include "eval/eval/expression_build_warning.h" #include "eval/eval/function_step.h" #include "eval/eval/ident_step.h" #include "eval/eval/jump_step.h" @@ -20,7 +21,6 @@ #include "eval/public/ast_visitor.h" #include "eval/public/cel_builtins.h" #include "eval/public/cel_function_registry.h" -#include "base/canonical_errors.h" namespace google { namespace api { @@ -49,14 +49,19 @@ class FlatExprVisitor : public AstVisitor { const std::set& enums, absl::string_view container, const absl::flat_hash_map& constant_idents, - bool enable_comprehension) + bool enable_comprehension, BuilderWarnings* warnings, + std::set* iter_variable_names) : flattened_path_(path), - progress_status_(cel_base::OkStatus()), + progress_status_(absl::OkStatus()), resolved_select_expr_(nullptr), function_registry_(function_registry), shortcircuiting_(shortcircuiting), constant_idents_(constant_idents), - enable_comprehension_(enable_comprehension) { + enable_comprehension_(enable_comprehension), + builder_warnings_(warnings), + iter_variable_names_(iter_variable_names) { + GOOGLE_CHECK(iter_variable_names_); + auto container_elements = absl::StrSplit(container, '.'); // Build list of prefixes from container. Non-empty prefixes must end with @@ -121,7 +126,7 @@ class FlatExprVisitor : public AstVisitor { if (value.has_value()) { AddStep(CreateConstValueStep(value.value(), expr->id())); } else { - SetProgressStatusError(cel_base::Status(cel_base::StatusCode::kInvalidArgument, + SetProgressStatusError(absl::Status(absl::StatusCode::kInvalidArgument, "Unsupported constant type")); } } @@ -166,7 +171,7 @@ class FlatExprVisitor : public AstVisitor { if (resolved_select_expr_) { if (!resolved_select_expr_->has_select_expr()) { - progress_status_ = cel_base::InternalError("Unexpected Expr type"); + progress_status_ = absl::InternalError("Unexpected Expr type"); return; } AddStep(CreateConstValueStep(value_desc, resolved_select_expr_->id())); @@ -201,9 +206,9 @@ class FlatExprVisitor : public AstVisitor { } // Check if we are "in the middle" of namespaced name. - // This is currently enum specific. Constant expression that corresponds to - // resolved enum value has been already created, thus preceding chain of - // selects is no longer relevant. + // This is currently enum specific. Constant expression that corresponds + // to resolved enum value has been already created, thus preceding chain + // of selects is no longer relevant. if (resolved_select_expr_) { if (expr == resolved_select_expr_) { resolved_select_expr_ = nullptr; @@ -267,7 +272,8 @@ class FlatExprVisitor : public AstVisitor { return; } // For regular functions, just create one based on registry. - AddStep(CreateFunctionStep(call_expr, expr->id(), *function_registry_)); + AddStep(CreateFunctionStep(call_expr, expr->id(), *function_registry_, + builder_warnings_)); } } @@ -277,7 +283,7 @@ class FlatExprVisitor : public AstVisitor { return; } if (!enable_comprehension_) { - SetProgressStatusError(cel_base::Status(cel_base::StatusCode::kInvalidArgument, + SetProgressStatusError(absl::Status(absl::StatusCode::kInvalidArgument, "Comprehension support is disabled")); } cond_visitor_stack_.emplace(expr, @@ -287,7 +293,8 @@ class FlatExprVisitor : public AstVisitor { } // Invoked after all child nodes are processed. - void PostVisitComprehension(const Comprehension*, const Expr* expr, + void PostVisitComprehension(const Comprehension* comprehension_expr, + const Expr* expr, const SourcePosition*) override { if (!progress_status_.ok()) { return; @@ -295,6 +302,15 @@ class FlatExprVisitor : public AstVisitor { auto cond_visitor = FindCondVisitor(expr); cond_visitor->PostVisit(expr); cond_visitor_stack_.pop(); + + // Save off the names of the variables we're using, such that we have a + // full set of the names from the entire evaluation tree at the end. + if (!comprehension_expr->accu_var().empty()) { + iter_variable_names_->insert(comprehension_expr->accu_var()); + } + if (!comprehension_expr->iter_var().empty()) { + iter_variable_names_->insert(comprehension_expr->iter_var()); + } } // Invoked after each argument node processed. @@ -331,7 +347,7 @@ class FlatExprVisitor : public AstVisitor { AddStep(CreateCreateStructStep(struct_expr, expr->id())); } - cel_base::Status progress_status() const { return progress_status_; } + absl::Status progress_status() const { return progress_status_; } private: class CondVisitor { @@ -363,6 +379,8 @@ class FlatExprVisitor : public AstVisitor { }; // Visitor managing the "?" operation. + // TODO(issues/41) Make sure Unknowns are properly supported by ternary + // operation. class TernaryCondVisitor : public CondVisitor { public: explicit TernaryCondVisitor(FlatExprVisitor* visitor) @@ -382,10 +400,7 @@ class FlatExprVisitor : public AstVisitor { class ComprehensionVisitor : public CondVisitor { public: explicit ComprehensionVisitor(FlatExprVisitor* visitor) - : CondVisitor(visitor) - , next_step_(nullptr) - , cond_step_(nullptr) - {} + : CondVisitor(visitor), next_step_(nullptr), cond_step_(nullptr) {} void PreVisit(const Expr* expr) override; void PostVisitArg(int arg_num, const Expr* expr) override; @@ -414,7 +429,7 @@ class FlatExprVisitor : public AstVisitor { } } - void SetProgressStatusError(const cel_base::Status& status) { + void SetProgressStatusError(const absl::Status& status) { if (progress_status_.ok() && !status.ok()) { progress_status_ = status; } @@ -433,13 +448,13 @@ class FlatExprVisitor : public AstVisitor { } ExecutionPath* flattened_path_; - cel_base::Status progress_status_; + absl::Status progress_status_; std::stack>> cond_visitor_stack_; - // Maps effective namespace names to Expr objects (IDENTs/SELECTs) that define - // scopes for those namespaces. + // Maps effective namespace names to Expr objects (IDENTs/SELECTs) that + // define scopes for those namespaces. std::unordered_map namespace_map_; // Tracks SELECT-...SELECT-IDENT chains. std::deque> namespace_stack_; @@ -458,6 +473,10 @@ class FlatExprVisitor : public AstVisitor { const absl::flat_hash_map& constant_idents_; bool enable_comprehension_; + + BuilderWarnings* builder_warnings_; + + std::set* iter_variable_names_; }; void FlatExprVisitor::BinaryCondVisitor::PreVisit(const Expr*) {} @@ -529,7 +548,7 @@ void FlatExprVisitor::TernaryCondVisitor::PostVisitArg(int arg_num, if (jump_to_second_.exists()) { jump_to_second_.set_target(visitor_->GetCurrentIndex()); } else { - visitor_->SetProgressStatusError(cel_base::InvalidArgumentError( + visitor_->SetProgressStatusError(absl::InvalidArgumentError( "Error configuring ternary operator: jump_to_second_ is null")); } } @@ -543,14 +562,14 @@ void FlatExprVisitor::TernaryCondVisitor::PostVisit(const Expr*) { if (error_jump_.exists()) { error_jump_.set_target(visitor_->GetCurrentIndex()); } else { - visitor_->SetProgressStatusError(cel_base::InvalidArgumentError( + visitor_->SetProgressStatusError(absl::InvalidArgumentError( "Error configuring ternary operator: error_jump_ is null")); return; } if (jump_after_first_.exists()) { jump_after_first_.set_target(visitor_->GetCurrentIndex()); } else { - visitor_->SetProgressStatusError(cel_base::InvalidArgumentError( + visitor_->SetProgressStatusError(absl::InvalidArgumentError( "Error configuring ternary operator: jump_after_first_ is null")); return; } @@ -623,12 +642,10 @@ void FlatExprVisitor::ComprehensionVisitor::PostVisitArg(int arg_num, visitor_->AddStep(std::move(jump_to_next)); } // Set offsets. - cond_step_->set_jump_offset( - visitor_->GetCurrentIndex() - cond_step_pos_ - 1 - ); - next_step_->set_jump_offset( - visitor_->GetCurrentIndex() - next_step_pos_ - 1 - ); + cond_step_->set_jump_offset(visitor_->GetCurrentIndex() - cond_step_pos_ - + 1); + next_step_->set_jump_offset(visitor_->GetCurrentIndex() - next_step_pos_ - + 1); break; } case RESULT: { @@ -647,11 +664,13 @@ void FlatExprVisitor::ComprehensionVisitor::PostVisit(const Expr*) {} cel_base::StatusOr> FlatExprBuilder::CreateExpression(const Expr* expr, - const SourceInfo* source_info) const { + const SourceInfo* source_info, + std::vector* warnings) const { ExecutionPath execution_path; + BuilderWarnings warnings_builder(fail_on_warnings_); if (absl::StartsWith(container(), ".") || absl::EndsWith(container(), ".")) { - return cel_base::InvalidArgumentError( + return absl::InvalidArgumentError( absl::StrCat("Invalid expression container:", container())); } @@ -663,9 +682,11 @@ FlatExprBuilder::CreateExpression(const Expr* expr, FoldConstants(*expr, *this->GetRegistry(), constant_arena_, idents, &out); } + std::set iter_variable_names; FlatExprVisitor visitor(this->GetRegistry(), &execution_path, shortcircuiting_, resolvable_enums(), container(), - idents, enable_comprehension_); + idents, enable_comprehension_, &warnings_builder, + &iter_variable_names); AstTraverse(constant_folding_ ? &out : expr, source_info, &visitor); @@ -674,12 +695,23 @@ FlatExprBuilder::CreateExpression(const Expr* expr, } std::unique_ptr expression_impl = - absl::make_unique(expr, std::move(execution_path), - comprehension_max_iterations_); + absl::make_unique( + expr, std::move(execution_path), comprehension_max_iterations_, + std::move(iter_variable_names), enable_unknowns_, + enable_unknown_function_results_); + if (warnings != nullptr) { + *warnings = std::move(warnings_builder).warnings(); + } return std::move(expression_impl); } +cel_base::StatusOr> +FlatExprBuilder::CreateExpression(const Expr* expr, + const SourceInfo* source_info) const { + return CreateExpression(expr, source_info, nullptr); +} + } // namespace runtime } // namespace expr } // namespace api diff --git a/eval/compiler/flat_expr_builder.h b/eval/compiler/flat_expr_builder.h index 48beaee18..62dfb84b0 100644 --- a/eval/compiler/flat_expr_builder.h +++ b/eval/compiler/flat_expr_builder.h @@ -1,8 +1,8 @@ #ifndef THIRD_PARTY_CEL_CPP_EVAL_COMPILER_FLAT_EXPR_BUILDER_H_ #define THIRD_PARTY_CEL_CPP_EVAL_COMPILER_FLAT_EXPR_BUILDER_H_ -#include "eval/public/cel_expression.h" #include "google/api/expr/v1alpha1/syntax.pb.h" +#include "eval/public/cel_expression.h" namespace google { namespace api { @@ -14,11 +14,23 @@ namespace runtime { class FlatExprBuilder : public CelExpressionBuilder { public: FlatExprBuilder() - : shortcircuiting_(true), + : enable_unknowns_(false), + enable_unknown_function_results_(false), + shortcircuiting_(true), constant_folding_(false), constant_arena_(nullptr), enable_comprehension_(true), - comprehension_max_iterations_(0) {} + comprehension_max_iterations_(0), + fail_on_warnings_(true) {} + + // set_enable_unknowns controls support for unknowns in expressions created. + void set_enable_unknowns(bool enabled) { enable_unknowns_ = enabled; } + + // set_enable_unknown_function_results controls support for unknown function + // results. + void set_enable_unknown_function_results(bool enabled) { + enable_unknown_function_results_ = enabled; + } // set_shortcircuiting regulates shortcircuiting of some expressions. // Be default shortcircuiting is enabled. @@ -39,17 +51,30 @@ class FlatExprBuilder : public CelExpressionBuilder { comprehension_max_iterations_ = max_iterations; } + // Warnings (e.g. no function bound) fail immediately. + void set_fail_on_warnings(bool should_fail) { + fail_on_warnings_ = should_fail; + } + cel_base::StatusOr> CreateExpression( const google::api::expr::v1alpha1::Expr* expr, const google::api::expr::v1alpha1::SourceInfo* source_info) const override; + cel_base::StatusOr> CreateExpression( + const google::api::expr::v1alpha1::Expr* expr, + const google::api::expr::v1alpha1::SourceInfo* source_info, + std::vector* warnings) const override; + private: + bool enable_unknowns_; + bool enable_unknown_function_results_; bool shortcircuiting_; bool constant_folding_; google::protobuf::Arena* constant_arena_; bool enable_comprehension_; int comprehension_max_iterations_; + bool fail_on_warnings_; }; } // namespace runtime diff --git a/eval/compiler/flat_expr_builder_comprehensions_test.cc b/eval/compiler/flat_expr_builder_comprehensions_test.cc new file mode 100644 index 000000000..331821853 --- /dev/null +++ b/eval/compiler/flat_expr_builder_comprehensions_test.cc @@ -0,0 +1,180 @@ +#include "google/api/expr/v1alpha1/syntax.pb.h" +#include "google/protobuf/field_mask.pb.h" +#include "google/protobuf/text_format.h" +#include "gmock/gmock.h" +#include "gtest/gtest.h" +#include "absl/status/status.h" +#include "absl/strings/str_split.h" +#include "absl/strings/string_view.h" +#include "eval/compiler/flat_expr_builder.h" +#include "eval/public/builtin_func_registrar.h" +#include "eval/public/cel_attribute.h" +#include "eval/public/cel_builtins.h" +#include "eval/public/cel_expression.h" +#include "eval/public/cel_options.h" +#include "eval/public/cel_value.h" +#include "eval/public/unknown_attribute_set.h" +#include "eval/public/unknown_set.h" +#include "eval/testutil/test_message.pb.h" +#include "base/status_macros.h" + +namespace google { +namespace api { +namespace expr { +namespace runtime { + +namespace { + +using google::api::expr::v1alpha1::Expr; +using google::api::expr::v1alpha1::SourceInfo; + +// [1, 2].filter(x, [3, 4].all(y, x < y)) +const char kNestedComprehension[] = R"pb( + id: 27 + comprehension_expr { + iter_var: "x" + iter_range { + id: 1 + list_expr { + elements { + id: 2 + const_expr { int64_value: 1 } + } + elements { + id: 3 + const_expr { int64_value: 2 } + } + } + } + accu_var: "__result__" + accu_init { + id: 22 + list_expr {} + } + loop_condition { + id: 23 + const_expr { bool_value: true } + } + loop_step { + id: 26 + call_expr { + function: "_?_:_" + args { + id: 20 + comprehension_expr { + iter_var: "y" + iter_range { + id: 6 + list_expr { + elements { + id: 7 + const_expr { int64_value: 3 } + } + elements { + id: 8 + const_expr { int64_value: 4 } + } + } + } + accu_var: "__result__" + accu_init { + id: 14 + const_expr { bool_value: true } + } + loop_condition { + id: 16 + call_expr { + function: "@not_strictly_false" + args { + id: 15 + ident_expr { name: "__result__" } + } + } + } + loop_step { + id: 18 + call_expr { + function: "_&&_" + args { + id: 17 + ident_expr { name: "__result__" } + } + args { + id: 12 + call_expr { + function: "_<_" + args { + id: 11 + ident_expr { name: "x" } + } + args { + id: 13 + ident_expr { name: "y" } + } + } + } + } + } + result { + id: 19 + ident_expr { name: "__result__" } + } + } + } + args { + id: 25 + call_expr { + function: "_+_" + args { + id: 21 + ident_expr { name: "__result__" } + } + args { + id: 24 + list_expr { + elements { + id: 5 + ident_expr { name: "x" } + } + } + } + } + } + args { + id: 21 + ident_expr { name: "__result__" } + } + } + } + result { + id: 21 + ident_expr { name: "__result__" } + } + })pb"; + +TEST(FlatExprBuilderComprehensionsTest, NestedComp) { + FlatExprBuilder builder; + Expr expr; + ASSERT_TRUE(google::protobuf::TextFormat::ParseFromString(kNestedComprehension, &expr)); + ASSERT_OK(RegisterBuiltinFunctions(builder.GetRegistry())); + SourceInfo source_info; + auto build_status = builder.CreateExpression(&expr, &source_info); + ASSERT_OK(build_status); + + auto cel_expr = std::move(build_status.ValueOrDie()); + + Activation activation; + google::protobuf::Arena arena; + auto result_or = cel_expr->Evaluate(activation, &arena); + ASSERT_OK(result_or); + CelValue result = result_or.ValueOrDie(); + ASSERT_TRUE(result.IsList()); + EXPECT_THAT(*result.ListOrDie(), testing::SizeIs(2)); +} + +} // namespace + +} // namespace runtime +} // namespace expr +} // namespace api +} // namespace google diff --git a/eval/compiler/flat_expr_builder_test.cc b/eval/compiler/flat_expr_builder_test.cc index 2b01bb22f..d9ddfd23a 100644 --- a/eval/compiler/flat_expr_builder_test.cc +++ b/eval/compiler/flat_expr_builder_test.cc @@ -5,10 +5,20 @@ #include "google/protobuf/text_format.h" #include "gmock/gmock.h" #include "gtest/gtest.h" +#include "absl/status/status.h" #include "absl/strings/str_split.h" #include "absl/strings/string_view.h" #include "eval/public/builtin_func_registrar.h" +#include "eval/public/cel_attribute.h" +#include "eval/public/cel_builtins.h" +#include "eval/public/cel_expression.h" +#include "eval/public/cel_options.h" +#include "eval/public/cel_value.h" +#include "eval/public/unknown_attribute_set.h" +#include "eval/public/unknown_set.h" #include "eval/testutil/test_message.pb.h" +#include "base/status_macros.h" + namespace google { namespace api { namespace expr { @@ -32,10 +42,10 @@ class ConcatFunction : public CelFunction { "concat", false, {CelValue::Type::kString, CelValue::Type::kString}}; } - cel_base::Status Evaluate(absl::Span args, CelValue* result, + absl::Status Evaluate(absl::Span args, CelValue* result, google::protobuf::Arena* arena) const override { if (args.size() != 2) { - return cel_base::Status(cel_base::StatusCode::kInvalidArgument, + return absl::Status(absl::StatusCode::kInvalidArgument, "Bad arguments number"); } @@ -47,7 +57,7 @@ class ConcatFunction : public CelFunction { *result = CelValue::CreateString(concatenated); - return cel_base::OkStatus(); + return absl::OkStatus(); } }; @@ -67,10 +77,10 @@ TEST(FlatExprBuilderTest, SimpleEndToEnd) { auto register_status = builder.GetRegistry()->Register(absl::make_unique()); - ASSERT_TRUE(register_status.ok()); + ASSERT_OK(register_status); auto build_status = builder.CreateExpression(&expr, &source_info); - ASSERT_TRUE(build_status.ok()); + ASSERT_OK(build_status); auto cel_expr = std::move(build_status.ValueOrDie()); @@ -82,7 +92,7 @@ TEST(FlatExprBuilderTest, SimpleEndToEnd) { google::protobuf::Arena arena; auto eval_status = cel_expr->Evaluate(activation, &arena); - ASSERT_TRUE(eval_status.ok()); + ASSERT_OK(eval_status); CelValue result = eval_status.ValueOrDie(); @@ -91,20 +101,66 @@ TEST(FlatExprBuilderTest, SimpleEndToEnd) { EXPECT_THAT(result.StringOrDie().value(), Eq("prefixtest")); } +TEST(FlatExprBuilderTest, DelayedFunctionResolutionErrors) { + Expr expr; + SourceInfo source_info; + auto call_expr = expr.mutable_call_expr(); + call_expr->set_function("concat"); + + auto arg1 = call_expr->add_args(); + arg1->mutable_const_expr()->set_string_value("prefix"); + + auto arg2 = call_expr->add_args(); + arg2->mutable_ident_expr()->set_name("value"); + + FlatExprBuilder builder; + builder.set_fail_on_warnings(false); + std::vector warnings; + + // Concat function not registered. + + auto build_status = builder.CreateExpression(&expr, &source_info, &warnings); + ASSERT_OK(build_status); + + auto cel_expr = std::move(build_status.ValueOrDie()); + + std::string variable = "test"; + + Activation activation; + activation.InsertValue("value", CelValue::CreateString(&variable)); + + google::protobuf::Arena arena; + + auto eval_status = cel_expr->Evaluate(activation, &arena); + ASSERT_OK(eval_status); + + CelValue result = eval_status.ValueOrDie(); + + ASSERT_TRUE(result.IsError()); + + EXPECT_THAT(result.ErrorOrDie()->message(), + Eq("No matching overloads found")); + + ASSERT_THAT(warnings, testing::SizeIs(1)); + EXPECT_EQ(warnings[0].code(), absl::StatusCode::kInvalidArgument); + EXPECT_THAT(warnings[0].message(), + testing::HasSubstr("No overloads provided")); +} + class RecorderFunction : public CelFunction { public: explicit RecorderFunction(const std::string& name, int* count) : CelFunction(CelFunctionDescriptor{name, false, {}}), count_(count) {} - cel_base::Status Evaluate(absl::Span args, CelValue* result, + absl::Status Evaluate(absl::Span args, CelValue* result, google::protobuf::Arena* arena) const override { if (!args.empty()) { - return cel_base::Status(cel_base::StatusCode::kInvalidArgument, + return absl::Status(absl::StatusCode::kInvalidArgument, "Bad arguments number"); } (*count_)++; *result = CelValue::CreateBool(true); - return cel_base::OkStatus(); + return absl::OkStatus(); } int* count_; @@ -130,21 +186,21 @@ TEST(FlatExprBuilderTest, Shortcircuiting) { auto register_status1 = builder.GetRegistry()->Register( absl::make_unique("recorder1", &count1)); - ASSERT_TRUE(register_status1.ok()); + ASSERT_OK(register_status1); auto register_status2 = builder.GetRegistry()->Register( absl::make_unique("recorder2", &count2)); - ASSERT_TRUE(register_status2.ok()); + ASSERT_OK(register_status2); // Shortcircuiting on. auto build_status_on = builder.CreateExpression(&expr, &source_info); - ASSERT_TRUE(build_status_on.ok()); + ASSERT_OK(build_status_on); auto cel_expr_on = std::move(build_status_on.ValueOrDie()); Activation activation; google::protobuf::Arena arena; auto eval_status_on = cel_expr_on->Evaluate(activation, &arena); - ASSERT_TRUE(eval_status_on.ok()); + ASSERT_OK(eval_status_on); EXPECT_THAT(count1, Eq(1)); EXPECT_THAT(count2, Eq(0)); @@ -152,7 +208,7 @@ TEST(FlatExprBuilderTest, Shortcircuiting) { // Shortcircuiting off. builder.set_shortcircuiting(false); auto build_status_off = builder.CreateExpression(&expr, &source_info); - ASSERT_TRUE(build_status_off.ok()); + ASSERT_OK(build_status_off); auto cel_expr_off = std::move(build_status_off.ValueOrDie()); @@ -160,7 +216,7 @@ TEST(FlatExprBuilderTest, Shortcircuiting) { count2 = 0; auto eval_status_off = cel_expr_off->Evaluate(activation, &arena); - ASSERT_TRUE(eval_status_off.ok()); + ASSERT_OK(eval_status_off); EXPECT_THAT(count1, Eq(1)); EXPECT_THAT(count2, Eq(1)); @@ -193,32 +249,32 @@ TEST(FlatExprBuilderTest, ShortcircuitingComprehension) { int count = 0; auto register_status = builder.GetRegistry()->Register( absl::make_unique("loop_step", &count)); - ASSERT_TRUE(register_status.ok()); + ASSERT_OK(register_status); // Shortcircuiting on. auto build_status_on = builder.CreateExpression(&expr, &source_info); - ASSERT_TRUE(build_status_on.ok()); + ASSERT_OK(build_status_on); auto cel_expr_on = std::move(build_status_on.ValueOrDie()); Activation activation; google::protobuf::Arena arena; auto eval_status_on = cel_expr_on->Evaluate(activation, &arena); - ASSERT_TRUE(eval_status_on.ok()); + ASSERT_OK(eval_status_on); EXPECT_THAT(count, Eq(0)); // Shortcircuiting off. builder.set_shortcircuiting(false); auto build_status_off = builder.CreateExpression(&expr, &source_info); - ASSERT_TRUE(build_status_off.ok()); + ASSERT_OK(build_status_off); auto cel_expr_off = std::move(build_status_off.ValueOrDie()); count = 0; auto eval_status_off = cel_expr_off->Evaluate(activation, &arena); - ASSERT_TRUE(eval_status_off.ok()); + ASSERT_OK(eval_status_off); EXPECT_THAT(count, Eq(3)); } @@ -266,17 +322,17 @@ TEST(FlatExprBuilderTest, MapComprehension) { &expr); FlatExprBuilder builder; - ASSERT_TRUE(RegisterBuiltinFunctions(builder.GetRegistry()).ok()); + ASSERT_OK(RegisterBuiltinFunctions(builder.GetRegistry())); SourceInfo source_info; auto build_status = builder.CreateExpression(&expr, &source_info); - ASSERT_TRUE(build_status.ok()); + ASSERT_OK(build_status); auto cel_expr = std::move(build_status.ValueOrDie()); Activation activation; google::protobuf::Arena arena; auto result_or = cel_expr->Evaluate(activation, &arena); - ASSERT_TRUE(result_or.ok()); + ASSERT_OK(result_or); CelValue result = result_or.ValueOrDie(); ASSERT_TRUE(result.IsBool()); EXPECT_TRUE(result.BoolOrDie()); @@ -353,17 +409,17 @@ TEST(FlatExprBuilderTest, ComprehensionWorksForError) { &expr); FlatExprBuilder builder; - ASSERT_TRUE(RegisterBuiltinFunctions(builder.GetRegistry()).ok()); + ASSERT_OK(RegisterBuiltinFunctions(builder.GetRegistry())); SourceInfo source_info; auto build_status = builder.CreateExpression(&expr, &source_info); - ASSERT_TRUE(build_status.ok()); + ASSERT_OK(build_status); auto cel_expr = std::move(build_status.ValueOrDie()); Activation activation; google::protobuf::Arena arena; auto result_or = cel_expr->Evaluate(activation, &arena); - ASSERT_TRUE(result_or.ok()); + ASSERT_OK(result_or); CelValue result = result_or.ValueOrDie(); ASSERT_TRUE(result.IsError()); } @@ -428,17 +484,17 @@ TEST(FlatExprBuilderTest, ComprehensionWorksForNonContainer) { &expr); FlatExprBuilder builder; - ASSERT_TRUE(RegisterBuiltinFunctions(builder.GetRegistry()).ok()); + ASSERT_OK(RegisterBuiltinFunctions(builder.GetRegistry())); SourceInfo source_info; auto build_status = builder.CreateExpression(&expr, &source_info); - ASSERT_TRUE(build_status.ok()); + ASSERT_OK(build_status); auto cel_expr = std::move(build_status.ValueOrDie()); Activation activation; google::protobuf::Arena arena; auto result_or = cel_expr->Evaluate(activation, &arena); - ASSERT_TRUE(result_or.ok()); + ASSERT_OK(result_or); CelValue result = result_or.ValueOrDie(); ASSERT_TRUE(result.IsError()); EXPECT_THAT(result.ErrorOrDie()->message(), @@ -483,10 +539,10 @@ TEST(FlatExprBuilderTest, ComprehensionBudget) { FlatExprBuilder builder; builder.set_comprehension_max_iterations(1); - ASSERT_TRUE(RegisterBuiltinFunctions(builder.GetRegistry()).ok()); + ASSERT_OK(RegisterBuiltinFunctions(builder.GetRegistry())); SourceInfo source_info; auto build_status = builder.CreateExpression(&expr, &source_info); - ASSERT_TRUE(build_status.ok()); + ASSERT_OK(build_status); auto cel_expr = std::move(build_status.ValueOrDie()); @@ -517,7 +573,7 @@ TEST(FlatExprBuilderTest, UnknownSupportTest) { FlatExprBuilder builder; auto build_status = builder.CreateExpression(&expr, &source_info); - ASSERT_TRUE(build_status.ok()); + ASSERT_OK(build_status); auto cel_expr = std::move(build_status.ValueOrDie()); @@ -529,7 +585,7 @@ TEST(FlatExprBuilderTest, UnknownSupportTest) { auto eval_status = cel_expr->Evaluate(activation, &arena); - ASSERT_TRUE(eval_status.ok()); + ASSERT_OK(eval_status); CelValue result = eval_status.ValueOrDie(); ASSERT_TRUE(result.IsInt64()); @@ -539,7 +595,7 @@ TEST(FlatExprBuilderTest, UnknownSupportTest) { mask.add_paths("message.message_value.int32_value"); activation.set_unknown_paths(mask); eval_status = cel_expr->Evaluate(activation, &arena); - ASSERT_TRUE(eval_status.ok()); + ASSERT_OK(eval_status); result = eval_status.ValueOrDie(); ASSERT_TRUE(result.IsError()); ASSERT_TRUE(IsUnknownValueError(result)); @@ -550,7 +606,7 @@ TEST(FlatExprBuilderTest, UnknownSupportTest) { mask.add_paths("message.message_value"); activation.set_unknown_paths(mask); eval_status = cel_expr->Evaluate(activation, &arena); - ASSERT_TRUE(eval_status.ok()); + ASSERT_OK(eval_status); result = eval_status.ValueOrDie(); ASSERT_TRUE(result.IsError()); ASSERT_TRUE(IsUnknownValueError(result)); @@ -579,10 +635,10 @@ TEST(FlatExprBuilderTest, SimpleEnumTest) { cur_expr->mutable_ident_expr()->set_name(enum_name_parts[0]); FlatExprBuilder builder; - builder.addResolvableEnum(TestMessage::TestEnum_descriptor()); + builder.AddResolvableEnum(TestMessage::TestEnum_descriptor()); auto build_status = builder.CreateExpression(&expr, &source_info); - ASSERT_TRUE(build_status.ok()); + ASSERT_OK(build_status); auto cel_expr = std::move(build_status.ValueOrDie()); @@ -590,7 +646,7 @@ TEST(FlatExprBuilderTest, SimpleEnumTest) { Activation activation; auto eval_status = cel_expr->Evaluate(activation, &arena); - ASSERT_TRUE(eval_status.ok()); + ASSERT_OK(eval_status); CelValue result = eval_status.ValueOrDie(); ASSERT_TRUE(result.IsInt64()); @@ -607,13 +663,13 @@ TEST(FlatExprBuilderTest, ContainerStringFormat) { builder.set_container(""); { auto build_status = builder.CreateExpression(&expr, &source_info); - ASSERT_TRUE(build_status.ok()); + ASSERT_OK(build_status); } builder.set_container("random.namespace"); { auto build_status = builder.CreateExpression(&expr, &source_info); - ASSERT_TRUE(build_status.ok()); + ASSERT_OK(build_status); } // Leading '.' @@ -650,19 +706,19 @@ void EvalExpressionWithEnum(absl::string_view enum_name, cur_expr->mutable_ident_expr()->set_name(enum_name_parts[0]); FlatExprBuilder builder; - builder.addResolvableEnum(TestMessage::TestEnum_descriptor()); - builder.addResolvableEnum(TestEnum_descriptor()); + builder.AddResolvableEnum(TestMessage::TestEnum_descriptor()); + builder.AddResolvableEnum(TestEnum_descriptor()); builder.set_container(std::string(container)); auto build_status = builder.CreateExpression(&expr, &source_info); - ASSERT_TRUE(build_status.ok()); + ASSERT_OK(build_status); auto cel_expr = std::move(build_status.ValueOrDie()); google::protobuf::Arena arena; Activation activation; auto eval_status = cel_expr->Evaluate(activation, &arena); - ASSERT_TRUE(eval_status.ok()); + ASSERT_OK(eval_status); *result = eval_status.ValueOrDie(); } @@ -720,6 +776,164 @@ TEST(FlatExprBuilderTest, PartialQualifiedEnumResolution) { EXPECT_THAT(result.Int64OrDie(), Eq(TestMessage::TEST_ENUM_1)); } +absl::Status RunTernaryExpression(CelValue selector, CelValue value1, + CelValue value2, google::protobuf::Arena* arena, + CelValue* result) { + Expr expr; + SourceInfo source_info; + auto call_expr = expr.mutable_call_expr(); + call_expr->set_function(builtin::kTernary); + + auto arg0 = call_expr->add_args(); + arg0->mutable_ident_expr()->set_name("selector"); + auto arg1 = call_expr->add_args(); + arg1->mutable_ident_expr()->set_name("value1"); + auto arg2 = call_expr->add_args(); + arg2->mutable_ident_expr()->set_name("value2"); + + FlatExprBuilder builder; + auto build_status = builder.CreateExpression(&expr, &source_info); + if (!build_status.ok()) { + return build_status.status(); + } + + auto cel_expr = std::move(build_status.ValueOrDie()); + + std::string variable = "test"; + + Activation activation; + activation.InsertValue("selector", selector); + activation.InsertValue("value1", value1); + activation.InsertValue("value2", value2); + + auto eval_status = cel_expr->Evaluate(activation, arena); + if (!eval_status.ok()) { + return eval_status.status(); + } + + *result = eval_status.ValueOrDie(); + return eval_status.status(); +} + +TEST(FlatExprBuilderTest, Ternary) { + Expr expr; + SourceInfo source_info; + auto call_expr = expr.mutable_call_expr(); + call_expr->set_function(builtin::kTernary); + + auto arg0 = call_expr->add_args(); + arg0->mutable_ident_expr()->set_name("selector"); + auto arg1 = call_expr->add_args(); + arg1->mutable_ident_expr()->set_name("value1"); + auto arg2 = call_expr->add_args(); + arg2->mutable_ident_expr()->set_name("value1"); + + FlatExprBuilder builder; + // builder.set_enable_unknowns(true); + auto build_status = builder.CreateExpression(&expr, &source_info); + ASSERT_OK(build_status); + + auto cel_expr = std::move(build_status.ValueOrDie()); + + google::protobuf::Arena arena; + + // On True, value 1 + { + CelValue result; + ASSERT_OK(RunTernaryExpression(CelValue::CreateBool(true), + CelValue::CreateInt64(1), + CelValue::CreateInt64(2), &arena, &result)); + ASSERT_TRUE(result.IsInt64()); + EXPECT_THAT(result.Int64OrDie(), Eq(1)); + + // Unknown handling + UnknownSet unknown_set; + ASSERT_OK(RunTernaryExpression(CelValue::CreateBool(true), + CelValue::CreateUnknownSet(&unknown_set), + CelValue::CreateInt64(2), &arena, &result)); + ASSERT_TRUE(result.IsUnknownSet()); + + ASSERT_OK(RunTernaryExpression( + CelValue::CreateBool(true), CelValue::CreateInt64(1), + CelValue::CreateUnknownSet(&unknown_set), &arena, &result)); + ASSERT_TRUE(result.IsInt64()); + EXPECT_THAT(result.Int64OrDie(), Eq(1)); + } + + // On False, value 2 + { + CelValue result; + ASSERT_OK(RunTernaryExpression(CelValue::CreateBool(false), + CelValue::CreateInt64(1), + CelValue::CreateInt64(2), &arena, &result)); + ASSERT_TRUE(result.IsInt64()); + EXPECT_THAT(result.Int64OrDie(), Eq(2)); + + // Unknown handling + UnknownSet unknown_set; + ASSERT_OK(RunTernaryExpression(CelValue::CreateBool(false), + CelValue::CreateUnknownSet(&unknown_set), + CelValue::CreateInt64(2), &arena, &result)); + ASSERT_TRUE(result.IsInt64()); + EXPECT_THAT(result.Int64OrDie(), Eq(2)); + + ASSERT_OK(RunTernaryExpression( + CelValue::CreateBool(false), CelValue::CreateInt64(1), + CelValue::CreateUnknownSet(&unknown_set), &arena, &result)); + ASSERT_TRUE(result.IsUnknownSet()); + } + // On Error, surface error + { + CelValue result; + ASSERT_OK(RunTernaryExpression(CreateErrorValue(&arena, "error"), + CelValue::CreateInt64(1), + CelValue::CreateInt64(2), &arena, &result)); + ASSERT_TRUE(result.IsError()); + } + // On Unknown, surface Unknown + { + UnknownSet unknown_set; + CelValue result; + ASSERT_OK(RunTernaryExpression(CelValue::CreateUnknownSet(&unknown_set), + CelValue::CreateInt64(1), + CelValue::CreateInt64(2), &arena, &result)); + ASSERT_TRUE(result.IsUnknownSet()); + EXPECT_THAT(&unknown_set, Eq(result.UnknownSetOrDie())); + } + // We should not merge unknowns + { + Expr selector; + selector.mutable_ident_expr()->set_name("selector"); + CelAttribute selector_attr(selector, {}); + + Expr value1; + value1.mutable_ident_expr()->set_name("value1"); + CelAttribute value1_attr(value1, {}); + + Expr value2; + value2.mutable_ident_expr()->set_name("value2"); + CelAttribute value2_attr(value2, {}); + + UnknownSet unknown_selector(UnknownAttributeSet({&selector_attr})); + UnknownSet unknown_value1(UnknownAttributeSet({&value1_attr})); + UnknownSet unknown_value2(UnknownAttributeSet({&value2_attr})); + CelValue result; + ASSERT_OK(RunTernaryExpression( + CelValue::CreateUnknownSet(&unknown_selector), + CelValue::CreateUnknownSet(&unknown_value1), + CelValue::CreateUnknownSet(&unknown_value2), &arena, &result)); + ASSERT_TRUE(result.IsUnknownSet()); + const UnknownSet* result_set = result.UnknownSetOrDie(); + EXPECT_THAT(result_set->unknown_attributes().attributes().size(), Eq(1)); + EXPECT_THAT(result_set->unknown_attributes() + .attributes()[0] + ->variable() + .ident_expr() + .name(), + Eq("selector")); + } +} + } // namespace } // namespace runtime diff --git a/eval/eval/BUILD b/eval/eval/BUILD index b66cd3c5b..f86f160ef 100644 --- a/eval/eval/BUILD +++ b/eval/eval/BUILD @@ -16,11 +16,19 @@ cc_library( ], copts = ["-std=c++14"], deps = [ + ":attribute_trail", + ":unknowns_utility", + "//base:status_macros", + "//base:statusor", "//eval/public:activation", + "//eval/public:cel_attribute", "//eval/public:cel_expression", "//eval/public:cel_value", - "@com_google_absl//absl/strings", + "//eval/public:unknown_attribute_set", + "@com_google_absl//absl/status", + "@com_google_absl//absl/types:optional", "@com_google_googleapis//google/api/expr/v1alpha1:syntax_cc_proto", + "@com_google_protobuf//:protobuf", ], ) @@ -74,9 +82,10 @@ cc_library( deps = [ ":evaluator_core", ":expression_step_base", - "//base:status", "//eval/public:activation", "//eval/public:cel_value", + "//eval/public:unknown_attribute_set", + "@com_google_absl//absl/status", "@com_google_absl//absl/strings", "@com_google_protobuf//:protobuf", ], @@ -96,8 +105,9 @@ cc_library( ":expression_step_base", "//eval/public:activation", "//eval/public:cel_value", + "//eval/public:unknown_attribute_set", "@com_google_absl//absl/strings", - "@com_google_googleapis//google/api/expr/v1alpha1:syntax_cc_proto", + "@com_google_protobuf//:protobuf", ], ) @@ -112,14 +122,18 @@ cc_library( copts = ["-std=c++14"], deps = [ ":evaluator_core", + ":expression_build_warning", ":expression_step_base", - "//base:status", + "//base:status_macros", "//eval/public:activation", "//eval/public:cel_function", "//eval/public:cel_function_provider", "//eval/public:cel_function_registry", "//eval/public:cel_value", + "//eval/public:unknown_function_result_set", + "//eval/public:unknown_set", "@com_google_absl//absl/strings", + "@com_google_protobuf//:protobuf", ], ) @@ -133,9 +147,9 @@ cc_library( ], copts = ["-std=c++14"], deps = [ - "//base:status", "//eval/public:cel_value", "//internal:proto_util", + "@com_google_absl//absl/status", "@com_google_absl//absl/strings", "@com_google_protobuf//:protobuf", ], @@ -223,7 +237,6 @@ cc_library( "//eval/public:activation", "//eval/public:cel_value", "@com_google_absl//absl/strings", - "@com_google_googleapis//google/api/expr/v1alpha1:syntax_cc_proto", ], ) @@ -262,7 +275,7 @@ cc_library( ":evaluator_core", ":expression_step_base", ":field_access", - "//base:status", + "@com_google_absl//absl/status", "@com_google_absl//absl/strings", "@com_google_absl//absl/types:span", "@com_google_googleapis//google/api/expr/v1alpha1:syntax_cc_proto", @@ -303,6 +316,7 @@ cc_library( "//eval/public:activation", "//eval/public:cel_function", "//eval/public:cel_value", + "//eval/public:unknown_attribute_set", "@com_google_absl//absl/strings", "@com_google_googleapis//google/api/expr/v1alpha1:syntax_cc_proto", ], @@ -320,6 +334,7 @@ cc_library( deps = [ ":evaluator_core", ":expression_step_base", + "//base:status_macros", "//eval/public:activation", "//eval/public:cel_function", "//eval/public:cel_value", @@ -337,8 +352,11 @@ cc_test( copts = ["-std=c++14"], deps = [ ":evaluator_core", + "//base:status_macros", "//eval/compiler:flat_expr_builder", "//eval/public:builtin_func_registrar", + "//eval/public:cel_attribute", + "//eval/public:cel_value", "@com_github_google_googletest//:gtest_main", "@com_google_googleapis//google/api/expr/v1alpha1:syntax_cc_proto", ], @@ -354,6 +372,7 @@ cc_test( deps = [ ":const_value_step", ":evaluator_core", + "//base:status_macros", "@com_github_google_googletest//:gtest_main", "@com_google_googleapis//google/api/expr/v1alpha1:syntax_cc_proto", ], @@ -371,6 +390,8 @@ cc_test( ":container_backed_list_impl", ":container_backed_map_impl", ":ident_step", + "//base:status_macros", + "//eval/public:cel_attribute", "//eval/public:cel_builtins", "//eval/public:cel_value", "@com_github_google_googletest//:gtest_main", @@ -388,6 +409,7 @@ cc_test( deps = [ ":evaluator_core", ":ident_step", + "//base:status_macros", "@com_github_google_googletest//:gtest_main", "@com_google_googleapis//google/api/expr/v1alpha1:syntax_cc_proto", ], @@ -402,9 +424,16 @@ cc_test( copts = ["-std=c++14"], deps = [ ":evaluator_core", + ":expression_build_warning", ":function_step", - "//base:status", + ":ident_step", + "//base:status_macros", + "//eval/public:cel_attribute", "//eval/public:cel_function", + "//eval/public:cel_options", + "//eval/public:cel_value", + "//eval/public:unknown_function_result_set", + "//eval/testutil:test_message_cc_proto", "@com_github_google_googletest//:gtest_main", "@com_google_absl//absl/memory", "@com_google_absl//absl/strings", @@ -459,6 +488,23 @@ cc_test( ], ) +cc_test( + name = "logic_step_test", + size = "small", + srcs = [ + "logic_step_test.cc", + ], + copts = ["-std=c++14"], + deps = [ + ":ident_step", + ":logic_step", + "//base:status_macros", + "//eval/public:unknown_attribute_set", + "//eval/public:unknown_set", + "@com_github_google_googletest//:gtest_main", + ], +) + cc_test( name = "select_step_test", size = "small", @@ -468,9 +514,11 @@ cc_test( copts = ["-std=c++14"], deps = [ ":container_backed_map_impl", - ":evaluator_core", ":ident_step", ":select_step", + "//base:status_macros", + "//eval/public:cel_attribute", + "//eval/public:unknown_attribute_set", "//eval/testutil:test_message_cc_proto", "//testutil:util", "@com_github_google_googletest//:gtest_main", @@ -488,10 +536,13 @@ cc_test( deps = [ ":const_value_step", ":create_list_step", - ":evaluator_core", - "//eval/testutil:test_message_cc_proto", + ":ident_step", + "//base:status_macros", + "//eval/public:activation", + "//eval/public:cel_attribute", + "//eval/public:unknown_attribute_set", "@com_github_google_googletest//:gtest_main", - "@com_google_googleapis//google/api/expr/v1alpha1:syntax_cc_proto", + "@com_google_absl//absl/strings", ], ) @@ -507,6 +558,7 @@ cc_test( ":container_backed_map_impl", ":create_struct_step", ":ident_step", + "//base:status_macros", "//eval/testutil:test_message_cc_proto", "//testutil:util", "@com_github_google_googletest//:gtest_main", @@ -514,3 +566,105 @@ cc_test( "@com_google_googleapis//google/api/expr/v1alpha1:syntax_cc_proto", ], ) + +cc_library( + name = "expression_build_warning", + srcs = [ + "expression_build_warning.cc", + ], + hdrs = [ + "expression_build_warning.h", + ], + copts = ["-std=c++14"], + deps = [ + "@com_google_absl//absl/status", + ], +) + +cc_test( + name = "expression_build_warning_test", + size = "small", + srcs = [ + "expression_build_warning_test.cc", + ], + copts = ["-std=c++14"], + deps = [ + ":expression_build_warning", + "@com_github_google_googletest//:gtest_main", + "@com_google_absl//absl/status", + ], +) + +cc_library( + name = "attribute_trail", + srcs = ["attribute_trail.cc"], + hdrs = ["attribute_trail.h"], + copts = ["-std=c++14"], + deps = [ + "//base:statusor", + "//eval/public:activation", + "//eval/public:cel_attribute", + "//eval/public:cel_expression", + "//eval/public:cel_value", + "//eval/public:unknown_attribute_set", + "@com_google_absl//absl/status", + "@com_google_absl//absl/types:optional", + "@com_google_googleapis//google/api/expr/v1alpha1:syntax_cc_proto", + "@com_google_protobuf//:protobuf", + ], +) + +cc_test( + name = "attribute_trail_test", + size = "small", + srcs = [ + "attribute_trail_test.cc", + ], + copts = ["-std=c++14"], + deps = [ + ":attribute_trail", + "//eval/public:cel_attribute", + "//eval/public:cel_value", + "@com_github_google_googletest//:gtest_main", + "@com_google_googleapis//google/api/expr/v1alpha1:syntax_cc_proto", + ], +) + +cc_library( + name = "unknowns_utility", + srcs = ["unknowns_utility.cc"], + hdrs = ["unknowns_utility.h"], + copts = ["-std=c++14"], + deps = [ + ":attribute_trail", + "//base:statusor", + "//eval/public:activation", + "//eval/public:cel_attribute", + "//eval/public:cel_expression", + "//eval/public:cel_value", + "//eval/public:unknown_attribute_set", + "//eval/public:unknown_set", + "@com_google_absl//absl/status", + "@com_google_absl//absl/types:optional", + "@com_google_googleapis//google/api/expr/v1alpha1:syntax_cc_proto", + "@com_google_protobuf//:protobuf", + ], +) + +cc_test( + name = "unknowns_utility_test", + size = "small", + srcs = [ + "unknowns_utility_test.cc", + ], + copts = ["-std=c++14"], + deps = [ + ":unknowns_utility", + "//eval/public:cel_attribute", + "//eval/public:cel_value", + "//eval/public:unknown_attribute_set", + "//eval/public:unknown_set", + "@com_github_google_googletest//:gtest_main", + "@com_google_googleapis//google/api/expr/v1alpha1:syntax_cc_proto", + ], +) diff --git a/eval/eval/attribute_trail.cc b/eval/eval/attribute_trail.cc new file mode 100644 index 000000000..ec07728ee --- /dev/null +++ b/eval/eval/attribute_trail.cc @@ -0,0 +1,25 @@ +#include "eval/eval/attribute_trail.h" + +#include "absl/status/status.h" +#include "eval/public/cel_value.h" +#include "base/statusor.h" + +namespace google { +namespace api { +namespace expr { +namespace runtime { +// Creates AttributeTrail with attribute path incremented by "qualifier". +AttributeTrail AttributeTrail::Step(CelAttributeQualifier qualifier, + google::protobuf::Arena* arena) const { + // Cannot continue void trail + if (empty()) return AttributeTrail(); + + std::vector qualifiers = attribute_->qualifier_path(); + qualifiers.push_back(qualifier); + return AttributeTrail(google::protobuf::Arena::Create( + arena, attribute_->variable(), std::move(qualifiers))); +} +} // namespace runtime +} // namespace expr +} // namespace api +} // namespace google diff --git a/eval/eval/attribute_trail.h b/eval/eval/attribute_trail.h new file mode 100644 index 000000000..eb18ac0d5 --- /dev/null +++ b/eval/eval/attribute_trail.h @@ -0,0 +1,63 @@ +#ifndef THIRD_PARTY_CEL_CPP_EVAL_EVAL_ATTRIBUTE_TRAIL_H_ +#define THIRD_PARTY_CEL_CPP_EVAL_EVAL_ATTRIBUTE_TRAIL_H_ + +#include +#include + +#include "google/api/expr/v1alpha1/syntax.pb.h" +#include "google/protobuf/arena.h" +#include "absl/types/optional.h" +#include "eval/public/activation.h" +#include "eval/public/cel_attribute.h" +#include "eval/public/cel_expression.h" +#include "eval/public/cel_value.h" +#include "eval/public/unknown_attribute_set.h" + +namespace google { +namespace api { +namespace expr { +namespace runtime { + +// AttributeTrail reflects current attribute path. +// It is functionally similar to CelAttribute, yet intended to have better +// complexity on attribute path increment operations. +// TODO(issues/41) Current AttributeTrail implementation is equivalent to +// CelAttribute - improve it. +// Intended to be used in conjunction with CelValue, describing the attribute +// value originated from. +// Empty AttributeTrail denotes object with attribute path not defined +// or supported. +class AttributeTrail { + public: + AttributeTrail() : attribute_(nullptr) {} + AttributeTrail(google::api::expr::v1alpha1::Expr root, google::protobuf::Arena* arena) + : AttributeTrail(google::protobuf::Arena::Create( + arena, root, std::vector())) {} + + // Creates AttributeTrail with attribute path incremented by "qualifier". + AttributeTrail Step(CelAttributeQualifier qualifier, + google::protobuf::Arena* arena) const; + + // Creates AttributeTrail with attribute path incremented by "qualifier". + AttributeTrail Step(const std::string* qualifier, + google::protobuf::Arena* arena) const { + return Step( + CelAttributeQualifier::Create(CelValue::CreateString(qualifier)), + arena); + } + + // Returns CelAttribute that corresponds to content of AttributeTrail. + const CelAttribute* attribute() const { return attribute_; } + + bool empty() const { return !attribute_; } + + private: + AttributeTrail(const CelAttribute* attribute) : attribute_(attribute) {} + const CelAttribute* attribute_; +}; +} // namespace runtime +} // namespace expr +} // namespace api +} // namespace google + +#endif // THIRD_PARTY_CEL_CPP_EVAL_EVAL_ATTRIBUTE_TRAIL_H_ diff --git a/eval/eval/attribute_trail_test.cc b/eval/eval/attribute_trail_test.cc new file mode 100644 index 000000000..ccf0d36c8 --- /dev/null +++ b/eval/eval/attribute_trail_test.cc @@ -0,0 +1,41 @@ +#include "eval/eval/attribute_trail.h" + +#include "google/api/expr/v1alpha1/syntax.pb.h" +#include "gmock/gmock.h" +#include "gtest/gtest.h" +#include "eval/public/cel_attribute.h" +#include "eval/public/cel_value.h" + +namespace google { +namespace api { +namespace expr { +namespace runtime { + +// Attribute Trail behavior +TEST(AttributeTrailTest, AttributeTrailEmptyStep) { + google::protobuf::Arena arena; + std::string step = "step"; + CelValue step_value = CelValue::CreateString(&step); + AttributeTrail trail; + ASSERT_TRUE(trail.Step(&step, &arena).empty()); + ASSERT_TRUE( + trail.Step(CelAttributeQualifier::Create(step_value), &arena).empty()); +} + +TEST(AttributeTrailTest, AttributeTrailStep) { + google::protobuf::Arena arena; + std::string step = "step"; + CelValue step_value = CelValue::CreateString(&step); + google::api::expr::v1alpha1::Expr root; + root.mutable_ident_expr()->set_name("ident"); + AttributeTrail trail = AttributeTrail(root, &arena).Step(&step, &arena); + + ASSERT_TRUE(trail.attribute() != nullptr); + ASSERT_EQ(*trail.attribute(), + CelAttribute(root, {CelAttributeQualifier::Create(step_value)})); +} + +} // namespace runtime +} // namespace expr +} // namespace api +} // namespace google diff --git a/eval/eval/comprehension_step.cc b/eval/eval/comprehension_step.cc index e835149cf..afe88984a 100644 --- a/eval/eval/comprehension_step.cc +++ b/eval/eval/comprehension_step.cc @@ -1,5 +1,7 @@ #include "eval/eval/comprehension_step.h" + #include "absl/strings/str_cat.h" +#include "base/status_macros.h" namespace google { namespace api { @@ -66,7 +68,7 @@ void ComprehensionNextStep::set_error_jump_offset(int offset) { // // Stack on error: // 0. error -cel_base::Status ComprehensionNextStep::Evaluate(ExecutionFrame* frame) const { +absl::Status ComprehensionNextStep::Evaluate(ExecutionFrame* frame) const { enum { POS_PREVIOUS_LOOP_STEP, POS_ITER_RANGE, @@ -75,7 +77,7 @@ cel_base::Status ComprehensionNextStep::Evaluate(ExecutionFrame* frame) const { POS_LOOP_STEP, }; if (!frame->value_stack().HasEnough(5)) { - return cel_base::Status(cel_base::StatusCode::kInternal, "Value stack underflow"); + return absl::Status(absl::StatusCode::kInternal, "Value stack underflow"); } auto state = frame->value_stack().GetSpan(5); CelValue iter_range = state[POS_ITER_RANGE]; @@ -95,19 +97,22 @@ cel_base::Status ComprehensionNextStep::Evaluate(ExecutionFrame* frame) const { "ComprehensionNextStep: want int64_t, got ", CelValue::TypeName(current_index_value.type()) ); - return cel_base::Status(cel_base::StatusCode::kInternal, message); + return absl::Status(absl::StatusCode::kInternal, message); } auto increment_status = frame->IncrementIterations(); if (!increment_status.ok()) { return increment_status; } int64_t current_index = current_index_value.Int64OrDie(); + if (current_index == -1) { + RETURN_IF_ERROR(frame->PushIterFrame()); + } CelValue loop_step = state[POS_LOOP_STEP]; frame->value_stack().Pop(5); frame->value_stack().Push(loop_step); - frame->iter_vars()[accu_var_] = loop_step; + RETURN_IF_ERROR(frame->SetIterVar(accu_var_, loop_step)); if (current_index >= cel_list->size() - 1) { - frame->iter_vars().erase(iter_var_); + RETURN_IF_ERROR(frame->ClearIterVar(iter_var_)); return frame->JumpTo(jump_offset_); } frame->value_stack().Push(iter_range); @@ -115,8 +120,8 @@ cel_base::Status ComprehensionNextStep::Evaluate(ExecutionFrame* frame) const { CelValue current_value = (*cel_list)[current_index]; frame->value_stack().Push(CelValue::CreateInt64(current_index)); frame->value_stack().Push(current_value); - frame->iter_vars()[iter_var_] = current_value; - return cel_base::OkStatus(); + RETURN_IF_ERROR(frame->SetIterVar(iter_var_, current_value)); + return absl::OkStatus(); } ComprehensionCondStep::ComprehensionCondStep(const std::string&, @@ -136,9 +141,9 @@ void ComprehensionCondStep::set_jump_offset(int offset) { // Stack size before: 5. // Stack size after: 4. // Stack size on break: 1. -cel_base::Status ComprehensionCondStep::Evaluate(ExecutionFrame* frame) const { +absl::Status ComprehensionCondStep::Evaluate(ExecutionFrame* frame) const { if (!frame->value_stack().HasEnough(5)) { - return cel_base::Status(cel_base::StatusCode::kInternal, "Value stack underflow"); + return absl::Status(absl::StatusCode::kInternal, "Value stack underflow"); } CelValue loop_condition_value = frame->value_stack().Peek(); if (!loop_condition_value.IsBool()) { @@ -146,16 +151,16 @@ cel_base::Status ComprehensionCondStep::Evaluate(ExecutionFrame* frame) const { "ComprehensionCondStep:: want bool, got ", CelValue::TypeName(loop_condition_value.type()) ); - return cel_base::Status(cel_base::StatusCode::kInternal, message); + return absl::Status(absl::StatusCode::kInternal, message); } bool loop_condition = loop_condition_value.BoolOrDie(); frame->value_stack().Pop(1); // loop_condition if (!loop_condition && shortcircuiting_) { frame->value_stack().Pop(3); // current_value, current_index, iter_range - frame->iter_vars().erase(iter_var_); + RETURN_IF_ERROR(frame->ClearIterVar(iter_var_)); return frame->JumpTo(jump_offset_); } - return cel_base::OkStatus(); + return absl::OkStatus(); } ComprehensionFinish::ComprehensionFinish(const std::string& accu_var, const std::string&, @@ -166,38 +171,38 @@ ComprehensionFinish::ComprehensionFinish(const std::string& accu_var, const std: // // Stack size before: 2. // Stack size after: 1. -cel_base::Status ComprehensionFinish::Evaluate(ExecutionFrame* frame) const { +absl::Status ComprehensionFinish::Evaluate(ExecutionFrame* frame) const { if (!frame->value_stack().HasEnough(2)) { - return cel_base::Status(cel_base::StatusCode::kInternal, "Value stack underflow"); + return absl::Status(absl::StatusCode::kInternal, "Value stack underflow"); } CelValue result = frame->value_stack().Peek(); frame->value_stack().Pop(1); // result frame->value_stack().PopAndPush(result); - frame->iter_vars().erase(accu_var_); - return cel_base::OkStatus(); + RETURN_IF_ERROR(frame->PopIterFrame()); + return absl::OkStatus(); } class ListKeysStep : public ExpressionStepBase { public: ListKeysStep(int64_t expr_id) : ExpressionStepBase(expr_id, false) {} - cel_base::Status Evaluate(ExecutionFrame* frame) const override; + absl::Status Evaluate(ExecutionFrame* frame) const override; }; std::unique_ptr CreateListKeysStep(int64_t expr_id) { return absl::make_unique(expr_id); } -cel_base::Status ListKeysStep::Evaluate(ExecutionFrame* frame) const { +absl::Status ListKeysStep::Evaluate(ExecutionFrame* frame) const { if (!frame->value_stack().HasEnough(1)) { - return cel_base::Status(cel_base::StatusCode::kInternal, "Value stack underflow"); + return absl::Status(absl::StatusCode::kInternal, "Value stack underflow"); } CelValue map_value = frame->value_stack().Peek(); if (map_value.IsMap()) { const CelMap* cel_map = map_value.MapOrDie(); frame->value_stack().PopAndPush(CelValue::CreateList(cel_map->ListKeys())); - return cel_base::OkStatus(); + return absl::OkStatus(); } - return cel_base::OkStatus(); + return absl::OkStatus(); } } // namespace runtime diff --git a/eval/eval/comprehension_step.h b/eval/eval/comprehension_step.h index 9a0c3146d..1fa9ee6bf 100644 --- a/eval/eval/comprehension_step.h +++ b/eval/eval/comprehension_step.h @@ -21,7 +21,7 @@ class ComprehensionNextStep : public ExpressionStepBase { void set_jump_offset(int offset); void set_error_jump_offset(int offset); - cel_base::Status Evaluate(ExecutionFrame* frame) const override; + absl::Status Evaluate(ExecutionFrame* frame) const override; private: std::string accu_var_; @@ -37,7 +37,7 @@ class ComprehensionCondStep : public ExpressionStepBase { void set_jump_offset(int offset); - cel_base::Status Evaluate(ExecutionFrame* frame) const override; + absl::Status Evaluate(ExecutionFrame* frame) const override; private: std::string iter_var_; @@ -50,7 +50,7 @@ class ComprehensionFinish : public ExpressionStepBase { ComprehensionFinish(const std::string& accu_var, const std::string& iter_var, int64_t expr_id); - cel_base::Status Evaluate(ExecutionFrame* frame) const override; + absl::Status Evaluate(ExecutionFrame* frame) const override; private: std::string accu_var_; diff --git a/eval/eval/const_value_step.cc b/eval/eval/const_value_step.cc index 465e40132..ee070a302 100644 --- a/eval/eval/const_value_step.cc +++ b/eval/eval/const_value_step.cc @@ -18,16 +18,16 @@ class ConstValueStep : public ExpressionStepBase { ConstValueStep(const CelValue& value, int64_t expr_id, bool comes_from_ast) : ExpressionStepBase(expr_id, comes_from_ast), value_(value) {} - cel_base::Status Evaluate(ExecutionFrame* context) const override; + absl::Status Evaluate(ExecutionFrame* frame) const override; private: CelValue value_; }; -cel_base::Status ConstValueStep::Evaluate(ExecutionFrame* frame) const { +absl::Status ConstValueStep::Evaluate(ExecutionFrame* frame) const { frame->value_stack().Push(value_); - return cel_base::OkStatus(); + return absl::OkStatus(); } } // namespace diff --git a/eval/eval/const_value_step_test.cc b/eval/eval/const_value_step_test.cc index d2da6a773..51acef959 100644 --- a/eval/eval/const_value_step_test.cc +++ b/eval/eval/const_value_step_test.cc @@ -1,9 +1,10 @@ #include "eval/eval/const_value_step.h" -#include "eval/eval/evaluator_core.h" -#include "google/api/expr/v1alpha1/syntax.pb.h" +#include "google/api/expr/v1alpha1/syntax.pb.h" #include "gmock/gmock.h" #include "gtest/gtest.h" +#include "eval/eval/evaluator_core.h" +#include "base/status_macros.h" namespace google { namespace api { @@ -14,8 +15,8 @@ namespace { using testing::Eq; -using google::api::expr::v1alpha1::Expr; using google::api::expr::v1alpha1::Constant; +using google::api::expr::v1alpha1::Expr; using google::protobuf::Arena; @@ -31,7 +32,7 @@ cel_base::StatusOr RunConstantExpression(const Expr* expr, google::api::expr::v1alpha1::Expr dummy_expr; - CelExpressionFlatImpl impl(&dummy_expr, std::move(path), 0); + CelExpressionFlatImpl impl(&dummy_expr, std::move(path), 0, {}); Activation activation; @@ -47,7 +48,7 @@ TEST(ConstValueStepTest, TestEvaluationConstInt64) { auto status = RunConstantExpression(&expr, const_expr, &arena); - ASSERT_TRUE(status.ok()); + ASSERT_OK(status); auto value = status.ValueOrDie(); @@ -64,7 +65,7 @@ TEST(ConstValueStepTest, TestEvaluationConstUint64) { auto status = RunConstantExpression(&expr, const_expr, &arena); - ASSERT_TRUE(status.ok()); + ASSERT_OK(status); auto value = status.ValueOrDie(); @@ -81,7 +82,7 @@ TEST(ConstValueStepTest, TestEvaluationConstBool) { auto status = RunConstantExpression(&expr, const_expr, &arena); - ASSERT_TRUE(status.ok()); + ASSERT_OK(status); auto value = status.ValueOrDie(); @@ -98,7 +99,7 @@ TEST(ConstValueStepTest, TestEvaluationConstNull) { auto status = RunConstantExpression(&expr, const_expr, &arena); - ASSERT_TRUE(status.ok()); + ASSERT_OK(status); auto value = status.ValueOrDie(); @@ -114,7 +115,7 @@ TEST(ConstValueStepTest, TestEvaluationConstString) { auto status = RunConstantExpression(&expr, const_expr, &arena); - ASSERT_TRUE(status.ok()); + ASSERT_OK(status); auto value = status.ValueOrDie(); @@ -131,7 +132,7 @@ TEST(ConstValueStepTest, TestEvaluationConstDouble) { auto status = RunConstantExpression(&expr, const_expr, &arena); - ASSERT_TRUE(status.ok()); + ASSERT_OK(status); auto value = status.ValueOrDie(); @@ -150,7 +151,7 @@ TEST(ConstValueStepTest, TestEvaluationConstBytes) { auto status = RunConstantExpression(&expr, const_expr, &arena); - ASSERT_TRUE(status.ok()); + ASSERT_OK(status); auto value = status.ValueOrDie(); diff --git a/eval/eval/container_access_step.cc b/eval/eval/container_access_step.cc index 15246a18a..37e4ba731 100644 --- a/eval/eval/container_access_step.cc +++ b/eval/eval/container_access_step.cc @@ -1,10 +1,12 @@ #include "eval/eval/container_access_step.h" #include "google/protobuf/arena.h" +#include "absl/status/status.h" #include "absl/strings/str_cat.h" +#include "eval/eval/evaluator_core.h" #include "eval/eval/expression_step_base.h" #include "eval/public/cel_value.h" -#include "base/status.h" +#include "eval/public/unknown_attribute_set.h" namespace google { namespace api { @@ -13,17 +15,20 @@ namespace runtime { namespace { +constexpr int NUM_CONTAINER_ACCESS_ARGUMENTS = 2; + // ContainerAccessStep performs message field access specified by Expr::Select // message. class ContainerAccessStep : public ExpressionStepBase { public: ContainerAccessStep(int64_t expr_id) : ExpressionStepBase(expr_id) {} - cel_base::Status Evaluate(ExecutionFrame* frame) const override; + absl::Status Evaluate(ExecutionFrame* frame) const override; private: - CelValue PerformLookup(const CelValue& container, const CelValue& key, - google::protobuf::Arena* arena) const; + using ValueAttributePair = std::pair; + + ValueAttributePair PerformLookup(ExecutionFrame* frame) const; CelValue LookupInMap(const CelMap* cel_map, const CelValue& key, google::protobuf::Arena* arena) const; CelValue LookupInList(const CelList* cel_list, const CelValue& key, @@ -72,52 +77,78 @@ inline CelValue ContainerAccessStep::LookupInList(const CelList* cel_list, } } -CelValue ContainerAccessStep::PerformLookup(const CelValue& container, - const CelValue& key, - google::protobuf::Arena* arena) const { - if (container.IsError()) { - return container; +ContainerAccessStep::ValueAttributePair ContainerAccessStep::PerformLookup( + ExecutionFrame* frame) const { + auto input_args = + frame->value_stack().GetSpan(NUM_CONTAINER_ACCESS_ARGUMENTS); + AttributeTrail trail; + + const CelValue& container = input_args[0]; + const CelValue& key = input_args[1]; + + if (frame->enable_unknowns()) { + auto unknown_set = + frame->unknowns_utility().MergeUnknowns(input_args, nullptr); + + if (unknown_set) { + return {CelValue::CreateUnknownSet(unknown_set), trail}; + } + + // We guarantee that GetAttributeSpan can aquire this number of arguments + // by calling HasEnough() at the beginning of Execute() method. + auto input_attrs = + frame->value_stack().GetAttributeSpan(NUM_CONTAINER_ACCESS_ARGUMENTS); + auto container_trail = input_attrs[0]; + trail = container_trail.Step(CelAttributeQualifier::Create(key), + frame->arena()); + + if (frame->unknowns_utility().CheckForUnknown(trail, + /*use_partial=*/false)) { + auto unknown_set = google::protobuf::Arena::Create( + frame->arena(), UnknownAttributeSet({trail.attribute()})); + + return {CelValue::CreateUnknownSet(unknown_set), trail}; + } } - if (key.IsError()) { - return key; + + for (const auto& value : input_args) { + if (value.IsError()) { + return {value, trail}; + } } + // Select steps can be applied to either maps or messages switch (container.type()) { case CelValue::Type::kMap: { const CelMap* cel_map = container.MapOrDie(); - return LookupInMap(cel_map, key, arena); + return {LookupInMap(cel_map, key, frame->arena()), trail}; } case CelValue::Type::kList: { const CelList* cel_list = container.ListOrDie(); - return LookupInList(cel_list, key, arena); + return {LookupInList(cel_list, key, frame->arena()), trail}; } default: { - return CreateErrorValue( - arena, absl::StrCat("Unexpected container type for [] operation: ", - CelValue::TypeName(key.type()))); + auto error = CreateErrorValue( + frame->arena(), + absl::StrCat("Unexpected container type for [] operation: ", + CelValue::TypeName(key.type()))); + return {error, trail}; } } } -cel_base::Status ContainerAccessStep::Evaluate(ExecutionFrame* frame) const { - const int NUM_ARGUMENTS = 2; - - if (!frame->value_stack().HasEnough(NUM_ARGUMENTS)) { - return cel_base::Status( - cel_base::StatusCode::kInternal, +absl::Status ContainerAccessStep::Evaluate(ExecutionFrame* frame) const { + if (!frame->value_stack().HasEnough(NUM_CONTAINER_ACCESS_ARGUMENTS)) { + return absl::Status( + absl::StatusCode::kInternal, "Insufficient arguments supplied for ContainerAccess-type expression"); } - auto input_args = frame->value_stack().GetSpan(NUM_ARGUMENTS); - - const CelValue& container = input_args[0]; - const CelValue& key = input_args[1]; - - CelValue result = PerformLookup(container, key, frame->arena()); - frame->value_stack().Pop(NUM_ARGUMENTS); - frame->value_stack().Push(result); + auto result = PerformLookup(frame); + frame->value_stack().Pop(NUM_CONTAINER_ACCESS_ARGUMENTS); + frame->value_stack().Push(result.first, result.second); - return cel_base::OkStatus(); + return absl::OkStatus(); } } // namespace diff --git a/eval/eval/container_access_step_test.cc b/eval/eval/container_access_step_test.cc index 22a13fd62..d3b507310 100644 --- a/eval/eval/container_access_step_test.cc +++ b/eval/eval/container_access_step_test.cc @@ -10,8 +10,10 @@ #include "eval/eval/container_backed_list_impl.h" #include "eval/eval/container_backed_map_impl.h" #include "eval/eval/ident_step.h" +#include "eval/public/cel_attribute.h" #include "eval/public/cel_builtins.h" #include "eval/public/cel_value.h" +#include "base/status_macros.h" namespace google { namespace api { @@ -25,95 +27,140 @@ using ::google::protobuf::Struct; using google::api::expr::v1alpha1::Expr; using google::api::expr::v1alpha1::SourceInfo; -class ContainerAccessStepTest : public ::testing::Test { - protected: - ContainerAccessStepTest() {} +using TestParamType = std::tuple; - void SetUp() override {} +// Helper method. Looks up in registry and tests comparison operation. +CelValue EvaluateAttributeHelper( + google::protobuf::Arena* arena, CelValue container, CelValue key, bool receiver_style, + bool enable_unknown, const std::vector& patterns) { + ExecutionPath path; - // Helper method. Looks up in registry and tests comparison operation. - CelValue PerformRun(CelValue container, CelValue key, bool receiver_style) { - ExecutionPath path; + Expr expr; + SourceInfo source_info; + auto call = expr.mutable_call_expr(); + + call->set_function(builtin::kIndex); + + Expr* container_expr = + (receiver_style) ? call->mutable_target() : call->add_args(); + Expr* key_expr = call->add_args(); + + container_expr->mutable_ident_expr()->set_name("container"); + key_expr->mutable_ident_expr()->set_name("key"); - Expr expr; - SourceInfo source_info; - auto call = expr.mutable_call_expr(); + path.push_back(std::move( + CreateIdentStep(&container_expr->ident_expr(), 1).ValueOrDie())); + path.push_back( + std::move(CreateIdentStep(&key_expr->ident_expr(), 2).ValueOrDie())); + path.push_back(std::move(CreateContainerAccessStep(call, 3).ValueOrDie())); - call->set_function(builtin::kIndex); + CelExpressionFlatImpl cel_expr(&expr, std::move(path), 0, {}, enable_unknown); + Activation activation; - Expr* container_expr = - (receiver_style) ? call->mutable_target() : call->add_args(); - Expr* key_expr = call->add_args(); + activation.InsertValue("container", container); + activation.InsertValue("key", key); + + activation.set_unknown_attribute_patterns(patterns); + auto eval_status = cel_expr.Evaluate(activation, arena); + + EXPECT_OK(eval_status); + return eval_status.ValueOrDie(); +} + +class ContainerAccessStepTest : public ::testing::Test { + protected: + ContainerAccessStepTest() {} - container_expr->mutable_ident_expr()->set_name("container"); - key_expr->mutable_ident_expr()->set_name("key"); + void SetUp() override {} - path.push_back(std::move( - CreateIdentStep(&container_expr->ident_expr(), 1).ValueOrDie())); - path.push_back( - std::move(CreateIdentStep(&key_expr->ident_expr(), 2).ValueOrDie())); - path.push_back(std::move(CreateContainerAccessStep(call, 3).ValueOrDie())); + CelValue EvaluateAttribute( + CelValue container, CelValue key, bool receiver_style, + bool enable_unknown, + const std::vector& patterns = {}) { + return EvaluateAttributeHelper(&arena_, container, key, receiver_style, + enable_unknown, patterns); + } + google::protobuf::Arena arena_; +}; - CelExpressionFlatImpl cel_expr(&expr, std::move(path), 0); - Activation activation; +class ContainerAccessStepUniformityTest + : public ::testing::TestWithParam { + protected: + ContainerAccessStepUniformityTest() {} - activation.InsertValue("container", container); - activation.InsertValue("key", key); - auto eval_status = cel_expr.Evaluate(activation, &arena_); + void SetUp() override {} - EXPECT_TRUE(eval_status.ok()); - return eval_status.ValueOrDie(); + // Helper method. Looks up in registry and tests comparison operation. + CelValue EvaluateAttribute( + CelValue container, CelValue key, bool receiver_style, + bool enable_unknown, + const std::vector& patterns = {}) { + return EvaluateAttributeHelper(&arena_, container, key, receiver_style, + enable_unknown, patterns); } google::protobuf::Arena arena_; }; -TEST_F(ContainerAccessStepTest, TestListIndexAccess) { +TEST_P(ContainerAccessStepUniformityTest, TestListIndexAccess) { ContainerBackedListImpl cel_list({CelValue::CreateInt64(1), CelValue::CreateInt64(2), CelValue::CreateInt64(3)}); - CelValue result = PerformRun(CelValue::CreateList(&cel_list), - CelValue::CreateInt64(1), true); + TestParamType param = GetParam(); + CelValue result = EvaluateAttribute(CelValue::CreateList(&cel_list), + CelValue::CreateInt64(1), + std::get<0>(param), std::get<1>(param)); ASSERT_TRUE(result.IsInt64()); ASSERT_EQ(result.Int64OrDie(), 2); } -TEST_F(ContainerAccessStepTest, TestListIndexAccessOutOfBounds) { +TEST_P(ContainerAccessStepUniformityTest, TestListIndexAccessOutOfBounds) { ContainerBackedListImpl cel_list({CelValue::CreateInt64(1), CelValue::CreateInt64(2), CelValue::CreateInt64(3)}); - CelValue result = PerformRun(CelValue::CreateList(&cel_list), - CelValue::CreateInt64(0), true); + TestParamType param = GetParam(); + + CelValue result = EvaluateAttribute(CelValue::CreateList(&cel_list), + CelValue::CreateInt64(0), + std::get<0>(param), std::get<1>(param)); ASSERT_TRUE(result.IsInt64()); - result = PerformRun(CelValue::CreateList(&cel_list), CelValue::CreateInt64(2), - true); + result = EvaluateAttribute(CelValue::CreateList(&cel_list), + CelValue::CreateInt64(2), std::get<0>(param), + std::get<1>(param)); ASSERT_TRUE(result.IsInt64()); - result = PerformRun(CelValue::CreateList(&cel_list), - CelValue::CreateInt64(-1), true); + result = EvaluateAttribute(CelValue::CreateList(&cel_list), + CelValue::CreateInt64(-1), std::get<0>(param), + std::get<1>(param)); ASSERT_TRUE(result.IsError()); - result = PerformRun(CelValue::CreateList(&cel_list), CelValue::CreateInt64(3), - true); + result = EvaluateAttribute(CelValue::CreateList(&cel_list), + CelValue::CreateInt64(3), std::get<0>(param), + std::get<1>(param)); ASSERT_TRUE(result.IsError()); } -TEST_F(ContainerAccessStepTest, TestListIndexAccessNotAnInt) { +TEST_P(ContainerAccessStepUniformityTest, TestListIndexAccessNotAnInt) { ContainerBackedListImpl cel_list({CelValue::CreateInt64(1), CelValue::CreateInt64(2), CelValue::CreateInt64(3)}); - CelValue result = PerformRun(CelValue::CreateList(&cel_list), - CelValue::CreateUint64(1), true); + TestParamType param = GetParam(); + + CelValue result = EvaluateAttribute(CelValue::CreateList(&cel_list), + CelValue::CreateUint64(1), + std::get<0>(param), std::get<1>(param)); ASSERT_TRUE(result.IsError()); } -TEST_F(ContainerAccessStepTest, TestMapKeyAccess) { +TEST_P(ContainerAccessStepUniformityTest, TestMapKeyAccess) { + TestParamType param = GetParam(); + const std::string kKey0 = "testkey0"; const std::string kKey1 = "testkey1"; const std::string kKey2 = "testkey2"; @@ -122,25 +169,92 @@ TEST_F(ContainerAccessStepTest, TestMapKeyAccess) { (*cel_struct.mutable_fields())[kKey1].set_string_value("value1"); (*cel_struct.mutable_fields())[kKey2].set_string_value("value2"); - CelValue result = PerformRun(CelValue::CreateMessage(&cel_struct, &arena_), - CelValue::CreateString(&kKey0), true); + CelValue result = EvaluateAttribute( + CelValue::CreateMessage(&cel_struct, &arena_), + CelValue::CreateString(&kKey0), std::get<0>(param), std::get<1>(param)); ASSERT_TRUE(result.IsString()); ASSERT_EQ(result.StringOrDie().value(), "value0"); } -TEST_F(ContainerAccessStepTest, TestMapKeyAccessNotFound) { +TEST_P(ContainerAccessStepUniformityTest, TestMapKeyAccessNotFound) { + TestParamType param = GetParam(); + const std::string kKey0 = "testkey0"; const std::string kKey1 = "testkey1"; Struct cel_struct; (*cel_struct.mutable_fields())[kKey0].set_string_value("value0"); - CelValue result = PerformRun(CelValue::CreateMessage(&cel_struct, &arena_), - CelValue::CreateString(&kKey1), true); + CelValue result = EvaluateAttribute( + CelValue::CreateMessage(&cel_struct, &arena_), + CelValue::CreateString(&kKey1), std::get<0>(param), std::get<1>(param)); ASSERT_TRUE(result.IsError()); } +TEST_F(ContainerAccessStepTest, TestListIndexAccessUnknown) { + ContainerBackedListImpl cel_list({CelValue::CreateInt64(1), + CelValue::CreateInt64(2), + CelValue::CreateInt64(3)}); + + CelValue result = EvaluateAttribute(CelValue::CreateList(&cel_list), + CelValue::CreateInt64(1), true, true, {}); + + ASSERT_TRUE(result.IsInt64()); + ASSERT_EQ(result.Int64OrDie(), 2); + + std::vector patterns = {CelAttributePattern( + "container", + {CelAttributeQualifierPattern::Create(CelValue::CreateInt64(1))})}; + + result = EvaluateAttribute(CelValue::CreateList(&cel_list), + CelValue::CreateInt64(1), true, true, patterns); + + ASSERT_TRUE(result.IsUnknownSet()); +} + +TEST_F(ContainerAccessStepTest, TestListUnknownKey) { + ContainerBackedListImpl cel_list({CelValue::CreateInt64(1), + CelValue::CreateInt64(2), + CelValue::CreateInt64(3)}); + + UnknownSet unknown_set; + CelValue result = + EvaluateAttribute(CelValue::CreateList(&cel_list), + CelValue::CreateUnknownSet(&unknown_set), true, true); + + ASSERT_TRUE(result.IsUnknownSet()); +} + +TEST_F(ContainerAccessStepTest, TestMapUnknownKey) { + const std::string kKey0 = "testkey0"; + const std::string kKey1 = "testkey1"; + const std::string kKey2 = "testkey2"; + Struct cel_struct; + (*cel_struct.mutable_fields())[kKey0].set_string_value("value0"); + (*cel_struct.mutable_fields())[kKey1].set_string_value("value1"); + (*cel_struct.mutable_fields())[kKey2].set_string_value("value2"); + + UnknownSet unknown_set; + CelValue result = + EvaluateAttribute(CelValue::CreateMessage(&cel_struct, &arena_), + CelValue::CreateUnknownSet(&unknown_set), true, true); + + ASSERT_TRUE(result.IsUnknownSet()); +} + +TEST_F(ContainerAccessStepTest, TestUnknownContainer) { + UnknownSet unknown_set; + CelValue result = EvaluateAttribute(CelValue::CreateUnknownSet(&unknown_set), + CelValue::CreateInt64(1), true, true); + + ASSERT_TRUE(result.IsUnknownSet()); +} + +INSTANTIATE_TEST_SUITE_P(CombinedContainerTest, + ContainerAccessStepUniformityTest, + testing::Combine(testing::Bool(), testing::Bool())); + } // namespace } // namespace runtime diff --git a/eval/eval/create_list_step.cc b/eval/eval/create_list_step.cc index d2b7fe66c..27cc6654e 100644 --- a/eval/eval/create_list_step.cc +++ b/eval/eval/create_list_step.cc @@ -1,4 +1,5 @@ #include "eval/eval/create_list_step.h" + #include "eval/eval/container_backed_list_impl.h" namespace google { @@ -13,32 +14,48 @@ class CreateListStep : public ExpressionStepBase { CreateListStep(int64_t expr_id, int list_size) : ExpressionStepBase(expr_id), list_size_(list_size) {} - cel_base::Status Evaluate(ExecutionFrame* frame) const override; + absl::Status Evaluate(ExecutionFrame* frame) const override; private: int list_size_; }; -cel_base::Status CreateListStep::Evaluate(ExecutionFrame* frame) const { +absl::Status CreateListStep::Evaluate(ExecutionFrame* frame) const { if (list_size_ < 0) { - return cel_base::Status(cel_base::StatusCode::kInternal, + return absl::Status(absl::StatusCode::kInternal, "CreateListStep: list size is <0"); } if (!frame->value_stack().HasEnough(list_size_)) { - return cel_base::Status(cel_base::StatusCode::kInternal, - "CreateListStep: stack undeflow"); + return absl::Status(absl::StatusCode::kInternal, + "CreateListStep: stack underflow"); } auto args = frame->value_stack().GetSpan(list_size_); - CelList* cel_list = google::protobuf::Arena::Create( - frame->arena(), std::vector(args.begin(), args.end())); + CelValue result; + + const UnknownSet* unknown_set = nullptr; + if (frame->enable_unknowns()) { + unknown_set = frame->unknowns_utility().MergeUnknowns( + args, frame->value_stack().GetAttributeSpan(list_size_), + /*initial_set=*/nullptr, + /*use_partial=*/true); + if (unknown_set != nullptr) { + result = CelValue::CreateUnknownSet(unknown_set); + } + } + + if (unknown_set == nullptr) { + CelList* cel_list = google::protobuf::Arena::Create( + frame->arena(), std::vector(args.begin(), args.end())); + result = CelValue::CreateList(cel_list); + } frame->value_stack().Pop(list_size_); - frame->value_stack().Push(CelValue::CreateList(cel_list)); + frame->value_stack().Push(result); - return cel_base::OkStatus(); + return absl::OkStatus(); } } // namespace diff --git a/eval/eval/create_list_step_test.cc b/eval/eval/create_list_step_test.cc index fc796b264..cef1737a4 100644 --- a/eval/eval/create_list_step_test.cc +++ b/eval/eval/create_list_step_test.cc @@ -1,8 +1,14 @@ #include "eval/eval/create_list_step.h" -#include "eval/eval/const_value_step.h" #include "gmock/gmock.h" #include "gtest/gtest.h" +#include "absl/strings/str_cat.h" +#include "eval/eval/const_value_step.h" +#include "eval/eval/ident_step.h" +#include "eval/public/activation.h" +#include "eval/public/cel_attribute.h" +#include "eval/public/unknown_attribute_set.h" +#include "base/status_macros.h" namespace google { namespace api { @@ -17,7 +23,8 @@ using google::api::expr::v1alpha1::Expr; // Helper method. Creates simple pipeline containing Select step and runs it. cel_base::StatusOr RunExpression(const std::vector& values, - google::protobuf::Arena* arena) { + google::protobuf::Arena* arena, + bool enable_unknowns) { ExecutionPath path; Expr dummy_expr; @@ -42,12 +49,54 @@ cel_base::StatusOr RunExpression(const std::vector& values, path.push_back(std::move(step0_status.ValueOrDie())); - CelExpressionFlatImpl cel_expr(&dummy_expr, std::move(path), 0); + CelExpressionFlatImpl cel_expr(&dummy_expr, std::move(path), 0, {}, + enable_unknowns); Activation activation; return cel_expr.Evaluate(activation, arena); } +// Helper method. Creates simple pipeline containing Select step and runs it. +cel_base::StatusOr RunExpressionWithCelValues( + const std::vector& values, google::protobuf::Arena* arena, + bool enable_unknowns) { + ExecutionPath path; + Expr dummy_expr; + + Activation activation; + auto create_list = dummy_expr.mutable_list_expr(); + int ind = 0; + for (auto value : values) { + std::string var_name = absl::StrCat("name_", ind++); + auto expr0 = create_list->add_elements(); + expr0->set_id(ind); + expr0->mutable_ident_expr()->set_name(var_name); + + auto ident_step_status = CreateIdentStep(&expr0->ident_expr(), expr0->id()); + if (!ident_step_status.ok()) { + return ident_step_status.status(); + } + + path.push_back(std::move(ident_step_status.ValueOrDie())); + activation.InsertValue(var_name, value); + } + + auto step0_status = CreateCreateListStep(create_list, dummy_expr.id()); + + if (!step0_status.ok()) { + return step0_status.status(); + } + + path.push_back(std::move(step0_status.ValueOrDie())); + + CelExpressionFlatImpl cel_expr(&dummy_expr, std::move(path), 0, {}, + enable_unknowns); + + return cel_expr.Evaluate(activation, arena); +} + +class CreateListStepTest : public testing::TestWithParam {}; + // Tests error when not enough list elements are on the stack during list // creation. TEST(CreateListStepTest, TestCreateListStackUndeflow) { @@ -60,11 +109,11 @@ TEST(CreateListStepTest, TestCreateListStackUndeflow) { auto step0_status = CreateCreateListStep(create_list, dummy_expr.id()); - ASSERT_TRUE(step0_status.ok()); + ASSERT_OK(step0_status); path.push_back(std::move(step0_status.ValueOrDie())); - CelExpressionFlatImpl cel_expr(&dummy_expr, std::move(path), 0); + CelExpressionFlatImpl cel_expr(&dummy_expr, std::move(path), 0, {}); Activation activation; google::protobuf::Arena arena; @@ -73,36 +122,36 @@ TEST(CreateListStepTest, TestCreateListStackUndeflow) { ASSERT_FALSE(status.ok()); } -TEST(CreateListStepTest, CreateListEmpty) { +TEST_P(CreateListStepTest, CreateListEmpty) { google::protobuf::Arena arena; - auto eval_result = RunExpression({}, &arena); + auto eval_result = RunExpression({}, &arena, GetParam()); - ASSERT_TRUE(eval_result.ok()); + ASSERT_OK(eval_result); const CelValue result_value = eval_result.ValueOrDie(); ASSERT_TRUE(result_value.IsList()); EXPECT_THAT(result_value.ListOrDie()->size(), Eq(0)); } -TEST(CreateListStepTest, CreateListOne) { +TEST_P(CreateListStepTest, CreateListOne) { google::protobuf::Arena arena; - auto eval_result = RunExpression({100}, &arena); + auto eval_result = RunExpression({100}, &arena, GetParam()); - ASSERT_TRUE(eval_result.ok()); + ASSERT_OK(eval_result); const CelValue result_value = eval_result.ValueOrDie(); ASSERT_TRUE(result_value.IsList()); EXPECT_THAT(result_value.ListOrDie()->size(), Eq(1)); EXPECT_THAT((*result_value.ListOrDie())[0].Int64OrDie(), Eq(100)); } -TEST(CreateListStepTest, CreateListHundred) { +TEST_P(CreateListStepTest, CreateListHundred) { google::protobuf::Arena arena; std::vector values; for (size_t i = 0; i < 100; i++) { values.push_back(i); } - auto eval_result = RunExpression(values, &arena); + auto eval_result = RunExpression(values, &arena, GetParam()); - ASSERT_TRUE(eval_result.ok()); + ASSERT_OK(eval_result); const CelValue result_value = eval_result.ValueOrDie(); ASSERT_TRUE(result_value.IsList()); EXPECT_THAT(result_value.ListOrDie()->size(), @@ -112,6 +161,36 @@ TEST(CreateListStepTest, CreateListHundred) { } } +TEST(CreateListStepTest, CreateListHundredAnd2Unknowns) { + google::protobuf::Arena arena; + std::vector values; + + Expr expr0; + expr0.mutable_ident_expr()->set_name("name0"); + CelAttribute attr0(expr0, {}); + Expr expr1; + expr1.mutable_ident_expr()->set_name("name1"); + CelAttribute attr1(expr1, {}); + UnknownSet unknown_set0(UnknownAttributeSet({&attr0})); + UnknownSet unknown_set1(UnknownAttributeSet({&attr1})); + for (size_t i = 0; i < 100; i++) { + values.push_back(CelValue::CreateInt64(i)); + } + values.push_back(CelValue::CreateUnknownSet(&unknown_set0)); + values.push_back(CelValue::CreateUnknownSet(&unknown_set1)); + + auto eval_result = RunExpressionWithCelValues(values, &arena, true); + + ASSERT_OK(eval_result); + const CelValue result_value = eval_result.ValueOrDie(); + ASSERT_TRUE(result_value.IsUnknownSet()); + const UnknownSet* result_set = result_value.UnknownSetOrDie(); + EXPECT_THAT(result_set->unknown_attributes().attributes().size(), Eq(2)); +} + +INSTANTIATE_TEST_SUITE_P(CombinedCreateListTest, CreateListStepTest, + testing::Bool()); + } // namespace } // namespace runtime diff --git a/eval/eval/create_struct_step.cc b/eval/eval/create_struct_step.cc index c417b9a32..7bba18c50 100644 --- a/eval/eval/create_struct_step.cc +++ b/eval/eval/create_struct_step.cc @@ -1,10 +1,10 @@ #include "eval/eval/create_struct_step.h" -#include "eval/eval/container_backed_map_impl.h" -#include "eval/eval/field_access.h" #include "google/api/expr/v1alpha1/syntax.pb.h" +#include "absl/status/status.h" #include "absl/strings/substitute.h" -#include "base/canonical_errors.h" +#include "eval/eval/container_backed_map_impl.h" +#include "eval/eval/field_access.h" namespace google { namespace api { @@ -13,12 +13,12 @@ namespace runtime { namespace { -using ::google::protobuf::Message; -using ::google::protobuf::MessageFactory; using ::google::protobuf::Arena; using ::google::protobuf::Descriptor; using ::google::protobuf::DescriptorPool; using ::google::protobuf::FieldDescriptor; +using ::google::protobuf::Message; +using ::google::protobuf::MessageFactory; class CreateStructStepForMessage : public ExpressionStepBase { public: @@ -32,10 +32,10 @@ class CreateStructStepForMessage : public ExpressionStepBase { descriptor_(descriptor), entries_(std::move(entries)) {} - cel_base::Status Evaluate(ExecutionFrame* frame) const override; + absl::Status Evaluate(ExecutionFrame* frame) const override; private: - cel_base::Status DoEvaluate(ExecutionFrame* frame, CelValue* result) const; + absl::Status DoEvaluate(ExecutionFrame* frame, CelValue* result) const; const Descriptor* descriptor_; std::vector entries_; @@ -46,20 +46,31 @@ class CreateStructStepForMap : public ExpressionStepBase { CreateStructStepForMap(int64_t expr_id, size_t entry_count) : ExpressionStepBase(expr_id), entry_count_(entry_count) {} - cel_base::Status Evaluate(ExecutionFrame* frame) const override; + absl::Status Evaluate(ExecutionFrame* frame) const override; private: - cel_base::Status DoEvaluate(ExecutionFrame* frame, CelValue* result) const; + absl::Status DoEvaluate(ExecutionFrame* frame, CelValue* result) const; size_t entry_count_; }; -::cel_base::Status CreateStructStepForMessage::DoEvaluate(ExecutionFrame* frame, - CelValue* result) const { +absl::Status CreateStructStepForMessage::DoEvaluate(ExecutionFrame* frame, + CelValue* result) const { int entries_size = entries_.size(); absl::Span args = frame->value_stack().GetSpan(entries_size); + if (frame->enable_unknowns()) { + auto unknown_set = frame->unknowns_utility().MergeUnknowns( + args, frame->value_stack().GetAttributeSpan(entries_size), + /*initial_set=*/nullptr, + /*use_partial=*/true); + if (unknown_set != nullptr) { + *result = CelValue::CreateUnknownSet(unknown_set); + return absl::OkStatus(); + } + } + const Message* prototype = MessageFactory::generated_factory()->GetPrototype(descriptor_); @@ -70,14 +81,14 @@ ::cel_base::Status CreateStructStepForMessage::DoEvaluate(ExecutionFrame* frame, *result = CreateErrorValue( frame->arena(), absl::Substitute("Failed to create message $0", descriptor_->name())); - return ::cel_base::OkStatus(); + return absl::OkStatus(); } int index = 0; for (const auto& entry : entries_) { const CelValue& arg = args[index++]; - ::cel_base::Status status = ::cel_base::OkStatus(); + absl::Status status = absl::OkStatus(); if (entry.field->is_map()) { constexpr int kKeyField = 1; @@ -85,7 +96,7 @@ ::cel_base::Status CreateStructStepForMessage::DoEvaluate(ExecutionFrame* frame, const CelMap* cel_map; if (!arg.GetValue(&cel_map) || cel_map == nullptr) { - status = cel_base::InvalidArgumentError(absl::Substitute( + status = absl::InvalidArgumentError(absl::Substitute( "Failed to create message $0, field $1: value is not CelMap", descriptor_->name(), entry.field->name())); break; @@ -94,7 +105,7 @@ ::cel_base::Status CreateStructStepForMessage::DoEvaluate(ExecutionFrame* frame, auto entry_descriptor = entry.field->message_type(); if (entry_descriptor == nullptr) { - status = cel_base::InvalidArgumentError( + status = absl::InvalidArgumentError( absl::Substitute("Failed to create message $0, field $1: failed to " "find map entry descriptor", descriptor_->name(), entry.field->name())); @@ -107,14 +118,14 @@ ::cel_base::Status CreateStructStepForMessage::DoEvaluate(ExecutionFrame* frame, entry_descriptor->FindFieldByNumber(kValueField); if (key_field_descriptor == nullptr) { - status = cel_base::InvalidArgumentError( + status = absl::InvalidArgumentError( absl::Substitute("Failed to create message $0, field $1: failed to " "find key field descriptor", descriptor_->name(), entry.field->name())); break; } if (value_field_descriptor == nullptr) { - status = cel_base::InvalidArgumentError( + status = absl::InvalidArgumentError( absl::Substitute("Failed to create message $0, field $1: failed to " "find value field descriptor", descriptor_->name(), entry.field->name())); @@ -127,7 +138,7 @@ ::cel_base::Status CreateStructStepForMessage::DoEvaluate(ExecutionFrame* frame, auto value = (*cel_map)[key]; if (!value.has_value()) { - status = cel_base::InvalidArgumentError(absl::Substitute( + status = absl::InvalidArgumentError(absl::Substitute( "Failed to create message $0, field $1: Error serializing CelMap", descriptor_->name(), entry.field->name())); break; @@ -153,7 +164,7 @@ ::cel_base::Status CreateStructStepForMessage::DoEvaluate(ExecutionFrame* frame, absl::Substitute( "Failed to create message $0: value $1 is not CelList", descriptor_->name(), entry.field->name())); - return ::cel_base::OkStatus(); + return absl::OkStatus(); } for (int i = 0; i < cel_list->size(); i++) { @@ -169,25 +180,24 @@ ::cel_base::Status CreateStructStepForMessage::DoEvaluate(ExecutionFrame* frame, frame->arena(), absl::Substitute("Failed to create message $0: reason $1", descriptor_->name(), status.ToString())); - return ::cel_base::OkStatus(); + return absl::OkStatus(); } } *result = CelValue::CreateMessage(msg, frame->arena()); - return ::cel_base::OkStatus(); + return absl::OkStatus(); } -::cel_base::Status CreateStructStepForMessage::Evaluate( - ExecutionFrame* frame) const { +absl::Status CreateStructStepForMessage::Evaluate(ExecutionFrame* frame) const { if (frame->value_stack().size() < entries_.size()) { - return cel_base::Status(cel_base::StatusCode::kInternal, + return absl::Status(absl::StatusCode::kInternal, "CreateStructStepForMessage: stack undeflow"); } CelValue result; - ::cel_base::Status status = DoEvaluate(frame, &result); + absl::Status status = DoEvaluate(frame, &result); if (!status.ok()) { return status; } @@ -195,14 +205,24 @@ ::cel_base::Status CreateStructStepForMessage::Evaluate( frame->value_stack().Pop(entries_.size()); frame->value_stack().Push(result); - return cel_base::OkStatus(); + return absl::OkStatus(); } -::cel_base::Status CreateStructStepForMap::DoEvaluate(ExecutionFrame* frame, - CelValue* result) const { +absl::Status CreateStructStepForMap::DoEvaluate(ExecutionFrame* frame, + CelValue* result) const { absl::Span args = frame->value_stack().GetSpan(2 * entry_count_); + if (frame->enable_unknowns()) { + const UnknownSet* unknown_set = frame->unknowns_utility().MergeUnknowns( + args, frame->value_stack().GetAttributeSpan(args.size()), + /*initial_set=*/nullptr, true); + if (unknown_set != nullptr) { + *result = CelValue::CreateUnknownSet(unknown_set); + return absl::OkStatus(); + } + } + std::vector> map_entries; map_entries.reserve(entry_count_); for (size_t i = 0; i < entry_count_; i += 1) { @@ -216,7 +236,7 @@ ::cel_base::Status CreateStructStepForMap::DoEvaluate(ExecutionFrame* frame, if (cel_map == nullptr) { *result = CreateErrorValue(frame->arena(), "Failed to create map"); - return ::cel_base::OkStatus(); + return absl::OkStatus(); } *result = CelValue::CreateMap(cel_map.get()); @@ -224,18 +244,18 @@ ::cel_base::Status CreateStructStepForMap::DoEvaluate(ExecutionFrame* frame, // Pass object ownership to Arena. frame->arena()->Own(cel_map.release()); - return ::cel_base::OkStatus(); + return absl::OkStatus(); } -::cel_base::Status CreateStructStepForMap::Evaluate(ExecutionFrame* frame) const { +absl::Status CreateStructStepForMap::Evaluate(ExecutionFrame* frame) const { if (frame->value_stack().size() < 2 * entry_count_) { - return cel_base::Status(cel_base::StatusCode::kInternal, + return absl::Status(absl::StatusCode::kInternal, "CreateStructStepForMap: stack undeflow"); } CelValue result; - ::cel_base::Status status = DoEvaluate(frame, &result); + absl::Status status = DoEvaluate(frame, &result); if (!status.ok()) { return status; } @@ -243,7 +263,7 @@ ::cel_base::Status CreateStructStepForMap::Evaluate(ExecutionFrame* frame) const frame->value_stack().Pop(2 * entry_count_); frame->value_stack().Push(result); - return cel_base::OkStatus(); + return absl::OkStatus(); } } // namespace @@ -260,20 +280,20 @@ cel_base::StatusOr> CreateCreateStructStep( create_struct_expr->message_name()); if (desc == nullptr) { - return cel_base::InvalidArgumentError( + return absl::InvalidArgumentError( "Error configuring message creation: message descriptor not found"); } for (const auto& entry : create_struct_expr->entries()) { if (entry.field_key().empty()) { - return cel_base::InvalidArgumentError( + return absl::InvalidArgumentError( "Error configuring message creation: field name missing"); } const FieldDescriptor* field_desc = desc->FindFieldByName(entry.field_key()); if (field_desc == nullptr) { - return cel_base::InvalidArgumentError( + return absl::InvalidArgumentError( "Error configuring message creation: field name not found"); } entries.push_back({field_desc}); diff --git a/eval/eval/create_struct_step_test.cc b/eval/eval/create_struct_step_test.cc index c8fef5922..94749a6f3 100644 --- a/eval/eval/create_struct_step_test.cc +++ b/eval/eval/create_struct_step_test.cc @@ -9,6 +9,7 @@ #include "eval/eval/ident_step.h" #include "eval/testutil/test_message.pb.h" #include "testutil/util.h" +#include "base/status_macros.h" namespace google { namespace api { @@ -16,12 +17,12 @@ namespace expr { namespace runtime { namespace { -using ::google::protobuf::Message; using ::google::protobuf::Arena; +using ::google::protobuf::Message; using testing::Eq; -using testing::Not; using testing::IsNull; +using testing::Not; using testing::Pointwise; using testutil::EqualsProto; @@ -32,7 +33,8 @@ using google::api::expr::v1alpha1::Expr; // builds message and runs it. cel_base::StatusOr RunExpression(absl::string_view field, const CelValue& value, - google::protobuf::Arena* arena) { + google::protobuf::Arena* arena, + bool enable_unknowns) { ExecutionPath path; Expr expr0; @@ -61,7 +63,8 @@ cel_base::StatusOr RunExpression(absl::string_view field, path.push_back(std::move(step0_status.ValueOrDie())); path.push_back(std::move(step1_status.ValueOrDie())); - CelExpressionFlatImpl cel_expr(&expr1, std::move(path), 0); + CelExpressionFlatImpl cel_expr(&expr1, std::move(path), 0, {}, + enable_unknowns); Activation activation; activation.InsertValue("message", value); @@ -69,9 +72,10 @@ cel_base::StatusOr RunExpression(absl::string_view field, } void RunExpressionAndGetMessage(absl::string_view field, const CelValue& value, - google::protobuf::Arena* arena, TestMessage* test_msg) { - auto status = RunExpression(field, value, arena); - ASSERT_TRUE(status.ok()); + google::protobuf::Arena* arena, TestMessage* test_msg, + bool enable_unknowns) { + auto status = RunExpression(field, value, arena, enable_unknowns); + ASSERT_OK(status); CelValue result = status.ValueOrDie(); ASSERT_TRUE(result.IsMessage()); @@ -85,13 +89,14 @@ void RunExpressionAndGetMessage(absl::string_view field, const CelValue& value, void RunExpressionAndGetMessage(absl::string_view field, std::vector values, - google::protobuf::Arena* arena, TestMessage* test_msg) { + google::protobuf::Arena* arena, TestMessage* test_msg, + bool enable_unknowns) { ContainerBackedListImpl cel_list(std::move(values)); CelValue value = CelValue::CreateList(&cel_list); - auto status = RunExpression(field, value, arena); - ASSERT_TRUE(status.ok()); + auto status = RunExpression(field, value, arena, enable_unknowns); + ASSERT_OK(status); CelValue result = status.ValueOrDie(); ASSERT_TRUE(result.IsMessage()); @@ -107,7 +112,7 @@ void RunExpressionAndGetMessage(absl::string_view field, // builds Map and runs it. cel_base::StatusOr RunCreateMapExpression( const std::vector> values, - google::protobuf::Arena* arena) { + google::protobuf::Arena* arena, bool enable_unknowns) { ExecutionPath path; Activation activation; @@ -158,11 +163,14 @@ cel_base::StatusOr RunCreateMapExpression( path.push_back(std::move(step1_status.ValueOrDie())); - CelExpressionFlatImpl cel_expr(&expr1, std::move(path), 0); + CelExpressionFlatImpl cel_expr(&expr1, std::move(path), 0, {}, + enable_unknowns); return cel_expr.Evaluate(activation, arena); } -TEST(CreateCreateStructStepTest, TestEmptyMessageCreation) { +class CreateCreateStructStepTest : public testing::TestWithParam {}; + +TEST_P(CreateCreateStructStepTest, TestEmptyMessageCreation) { ExecutionPath path; Expr expr1; @@ -172,17 +180,17 @@ TEST(CreateCreateStructStepTest, TestEmptyMessageCreation) { auto step_status = CreateCreateStructStep(create_struct, expr1.id()); - ASSERT_TRUE(step_status.ok()); + ASSERT_OK(step_status); path.push_back(std::move(step_status.ValueOrDie())); - CelExpressionFlatImpl cel_expr(&expr1, std::move(path), 0); + CelExpressionFlatImpl cel_expr(&expr1, std::move(path), 0, {}, GetParam()); Activation activation; google::protobuf::Arena arena; auto status = cel_expr.Evaluate(activation, &arena); - ASSERT_TRUE(status.ok()); + ASSERT_OK(status); CelValue result = status.ValueOrDie(); ASSERT_TRUE(result.IsMessage()); @@ -193,108 +201,126 @@ TEST(CreateCreateStructStepTest, TestEmptyMessageCreation) { ASSERT_EQ(msg->GetDescriptor(), TestMessage::descriptor()); } +// Test message creation if unknown argument is passed +TEST(CreateCreateStructStepTest, TestMessageCreateWithUnknown) { + Arena arena; + TestMessage test_msg; + UnknownSet unknown_set; + + auto eval_status = RunExpression( + "bool_value", CelValue::CreateUnknownSet(&unknown_set), &arena, true); + ASSERT_OK(eval_status); + ASSERT_TRUE(eval_status->IsUnknownSet()); +} + // Test that fields of type bool are set correctly -TEST(CreateCreateStructStepTest, TestSetBoolField) { +TEST_P(CreateCreateStructStepTest, TestSetBoolField) { Arena arena; TestMessage test_msg; ASSERT_NO_FATAL_FAILURE(RunExpressionAndGetMessage( - "bool_value", CelValue::CreateBool(true), &arena, &test_msg)); + "bool_value", CelValue::CreateBool(true), &arena, &test_msg, GetParam())); ASSERT_EQ(test_msg.bool_value(), true); } // Test that fields of type int32_t are set correctly -TEST(CreateCreateStructStepTest, TestSetInt32Field) { +TEST_P(CreateCreateStructStepTest, TestSetInt32Field) { Arena arena; TestMessage test_msg; ASSERT_NO_FATAL_FAILURE(RunExpressionAndGetMessage( - "int32_value", CelValue::CreateInt64(1), &arena, &test_msg)); + "int32_value", CelValue::CreateInt64(1), &arena, &test_msg, GetParam())); ASSERT_EQ(test_msg.int32_value(), 1); } // Test that fields of type uint32_t are set correctly. -TEST(CreateCreateStructStepTest, TestSetUInt32Field) { +TEST_P(CreateCreateStructStepTest, TestSetUInt32Field) { Arena arena; TestMessage test_msg; - ASSERT_NO_FATAL_FAILURE(RunExpressionAndGetMessage( - "uint32_value", CelValue::CreateUint64(1), &arena, &test_msg)); + ASSERT_NO_FATAL_FAILURE( + RunExpressionAndGetMessage("uint32_value", CelValue::CreateUint64(1), + &arena, &test_msg, GetParam())); ASSERT_EQ(test_msg.uint32_value(), 1); } // Test that fields of type int64_t are set correctly. -TEST(CreateCreateStructStepTest, TestSetInt64Field) { +TEST_P(CreateCreateStructStepTest, TestSetInt64Field) { Arena arena; TestMessage test_msg; ASSERT_NO_FATAL_FAILURE(RunExpressionAndGetMessage( - "int64_value", CelValue::CreateInt64(1), &arena, &test_msg)); + "int64_value", CelValue::CreateInt64(1), &arena, &test_msg, GetParam())); EXPECT_EQ(test_msg.int64_value(), 1); } // Test that fields of type uint64_t are set correctly. -TEST(CreateCreateStructStepTest, TestSetUInt64Field) { +TEST_P(CreateCreateStructStepTest, TestSetUInt64Field) { Arena arena; TestMessage test_msg; - ASSERT_NO_FATAL_FAILURE(RunExpressionAndGetMessage( - "uint64_value", CelValue::CreateUint64(1), &arena, &test_msg)); + ASSERT_NO_FATAL_FAILURE( + RunExpressionAndGetMessage("uint64_value", CelValue::CreateUint64(1), + &arena, &test_msg, GetParam())); EXPECT_EQ(test_msg.uint64_value(), 1); } // Test that fields of type float are set correctly -TEST(CreateCreateStructStepTest, TestSetFloatField) { +TEST_P(CreateCreateStructStepTest, TestSetFloatField) { Arena arena; TestMessage test_msg; - ASSERT_NO_FATAL_FAILURE(RunExpressionAndGetMessage( - "float_value", CelValue::CreateDouble(2.0), &arena, &test_msg)); + ASSERT_NO_FATAL_FAILURE( + RunExpressionAndGetMessage("float_value", CelValue::CreateDouble(2.0), + &arena, &test_msg, GetParam())); EXPECT_DOUBLE_EQ(test_msg.float_value(), 2.0); } // Test that fields of type double are set correctly -TEST(CreateCreateStructStepTest, TestSetDoubleField) { +TEST_P(CreateCreateStructStepTest, TestSetDoubleField) { Arena arena; TestMessage test_msg; - ASSERT_NO_FATAL_FAILURE(RunExpressionAndGetMessage( - "double_value", CelValue::CreateDouble(2.0), &arena, &test_msg)); + ASSERT_NO_FATAL_FAILURE( + RunExpressionAndGetMessage("double_value", CelValue::CreateDouble(2.0), + &arena, &test_msg, GetParam())); EXPECT_DOUBLE_EQ(test_msg.double_value(), 2.0); } // Test that fields of type string are set correctly. -TEST(CreateCreateStructStepTest, TestSetStringField) { +TEST_P(CreateCreateStructStepTest, TestSetStringField) { const std::string kTestStr = "test"; Arena arena; TestMessage test_msg; ASSERT_NO_FATAL_FAILURE(RunExpressionAndGetMessage( - "string_value", CelValue::CreateString(&kTestStr), &arena, &test_msg)); + "string_value", CelValue::CreateString(&kTestStr), &arena, &test_msg, + GetParam())); EXPECT_EQ(test_msg.string_value(), kTestStr); } // Test that fields of type bytes are set correctly. -TEST(CreateCreateStructStepTest, TestSetBytesField) { +TEST_P(CreateCreateStructStepTest, TestSetBytesField) { Arena arena; const std::string kTestStr = "test"; TestMessage test_msg; ASSERT_NO_FATAL_FAILURE(RunExpressionAndGetMessage( - "bytes_value", CelValue::CreateBytes(&kTestStr), &arena, &test_msg)); + "bytes_value", CelValue::CreateBytes(&kTestStr), &arena, &test_msg, + GetParam())); EXPECT_EQ(test_msg.bytes_value(), kTestStr); } // Test that fields of type duration are set correctly. -TEST(CreateCreateStructStepTest, TestSetDurationField) { +TEST_P(CreateCreateStructStepTest, TestSetDurationField) { Arena arena; google::protobuf::Duration test_duration; @@ -304,12 +330,12 @@ TEST(CreateCreateStructStepTest, TestSetDurationField) { ASSERT_NO_FATAL_FAILURE(RunExpressionAndGetMessage( "duration_value", CelValue::CreateDuration(&test_duration), &arena, - &test_msg)); + &test_msg, GetParam())); EXPECT_THAT(test_msg.duration_value(), EqualsProto(test_duration)); } // Test that fields of type timestamp are set correctly. -TEST(CreateCreateStructStepTest, TestSetTimestampField) { +TEST_P(CreateCreateStructStepTest, TestSetTimestampField) { Arena arena; google::protobuf::Timestamp test_timestamp; @@ -319,12 +345,12 @@ TEST(CreateCreateStructStepTest, TestSetTimestampField) { ASSERT_NO_FATAL_FAILURE(RunExpressionAndGetMessage( "timestamp_value", CelValue::CreateTimestamp(&test_timestamp), &arena, - &test_msg)); + &test_msg, GetParam())); EXPECT_THAT(test_msg.timestamp_value(), EqualsProto(test_timestamp)); } // Test that fields of type Message are set correctly. -TEST(CreateCreateStructStepTest, TestSetMessageField) { +TEST_P(CreateCreateStructStepTest, TestSetMessageField) { Arena arena; // Create payload message and set some fields. @@ -336,23 +362,23 @@ TEST(CreateCreateStructStepTest, TestSetMessageField) { ASSERT_NO_FATAL_FAILURE(RunExpressionAndGetMessage( "message_value", CelValue::CreateMessage(&orig_msg, &arena), &arena, - &test_msg)); + &test_msg, GetParam())); EXPECT_THAT(test_msg.message_value(), EqualsProto(orig_msg)); } // Test that fields of type Message are set correctly. -TEST(CreateCreateStructStepTest, TestSetEnumField) { +TEST_P(CreateCreateStructStepTest, TestSetEnumField) { Arena arena; TestMessage test_msg; ASSERT_NO_FATAL_FAILURE(RunExpressionAndGetMessage( "enum_value", CelValue::CreateInt64(TestMessage::TEST_ENUM_2), &arena, - &test_msg)); + &test_msg, GetParam())); EXPECT_EQ(test_msg.enum_value(), TestMessage::TEST_ENUM_2); } // Test that fields of type bool are set correctly -TEST(CreateCreateStructStepTest, TestSetRepeatedBoolField) { +TEST_P(CreateCreateStructStepTest, TestSetRepeatedBoolField) { Arena arena; TestMessage test_msg; @@ -362,13 +388,13 @@ TEST(CreateCreateStructStepTest, TestSetRepeatedBoolField) { values.push_back(CelValue::CreateBool(value)); } - ASSERT_NO_FATAL_FAILURE( - RunExpressionAndGetMessage("bool_list", values, &arena, &test_msg)); + ASSERT_NO_FATAL_FAILURE(RunExpressionAndGetMessage( + "bool_list", values, &arena, &test_msg, GetParam())); ASSERT_THAT(test_msg.bool_list(), Pointwise(Eq(), kValues)); } // Test that repeated fields of type int32_t are set correctly -TEST(CreateCreateStructStepTest, TestSetRepeatedInt32Field) { +TEST_P(CreateCreateStructStepTest, TestSetRepeatedInt32Field) { Arena arena; TestMessage test_msg; @@ -378,13 +404,13 @@ TEST(CreateCreateStructStepTest, TestSetRepeatedInt32Field) { values.push_back(CelValue::CreateInt64(value)); } - ASSERT_NO_FATAL_FAILURE( - RunExpressionAndGetMessage("int32_list", values, &arena, &test_msg)); + ASSERT_NO_FATAL_FAILURE(RunExpressionAndGetMessage( + "int32_list", values, &arena, &test_msg, GetParam())); ASSERT_THAT(test_msg.int32_list(), Pointwise(Eq(), kValues)); } // Test that repeated fields of type uint32_t are set correctly -TEST(CreateCreateStructStepTest, TestSetRepeatedUInt32Field) { +TEST_P(CreateCreateStructStepTest, TestSetRepeatedUInt32Field) { Arena arena; TestMessage test_msg; @@ -394,13 +420,13 @@ TEST(CreateCreateStructStepTest, TestSetRepeatedUInt32Field) { values.push_back(CelValue::CreateUint64(value)); } - ASSERT_NO_FATAL_FAILURE( - RunExpressionAndGetMessage("uint32_list", values, &arena, &test_msg)); + ASSERT_NO_FATAL_FAILURE(RunExpressionAndGetMessage( + "uint32_list", values, &arena, &test_msg, GetParam())); ASSERT_THAT(test_msg.uint32_list(), Pointwise(Eq(), kValues)); } // Test that repeated fields of type int64_t are set correctly -TEST(CreateCreateStructStepTest, TestSetRepeatedInt64Field) { +TEST_P(CreateCreateStructStepTest, TestSetRepeatedInt64Field) { Arena arena; TestMessage test_msg; @@ -410,13 +436,13 @@ TEST(CreateCreateStructStepTest, TestSetRepeatedInt64Field) { values.push_back(CelValue::CreateInt64(value)); } - ASSERT_NO_FATAL_FAILURE( - RunExpressionAndGetMessage("int64_list", values, &arena, &test_msg)); + ASSERT_NO_FATAL_FAILURE(RunExpressionAndGetMessage( + "int64_list", values, &arena, &test_msg, GetParam())); ASSERT_THAT(test_msg.int64_list(), Pointwise(Eq(), kValues)); } // Test that repeated fields of type uint64_t are set correctly -TEST(CreateCreateStructStepTest, TestSetRepeatedUInt64Field) { +TEST_P(CreateCreateStructStepTest, TestSetRepeatedUInt64Field) { Arena arena; TestMessage test_msg; @@ -426,13 +452,13 @@ TEST(CreateCreateStructStepTest, TestSetRepeatedUInt64Field) { values.push_back(CelValue::CreateUint64(value)); } - ASSERT_NO_FATAL_FAILURE( - RunExpressionAndGetMessage("uint64_list", values, &arena, &test_msg)); + ASSERT_NO_FATAL_FAILURE(RunExpressionAndGetMessage( + "uint64_list", values, &arena, &test_msg, GetParam())); ASSERT_THAT(test_msg.uint64_list(), Pointwise(Eq(), kValues)); } // Test that repeated fields of type float are set correctly -TEST(CreateCreateStructStepTest, TestSetRepeatedFloatField) { +TEST_P(CreateCreateStructStepTest, TestSetRepeatedFloatField) { Arena arena; TestMessage test_msg; @@ -442,13 +468,13 @@ TEST(CreateCreateStructStepTest, TestSetRepeatedFloatField) { values.push_back(CelValue::CreateDouble(value)); } - ASSERT_NO_FATAL_FAILURE( - RunExpressionAndGetMessage("float_list", values, &arena, &test_msg)); + ASSERT_NO_FATAL_FAILURE(RunExpressionAndGetMessage( + "float_list", values, &arena, &test_msg, GetParam())); ASSERT_THAT(test_msg.float_list(), Pointwise(Eq(), kValues)); } // Test that repeated fields of type uint32_t are set correctly -TEST(CreateCreateStructStepTest, TestSetRepeatedDoubleField) { +TEST_P(CreateCreateStructStepTest, TestSetRepeatedDoubleField) { Arena arena; TestMessage test_msg; @@ -458,13 +484,13 @@ TEST(CreateCreateStructStepTest, TestSetRepeatedDoubleField) { values.push_back(CelValue::CreateDouble(value)); } - ASSERT_NO_FATAL_FAILURE( - RunExpressionAndGetMessage("double_list", values, &arena, &test_msg)); + ASSERT_NO_FATAL_FAILURE(RunExpressionAndGetMessage( + "double_list", values, &arena, &test_msg, GetParam())); ASSERT_THAT(test_msg.double_list(), Pointwise(Eq(), kValues)); } // Test that repeated fields of type String are set correctly -TEST(CreateCreateStructStepTest, TestSetRepeatedStringField) { +TEST_P(CreateCreateStructStepTest, TestSetRepeatedStringField) { Arena arena; TestMessage test_msg; @@ -474,13 +500,13 @@ TEST(CreateCreateStructStepTest, TestSetRepeatedStringField) { values.push_back(CelValue::CreateString(&value)); } - ASSERT_NO_FATAL_FAILURE( - RunExpressionAndGetMessage("string_list", values, &arena, &test_msg)); + ASSERT_NO_FATAL_FAILURE(RunExpressionAndGetMessage( + "string_list", values, &arena, &test_msg, GetParam())); ASSERT_THAT(test_msg.string_list(), Pointwise(Eq(), kValues)); } // Test that repeated fields of type String are set correctly -TEST(CreateCreateStructStepTest, TestSetRepeatedBytesField) { +TEST_P(CreateCreateStructStepTest, TestSetRepeatedBytesField) { Arena arena; TestMessage test_msg; @@ -490,14 +516,14 @@ TEST(CreateCreateStructStepTest, TestSetRepeatedBytesField) { values.push_back(CelValue::CreateBytes(&value)); } - ASSERT_NO_FATAL_FAILURE( - RunExpressionAndGetMessage("bytes_list", values, &arena, &test_msg)); + ASSERT_NO_FATAL_FAILURE(RunExpressionAndGetMessage( + "bytes_list", values, &arena, &test_msg, GetParam())); ASSERT_THAT(test_msg.bytes_list(), Pointwise(Eq(), kValues)); } // Test that repeated fields of type Message are set correctly -TEST(CreateCreateStructStepTest, TestSetRepeatedMessageField) { +TEST_P(CreateCreateStructStepTest, TestSetRepeatedMessageField) { Arena arena; TestMessage test_msg; @@ -509,15 +535,15 @@ TEST(CreateCreateStructStepTest, TestSetRepeatedMessageField) { values.push_back(CelValue::CreateMessage(&value, &arena)); } - ASSERT_NO_FATAL_FAILURE( - RunExpressionAndGetMessage("message_list", values, &arena, &test_msg)); + ASSERT_NO_FATAL_FAILURE(RunExpressionAndGetMessage( + "message_list", values, &arena, &test_msg, GetParam())); ASSERT_THAT(test_msg.message_list()[0], EqualsProto(kValues[0])); ASSERT_THAT(test_msg.message_list()[1], EqualsProto(kValues[1])); } // Test that fields of type map are set correctly -TEST(CreateCreateStructStepTest, TestSetStringMapField) { +TEST_P(CreateCreateStructStepTest, TestSetStringMapField) { Arena arena; TestMessage test_msg; @@ -535,8 +561,8 @@ TEST(CreateCreateStructStepTest, TestSetStringMapField) { entries.data(), entries.size())); ASSERT_NO_FATAL_FAILURE(RunExpressionAndGetMessage( - "string_int32_map", CelValue::CreateMap(cel_map.get()), &arena, - &test_msg)); + "string_int32_map", CelValue::CreateMap(cel_map.get()), &arena, &test_msg, + GetParam())); ASSERT_EQ(test_msg.string_int32_map().size(), 2); ASSERT_EQ(test_msg.string_int32_map().at(kKeys[0]), 2); @@ -544,7 +570,7 @@ TEST(CreateCreateStructStepTest, TestSetStringMapField) { } // Test that fields of type map are set correctly -TEST(CreateCreateStructStepTest, TestSetInt64MapField) { +TEST_P(CreateCreateStructStepTest, TestSetInt64MapField) { Arena arena; TestMessage test_msg; @@ -562,8 +588,8 @@ TEST(CreateCreateStructStepTest, TestSetInt64MapField) { entries.data(), entries.size())); ASSERT_NO_FATAL_FAILURE(RunExpressionAndGetMessage( - "int64_int32_map", CelValue::CreateMap(cel_map.get()), &arena, - &test_msg)); + "int64_int32_map", CelValue::CreateMap(cel_map.get()), &arena, &test_msg, + GetParam())); ASSERT_EQ(test_msg.int64_int32_map().size(), 2); ASSERT_EQ(test_msg.int64_int32_map().at(kKeys[0]), 1); @@ -571,7 +597,7 @@ TEST(CreateCreateStructStepTest, TestSetInt64MapField) { } // Test that fields of type map are set correctly -TEST(CreateCreateStructStepTest, TestSetUInt64MapField) { +TEST_P(CreateCreateStructStepTest, TestSetUInt64MapField) { Arena arena; TestMessage test_msg; @@ -589,8 +615,8 @@ TEST(CreateCreateStructStepTest, TestSetUInt64MapField) { entries.data(), entries.size())); ASSERT_NO_FATAL_FAILURE(RunExpressionAndGetMessage( - "uint64_int32_map", CelValue::CreateMap(cel_map.get()), &arena, - &test_msg)); + "uint64_int32_map", CelValue::CreateMap(cel_map.get()), &arena, &test_msg, + GetParam())); ASSERT_EQ(test_msg.uint64_int32_map().size(), 2); ASSERT_EQ(test_msg.uint64_int32_map().at(kKeys[0]), 1); @@ -598,11 +624,11 @@ TEST(CreateCreateStructStepTest, TestSetUInt64MapField) { } // Test that Empty Map is created successfully. -TEST(CreateCreateStructStepTest, TestCreateEmptyMap) { +TEST_P(CreateCreateStructStepTest, TestCreateEmptyMap) { Arena arena; - auto status = RunCreateMapExpression({}, &arena); + auto status = RunCreateMapExpression({}, &arena, GetParam()); - ASSERT_TRUE(status.ok()); + ASSERT_OK(status); CelValue result_value = status.ValueOrDie(); ASSERT_TRUE(result_value.IsMap()); @@ -611,8 +637,29 @@ TEST(CreateCreateStructStepTest, TestCreateEmptyMap) { ASSERT_EQ(cel_map->size(), 0); } +// Test message creation if unknown argument is passed +TEST(CreateCreateStructStepTest, TestMapCreateWithUnknown) { + Arena arena; + UnknownSet unknown_set; + std::vector> entries; + + std::vector kKeys = {"test2", "test1"}; + + entries.push_back( + {CelValue::CreateString(&kKeys[0]), CelValue::CreateInt64(2)}); + entries.push_back({CelValue::CreateString(&kKeys[1]), + CelValue::CreateUnknownSet(&unknown_set)}); + + auto status = RunCreateMapExpression(entries, &arena, true); + + ASSERT_OK(status); + + CelValue result_value = status.ValueOrDie(); + ASSERT_TRUE(result_value.IsUnknownSet()); +} + // Test that String Map is created successfully. -TEST(CreateCreateStructStepTest, TestCreateStringMap) { +TEST_P(CreateCreateStructStepTest, TestCreateStringMap) { Arena arena; std::vector> entries; @@ -624,9 +671,9 @@ TEST(CreateCreateStructStepTest, TestCreateStringMap) { entries.push_back( {CelValue::CreateString(&kKeys[1]), CelValue::CreateInt64(1)}); - auto status = RunCreateMapExpression(entries, &arena); + auto status = RunCreateMapExpression(entries, &arena, GetParam()); - ASSERT_TRUE(status.ok()); + ASSERT_OK(status); CelValue result_value = status.ValueOrDie(); ASSERT_TRUE(result_value.IsMap()); @@ -645,6 +692,9 @@ TEST(CreateCreateStructStepTest, TestCreateStringMap) { EXPECT_EQ(lookup1.value().Int64OrDie(), 1); } +INSTANTIATE_TEST_SUITE_P(CombinedCreateStructTest, CreateCreateStructStepTest, + testing::Bool()); + } // namespace } // namespace runtime diff --git a/eval/eval/evaluator_core.cc b/eval/eval/evaluator_core.cc index 5e0579aa4..29d601dd9 100644 --- a/eval/eval/evaluator_core.cc +++ b/eval/eval/evaluator_core.cc @@ -1,31 +1,147 @@ #include "eval/eval/evaluator_core.h" +#include "absl/status/status.h" +#include "absl/types/optional.h" +#include "eval/public/cel_value.h" +#include "base/status_macros.h" +#include "base/statusor.h" + namespace google { namespace api { namespace expr { namespace runtime { +namespace { + +absl::Status CheckIterAccess(CelExpressionFlatEvaluationState* state, + const std::string& name) { + if (state->iter_stack().empty()) { + return absl::Status( + absl::StatusCode::kInternal, + absl::StrCat( + "Attempted to update iteration variable outside of comprehension.'", + name, "'")); + } + auto iter = state->iter_variable_names().find(name); + if (iter == state->iter_variable_names().end()) { + return absl::Status( + absl::StatusCode::kInternal, + absl::StrCat("Attempted to set unknown variable '", name, "'")); + } + + return absl::OkStatus(); +} + +} // namespace + +void ValueStack::Clear() { + for (auto& v : stack_) { + v = CelValue(); + } + for (auto& attr : attribute_stack_) { + attr = AttributeTrail(); + } + + current_size_ = 0; +} -using google::api::expr::v1alpha1::Expr; +CelExpressionFlatEvaluationState::CelExpressionFlatEvaluationState( + size_t value_stack_size, const std::set& iter_variable_names, + google::protobuf::Arena* arena) + : value_stack_(value_stack_size), + iter_variable_names_(iter_variable_names), + arena_(arena) {} + +void CelExpressionFlatEvaluationState::Reset() { + iter_stack_.clear(); + value_stack_.Clear(); +} const ExpressionStep* ExecutionFrame::Next() { - size_t end_pos = execution_path_->size(); + size_t end_pos = execution_path_.size(); - if (pc_ < end_pos) return (*execution_path_)[pc_++].get(); + if (pc_ < end_pos) return execution_path_[pc_++].get(); if (pc_ > end_pos) { GOOGLE_LOG(ERROR) << "Attempting to step beyond the end of execution path."; } return nullptr; } +absl::Status ExecutionFrame::PushIterFrame() { + state_->iter_stack().push_back({}); + return absl::OkStatus(); +} + +absl::Status ExecutionFrame::PopIterFrame() { + if (state_->iter_stack().empty()) { + return absl::InternalError("Loop stack underflow."); + } + state_->iter_stack().pop_back(); + return absl::OkStatus(); +} + +absl::Status ExecutionFrame::SetIterVar(const std::string& name, + const CelValue& val) { + RETURN_IF_ERROR(CheckIterAccess(state_, name)); + state_->IterStackTop()[name] = val; + + return absl::OkStatus(); +} + +absl::Status ExecutionFrame::ClearIterVar(const std::string& name) { + RETURN_IF_ERROR(CheckIterAccess(state_, name)); + state_->IterStackTop()[name] = absl::nullopt; + return absl::OkStatus(); +} + +bool ExecutionFrame::GetIterVar(const std::string& name, CelValue* val) const { + absl::Status status = CheckIterAccess(state_, name); + if (!status.ok()) { + return false; + } + + for (auto iter = state_->iter_stack().rbegin(); + iter != state_->iter_stack().rend(); ++iter) { + auto& frame = *iter; + auto frame_iter = frame.find(name); + if (frame_iter != frame.end()) { + if (frame_iter->second.has_value()) { + *val = frame_iter->second.value(); + return true; + } + } + } + + return false; +} + +std::unique_ptr CelExpressionFlatImpl::InitializeState( + google::protobuf::Arena* arena) const { + return absl::make_unique( + path_.size(), iter_variable_names_, arena); +} + cel_base::StatusOr CelExpressionFlatImpl::Evaluate( - const BaseActivation& activation, google::protobuf::Arena* arena) const { - return Trace(activation, arena, CelEvaluationListener()); + const BaseActivation& activation, CelEvaluationState* state) const { + return Trace(activation, state, CelEvaluationListener()); } cel_base::StatusOr CelExpressionFlatImpl::Trace( - const BaseActivation& activation, google::protobuf::Arena* arena, + const BaseActivation& activation, CelEvaluationState* _state, CelEvaluationListener callback) const { - ExecutionFrame frame(&path_, activation, arena, max_iterations_); + auto state = down_cast(_state); + state->Reset(); + + // Using both unknown attribute patterns and unknown paths via FieldMask is + // not allowed. + if (activation.unknown_paths().paths_size() != 0 && + !activation.unknown_attribute_patterns().empty()) { + return absl::InvalidArgumentError( + "Attempting to evaluate expression with both unknown_paths and " + "unknown_attribute_patterns set in the Activation"); + } + + ExecutionFrame frame(path_, activation, max_iterations_, state, + enable_unknowns_, enable_unknown_function_results_); ValueStack* stack = &frame.value_stack(); size_t initial_stack_size = stack->size(); @@ -48,7 +164,7 @@ cel_base::StatusOr CelExpressionFlatImpl::Trace( "Try to disable short-circuiting."; continue; } - auto status2 = callback(expr->id(), stack->Peek(), arena); + auto status2 = callback(expr->id(), stack->Peek(), state->arena()); if (!status2.ok()) { return status2; } @@ -56,7 +172,7 @@ cel_base::StatusOr CelExpressionFlatImpl::Trace( size_t final_stack_size = stack->size(); if (initial_stack_size + 1 != final_stack_size || final_stack_size == 0) { - return cel_base::Status(cel_base::StatusCode::kInternal, + return absl::Status(absl::StatusCode::kInternal, "Stack error during evaluation"); } CelValue value = stack->Peek(); diff --git a/eval/eval/evaluator_core.h b/eval/eval/evaluator_core.h index 50d2d4c0d..475613a0e 100644 --- a/eval/eval/evaluator_core.h +++ b/eval/eval/evaluator_core.h @@ -1,10 +1,19 @@ #ifndef THIRD_PARTY_CEL_CPP_EVAL_EVAL_EVALUATOR_CORE_H_ #define THIRD_PARTY_CEL_CPP_EVAL_EVAL_EVALUATOR_CORE_H_ +#include +#include + +#include "google/api/expr/v1alpha1/syntax.pb.h" +#include "google/protobuf/arena.h" +#include "absl/types/optional.h" +#include "eval/eval/attribute_trail.h" +#include "eval/eval/unknowns_utility.h" #include "eval/public/activation.h" +#include "eval/public/cel_attribute.h" #include "eval/public/cel_expression.h" #include "eval/public/cel_value.h" -#include "google/api/expr/v1alpha1/syntax.pb.h" +#include "eval/public/unknown_attribute_set.h" namespace google { namespace api { @@ -26,7 +35,7 @@ class ExpressionStep { // interface. // ExpressionStep instances can in specific cases // modify execution order(perform jumps). - virtual cel_base::Status Evaluate(ExecutionFrame* context) const = 0; + virtual absl::Status Evaluate(ExecutionFrame* context) const = 0; // Returns corresponding expression object ID. // Requires that the input expression has IDs assigned to sub-expressions, @@ -40,9 +49,6 @@ class ExpressionStep { virtual bool ComesFromAst() const = 0; }; -// CelValue stack. -// Implementation is based on vector to allow passing parameters from -// stack as Span<>. using ExecutionPath = std::vector>; // CelValue stack. @@ -50,42 +56,144 @@ using ExecutionPath = std::vector>; // stack as Span<>. class ValueStack { public: - ValueStack() = default; + ValueStack(size_t max_size) : current_size_(0) { + stack_.resize(max_size); + attribute_stack_.resize(max_size); + } + + // Return the current stack size. + size_t size() const { return current_size_; } + + // Return the maximum size of the stack. + size_t max_size() const { return stack_.size(); } - // Stack size. - size_t size() const { return stack_.size(); } + // Returns true if stack is empty. + bool empty() const { return current_size_ == 0; } + + // Attributes stack size. + size_t attribute_size() const { return current_size_; } // Check that stack has enough elements. - bool HasEnough(size_t size) const { return stack_.size() >= size; } + bool HasEnough(size_t size) const { return current_size_ >= size; } + + // Dumps the entire stack state as is. + void Clear(); // Gets the last size elements of the stack. // Checking that stack has enough elements is caller's responsibility. // Please note that calls to Push may invalidate returned Span object. absl::Span GetSpan(size_t size) const { - return absl::Span(stack_.data() + stack_.size() - size, + if (!HasEnough(size)) { + GOOGLE_LOG(ERROR) << "Requested span size (" << size + << ") exceeds current stack size: " << current_size_; + } + return absl::Span(stack_.data() + current_size_ - size, size); } + // Gets the last size attribute trails of the stack. + // Checking that stack has enough elements is caller's responsibility. + // Please note that calls to Push may invalidate returned Span object. + absl::Span GetAttributeSpan(size_t size) const { + return absl::Span( + attribute_stack_.data() + current_size_ - size, size); + } + // Peeks the last element of the stack. // Checking that stack is not empty is caller's responsibility. - const CelValue& Peek() const { return stack_.back(); } + const CelValue& Peek() const { + if (empty()) { + GOOGLE_LOG(ERROR) << "Peeking on empty ValueStack"; + } + return stack_[current_size_ - 1]; + } + + // Peeks the last element of the attribute stack. + // Checking that stack is not empty is caller's responsibility. + const AttributeTrail& PeekAttribute() const { + if (empty()) { + GOOGLE_LOG(ERROR) << "Peeking on empty ValueStack"; + } + return attribute_stack_[current_size_ - 1]; + } // Clears the last size elements of the stack. // Checking that stack has enough elements is caller's responsibility. - void Pop(size_t size) { stack_.resize(stack_.size() - size); } + void Pop(size_t size) { + if (!HasEnough(size)) { + GOOGLE_LOG(ERROR) << "Trying to pop more elements (" << size + << ") than the current stack size: " << current_size_; + } + current_size_ -= size; + } // Put element on the top of the stack. - void Push(const CelValue& value) { stack_.push_back(value); } + void Push(const CelValue& value) { Push(value, AttributeTrail()); } + + void Push(const CelValue& value, AttributeTrail attribute) { + if (current_size_ >= stack_.size()) { + GOOGLE_LOG(ERROR) << "No room to push more elements on to ValueStack"; + } + stack_[current_size_] = value; + attribute_stack_[current_size_] = attribute; + current_size_++; + } // Replace element on the top of the stack. // Checking that stack is not empty is caller's responsibility. - void PopAndPush(const CelValue& value) { stack_.back() = value; } + void PopAndPush(const CelValue& value) { + PopAndPush(value, AttributeTrail()); + } + + // Replace element on the top of the stack. + // Checking that stack is not empty is caller's responsibility. + void PopAndPush(const CelValue& value, AttributeTrail attribute) { + if (empty()) { + GOOGLE_LOG(ERROR) << "Cannot PopAndPush on empty stack."; + } + stack_[current_size_ - 1] = value; + attribute_stack_[current_size_ - 1] = attribute; + } // Preallocate stack. - void Reserve(size_t size) { stack_.reserve(size); } + void Reserve(size_t size) { + stack_.reserve(size); + attribute_stack_.reserve(size); + } private: std::vector stack_; + std::vector attribute_stack_; + size_t current_size_; +}; + +class CelExpressionFlatEvaluationState : public CelEvaluationState { + public: + CelExpressionFlatEvaluationState( + size_t value_stack_size, const std::set& iter_variable_names, + google::protobuf::Arena* arena); + + void Reset(); + + ValueStack& value_stack() { return value_stack_; } + + std::vector>>& iter_stack() { + return iter_stack_; + } + + std::map>& IterStackTop() { + return iter_stack_[iter_stack().size() - 1]; + } + + std::set& iter_variable_names() { return iter_variable_names_; } + + google::protobuf::Arena* arena() { return arena_; } + + private: + ValueStack value_stack_; + std::set iter_variable_names_; + std::vector>> iter_stack_; + google::protobuf::Arena* arena_; }; // ExecutionFrame provides context for expression evaluation. @@ -95,68 +203,90 @@ class ExecutionFrame { // flat is the flattened sequence of execution steps that will be evaluated. // activation provides bindings between parameter names and values. // arena serves as allocation manager during the expression evaluation. - ExecutionFrame(const ExecutionPath* flat, const BaseActivation& activation, - google::protobuf::Arena* arena, int max_iterations) + + ExecutionFrame(const ExecutionPath& flat, const BaseActivation& activation, + int max_iterations, CelExpressionFlatEvaluationState* state, + bool enable_unknowns, bool enable_unknown_function_results) : pc_(0UL), execution_path_(flat), activation_(activation), - arena_(arena), + enable_unknowns_(enable_unknowns), + enable_unknown_function_results_(enable_unknown_function_results), + unknowns_utility_(&activation.unknown_attribute_patterns(), + state->arena()), max_iterations_(max_iterations), - iterations_(0) { - // Reserve space on stack to minimize reallocations - // on stack resize. - value_stack_.Reserve(flat->size()); - } + iterations_(0), + state_(state) {} // Returns next expression to evaluate. const ExpressionStep* Next(); // Intended for use only in conditionals. - cel_base::Status JumpTo(int offset) { + absl::Status JumpTo(int offset) { int new_pc = static_cast(pc_) + offset; - if (new_pc < 0 || new_pc > static_cast(execution_path_->size())) { - return cel_base::Status(cel_base::StatusCode::kInternal, + if (new_pc < 0 || new_pc > static_cast(execution_path_.size())) { + return absl::Status(absl::StatusCode::kInternal, absl::StrCat("Jump address out of range: position: ", pc_, ",offset: ", offset, - ", range: ", execution_path_->size())); + ", range: ", execution_path_.size())); } pc_ = static_cast(new_pc); - return cel_base::OkStatus(); + return absl::OkStatus(); } - ValueStack& value_stack() { return value_stack_; } + ValueStack& value_stack() { return state_->value_stack(); } + bool enable_unknowns() const { return enable_unknowns_; } + bool enable_unknown_function_results() const { + return enable_unknown_function_results_; + } - google::protobuf::Arena* arena() { return arena_; } + google::protobuf::Arena* arena() { return state_->arena(); } + const UnknownsUtility& unknowns_utility() const { return unknowns_utility_; } // Returns reference to Activation const BaseActivation& activation() const { return activation_; } - // Returns reference to iter_vars - std::map& iter_vars() { return iter_vars_; } + // Creates a new frame for iteration variables. + absl::Status PushIterFrame(); + + // Discards the top frame for iteration variables. + absl::Status PopIterFrame(); + + // Sets the value of an iteration variable + absl::Status SetIterVar(const std::string& name, const CelValue& val); + + // Clears the value of an iteration variable + absl::Status ClearIterVar(const std::string& name); + + // Gets the current value of an iteration variable. + // Returns false if the variable is not currently in use (Set has been called + // since init or last clear). + bool GetIterVar(const std::string& name, CelValue* val) const; // Increment iterations and return an error if the iteration budget is // exceeded - cel_base::Status IncrementIterations() { + absl::Status IncrementIterations() { if (max_iterations_ == 0) { - return cel_base::OkStatus(); + return absl::OkStatus(); } iterations_++; if (iterations_ >= max_iterations_) { - return cel_base::Status(cel_base::StatusCode::kInternal, + return absl::Status(absl::StatusCode::kInternal, "Iteration budget exceeded"); } - return cel_base::OkStatus(); + return absl::OkStatus(); } private: size_t pc_; // pc_ - Program Counter. Current position on execution path. - const ExecutionPath* execution_path_; + const ExecutionPath& execution_path_; const BaseActivation& activation_; - ValueStack value_stack_; - google::protobuf::Arena* arena_; + bool enable_unknowns_; + bool enable_unknown_function_results_; + UnknownsUtility unknowns_utility_; const int max_iterations_; int iterations_; - std::map iter_vars_; // variables declared in the frame. + CelExpressionFlatEvaluationState* state_; }; // Implementation of the CelExpression that utilizes flattening @@ -168,22 +298,50 @@ class CelExpressionFlatImpl : public CelExpression { // flattened AST tree. Max iterations dictates the maximum number of // iterations in the comprehension expressions (use 0 to disable the upper // bound). - CelExpressionFlatImpl(const google::api::expr::v1alpha1::Expr*, ExecutionPath path, - int max_iterations) - : path_(std::move(path)), max_iterations_(max_iterations) {} + CelExpressionFlatImpl(const google::api::expr::v1alpha1::Expr* root_expr, + ExecutionPath path, int max_iterations, + std::set iter_variable_names, + bool enable_unknowns = false, + bool enable_unknown_function_results = false) + : path_(std::move(path)), + max_iterations_(max_iterations), + iter_variable_names_(std::move(iter_variable_names)), + enable_unknowns_(enable_unknowns), + enable_unknown_function_results_(enable_unknown_function_results) {} + + // Move-only + CelExpressionFlatImpl(const CelExpressionFlatImpl&) = delete; + CelExpressionFlatImpl& operator=(const CelExpressionFlatImpl&) = delete; + + std::unique_ptr InitializeState( + google::protobuf::Arena* arena) const override; // Implementation of CelExpression evaluate method. cel_base::StatusOr Evaluate(const BaseActivation& activation, - google::protobuf::Arena* arena) const override; + google::protobuf::Arena* arena) const override { + return Evaluate(activation, InitializeState(arena).get()); + } + + cel_base::StatusOr Evaluate(const BaseActivation& activation, + CelEvaluationState* state) const override; // Implementation of CelExpression trace method. + cel_base::StatusOr Trace( + const BaseActivation& activation, google::protobuf::Arena* arena, + CelEvaluationListener callback) const override { + return Trace(activation, InitializeState(arena).get(), callback); + } + cel_base::StatusOr Trace(const BaseActivation& activation, - google::protobuf::Arena* arena, + CelEvaluationState* state, CelEvaluationListener callback) const override; private: const ExecutionPath path_; const int max_iterations_; + const std::set iter_variable_names_; + bool enable_unknowns_; + bool enable_unknown_function_results_; }; } // namespace runtime diff --git a/eval/eval/evaluator_core_test.cc b/eval/eval/evaluator_core_test.cc index d8eae3cf5..70a32c41e 100644 --- a/eval/eval/evaluator_core_test.cc +++ b/eval/eval/evaluator_core_test.cc @@ -1,10 +1,13 @@ #include "eval/eval/evaluator_core.h" -#include "google/api/expr/v1alpha1/syntax.pb.h" -#include "eval/compiler/flat_expr_builder.h" -#include "eval/public/builtin_func_registrar.h" +#include "google/api/expr/v1alpha1/syntax.pb.h" #include "gmock/gmock.h" #include "gtest/gtest.h" +#include "eval/compiler/flat_expr_builder.h" +#include "eval/public/builtin_func_registrar.h" +#include "eval/public/cel_attribute.h" +#include "eval/public/cel_value.h" +#include "base/status_macros.h" namespace google { namespace api { @@ -13,15 +16,16 @@ namespace runtime { using ::google::api::expr::runtime::RegisterBuiltinFunctions; using testing::_; -using testing::Eq; // Optional ::testing aliases. Remove if unused. +using testing::Eq; +using testing::NotNull; // Fake expression implementation // Pushes int64_t(0) on top of value stack. class FakeConstExpressionStep : public ExpressionStep { public: - cel_base::Status Evaluate(ExecutionFrame* frame) const override { + absl::Status Evaluate(ExecutionFrame* frame) const override { frame->value_stack().Push(CelValue::CreateInt64(0)); - return cel_base::OkStatus(); + return absl::OkStatus(); } int64_t id() const override { return 0; } @@ -33,13 +37,13 @@ class FakeConstExpressionStep : public ExpressionStep { // Increments argument on top of the stack. class FakeIncrementExpressionStep : public ExpressionStep { public: - cel_base::Status Evaluate(ExecutionFrame* frame) const override { + absl::Status Evaluate(ExecutionFrame* frame) const override { CelValue value = frame->value_stack().Peek(); frame->value_stack().Pop(1); EXPECT_TRUE(value.IsInt64()); int64_t val = value.Int64OrDie(); frame->value_stack().Push(CelValue::CreateInt64(val + 1)); - return cel_base::OkStatus(); + return absl::OkStatus(); } int64_t id() const override { return 0; } @@ -47,6 +51,52 @@ class FakeIncrementExpressionStep : public ExpressionStep { bool ComesFromAst() const override { return true; } }; +// Test Value Stack Push/Pop operation +TEST(EvaluatorCoreTest, ValueStackPushPop) { + google::protobuf::Arena arena; + google::api::expr::v1alpha1::Expr expr; + expr.mutable_ident_expr()->set_name("name"); + CelAttribute attribute(expr, {}); + ValueStack stack(10); + stack.Push(CelValue::CreateInt64(1)); + stack.Push(CelValue::CreateInt64(2), AttributeTrail()); + stack.Push(CelValue::CreateInt64(3), AttributeTrail(expr, &arena)); + + ASSERT_EQ(stack.Peek().Int64OrDie(), 3); + ASSERT_THAT(stack.PeekAttribute().attribute(), NotNull()); + ASSERT_EQ(*stack.PeekAttribute().attribute(), attribute); + + stack.Pop(1); + + ASSERT_EQ(stack.Peek().Int64OrDie(), 2); + ASSERT_EQ(stack.PeekAttribute().attribute(), nullptr); + + stack.Pop(1); + + ASSERT_EQ(stack.Peek().Int64OrDie(), 1); + ASSERT_EQ(stack.PeekAttribute().attribute(), nullptr); +} + +// Test that inner stacks within value stack retain the equality of their sizes. +TEST(EvaluatorCoreTest, ValueStackBalanced) { + ValueStack stack(10); + ASSERT_EQ(stack.size(), stack.attribute_size()); + + stack.Push(CelValue::CreateInt64(1)); + ASSERT_EQ(stack.size(), stack.attribute_size()); + stack.Push(CelValue::CreateInt64(2), AttributeTrail()); + stack.Push(CelValue::CreateInt64(3), AttributeTrail()); + ASSERT_EQ(stack.size(), stack.attribute_size()); + + stack.PopAndPush(CelValue::CreateInt64(4), AttributeTrail()); + ASSERT_EQ(stack.size(), stack.attribute_size()); + stack.PopAndPush(CelValue::CreateInt64(5)); + ASSERT_EQ(stack.size(), stack.attribute_size()); + + stack.Pop(3); + ASSERT_EQ(stack.size(), stack.attribute_size()); +} + TEST(EvaluatorCoreTest, ExecutionFrameNext) { ExecutionPath path; auto const_step = absl::make_unique(); @@ -60,7 +110,8 @@ TEST(EvaluatorCoreTest, ExecutionFrameNext) { auto dummy_expr = absl::make_unique(); Activation activation; - ExecutionFrame frame(&path, activation, nullptr, 0); + CelExpressionFlatEvaluationState state(path.size(), {}, nullptr); + ExecutionFrame frame(path, activation, 0, &state, false, false); EXPECT_THAT(frame.Next(), Eq(path[0].get())); EXPECT_THAT(frame.Next(), Eq(path[1].get())); @@ -68,6 +119,55 @@ TEST(EvaluatorCoreTest, ExecutionFrameNext) { EXPECT_THAT(frame.Next(), Eq(nullptr)); } +// Test the set, get, and clear functions for "IterVar" on ExecutionFrame +TEST(EvaluatorCoreTest, ExecutionFrameSetGetClearVar) { + const std::string test_key = "test_key"; + const int64_t test_value = 0xF00F00; + + Activation activation; + ExecutionPath path; + CelExpressionFlatEvaluationState state(path.size(), {test_key}, nullptr); + ExecutionFrame frame(path, activation, 0, &state, false, false); + + CelValue original = CelValue::CreateInt64(test_value); + CelValue result; + + ASSERT_OK(frame.PushIterFrame()); + + // Nothing is there yet + ASSERT_FALSE(frame.GetIterVar(test_key, &result)); + ASSERT_OK(frame.SetIterVar(test_key, original)); + + // Make sure its now there + ASSERT_TRUE(frame.GetIterVar(test_key, &result)); + + int64_t result_value; + ASSERT_TRUE(result.GetValue(&result_value)); + EXPECT_EQ(test_value, result_value); + + // Test that it goes away properly + ASSERT_OK(frame.ClearIterVar(test_key)); + ASSERT_FALSE(frame.GetIterVar(test_key, &result)); + + // Test that bogus names return the right thing + ASSERT_FALSE(frame.SetIterVar("foo", original).ok()); + ASSERT_FALSE(frame.ClearIterVar("bar").ok()); + + // Test error conditions for accesses outside of comprehension. + ASSERT_OK(frame.SetIterVar(test_key, original)); + ASSERT_OK(frame.PopIterFrame()); + + // Access on empty stack ok, but no value. + ASSERT_FALSE(frame.GetIterVar(test_key, &result)); + + // Pop empty stack + ASSERT_FALSE(frame.PopIterFrame().ok()); + + // Updates on empty stack not ok. + ASSERT_FALSE(frame.SetIterVar(test_key, original).ok()); + ASSERT_FALSE(frame.ClearIterVar(test_key).ok()); +} + TEST(EvaluatorCoreTest, SimpleEvaluatorTest) { ExecutionPath path; auto const_step = absl::make_unique(); @@ -80,13 +180,13 @@ TEST(EvaluatorCoreTest, SimpleEvaluatorTest) { auto dummy_expr = absl::make_unique(); - CelExpressionFlatImpl impl(dummy_expr.get(), std::move(path), 0); + CelExpressionFlatImpl impl(dummy_expr.get(), std::move(path), 0, {}); Activation activation; google::protobuf::Arena arena; auto status = impl.Evaluate(activation, &arena); - EXPECT_TRUE(status.ok()); + EXPECT_OK(status); auto value = status.ValueOrDie(); EXPECT_TRUE(value.IsInt64()); @@ -158,10 +258,10 @@ TEST(EvaluatorCoreTest, TraceTest) { FlatExprBuilder builder; auto builtin_status = RegisterBuiltinFunctions(builder.GetRegistry()); - ASSERT_TRUE(builtin_status.ok()); + ASSERT_OK(builtin_status); builder.set_shortcircuiting(false); auto build_status = builder.CreateExpression(&expr, &source_info); - ASSERT_TRUE(build_status.ok()); + ASSERT_OK(build_status); auto cel_expr = std::move(build_status.ValueOrDie()); @@ -191,9 +291,9 @@ TEST(EvaluatorCoreTest, TraceTest) { activation, &arena, [&](int64_t expr_id, const CelValue& value, google::protobuf::Arena* arena) { callback.Call(expr_id, value, arena); - return ::cel_base::OkStatus(); + return absl::OkStatus(); }); - ASSERT_TRUE(eval_status.ok()); + ASSERT_OK(eval_status); } } // namespace runtime diff --git a/eval/eval/expression_build_warning.cc b/eval/eval/expression_build_warning.cc new file mode 100644 index 000000000..59a54651a --- /dev/null +++ b/eval/eval/expression_build_warning.cc @@ -0,0 +1,19 @@ +#include "eval/eval/expression_build_warning.h" + +namespace google { +namespace api { +namespace expr { +namespace runtime { + +absl::Status BuilderWarnings::AddWarning(const absl::Status& warning) { + if (fail_immediately_) { + return warning; + } + warnings_.push_back(warning); + return absl::OkStatus(); +} + +} // namespace runtime +} // namespace expr +} // namespace api +} // namespace google diff --git a/eval/eval/expression_build_warning.h b/eval/eval/expression_build_warning.h new file mode 100644 index 000000000..20575abe9 --- /dev/null +++ b/eval/eval/expression_build_warning.h @@ -0,0 +1,36 @@ +#ifndef THIRD_PARTY_CEL_CPP_EVAL_EVAL_EXPRESSION_BUILD_WARNING_H_ +#define THIRD_PARTY_CEL_CPP_EVAL_EVAL_EXPRESSION_BUILD_WARNING_H_ + +#include + +#include "absl/status/status.h" + +namespace google { +namespace api { +namespace expr { +namespace runtime { + +// Container for recording warnings. +class BuilderWarnings { + public: + explicit BuilderWarnings(bool fail_immediately = false) + : fail_immediately_(fail_immediately) {} + + // Add a warning. Returns the util:Status immediately if fail on warning is + // set. + absl::Status AddWarning(const absl::Status& warning); + + // Return the list of recorded warnings. + const std::vector& warnings() const { return warnings_; } + + private: + std::vector warnings_; + bool fail_immediately_; +}; + +} // namespace runtime +} // namespace expr +} // namespace api +} // namespace google + +#endif // THIRD_PARTY_CEL_CPP_EVAL_EVAL_EXPRESSION_BUILD_WARNING_H_ diff --git a/eval/eval/expression_build_warning_test.cc b/eval/eval/expression_build_warning_test.cc new file mode 100644 index 000000000..212b2e5ae --- /dev/null +++ b/eval/eval/expression_build_warning_test.cc @@ -0,0 +1,36 @@ +#include "eval/eval/expression_build_warning.h" + +#include "gmock/gmock.h" +#include "gtest/gtest.h" +#include "absl/status/status.h" + +namespace google { +namespace api { +namespace expr { +namespace runtime { +namespace { + + +TEST(BuilderWarnings, NoFailCollects) { + BuilderWarnings warnings(false); + + auto status = warnings.AddWarning(absl::InternalError("internal")); + EXPECT_TRUE(status.ok()); + auto status2 = warnings.AddWarning(absl::InternalError("internal error 2")); + EXPECT_TRUE(status2.ok()); + + EXPECT_THAT(warnings.warnings(), testing::SizeIs(2)); +} + +TEST(BuilderWarnings, FailReturnsStatus) { + BuilderWarnings warnings(true); + + EXPECT_EQ(warnings.AddWarning(absl::InternalError("internal")).code(), + absl::StatusCode::kInternal); +} + +} // namespace +} // namespace runtime +} // namespace expr +} // namespace api +} // namespace google diff --git a/eval/eval/field_access.cc b/eval/eval/field_access.cc index e7921a4f4..316698af1 100644 --- a/eval/eval/field_access.cc +++ b/eval/eval/field_access.cc @@ -3,10 +3,10 @@ #include #include "google/protobuf/map_field.h" +#include "absl/status/status.h" #include "absl/strings/str_cat.h" #include "absl/strings/substitute.h" #include "internal/proto_util.h" -#include "base/canonical_errors.h" namespace google { namespace api { @@ -73,7 +73,7 @@ class FieldAccessor { // If value provided successfully, returns Ok. // arena Arena to use for allocations if needed. // result pointer to object to store value in. - cel_base::Status CreateValueFromFieldAccessor(Arena* arena, CelValue* result) { + absl::Status CreateValueFromFieldAccessor(Arena* arena, CelValue* result) { switch (field_desc_->cpp_type()) { case FieldDescriptor::CPPTYPE_BOOL: { bool value = GetBool(); @@ -124,7 +124,7 @@ class FieldAccessor { *result = CelValue::CreateBytes(value); break; default: - return cel_base::Status(cel_base::StatusCode::kInvalidArgument, + return absl::Status(absl::StatusCode::kInvalidArgument, "Error handling C++ string conversion"); } break; @@ -140,11 +140,11 @@ class FieldAccessor { break; } default: - return cel_base::Status(cel_base::StatusCode::kInvalidArgument, + return absl::Status(absl::StatusCode::kInvalidArgument, "Unhandled C++ type conversion"); } - return cel_base::OkStatus(); + return absl::OkStatus(); } protected: @@ -322,7 +322,7 @@ class MessageRetrieverOp { } // namespace -cel_base::Status CreateValueFromSingleField(const google::protobuf::Message* msg, +absl::Status CreateValueFromSingleField(const google::protobuf::Message* msg, const FieldDescriptor* desc, google::protobuf::Arena* arena, CelValue* result) { @@ -330,7 +330,7 @@ cel_base::Status CreateValueFromSingleField(const google::protobuf::Message* msg return accessor.CreateValueFromFieldAccessor(arena, result); } -cel_base::Status CreateValueFromRepeatedField(const google::protobuf::Message* msg, +absl::Status CreateValueFromRepeatedField(const google::protobuf::Message* msg, const FieldDescriptor* desc, google::protobuf::Arena* arena, int index, CelValue* result) { @@ -338,7 +338,7 @@ cel_base::Status CreateValueFromRepeatedField(const google::protobuf::Message* m return accessor.CreateValueFromFieldAccessor(arena, result); } -cel_base::Status CreateValueFromMapValue(const google::protobuf::Message* msg, +absl::Status CreateValueFromMapValue(const google::protobuf::Message* msg, const FieldDescriptor* desc, const MapValueRef* value_ref, google::protobuf::Arena* arena, CelValue* result) { @@ -700,33 +700,32 @@ class RepeatedFieldSetter : public FieldSetter { // If value provided successfully, returns Ok. // arena Arena to use for allocations if needed. // result pointer to object to store value in. -::cel_base::Status SetValueToSingleField(const CelValue& value, - const FieldDescriptor* desc, - Message* msg) { +absl::Status SetValueToSingleField(const CelValue& value, + const FieldDescriptor* desc, Message* msg) { ScalarFieldSetter setter(msg, desc); return (setter.SetFieldFromCelValue(value)) - ? ::cel_base::OkStatus() - : ::cel_base::InvalidArgumentError(absl::Substitute( + ? absl::OkStatus() + : absl::InvalidArgumentError(absl::Substitute( "Could not assign supplied argument to message \"$0\" field " "\"$1\" of type $2: type was $3", msg->GetDescriptor()->name(), desc->name(), desc->type_name(), absl::StrCat(value.type()))); } -::cel_base::Status AddValueToRepeatedField(const CelValue& value, - const FieldDescriptor* desc, - Message* msg) { +absl::Status AddValueToRepeatedField(const CelValue& value, + const FieldDescriptor* desc, + Message* msg) { RepeatedFieldSetter setter(msg, desc); return (setter.SetFieldFromCelValue(value)) - ? ::cel_base::OkStatus() - : ::cel_base::InvalidArgumentError(absl::Substitute( + ? absl::OkStatus() + : absl::InvalidArgumentError(absl::Substitute( "Could not add supplied argument to message \"$0\" field " "\"$1\".", msg->GetDescriptor()->name(), desc->name())); } -::cel_base::Status AddValueToMapField(const CelValue& key, const CelValue& value, - const FieldDescriptor* desc, Message* msg) { +absl::Status AddValueToMapField(const CelValue& key, const CelValue& value, + const FieldDescriptor* desc, Message* msg) { auto entry_msg = msg->GetReflection()->AddMessage(msg, desc); auto key_field_desc = entry_msg->GetDescriptor()->FindFieldByNumber(1); auto value_field_desc = entry_msg->GetDescriptor()->FindFieldByNumber(2); @@ -735,20 +734,20 @@ ::cel_base::Status AddValueToMapField(const CelValue& key, const CelValue& value ScalarFieldSetter value_setter(entry_msg, value_field_desc); if (!key_setter.SetFieldFromCelValue(key)) { - return ::cel_base::InvalidArgumentError( + return absl::InvalidArgumentError( absl::Substitute("Could not assign supplied argument to message \"$0\" " "field \"$1\" map key.", msg->GetDescriptor()->name(), desc->name())); } if (!value_setter.SetFieldFromCelValue(value)) { - return ::cel_base::InvalidArgumentError( + return absl::InvalidArgumentError( absl::Substitute("Could not assign supplied argument to message \"$0\" " "field \"$1\" map value.", msg->GetDescriptor()->name(), desc->name())); } - return ::cel_base::OkStatus(); + return absl::OkStatus(); } } // namespace runtime diff --git a/eval/eval/field_access.h b/eval/eval/field_access.h index 45deb22f6..63ed38369 100644 --- a/eval/eval/field_access.h +++ b/eval/eval/field_access.h @@ -14,7 +14,7 @@ namespace runtime { // desc Descriptor of the field to access. // arena Arena object to allocate result on, if needed. // result pointer to CelValue to store the result in. -cel_base::Status CreateValueFromSingleField(const google::protobuf::Message* msg, +absl::Status CreateValueFromSingleField(const google::protobuf::Message* msg, const google::protobuf::FieldDescriptor* desc, google::protobuf::Arena* arena, CelValue* result); @@ -25,7 +25,7 @@ cel_base::Status CreateValueFromSingleField(const google::protobuf::Message* msg // arena Arena object to allocate result on, if needed. // index position in the repeated field. // result pointer to CelValue to store the result in. -cel_base::Status CreateValueFromRepeatedField(const google::protobuf::Message* msg, +absl::Status CreateValueFromRepeatedField(const google::protobuf::Message* msg, const google::protobuf::FieldDescriptor* desc, google::protobuf::Arena* arena, int index, CelValue* result); @@ -37,7 +37,7 @@ cel_base::Status CreateValueFromRepeatedField(const google::protobuf::Message* m // value_ref pointer to map value. // arena Arena object to allocate result on, if needed. // result pointer to CelValue to store the result in. -cel_base::Status CreateValueFromMapValue(const google::protobuf::Message* msg, +absl::Status CreateValueFromMapValue(const google::protobuf::Message* msg, const google::protobuf::FieldDescriptor* desc, const google::protobuf::MapValueRef* value_ref, google::protobuf::Arena* arena, CelValue* result); @@ -46,25 +46,25 @@ cel_base::Status CreateValueFromMapValue(const google::protobuf::Message* msg, // Returns status of the operation. // msg Message containing the field. // desc Descriptor of the field to access. -::cel_base::Status SetValueToSingleField(const CelValue& value, - const google::protobuf::FieldDescriptor* desc, - google::protobuf::Message* msg); +absl::Status SetValueToSingleField(const CelValue& value, + const google::protobuf::FieldDescriptor* desc, + google::protobuf::Message* msg); // Adds content of CelValue to repeated message field. // Returns status of the operation. // msg Message containing the field. // desc Descriptor of the field to access. -::cel_base::Status AddValueToRepeatedField(const CelValue& value, - const google::protobuf::FieldDescriptor* desc, - google::protobuf::Message* msg); +absl::Status AddValueToRepeatedField(const CelValue& value, + const google::protobuf::FieldDescriptor* desc, + google::protobuf::Message* msg); // Adds content of CelValue to repeated message field. // Returns status of the operation. // msg Message containing the field. // desc Descriptor of the field to access. -::cel_base::Status AddValueToMapField(const CelValue& key, const CelValue& value, - const google::protobuf::FieldDescriptor* desc, - google::protobuf::Message* msg); +absl::Status AddValueToMapField(const CelValue& key, const CelValue& value, + const google::protobuf::FieldDescriptor* desc, + google::protobuf::Message* msg); } // namespace runtime } // namespace expr diff --git a/eval/eval/function_step.cc b/eval/eval/function_step.cc index d0afdbc7b..76c787cce 100644 --- a/eval/eval/function_step.cc +++ b/eval/eval/function_step.cc @@ -5,11 +5,17 @@ #include #include +#include "google/protobuf/arena.h" #include "absl/strings/str_cat.h" #include "eval/eval/evaluator_core.h" +#include "eval/eval/expression_build_warning.h" #include "eval/eval/expression_step_base.h" #include "eval/public/cel_function_provider.h" #include "eval/public/cel_function_registry.h" +#include "eval/public/cel_value.h" +#include "eval/public/unknown_function_result_set.h" +#include "eval/public/unknown_set.h" +#include "base/status_macros.h" namespace google { namespace api { @@ -27,7 +33,10 @@ class AbstractFunctionStep : public ExpressionStepBase { AbstractFunctionStep(size_t num_arguments, int64_t expr_id) : ExpressionStepBase(expr_id), num_arguments_(num_arguments) {} - cel_base::Status Evaluate(ExecutionFrame* frame) const override; + absl::Status Evaluate(ExecutionFrame* frame) const override; + + absl::Status DoEvaluate(ExecutionFrame* frame, CelValue* result) const; + virtual cel_base::StatusOr ResolveFunction( absl::Span args, const ExecutionFrame* frame) const = 0; @@ -35,14 +44,18 @@ class AbstractFunctionStep : public ExpressionStepBase { size_t num_arguments_; }; -cel_base::Status AbstractFunctionStep::Evaluate(ExecutionFrame* frame) const { - if (!frame->value_stack().HasEnough(num_arguments_)) { - return cel_base::Status(cel_base::StatusCode::kInternal, "Value stack underflow"); - } - +absl::Status AbstractFunctionStep::DoEvaluate(ExecutionFrame* frame, + CelValue* result) const { // Create Span object that contains input arguments to the function. auto input_args = frame->value_stack().GetSpan(num_arguments_); + const UnknownSet* unknown_set = nullptr; + if (frame->enable_unknowns()) { + auto input_attrs = frame->value_stack().GetAttributeSpan(num_arguments_); + unknown_set = frame->unknowns_utility().MergeUnknowns( + input_args, input_attrs, /*initial_set=*/nullptr, /*use_partial=*/true); + } + // Derived class resolves to a single function overload or none. auto status = ResolveFunction(input_args, frame); if (!status.ok()) { @@ -50,15 +63,23 @@ cel_base::Status AbstractFunctionStep::Evaluate(ExecutionFrame* frame) const { } const CelFunction* matched_function = status.ValueOrDie(); - CelValue result = CelValue::CreateNull(); - // Overload found - if (matched_function != nullptr) { - cel_base::Status status = - matched_function->Evaluate(input_args, &result, frame->arena()); + if (matched_function != nullptr && unknown_set == nullptr) { + absl::Status status = + matched_function->Evaluate(input_args, result, frame->arena()); if (!status.ok()) { return status; } + if (frame->enable_unknown_function_results() && + IsUnknownFunctionResult(*result)) { + const auto* function_result = + google::protobuf::Arena::Create( + frame->arena(), matched_function->descriptor(), id(), + std::vector(input_args.begin(), input_args.end())); + const auto* unknown_set = google::protobuf::Arena::Create( + frame->arena(), UnknownFunctionResultSet(function_result)); + *result = CelValue::CreateUnknownSet(unknown_set); + } } else { // No matching overloads. // We should not treat absense of overloads as non-recoverable error. @@ -67,20 +88,40 @@ cel_base::Status AbstractFunctionStep::Evaluate(ExecutionFrame* frame) const { // should be propagated along execution path. for (const CelValue& arg : input_args) { if (arg.IsError()) { - result = arg; - break; + *result = arg; + return absl::OkStatus(); } } + + if (unknown_set) { + *result = CelValue::CreateUnknownSet(unknown_set); + return absl::OkStatus(); + } + // If no errors in input args, create new CelError. - if (!result.IsError()) { - result = CreateNoMatchingOverloadError(frame->arena()); + if (!result->IsError()) { + *result = CreateNoMatchingOverloadError(frame->arena()); } } - frame->value_stack().Pop(input_args.length()); + return absl::OkStatus(); +} + +absl::Status AbstractFunctionStep::Evaluate(ExecutionFrame* frame) const { + if (!frame->value_stack().HasEnough(num_arguments_)) { + return absl::Status(absl::StatusCode::kInternal, "Value stack underflow"); + } + + CelValue result; + auto status = DoEvaluate(frame, &result); + if (!status.ok()) { + return status; + } + + frame->value_stack().Pop(num_arguments_); frame->value_stack().Push(result); - return cel_base::OkStatus(); + return absl::OkStatus(); } class EagerFunctionStep : public AbstractFunctionStep { @@ -105,7 +146,7 @@ cel_base::StatusOr EagerFunctionStep::ResolveFunction( if (overload->MatchArguments(input_args)) { // More than one overload matches our arguments. if (matched_function != nullptr) { - return cel_base::Status(cel_base::StatusCode::kInternal, + return absl::Status(absl::StatusCode::kInternal, "Cannot resolve overloads"); } @@ -159,7 +200,7 @@ cel_base::StatusOr LazyFunctionStep::ResolveFunction( if (overload != nullptr && overload->MatchArguments(input_args)) { // More than one overload matches our arguments. if (matched_function != nullptr) { - return cel_base::Status(cel_base::StatusCode::kInternal, + return absl::Status(absl::StatusCode::kInternal, "Cannot resolve overloads"); } @@ -174,7 +215,8 @@ cel_base::StatusOr LazyFunctionStep::ResolveFunction( cel_base::StatusOr> CreateFunctionStep( const google::api::expr::v1alpha1::Expr::Call* call_expr, int64_t expr_id, - const CelFunctionRegistry& function_registry) { + const CelFunctionRegistry& function_registry, + BuilderWarnings* builder_warnings) { bool receiver_style = call_expr->has_target(); size_t num_args = call_expr->args_size() + (receiver_style ? 1 : 0); const std::string& name = call_expr->function(); @@ -192,15 +234,16 @@ cel_base::StatusOr> CreateFunctionStep( auto overloads = function_registry.FindOverloads(name, receiver_style, args); - if (!overloads.empty()) { - std::unique_ptr step = absl::make_unique( - std::move(overloads), num_args, expr_id); - return std::move(step); + // No overloads found. + if (overloads.empty()) { + RETURN_IF_ERROR(builder_warnings->AddWarning( + absl::Status(absl::StatusCode::kInvalidArgument, + "No overloads provided for FunctionStep creation"))); } - // No overloads found. - return ::cel_base::Status(cel_base::StatusCode::kInvalidArgument, - "No overloads provided for FunctionStep creation"); + std::unique_ptr step = absl::make_unique( + std::move(overloads), num_args, expr_id); + return std::move(step); } } // namespace runtime diff --git a/eval/eval/function_step.h b/eval/eval/function_step.h index 5ed34885d..12facbb63 100644 --- a/eval/eval/function_step.h +++ b/eval/eval/function_step.h @@ -2,6 +2,7 @@ #define THIRD_PARTY_CEL_CPP_EVAL_EVAL_FUNCTION_STEP_H_ #include "eval/eval/evaluator_core.h" +#include "eval/eval/expression_build_warning.h" #include "eval/public/activation.h" #include "eval/public/cel_function.h" #include "eval/public/cel_function_registry.h" @@ -16,7 +17,8 @@ namespace runtime { // Looks up function registry using data provided through Call parameter. cel_base::StatusOr> CreateFunctionStep( const google::api::expr::v1alpha1::Expr::Call* call, int64_t expr_id, - const CelFunctionRegistry& function_registry); + const CelFunctionRegistry& function_registry, + BuilderWarnings* builder_warnings); } // namespace runtime } // namespace expr diff --git a/eval/eval/function_step_test.cc b/eval/eval/function_step_test.cc index d4659e988..8b55ab721 100644 --- a/eval/eval/function_step_test.cc +++ b/eval/eval/function_step_test.cc @@ -6,7 +6,15 @@ #include "absl/memory/memory.h" #include "absl/strings/string_view.h" #include "eval/eval/evaluator_core.h" +#include "eval/eval/expression_build_warning.h" +#include "eval/eval/ident_step.h" +#include "eval/public/cel_attribute.h" #include "eval/public/cel_function.h" +#include "eval/public/cel_options.h" +#include "eval/public/cel_value.h" +#include "eval/public/unknown_function_result_set.h" +#include "eval/testutil/test_message.pb.h" +#include "base/status_macros.h" namespace google { namespace api { @@ -15,8 +23,10 @@ namespace runtime { namespace { +using testing::ElementsAre; using testing::Eq; using testing::Not; +using testing::UnorderedElementsAre; using google::api::expr::v1alpha1::Expr; @@ -26,6 +36,7 @@ int GetExprId() { return id; } +// Simple function that takes no arguments and returns a constant value. class ConstFunction : public CelFunction { public: explicit ConstFunction(const CelValue& value, absl::string_view name) @@ -42,24 +53,30 @@ class ConstFunction : public CelFunction { return call; } - cel_base::Status Evaluate(absl::Span args, CelValue* result, + absl::Status Evaluate(absl::Span args, CelValue* result, google::protobuf::Arena* arena) const override { if (!args.empty()) { - return cel_base::Status(cel_base::StatusCode::kInvalidArgument, + return absl::Status(absl::StatusCode::kInvalidArgument, "Bad arguments number"); } *result = value_; - return cel_base::OkStatus(); + return absl::OkStatus(); } private: CelValue value_; }; +enum class ShouldReturnUnknown : bool { kYes = true, kNo = false }; + class AddFunction : public CelFunction { public: - AddFunction() : CelFunction(CreateDescriptor()) {} + AddFunction() + : CelFunction(CreateDescriptor()), should_return_unknown_(false) {} + explicit AddFunction(ShouldReturnUnknown should_return_unknown) + : CelFunction(CreateDescriptor()), + should_return_unknown_(static_cast(should_return_unknown)) {} static CelFunctionDescriptor CreateDescriptor() { return CelFunctionDescriptor{ @@ -75,23 +92,56 @@ class AddFunction : public CelFunction { return call; } - cel_base::Status Evaluate(absl::Span args, CelValue* result, + absl::Status Evaluate(absl::Span args, CelValue* result, google::protobuf::Arena* arena) const override { if (args.size() != 2 || !args[0].IsInt64() || !args[1].IsInt64()) { - return cel_base::Status(cel_base::StatusCode::kInvalidArgument, + return absl::Status(absl::StatusCode::kInvalidArgument, "Mismatched arguments passed to method"); } + if (should_return_unknown_) { + *result = + CreateUnknownFunctionResultError(arena, "Add can't be resolved."); + return absl::OkStatus(); + } int64_t arg0 = args[0].Int64OrDie(); int64_t arg1 = args[1].Int64OrDie(); *result = CelValue::CreateInt64(arg0 + arg1); - return cel_base::OkStatus(); + return absl::OkStatus(); + } + + private: + bool should_return_unknown_; +}; + +class SinkFunction : public CelFunction { + public: + SinkFunction(CelValue::Type type) : CelFunction(CreateDescriptor(type)) {} + + static CelFunctionDescriptor CreateDescriptor(CelValue::Type type) { + return CelFunctionDescriptor{"Sink", false, {type}}; + } + + static Expr::Call MakeCall() { + Expr::Call call; + call.set_function("Sink"); + call.add_args(); + call.clear_target(); + return call; + } + + absl::Status Evaluate(absl::Span args, CelValue* result, + google::protobuf::Arena* arena) const override { + // Return value is ignored. + *result = CelValue::CreateInt64(0); + return absl::OkStatus(); } }; // Create and initialize a registry with some default functions. void AddDefaults(CelFunctionRegistry& registry) { + static UnknownSet* unknown_set = new UnknownSet(); EXPECT_TRUE(registry .Register(absl::make_unique( CelValue::CreateInt64(3), "Const3")) @@ -100,11 +150,60 @@ void AddDefaults(CelFunctionRegistry& registry) { .Register(absl::make_unique( CelValue::CreateInt64(2), "Const2")) .ok()); + EXPECT_TRUE(registry + .Register(absl::make_unique( + CelValue::CreateUnknownSet(unknown_set), "ConstUnknown")) + .ok()); EXPECT_TRUE(registry.Register(absl::make_unique()).ok()); + + EXPECT_TRUE( + registry.Register(absl::make_unique(CelValue::Type::kList)) + .ok()); + + EXPECT_TRUE( + registry.Register(absl::make_unique(CelValue::Type::kMap)) + .ok()); + + EXPECT_TRUE( + registry + .Register(absl::make_unique(CelValue::Type::kMessage)) + .ok()); } -TEST(FunctionStepTest, SimpleFunctionTest) { +// Test common functions with varying levels of unknown support. +class FunctionStepTest + : public testing::TestWithParam { + public: + // underlying expression impl moves path + std::unique_ptr GetExpression(ExecutionPath&& path) { + bool unknowns; + bool unknown_function_results; + switch (GetParam()) { + case UnknownProcessingOptions::kAttributeAndFunction: + unknowns = true; + unknown_function_results = true; + break; + case UnknownProcessingOptions::kAttributeOnly: + unknowns = true; + unknown_function_results = false; + break; + case UnknownProcessingOptions::kDisabled: + unknowns = false; + unknown_function_results = false; + break; + } + return absl::make_unique( + &dummy_expr_, std::move(path), 0, std::set(), unknowns, + unknown_function_results); + } + + private: + Expr dummy_expr_; +}; + +TEST_P(FunctionStepTest, SimpleFunctionTest) { ExecutionPath path; + BuilderWarnings warnings; CelFunctionRegistry registry; AddDefaults(registry); @@ -113,27 +212,28 @@ TEST(FunctionStepTest, SimpleFunctionTest) { Expr::Call call2 = ConstFunction::MakeCall("Const2"); Expr::Call add_call = AddFunction::MakeCall(); - auto step0_status = CreateFunctionStep(&call1, GetExprId(), registry); - auto step1_status = CreateFunctionStep(&call2, GetExprId(), registry); - auto step2_status = CreateFunctionStep(&add_call, GetExprId(), registry); + auto step0_status = + CreateFunctionStep(&call1, GetExprId(), registry, &warnings); + auto step1_status = + CreateFunctionStep(&call2, GetExprId(), registry, &warnings); + auto step2_status = + CreateFunctionStep(&add_call, GetExprId(), registry, &warnings); - ASSERT_TRUE(step0_status.ok()); - ASSERT_TRUE(step1_status.ok()); - ASSERT_TRUE(step2_status.ok()); + ASSERT_OK(step0_status); + ASSERT_OK(step1_status); + ASSERT_OK(step2_status); path.push_back(std::move(step0_status.ValueOrDie())); path.push_back(std::move(step1_status.ValueOrDie())); path.push_back(std::move(step2_status.ValueOrDie())); - auto dummy_expr = absl::make_unique(); - - CelExpressionFlatImpl impl(dummy_expr.get(), std::move(path), 0); + std::unique_ptr impl = GetExpression(std::move(path)); Activation activation; google::protobuf::Arena arena; - auto status = impl.Evaluate(activation, &arena); - EXPECT_TRUE(status.ok()); + auto status = impl->Evaluate(activation, &arena); + ASSERT_OK(status); auto value = status.ValueOrDie(); @@ -141,8 +241,9 @@ TEST(FunctionStepTest, SimpleFunctionTest) { EXPECT_THAT(value.Int64OrDie(), Eq(5)); } -TEST(FunctionStepTest, TestStackUnderflow) { +TEST_P(FunctionStepTest, TestStackUnderflow) { ExecutionPath path; + BuilderWarnings warnings; CelFunctionRegistry registry; AddDefaults(registry); @@ -152,40 +253,74 @@ TEST(FunctionStepTest, TestStackUnderflow) { Expr::Call call1 = ConstFunction::MakeCall("Const3"); Expr::Call add_call = AddFunction::MakeCall(); - auto step0_status = CreateFunctionStep(&call1, GetExprId(), registry); - auto step2_status = CreateFunctionStep(&add_call, GetExprId(), registry); + auto step0_status = + CreateFunctionStep(&call1, GetExprId(), registry, &warnings); + auto step2_status = + CreateFunctionStep(&add_call, GetExprId(), registry, &warnings); - ASSERT_TRUE(step0_status.ok()); - ASSERT_TRUE(step2_status.ok()); + ASSERT_OK(step0_status); + ASSERT_OK(step2_status); path.push_back(std::move(step0_status.ValueOrDie())); path.push_back(std::move(step2_status.ValueOrDie())); - auto dummy_expr = absl::make_unique(); - - CelExpressionFlatImpl impl(dummy_expr.get(), std::move(path), 0); + std::unique_ptr impl = GetExpression(std::move(path)); Activation activation; google::protobuf::Arena arena; - auto status = impl.Evaluate(activation, &arena); + auto status = impl->Evaluate(activation, &arena); EXPECT_FALSE(status.ok()); } -// Test factory method when empty overload set is provided. +// Test that creation fails if fail on warnings is set in the warnings +// collection. TEST(FunctionStepTest, TestNoOverloadsOnCreation) { - CelFunctionRegistry registry; // SetupRegistry(); + CelFunctionRegistry registry; + BuilderWarnings warnings(true); + Expr::Call call = ConstFunction::MakeCall("Const0"); // function step with empty overloads - auto step0_status = CreateFunctionStep(&call, GetExprId(), registry); + auto step0_status = + CreateFunctionStep(&call, GetExprId(), registry, &warnings); EXPECT_FALSE(step0_status.ok()); } +// Test that no overloads error is warned, actual error delayed to runtime by +// default. +TEST_P(FunctionStepTest, TestNoOverloadsOnCreationDelayedError) { + CelFunctionRegistry registry; + ExecutionPath path; + Expr::Call call = ConstFunction::MakeCall("Const0"); + BuilderWarnings warnings; + + // function step with empty overloads + auto step0_status = + CreateFunctionStep(&call, GetExprId(), registry, &warnings); + + EXPECT_TRUE(step0_status.ok()); + EXPECT_THAT(warnings.warnings(), testing::SizeIs(1)); + + path.push_back(std::move(step0_status.ValueOrDie())); + + std::unique_ptr impl = GetExpression(std::move(path)); + + Activation activation; + google::protobuf::Arena arena; + + auto status = impl->Evaluate(activation, &arena); + ASSERT_OK(status); + + auto value = status.ValueOrDie(); + ASSERT_TRUE(value.IsError()); +} + // Test situation when no overloads match input arguments during evaluation. -TEST(FunctionStepTest, TestNoMatchingOverloadsDuringEvaluation) { +TEST_P(FunctionStepTest, TestNoMatchingOverloadsDuringEvaluation) { ExecutionPath path; + BuilderWarnings warnings; CelFunctionRegistry registry; AddDefaults(registry); @@ -200,37 +335,39 @@ TEST(FunctionStepTest, TestNoMatchingOverloadsDuringEvaluation) { // Add expects {int64_t, int64_t} but it's {int64_t, uint64_t}. Expr::Call add_call = AddFunction::MakeCall(); - auto step0_status = CreateFunctionStep(&call1, GetExprId(), registry); - auto step1_status = CreateFunctionStep(&call2, GetExprId(), registry); - auto step2_status = CreateFunctionStep(&add_call, GetExprId(), registry); + auto step0_status = + CreateFunctionStep(&call1, GetExprId(), registry, &warnings); + auto step1_status = + CreateFunctionStep(&call2, GetExprId(), registry, &warnings); + auto step2_status = + CreateFunctionStep(&add_call, GetExprId(), registry, &warnings); - ASSERT_TRUE(step0_status.ok()); - ASSERT_TRUE(step1_status.ok()); - ASSERT_TRUE(step2_status.ok()); + ASSERT_OK(step0_status); + ASSERT_OK(step1_status); + ASSERT_OK(step2_status); path.push_back(std::move(step0_status.ValueOrDie())); path.push_back(std::move(step1_status.ValueOrDie())); path.push_back(std::move(step2_status.ValueOrDie())); - auto dummy_expr = absl::make_unique(); - - CelExpressionFlatImpl impl(dummy_expr.get(), std::move(path), 0); + std::unique_ptr impl = GetExpression(std::move(path)); Activation activation; google::protobuf::Arena arena; - auto status = impl.Evaluate(activation, &arena); - ASSERT_TRUE(status.ok()); + auto status = impl->Evaluate(activation, &arena); + ASSERT_OK(status); auto value = status.ValueOrDie(); ASSERT_TRUE(value.IsError()); } - // Test situation when no overloads match input arguments during evaluation // and at least one of arguments is error. -TEST(FunctionStepTest, TestNoMatchingOverloadsDuringEvaluationErrorForwarding) { +TEST_P(FunctionStepTest, + TestNoMatchingOverloadsDuringEvaluationErrorForwarding) { ExecutionPath path; + BuilderWarnings warnings; CelFunctionRegistry registry; AddDefaults(registry); @@ -252,27 +389,28 @@ TEST(FunctionStepTest, TestNoMatchingOverloadsDuringEvaluationErrorForwarding) { Expr::Call call2 = ConstFunction::MakeCall("ConstError2"); Expr::Call add_call = AddFunction::MakeCall(); - auto step0_status = CreateFunctionStep(&call1, GetExprId(), registry); - auto step1_status = CreateFunctionStep(&call2, GetExprId(), registry); - auto step2_status = CreateFunctionStep(&add_call, GetExprId(), registry); + auto step0_status = + CreateFunctionStep(&call1, GetExprId(), registry, &warnings); + auto step1_status = + CreateFunctionStep(&call2, GetExprId(), registry, &warnings); + auto step2_status = + CreateFunctionStep(&add_call, GetExprId(), registry, &warnings); - ASSERT_TRUE(step0_status.ok()); - ASSERT_TRUE(step1_status.ok()); - ASSERT_TRUE(step2_status.ok()); + ASSERT_OK(step0_status); + ASSERT_OK(step1_status); + ASSERT_OK(step2_status); path.push_back(std::move(step0_status.ValueOrDie())); path.push_back(std::move(step1_status.ValueOrDie())); path.push_back(std::move(step2_status.ValueOrDie())); - auto dummy_expr = absl::make_unique(); - - CelExpressionFlatImpl impl(dummy_expr.get(), std::move(path), 0); + std::unique_ptr impl = GetExpression(std::move(path)); Activation activation; google::protobuf::Arena arena; - auto status = impl.Evaluate(activation, &arena); - ASSERT_TRUE(status.ok()); + auto status = impl->Evaluate(activation, &arena); + ASSERT_OK(status); auto value = status.ValueOrDie(); @@ -280,49 +418,51 @@ TEST(FunctionStepTest, TestNoMatchingOverloadsDuringEvaluationErrorForwarding) { EXPECT_THAT(value.ErrorOrDie(), Eq(&error0)); } -TEST(FunctionStepTest, LazyFunctionTest) { +TEST_P(FunctionStepTest, LazyFunctionTest) { ExecutionPath path; Activation activation; CelFunctionRegistry registry; + BuilderWarnings warnings; auto register0_status = registry.RegisterLazyFunction(ConstFunction::CreateDescriptor("Const3")); - EXPECT_TRUE(register0_status.ok()); + ASSERT_OK(register0_status); auto insert0_status = activation.InsertFunction( absl::make_unique(CelValue::CreateInt64(3), "Const3")); - EXPECT_TRUE(insert0_status.ok()); + ASSERT_OK(insert0_status); auto register1_status = registry.RegisterLazyFunction(ConstFunction::CreateDescriptor("Const2")); - EXPECT_TRUE(register1_status.ok()); + ASSERT_OK(register1_status); auto insert1_status = activation.InsertFunction( absl::make_unique(CelValue::CreateInt64(2), "Const2")); - EXPECT_TRUE(insert1_status.ok()); - EXPECT_TRUE(registry.Register(absl::make_unique()).ok()); + ASSERT_OK(insert1_status); + ASSERT_OK(registry.Register(absl::make_unique())); Expr::Call call1 = ConstFunction::MakeCall("Const3"); Expr::Call call2 = ConstFunction::MakeCall("Const2"); Expr::Call add_call = AddFunction::MakeCall(); - auto step0_status = CreateFunctionStep(&call1, GetExprId(), registry); - auto step1_status = CreateFunctionStep(&call2, GetExprId(), registry); - auto step2_status = CreateFunctionStep(&add_call, GetExprId(), registry); + auto step0_status = + CreateFunctionStep(&call1, GetExprId(), registry, &warnings); + auto step1_status = + CreateFunctionStep(&call2, GetExprId(), registry, &warnings); + auto step2_status = + CreateFunctionStep(&add_call, GetExprId(), registry, &warnings); - ASSERT_TRUE(step0_status.ok()); - ASSERT_TRUE(step1_status.ok()); - ASSERT_TRUE(step2_status.ok()); + ASSERT_OK(step0_status); + ASSERT_OK(step1_status); + ASSERT_OK(step2_status); path.push_back(std::move(step0_status.ValueOrDie())); path.push_back(std::move(step1_status.ValueOrDie())); path.push_back(std::move(step2_status.ValueOrDie())); - auto dummy_expr = absl::make_unique(); - - CelExpressionFlatImpl impl(dummy_expr.get(), std::move(path), 0); + std::unique_ptr impl = GetExpression(std::move(path)); google::protobuf::Arena arena; - auto status = impl.Evaluate(activation, &arena); - EXPECT_TRUE(status.ok()); + auto status = impl->Evaluate(activation, &arena); + ASSERT_OK(status); auto value = status.ValueOrDie(); @@ -332,12 +472,14 @@ TEST(FunctionStepTest, LazyFunctionTest) { // Test situation when no overloads match input arguments during evaluation // and at least one of arguments is error. -TEST(FunctionStepTest, - TestNoMatchingOverloadsDuringEvaluationErrorForwardingLazy) { +TEST_P(FunctionStepTest, + TestNoMatchingOverloadsDuringEvaluationErrorForwardingLazy) { ExecutionPath path; Activation activation; google::protobuf::Arena arena; CelFunctionRegistry registry; + BuilderWarnings warnings; + AddDefaults(registry); CelError error0; @@ -346,46 +488,483 @@ TEST(FunctionStepTest, // Constants have ERROR type, while AddFunction expects INT. auto register0_status = registry.RegisterLazyFunction( ConstFunction::CreateDescriptor("ConstError1")); - ASSERT_TRUE(register0_status.ok()); + ASSERT_OK(register0_status); auto insert0_status = activation.InsertFunction(absl::make_unique( CelValue::CreateError(&error0), "ConstError1")); - ASSERT_TRUE(insert0_status.ok()); + ASSERT_OK(insert0_status); auto register1_status = registry.RegisterLazyFunction( ConstFunction::CreateDescriptor("ConstError2")); - ASSERT_TRUE(register1_status.ok()); + ASSERT_OK(register1_status); auto insert1_status = activation.InsertFunction(absl::make_unique( CelValue::CreateError(&error1), "ConstError2")); - ASSERT_TRUE(insert1_status.ok()); + ASSERT_OK(insert1_status); Expr::Call call1 = ConstFunction::MakeCall("ConstError1"); Expr::Call call2 = ConstFunction::MakeCall("ConstError2"); Expr::Call add_call = AddFunction::MakeCall(); - auto step0_status = CreateFunctionStep(&call1, GetExprId(), registry); - auto step1_status = CreateFunctionStep(&call2, GetExprId(), registry); - auto step2_status = CreateFunctionStep(&add_call, GetExprId(), registry); + auto step0_status = + CreateFunctionStep(&call1, GetExprId(), registry, &warnings); + auto step1_status = + CreateFunctionStep(&call2, GetExprId(), registry, &warnings); + auto step2_status = + CreateFunctionStep(&add_call, GetExprId(), registry, &warnings); - ASSERT_TRUE(step0_status.ok()); - ASSERT_TRUE(step1_status.ok()); - ASSERT_TRUE(step2_status.ok()); + ASSERT_OK(step0_status); + ASSERT_OK(step1_status); + ASSERT_OK(step2_status); path.push_back(std::move(step0_status.ValueOrDie())); path.push_back(std::move(step1_status.ValueOrDie())); path.push_back(std::move(step2_status.ValueOrDie())); - auto dummy_expr = absl::make_unique(); + std::unique_ptr impl = GetExpression(std::move(path)); + + auto status = impl->Evaluate(activation, &arena); + ASSERT_OK(status); + + auto value = status.ValueOrDie(); + + ASSERT_TRUE(value.IsError()); + EXPECT_THAT(value.ErrorOrDie(), Eq(&error0)); +} + +std::string TestNameFn(testing::TestParamInfo opt) { + switch (opt.param) { + case UnknownProcessingOptions::kDisabled: + return "disabled"; + case UnknownProcessingOptions::kAttributeOnly: + return "attribute_only"; + case UnknownProcessingOptions::kAttributeAndFunction: + return "attribute_and_function"; + } +} - CelExpressionFlatImpl impl(dummy_expr.get(), std::move(path), 0); +INSTANTIATE_TEST_SUITE_P( + UnknownSupport, FunctionStepTest, + testing::Values(UnknownProcessingOptions::kDisabled, + UnknownProcessingOptions::kAttributeOnly, + UnknownProcessingOptions::kAttributeAndFunction), + &TestNameFn); + +class FunctionStepTestUnknowns + : public testing::TestWithParam { + public: + std::unique_ptr GetExpression(ExecutionPath&& path) { + bool unknown_functions; + switch (GetParam()) { + case UnknownProcessingOptions::kAttributeAndFunction: + unknown_functions = true; + break; + default: + unknown_functions = false; + break; + } + return absl::make_unique(&expr, std::move(path), 0, + std::set(), + true, unknown_functions); + } + + private: + Expr expr; +}; + +TEST_P(FunctionStepTestUnknowns, PassedUnknownTest) { + ExecutionPath path; + BuilderWarnings warnings; + + CelFunctionRegistry registry; + AddDefaults(registry); + + Expr::Call call1 = ConstFunction::MakeCall("Const3"); + Expr::Call call2 = ConstFunction::MakeCall("ConstUnknown"); + Expr::Call add_call = AddFunction::MakeCall(); + + auto step0_status = + CreateFunctionStep(&call1, GetExprId(), registry, &warnings); + auto step1_status = + CreateFunctionStep(&call2, GetExprId(), registry, &warnings); + auto step2_status = + CreateFunctionStep(&add_call, GetExprId(), registry, &warnings); + + ASSERT_OK(step0_status); + ASSERT_OK(step1_status); + ASSERT_OK(step2_status); + + path.push_back(std::move(step0_status.ValueOrDie())); + path.push_back(std::move(step1_status.ValueOrDie())); + path.push_back(std::move(step2_status.ValueOrDie())); + + std::unique_ptr impl = GetExpression(std::move(path)); + + Activation activation; + google::protobuf::Arena arena; + + auto status = impl->Evaluate(activation, &arena); + ASSERT_OK(status); + + auto value = status.ValueOrDie(); + + ASSERT_TRUE(value.IsUnknownSet()); +} + +TEST_P(FunctionStepTestUnknowns, PartialUnknownHandlingTest) { + ExecutionPath path; + BuilderWarnings warnings; + + CelFunctionRegistry registry; + AddDefaults(registry); + + // Build the expression path that corresponds to CEL expression + // "sink(param)". + Expr::Ident ident1; + ident1.set_name("param"); + Expr::Call call1 = SinkFunction::MakeCall(); + + auto step0_status = CreateIdentStep(&ident1, GetExprId()); + auto step1_status = + CreateFunctionStep(&call1, GetExprId(), registry, &warnings); + + ASSERT_OK(step0_status); + ASSERT_OK(step1_status); + + path.push_back(std::move(step0_status.ValueOrDie())); + path.push_back(std::move(step1_status.ValueOrDie())); + + std::unique_ptr impl = GetExpression(std::move(path)); + + Activation activation; + TestMessage msg; + google::protobuf::Arena arena; + activation.InsertValue("param", CelValue::CreateMessage(&msg, &arena)); + CelAttributePattern pattern( + "param", + {CelAttributeQualifierPattern::Create(CelValue::CreateBool(true))}); + + // Set attribute pattern that marks attribute "param[true]" as unknown. + // It should result in "param" being handled as partially unknown, which is + // is handled as fully unknown when used as function input argument. + activation.set_unknown_attribute_patterns({pattern}); + + auto status = impl->Evaluate(activation, &arena); + ASSERT_OK(status); + + auto value = status.ValueOrDie(); + + ASSERT_TRUE(value.IsUnknownSet()); +} + +TEST_P(FunctionStepTestUnknowns, UnknownVsErrorPrecedenceTest) { + ExecutionPath path; + BuilderWarnings warnings; + + CelFunctionRegistry registry; + AddDefaults(registry); + + CelError error0; + CelValue error_value = CelValue::CreateError(&error0); + + ASSERT_TRUE( + registry + .Register(absl::make_unique(error_value, "ConstError")) + .ok()); + + Expr::Call call1 = ConstFunction::MakeCall("ConstError"); + Expr::Call call2 = ConstFunction::MakeCall("ConstUnknown"); + Expr::Call add_call = AddFunction::MakeCall(); + + auto step0_status = + CreateFunctionStep(&call1, GetExprId(), registry, &warnings); + auto step1_status = + CreateFunctionStep(&call2, GetExprId(), registry, &warnings); + auto step2_status = + CreateFunctionStep(&add_call, GetExprId(), registry, &warnings); + + ASSERT_OK(step0_status); + ASSERT_OK(step1_status); + ASSERT_OK(step2_status); + + path.push_back(std::move(step0_status.ValueOrDie())); + path.push_back(std::move(step1_status.ValueOrDie())); + path.push_back(std::move(step2_status.ValueOrDie())); + + std::unique_ptr impl = GetExpression(std::move(path)); + + Activation activation; + google::protobuf::Arena arena; + + auto status = impl->Evaluate(activation, &arena); + ASSERT_OK(status); + + auto value = status.ValueOrDie(); + + ASSERT_TRUE(value.IsError()); + // Making sure we propagate the error. + ASSERT_EQ(value.ErrorOrDie(), error_value.ErrorOrDie()); +} + +INSTANTIATE_TEST_SUITE_P( + UnknownFunctionSupport, FunctionStepTestUnknowns, + testing::Values(UnknownProcessingOptions::kAttributeOnly, + UnknownProcessingOptions::kAttributeAndFunction), + &TestNameFn); + +MATCHER_P2(IsAdd, a, b, "") { + const UnknownFunctionResult* result = arg; + return result->arguments().size() == 2 && + result->arguments().at(0).IsInt64() && + result->arguments().at(1).IsInt64() && + result->arguments().at(0).Int64OrDie() == a && + result->arguments().at(1).Int64OrDie() == b && + result->descriptor().name() == "_+_"; +} + +TEST(FunctionStepTestUnknownFunctionResults, CaptureArgs) { + ExecutionPath path; + BuilderWarnings warnings; + + CelFunctionRegistry registry; + + ASSERT_OK(registry.Register( + absl::make_unique(CelValue::CreateInt64(2), "Const2"))); + ASSERT_OK(registry.Register( + absl::make_unique(CelValue::CreateInt64(3), "Const3"))); + ASSERT_OK(registry.Register( + absl::make_unique(ShouldReturnUnknown::kYes))); + + Expr::Call call1 = ConstFunction::MakeCall("Const2"); + Expr::Call call2 = ConstFunction::MakeCall("Const3"); + Expr::Call add_call = AddFunction::MakeCall(); + + auto step0_status = + CreateFunctionStep(&call1, GetExprId(), registry, &warnings); + auto step1_status = + CreateFunctionStep(&call2, GetExprId(), registry, &warnings); + auto step2_status = + CreateFunctionStep(&add_call, GetExprId(), registry, &warnings); + + ASSERT_OK(step0_status); + ASSERT_OK(step1_status); + ASSERT_OK(step2_status); + + path.push_back(std::move(step0_status.ValueOrDie())); + path.push_back(std::move(step1_status.ValueOrDie())); + path.push_back(std::move(step2_status.ValueOrDie())); + + Expr dummy_expr; + + CelExpressionFlatImpl impl(&dummy_expr, std::move(path), 0, {}, true, true); + + Activation activation; + google::protobuf::Arena arena; + + auto status = impl.Evaluate(activation, &arena); + ASSERT_OK(status); + + auto value = status.ValueOrDie(); + + ASSERT_TRUE(value.IsUnknownSet()); + // Arguments captured. + EXPECT_THAT(value.UnknownSetOrDie() + ->unknown_function_results() + .unknown_function_results(), + ElementsAre(IsAdd(2, 3))); +} + +TEST(FunctionStepTestUnknownFunctionResults, MergeDownCaptureArgs) { + ExecutionPath path; + BuilderWarnings warnings; + + CelFunctionRegistry registry; + + ASSERT_OK(registry.Register( + absl::make_unique(CelValue::CreateInt64(2), "Const2"))); + ASSERT_OK(registry.Register( + absl::make_unique(CelValue::CreateInt64(3), "Const3"))); + ASSERT_OK(registry.Register( + absl::make_unique(ShouldReturnUnknown::kYes))); + + // Add(Add(2, 3), Add(2, 3)) + + Expr::Call call1 = ConstFunction::MakeCall("Const2"); + Expr::Call call2 = ConstFunction::MakeCall("Const3"); + Expr::Call add_call = AddFunction::MakeCall(); + + auto step0_status = + CreateFunctionStep(&call1, GetExprId(), registry, &warnings); + auto step1_status = + CreateFunctionStep(&call2, GetExprId(), registry, &warnings); + auto step2_status = + CreateFunctionStep(&add_call, GetExprId(), registry, &warnings); + auto step3_status = + CreateFunctionStep(&call1, GetExprId(), registry, &warnings); + auto step4_status = + CreateFunctionStep(&call2, GetExprId(), registry, &warnings); + auto step5_status = + CreateFunctionStep(&add_call, GetExprId(), registry, &warnings); + auto step6_status = + CreateFunctionStep(&add_call, GetExprId(), registry, &warnings); + + ASSERT_OK(step0_status); + ASSERT_OK(step1_status); + ASSERT_OK(step2_status); + ASSERT_OK(step3_status); + ASSERT_OK(step4_status); + ASSERT_OK(step5_status); + ASSERT_OK(step6_status); + + path.push_back(std::move(step0_status.ValueOrDie())); + path.push_back(std::move(step1_status.ValueOrDie())); + path.push_back(std::move(step2_status.ValueOrDie())); + path.push_back(std::move(step3_status.ValueOrDie())); + path.push_back(std::move(step4_status.ValueOrDie())); + path.push_back(std::move(step5_status.ValueOrDie())); + path.push_back(std::move(step6_status.ValueOrDie())); + + Expr dummy_expr; + + CelExpressionFlatImpl impl(&dummy_expr, std::move(path), 0, {}, true, true); + + Activation activation; + google::protobuf::Arena arena; auto status = impl.Evaluate(activation, &arena); - ASSERT_TRUE(status.ok()); + ASSERT_OK(status); + + auto value = status.ValueOrDie(); + + ASSERT_TRUE(value.IsUnknownSet()); + // Arguments captured. + EXPECT_THAT(value.UnknownSetOrDie() + ->unknown_function_results() + .unknown_function_results(), + ElementsAre(IsAdd(2, 3))); +} + +TEST(FunctionStepTestUnknownFunctionResults, MergeCaptureArgs) { + ExecutionPath path; + BuilderWarnings warnings; + + CelFunctionRegistry registry; + + ASSERT_OK(registry.Register( + absl::make_unique(CelValue::CreateInt64(2), "Const2"))); + ASSERT_OK(registry.Register( + absl::make_unique(CelValue::CreateInt64(3), "Const3"))); + ASSERT_OK(registry.Register( + absl::make_unique(ShouldReturnUnknown::kYes))); + + // Add(Add(2, 3), Add(3, 2)) + + Expr::Call call1 = ConstFunction::MakeCall("Const2"); + Expr::Call call2 = ConstFunction::MakeCall("Const3"); + Expr::Call add_call = AddFunction::MakeCall(); + + auto step0_status = + CreateFunctionStep(&call1, GetExprId(), registry, &warnings); + auto step1_status = + CreateFunctionStep(&call2, GetExprId(), registry, &warnings); + auto step2_status = + CreateFunctionStep(&add_call, GetExprId(), registry, &warnings); + auto step3_status = + CreateFunctionStep(&call2, GetExprId(), registry, &warnings); + auto step4_status = + CreateFunctionStep(&call1, GetExprId(), registry, &warnings); + auto step5_status = + CreateFunctionStep(&add_call, GetExprId(), registry, &warnings); + auto step6_status = + CreateFunctionStep(&add_call, GetExprId(), registry, &warnings); + + ASSERT_OK(step0_status); + ASSERT_OK(step1_status); + ASSERT_OK(step2_status); + ASSERT_OK(step3_status); + ASSERT_OK(step4_status); + ASSERT_OK(step5_status); + ASSERT_OK(step6_status); + + path.push_back(std::move(step0_status.ValueOrDie())); + path.push_back(std::move(step1_status.ValueOrDie())); + path.push_back(std::move(step2_status.ValueOrDie())); + path.push_back(std::move(step3_status.ValueOrDie())); + path.push_back(std::move(step4_status.ValueOrDie())); + path.push_back(std::move(step5_status.ValueOrDie())); + path.push_back(std::move(step6_status.ValueOrDie())); + + Expr dummy_expr; + + CelExpressionFlatImpl impl(&dummy_expr, std::move(path), 0, {}, true, true); + + Activation activation; + google::protobuf::Arena arena; + + auto status = impl.Evaluate(activation, &arena); + ASSERT_OK(status); + + auto value = status.ValueOrDie(); + + ASSERT_TRUE(value.IsUnknownSet()) << value.ErrorOrDie()->ToString(); + // Arguments captured. + EXPECT_THAT(value.UnknownSetOrDie() + ->unknown_function_results() + .unknown_function_results(), + UnorderedElementsAre(IsAdd(2, 3), IsAdd(3, 2))); +} + +TEST(FunctionStepTestUnknownFunctionResults, UnknownVsErrorPrecedenceTest) { + ExecutionPath path; + BuilderWarnings warnings; + + CelFunctionRegistry registry; + + CelError error0; + CelValue error_value = CelValue::CreateError(&error0); + UnknownSet unknown_set; + CelValue unknown_value = CelValue::CreateUnknownSet(&unknown_set); + + ASSERT_OK(registry.Register( + absl::make_unique(error_value, "ConstError"))); + ASSERT_OK(registry.Register( + absl::make_unique(unknown_value, "ConstUnknown"))); + ASSERT_OK(registry.Register( + absl::make_unique(ShouldReturnUnknown::kYes))); + + Expr::Call call1 = ConstFunction::MakeCall("ConstError"); + Expr::Call call2 = ConstFunction::MakeCall("ConstUnknown"); + Expr::Call add_call = AddFunction::MakeCall(); + + auto step0_status = + CreateFunctionStep(&call1, GetExprId(), registry, &warnings); + auto step1_status = + CreateFunctionStep(&call2, GetExprId(), registry, &warnings); + auto step2_status = + CreateFunctionStep(&add_call, GetExprId(), registry, &warnings); + + ASSERT_OK(step0_status); + ASSERT_OK(step1_status); + ASSERT_OK(step2_status); + + path.push_back(std::move(step0_status.ValueOrDie())); + path.push_back(std::move(step1_status.ValueOrDie())); + path.push_back(std::move(step2_status.ValueOrDie())); + + Expr dummy_expr; + + CelExpressionFlatImpl impl(&dummy_expr, std::move(path), 0, {}, true, true); + + Activation activation; + google::protobuf::Arena arena; + + auto status = impl.Evaluate(activation, &arena); + ASSERT_OK(status); auto value = status.ValueOrDie(); ASSERT_TRUE(value.IsError()); - EXPECT_THAT(value.ErrorOrDie(), Eq(&error0)); + // Making sure we propagate the error. + ASSERT_EQ(value.ErrorOrDie(), error_value.ErrorOrDie()); } } // namespace diff --git a/eval/eval/ident_step.cc b/eval/eval/ident_step.cc index 9a7565378..178941c71 100644 --- a/eval/eval/ident_step.cc +++ b/eval/eval/ident_step.cc @@ -1,6 +1,10 @@ #include "eval/eval/ident_step.h" -#include "eval/eval/expression_step_base.h" + +#include "google/protobuf/arena.h" #include "absl/strings/substitute.h" +#include "eval/eval/evaluator_core.h" +#include "eval/eval/expression_step_base.h" +#include "eval/public/unknown_attribute_set.h" namespace google { namespace api { @@ -13,41 +17,69 @@ class IdentStep : public ExpressionStepBase { IdentStep(absl::string_view name, int64_t expr_id) : ExpressionStepBase(expr_id), name_(name) {} - cel_base::Status Evaluate(ExecutionFrame* frame) const override; + absl::Status Evaluate(ExecutionFrame* frame) const override; private: + void DoEvaluate(ExecutionFrame* frame, CelValue* result, + AttributeTrail* trail) const; + std::string name_; }; -cel_base::Status IdentStep::Evaluate(ExecutionFrame* frame) const { - CelValue result; - auto it = frame->iter_vars().find(name_); - if (it != frame->iter_vars().end()) { - result = it->second; - } else { - auto value = frame->activation().FindValue(name_, frame->arena()); +void IdentStep::DoEvaluate(ExecutionFrame* frame, CelValue* result, + AttributeTrail* trail) const { + // Special case - iterator looked up in + if (frame->GetIterVar(name_, result)) { + return; + } + + auto value = frame->activation().FindValue(name_, frame->arena()); + { // We handle masked unknown paths for the sake of uniformity, although it is // better not to bind unknown values to activation in first place. + // TODO(issues/41) Deprecate this style of unknowns handling after + // Unknowns are properly supported. bool unknown_value = frame->activation().IsPathUnknown(name_); - if (!unknown_value) { - if (value.has_value()) { - result = value.value(); - } else { - result = CreateErrorValue( - frame->arena(), - absl::Substitute("No value with name \"$0\" found in Activation", - name_)); - } - } else { - result = CreateUnknownValueError(frame->arena(), name_); + if (unknown_value) { + *result = CreateUnknownValueError(frame->arena(), name_); + return; } } - frame->value_stack().Push(result); + if (frame->enable_unknowns()) { + google::api::expr::v1alpha1::Expr expr; + expr.mutable_ident_expr()->set_name(name_); + *trail = AttributeTrail(expr, frame->arena()); + + if (frame->unknowns_utility().CheckForUnknown(*trail, false)) { + auto unknown_set = google::protobuf::Arena::Create( + frame->arena(), UnknownAttributeSet({trail->attribute()})); + *result = CelValue::CreateUnknownSet(unknown_set); + return; + } + } + + if (value.has_value()) { + *result = value.value(); + } else { + *result = CreateErrorValue( + frame->arena(), + absl::Substitute("No value with name \"$0\" found in Activation", + name_)); + } +} + +absl::Status IdentStep::Evaluate(ExecutionFrame* frame) const { + CelValue result; + AttributeTrail trail; + + DoEvaluate(frame, &result, &trail); + + frame->value_stack().Push(result, trail); - return cel_base::OkStatus(); + return absl::OkStatus(); } } // namespace diff --git a/eval/eval/ident_step_test.cc b/eval/eval/ident_step_test.cc index 5b7a8a9f3..a77744cbe 100644 --- a/eval/eval/ident_step_test.cc +++ b/eval/eval/ident_step_test.cc @@ -1,9 +1,10 @@ #include "eval/eval/ident_step.h" -#include "eval/eval/evaluator_core.h" -#include "google/api/expr/v1alpha1/syntax.pb.h" +#include "google/api/expr/v1alpha1/syntax.pb.h" #include "gmock/gmock.h" #include "gtest/gtest.h" +#include "eval/eval/evaluator_core.h" +#include "base/status_macros.h" namespace google { namespace api { @@ -24,14 +25,14 @@ TEST(IdentStepTest, TestIdentStep) { ident_expr->set_name("name0"); auto step_status = CreateIdentStep(ident_expr, expr.id()); - ASSERT_TRUE(step_status.ok()); + ASSERT_OK(step_status); ExecutionPath path; path.push_back(std::move(step_status.ValueOrDie())); auto dummy_expr = absl::make_unique(); - CelExpressionFlatImpl impl(dummy_expr.get(), std::move(path), 0); + CelExpressionFlatImpl impl(dummy_expr.get(), std::move(path), 0, {}); Activation activation; Arena arena; @@ -39,7 +40,7 @@ TEST(IdentStepTest, TestIdentStep) { activation.InsertValue("name0", CelValue::CreateString(&value)); auto status0 = impl.Evaluate(activation, &arena); - ASSERT_TRUE(status0.ok()); + ASSERT_OK(status0); CelValue result = status0.ValueOrDie(); @@ -53,40 +54,40 @@ TEST(IdentStepTest, TestIdentStepNameNotFound) { ident_expr->set_name("name0"); auto step_status = CreateIdentStep(ident_expr, expr.id()); - ASSERT_TRUE(step_status.ok()); + ASSERT_OK(step_status); ExecutionPath path; path.push_back(std::move(step_status.ValueOrDie())); auto dummy_expr = absl::make_unique(); - CelExpressionFlatImpl impl(dummy_expr.get(), std::move(path), 0); + CelExpressionFlatImpl impl(dummy_expr.get(), std::move(path), 0, {}); Activation activation; Arena arena; std::string value("test"); auto status0 = impl.Evaluate(activation, &arena); - ASSERT_TRUE(status0.ok()); + ASSERT_OK(status0); CelValue result = status0.ValueOrDie(); ASSERT_TRUE(result.IsError()); } -TEST(IdentStepTest, TestIdentStepUnknownValue) { +TEST(IdentStepTest, TestIdentStepUnknownValueError) { Expr expr; auto ident_expr = expr.mutable_ident_expr(); ident_expr->set_name("name0"); auto step_status = CreateIdentStep(ident_expr, expr.id()); - ASSERT_TRUE(step_status.ok()); + ASSERT_OK(step_status); ExecutionPath path; path.push_back(std::move(step_status.ValueOrDie())); auto dummy_expr = absl::make_unique(); - CelExpressionFlatImpl impl(dummy_expr.get(), std::move(path), 0); + CelExpressionFlatImpl impl(dummy_expr.get(), std::move(path), 0, {}); Activation activation; Arena arena; @@ -94,7 +95,7 @@ TEST(IdentStepTest, TestIdentStepUnknownValue) { activation.InsertValue("name0", CelValue::CreateString(&value)); auto status0 = impl.Evaluate(activation, &arena); - ASSERT_TRUE(status0.ok()); + ASSERT_OK(status0); CelValue result = status0.ValueOrDie(); @@ -106,7 +107,7 @@ TEST(IdentStepTest, TestIdentStepUnknownValue) { activation.set_unknown_paths(unknown_mask); status0 = impl.Evaluate(activation, &arena); - ASSERT_TRUE(status0.ok()); + ASSERT_OK(status0); result = status0.ValueOrDie(); @@ -115,6 +116,50 @@ TEST(IdentStepTest, TestIdentStepUnknownValue) { EXPECT_THAT(GetUnknownPathsSetOrDie(result), Eq(std::set({"name0"}))); } +TEST(IdentStepTest, TestIdentStepUnknownAttribute) { + Expr expr; + auto ident_expr = expr.mutable_ident_expr(); + ident_expr->set_name("name0"); + + auto step_status = CreateIdentStep(ident_expr, expr.id()); + ASSERT_OK(step_status); + + ExecutionPath path; + path.push_back(std::move(step_status.ValueOrDie())); + + auto dummy_expr = absl::make_unique(); + + // Expression with unknowns enabled. + CelExpressionFlatImpl impl(dummy_expr.get(), std::move(path), 0, {}, true); + + Activation activation; + Arena arena; + std::string value("test"); + + activation.InsertValue("name0", CelValue::CreateString(&value)); + std::vector unknown_patterns; + unknown_patterns.push_back(CelAttributePattern("name_bad", {})); + + activation.set_unknown_attribute_patterns(unknown_patterns); + auto status0 = impl.Evaluate(activation, &arena); + ASSERT_OK(status0); + + CelValue result = status0.ValueOrDie(); + + ASSERT_TRUE(result.IsString()); + EXPECT_THAT(result.StringOrDie().value(), Eq("test")); + + unknown_patterns.push_back(CelAttributePattern("name0", {})); + + activation.set_unknown_attribute_patterns(unknown_patterns); + status0 = impl.Evaluate(activation, &arena); + ASSERT_OK(status0); + + result = status0.ValueOrDie(); + + ASSERT_TRUE(result.IsUnknownSet()); +} + } // namespace } // namespace runtime diff --git a/eval/eval/jump_step.cc b/eval/eval/jump_step.cc index 0988dbda1..908fea38c 100644 --- a/eval/eval/jump_step.cc +++ b/eval/eval/jump_step.cc @@ -14,7 +14,7 @@ class JumpStep : public JumpStepBase { JumpStep(absl::optional jump_offset, int64_t expr_id) : JumpStepBase(jump_offset, expr_id) {} - cel_base::Status Evaluate(ExecutionFrame* frame) const override { + absl::Status Evaluate(ExecutionFrame* frame) const override { return Jump(frame); } }; @@ -28,10 +28,10 @@ class CondJumpStep : public JumpStepBase { jump_condition_(jump_condition), leave_on_stack_(leave_on_stack) {} - cel_base::Status Evaluate(ExecutionFrame* frame) const override { + absl::Status Evaluate(ExecutionFrame* frame) const override { // Peek the top value if (!frame->value_stack().HasEnough(1)) { - return cel_base::Status(cel_base::StatusCode::kInternal, "Value stack underflow"); + return absl::Status(absl::StatusCode::kInternal, "Value stack underflow"); } CelValue value = frame->value_stack().Peek(); @@ -44,7 +44,7 @@ class CondJumpStep : public JumpStepBase { return Jump(frame); } - return cel_base::OkStatus(); + return absl::OkStatus(); } private: @@ -57,15 +57,16 @@ class BoolCheckJumpStep : public JumpStepBase { // Checks if the top value is a boolean: // - no-op if it is a boolean // - jump to the label if it is an error value + // - jump to the label if it is unknown value // - jump to the label if it is neither an error nor a boolean, pops it and // pushes "no matching overload" error BoolCheckJumpStep(absl::optional jump_offset, int64_t expr_id) : JumpStepBase(jump_offset, expr_id) {} - cel_base::Status Evaluate(ExecutionFrame* frame) const override { + absl::Status Evaluate(ExecutionFrame* frame) const override { // Peek the top value if (!frame->value_stack().HasEnough(1)) { - return cel_base::Status(cel_base::StatusCode::kInternal, "Value stack underflow"); + return absl::Status(absl::StatusCode::kInternal, "Value stack underflow"); } CelValue value = frame->value_stack().Peek(); @@ -74,13 +75,17 @@ class BoolCheckJumpStep : public JumpStepBase { return Jump(frame); } + if (value.IsUnknownSet()) { + return Jump(frame); + } + if (!value.IsBool()) { CelValue error_value = CreateNoMatchingOverloadError(frame->arena()); frame->value_stack().PopAndPush(error_value); return Jump(frame); } - return cel_base::OkStatus(); + return absl::OkStatus(); } }; @@ -109,7 +114,7 @@ cel_base::StatusOr> CreateJumpStep( // Factory method for Conditional Jump step. // Conditional Jump requires a value to sit on the stack. -// If this value is an error, a jump is performed. +// If this value is an error or unknown, a jump is performed. cel_base::StatusOr> CreateBoolCheckJumpStep( absl::optional jump_offset, int64_t expr_id) { std::unique_ptr step = @@ -118,6 +123,9 @@ cel_base::StatusOr> CreateBoolCheckJumpStep( return std::move(step); } +// TODO(issues/41) Make sure Unknowns are properly supported by ternary +// operation. + } // namespace runtime } // namespace expr } // namespace api diff --git a/eval/eval/jump_step.h b/eval/eval/jump_step.h index fd474b763..3515c8bb8 100644 --- a/eval/eval/jump_step.h +++ b/eval/eval/jump_step.h @@ -19,9 +19,9 @@ class JumpStepBase : public ExpressionStepBase { void set_jump_offset(int offset) { jump_offset_ = offset; } - cel_base::Status Jump(ExecutionFrame* frame) const { + absl::Status Jump(ExecutionFrame* frame) const { if (!jump_offset_.has_value()) { - return cel_base::Status(cel_base::StatusCode::kInternal, "Jump offset not set"); + return absl::Status(absl::StatusCode::kInternal, "Jump offset not set"); } return frame->JumpTo(jump_offset_.value()); } diff --git a/eval/eval/logic_step.cc b/eval/eval/logic_step.cc index 7103fb0ab..a267eefda 100644 --- a/eval/eval/logic_step.cc +++ b/eval/eval/logic_step.cc @@ -1,7 +1,9 @@ #include "eval/eval/logic_step.h" -#include "eval/eval/expression_step_base.h" #include "absl/strings/str_cat.h" +#include "eval/eval/expression_step_base.h" +#include "eval/public/cel_value.h" +#include "eval/public/unknown_attribute_set.h" namespace google { namespace api { @@ -20,10 +22,10 @@ class LogicalOpStep : public ExpressionStepBase { shortcircuit_ = (op_type_ == OpType::OR); } - cel_base::Status Evaluate(ExecutionFrame* frame) const override; + absl::Status Evaluate(ExecutionFrame* frame) const override; private: - cel_base::Status Calculate(ExecutionFrame* frame, absl::Span args, + absl::Status Calculate(ExecutionFrame* frame, absl::Span args, CelValue* result) const { bool bool_args[2]; bool has_bool_args[2]; @@ -32,7 +34,7 @@ class LogicalOpStep : public ExpressionStepBase { has_bool_args[i] = args[i].GetValue(bool_args + i); if (has_bool_args[i] && shortcircuit_ == bool_args[i]) { *result = CelValue::CreateBool(bool_args[i]); - return cel_base::OkStatus(); + return absl::OkStatus(); } } @@ -40,34 +42,51 @@ class LogicalOpStep : public ExpressionStepBase { switch (op_type_) { case OpType::AND: *result = CelValue::CreateBool(bool_args[0] && bool_args[1]); - return cel_base::OkStatus(); + return absl::OkStatus(); break; case OpType::OR: *result = CelValue::CreateBool(bool_args[0] || bool_args[1]); - return cel_base::OkStatus(); + return absl::OkStatus(); break; } } + // As opposed to regular function, logical operation treat Unknowns with + // higher precedence than error. This is due to the fact that after Unknown + // is resolved to actual value, it may shortcircuit and thus hide the error. + if (frame->enable_unknowns()) { + // Check if unknown? + const UnknownSet* unknown_set = + frame->unknowns_utility().MergeUnknowns(args, + /*initial_set=*/nullptr); + + if (unknown_set) { + *result = CelValue::CreateUnknownSet(unknown_set); + return absl::OkStatus(); + } + } + if (args[0].IsError()) { *result = args[0]; + return absl::OkStatus(); } else if (args[1].IsError()) { *result = args[1]; - } else { - *result = CreateNoMatchingOverloadError(frame->arena()); + return absl::OkStatus(); } - return cel_base::OkStatus(); + // Fallback. + *result = CreateNoMatchingOverloadError(frame->arena()); + return absl::OkStatus(); } const OpType op_type_; bool shortcircuit_; }; -cel_base::Status LogicalOpStep::Evaluate(ExecutionFrame* frame) const { +absl::Status LogicalOpStep::Evaluate(ExecutionFrame* frame) const { // Must have 2 or more values on the stack. if (!frame->value_stack().HasEnough(2)) { - return cel_base::Status(cel_base::StatusCode::kInternal, "Value stack underflow"); + return absl::Status(absl::StatusCode::kInternal, "Value stack underflow"); } // Create Span object that contains input arguments to the function. @@ -88,6 +107,9 @@ cel_base::Status LogicalOpStep::Evaluate(ExecutionFrame* frame) const { } // namespace +// TODO(issues/41) Make sure Unknowns are properly supported by ternary +// operation. + // Factory method for "And" Execution step cel_base::StatusOr> CreateAndStep(int64_t expr_id) { std::unique_ptr step = diff --git a/eval/eval/logic_step_test.cc b/eval/eval/logic_step_test.cc new file mode 100644 index 000000000..e9b410b1e --- /dev/null +++ b/eval/eval/logic_step_test.cc @@ -0,0 +1,319 @@ +#include "eval/eval/logic_step.h" + +#include "gmock/gmock.h" +#include "gtest/gtest.h" +#include "eval/eval/ident_step.h" +#include "eval/public/unknown_attribute_set.h" +#include "eval/public/unknown_set.h" +#include "base/status_macros.h" + +namespace google { +namespace api { +namespace expr { +namespace runtime { + +namespace { + +using google::api::expr::v1alpha1::Expr; + +using google::protobuf::Arena; +using testing::Eq; +class LogicStepTest : public testing::TestWithParam { + public: + absl::Status EvaluateLogic(CelValue arg0, CelValue arg1, bool is_or, + CelValue* result, bool enable_unknown) { + Expr expr0; + auto ident_expr0 = expr0.mutable_ident_expr(); + ident_expr0->set_name("name0"); + + Expr expr1; + auto ident_expr1 = expr1.mutable_ident_expr(); + ident_expr1->set_name("name1"); + + ExecutionPath path; + + auto step_status = CreateIdentStep(ident_expr0, expr0.id()); + if (!step_status.ok()) return step_status.status(); + + path.push_back(std::move(step_status).ValueOrDie()); + + step_status = CreateIdentStep(ident_expr1, expr1.id()); + if (!step_status.ok()) return step_status.status(); + + path.push_back(std::move(step_status).ValueOrDie()); + + step_status = (is_or) ? CreateOrStep(2) : CreateAndStep(2); + if (!step_status.ok()) return step_status.status(); + + path.push_back(std::move(step_status).ValueOrDie()); + + auto dummy_expr = absl::make_unique(); + + CelExpressionFlatImpl impl(dummy_expr.get(), std::move(path), 0, {}, + enable_unknown); + + Activation activation; + std::string value("test"); + + activation.InsertValue("name0", arg0); + activation.InsertValue("name1", arg1); + auto status0 = impl.Evaluate(activation, &arena_); + if (!status0.ok()) return status0.status(); + + *result = status0.ValueOrDie(); + return absl::OkStatus(); + } + + private: + Arena arena_; +}; + +TEST_P(LogicStepTest, TestAndLogic) { + CelValue result; + absl::Status status = + EvaluateLogic(CelValue::CreateBool(true), CelValue::CreateBool(true), + false, &result, GetParam()); + ASSERT_OK(status); + ASSERT_TRUE(result.IsBool()); + ASSERT_TRUE(result.BoolOrDie()); + + status = + EvaluateLogic(CelValue::CreateBool(true), CelValue::CreateBool(false), + false, &result, GetParam()); + ASSERT_OK(status); + ASSERT_TRUE(result.IsBool()); + ASSERT_FALSE(result.BoolOrDie()); + + status = + EvaluateLogic(CelValue::CreateBool(false), CelValue::CreateBool(true), + false, &result, GetParam()); + ASSERT_OK(status); + ASSERT_TRUE(result.IsBool()); + ASSERT_FALSE(result.BoolOrDie()); + + status = + EvaluateLogic(CelValue::CreateBool(false), CelValue::CreateBool(false), + false, &result, GetParam()); + ASSERT_OK(status); + ASSERT_TRUE(result.IsBool()); + ASSERT_FALSE(result.BoolOrDie()); +} + +TEST_P(LogicStepTest, TestOrLogic) { + CelValue result; + absl::Status status = + EvaluateLogic(CelValue::CreateBool(true), CelValue::CreateBool(true), + true, &result, GetParam()); + ASSERT_OK(status); + ASSERT_TRUE(result.IsBool()); + ASSERT_TRUE(result.BoolOrDie()); + + status = + EvaluateLogic(CelValue::CreateBool(true), CelValue::CreateBool(false), + true, &result, GetParam()); + ASSERT_OK(status); + ASSERT_TRUE(result.IsBool()); + ASSERT_TRUE(result.BoolOrDie()); + + status = EvaluateLogic(CelValue::CreateBool(false), + CelValue::CreateBool(true), true, &result, GetParam()); + ASSERT_OK(status); + ASSERT_TRUE(result.IsBool()); + ASSERT_TRUE(result.BoolOrDie()); + + status = + EvaluateLogic(CelValue::CreateBool(false), CelValue::CreateBool(false), + true, &result, GetParam()); + ASSERT_OK(status); + ASSERT_TRUE(result.IsBool()); + ASSERT_FALSE(result.BoolOrDie()); +} + +TEST_P(LogicStepTest, TestAndLogicErrorHandling) { + CelValue result; + CelError error; + CelValue error_value = CelValue::CreateError(&error); + absl::Status status = EvaluateLogic(error_value, CelValue::CreateBool(true), + false, &result, GetParam()); + ASSERT_OK(status); + ASSERT_TRUE(result.IsError()); + + status = EvaluateLogic(CelValue::CreateBool(true), error_value, false, + &result, GetParam()); + ASSERT_OK(status); + ASSERT_TRUE(result.IsError()); + + status = EvaluateLogic(CelValue::CreateBool(false), error_value, false, + &result, GetParam()); + ASSERT_OK(status); + ASSERT_TRUE(result.IsBool()); + ASSERT_FALSE(result.BoolOrDie()); + + status = EvaluateLogic(error_value, CelValue::CreateBool(false), false, + &result, GetParam()); + ASSERT_OK(status); + ASSERT_TRUE(result.IsBool()); + ASSERT_FALSE(result.BoolOrDie()); +} + +TEST_P(LogicStepTest, TestOrLogicErrorHandling) { + CelValue result; + CelError error; + CelValue error_value = CelValue::CreateError(&error); + absl::Status status = EvaluateLogic(error_value, CelValue::CreateBool(false), + true, &result, GetParam()); + ASSERT_OK(status); + ASSERT_TRUE(result.IsError()); + + status = EvaluateLogic(CelValue::CreateBool(false), error_value, true, + &result, GetParam()); + ASSERT_OK(status); + ASSERT_TRUE(result.IsError()); + + status = EvaluateLogic(CelValue::CreateBool(true), error_value, true, &result, + GetParam()); + ASSERT_OK(status); + ASSERT_TRUE(result.IsBool()); + ASSERT_TRUE(result.BoolOrDie()); + + status = EvaluateLogic(error_value, CelValue::CreateBool(true), true, &result, + GetParam()); + ASSERT_OK(status); + ASSERT_TRUE(result.IsBool()); + ASSERT_TRUE(result.BoolOrDie()); +} + +TEST_F(LogicStepTest, TestAndLogicUnknownHandling) { + CelValue result; + UnknownSet unknown_set; + CelError cel_error; + CelValue unknown_value = CelValue::CreateUnknownSet(&unknown_set); + CelValue error_value = CelValue::CreateError(&cel_error); + absl::Status status = EvaluateLogic(unknown_value, CelValue::CreateBool(true), + false, &result, true); + ASSERT_OK(status); + ASSERT_TRUE(result.IsUnknownSet()); + + status = EvaluateLogic(CelValue::CreateBool(true), unknown_value, false, + &result, true); + ASSERT_OK(status); + ASSERT_TRUE(result.IsUnknownSet()); + + status = EvaluateLogic(CelValue::CreateBool(false), unknown_value, false, + &result, true); + ASSERT_OK(status); + ASSERT_TRUE(result.IsBool()); + ASSERT_FALSE(result.BoolOrDie()); + + status = EvaluateLogic(unknown_value, CelValue::CreateBool(false), false, + &result, true); + ASSERT_OK(status); + ASSERT_TRUE(result.IsBool()); + ASSERT_FALSE(result.BoolOrDie()); + + status = EvaluateLogic(error_value, unknown_value, false, &result, true); + ASSERT_OK(status); + ASSERT_TRUE(result.IsUnknownSet()); + + status = EvaluateLogic(unknown_value, error_value, false, &result, true); + ASSERT_OK(status); + ASSERT_TRUE(result.IsUnknownSet()); + + Expr expr0; + auto ident_expr0 = expr0.mutable_ident_expr(); + ident_expr0->set_name("name0"); + + Expr expr1; + auto ident_expr1 = expr1.mutable_ident_expr(); + ident_expr1->set_name("name1"); + + CelAttribute attr0(expr0, {}), attr1(expr1, {}); + UnknownAttributeSet unknown_attr_set0({&attr0}); + UnknownAttributeSet unknown_attr_set1({&attr1}); + UnknownSet unknown_set0(unknown_attr_set0); + UnknownSet unknown_set1(unknown_attr_set1); + + EXPECT_THAT(unknown_attr_set0.attributes().size(), Eq(1)); + EXPECT_THAT(unknown_attr_set1.attributes().size(), Eq(1)); + + status = EvaluateLogic(CelValue::CreateUnknownSet(&unknown_set0), + CelValue::CreateUnknownSet(&unknown_set1), false, + &result, true); + ASSERT_OK(status); + ASSERT_TRUE(result.IsUnknownSet()); + ASSERT_THAT( + result.UnknownSetOrDie()->unknown_attributes().attributes().size(), + Eq(2)); +} + +TEST_F(LogicStepTest, TestOrLogicUnknownHandling) { + CelValue result; + UnknownSet unknown_set; + CelError cel_error; + CelValue unknown_value = CelValue::CreateUnknownSet(&unknown_set); + CelValue error_value = CelValue::CreateError(&cel_error); + absl::Status status = EvaluateLogic( + unknown_value, CelValue::CreateBool(false), true, &result, true); + ASSERT_OK(status); + ASSERT_TRUE(result.IsUnknownSet()); + + status = EvaluateLogic(CelValue::CreateBool(false), unknown_value, true, + &result, true); + ASSERT_OK(status); + ASSERT_TRUE(result.IsUnknownSet()); + + status = EvaluateLogic(CelValue::CreateBool(true), unknown_value, true, + &result, true); + ASSERT_OK(status); + ASSERT_TRUE(result.IsBool()); + ASSERT_TRUE(result.BoolOrDie()); + + status = EvaluateLogic(unknown_value, CelValue::CreateBool(true), true, + &result, true); + ASSERT_OK(status); + ASSERT_TRUE(result.IsBool()); + ASSERT_TRUE(result.BoolOrDie()); + + status = EvaluateLogic(unknown_value, error_value, true, &result, true); + ASSERT_OK(status); + ASSERT_TRUE(result.IsUnknownSet()); + + status = EvaluateLogic(error_value, unknown_value, true, &result, true); + ASSERT_OK(status); + ASSERT_TRUE(result.IsUnknownSet()); + + Expr expr0; + auto ident_expr0 = expr0.mutable_ident_expr(); + ident_expr0->set_name("name0"); + + Expr expr1; + auto ident_expr1 = expr1.mutable_ident_expr(); + ident_expr1->set_name("name1"); + + CelAttribute attr0(expr0, {}), attr1(expr1, {}); + UnknownAttributeSet unknown_attr_set0({&attr0}); + UnknownAttributeSet unknown_attr_set1({&attr1}); + + UnknownSet unknown_set0(unknown_attr_set0); + UnknownSet unknown_set1(unknown_attr_set1); + + EXPECT_THAT(unknown_attr_set0.attributes().size(), Eq(1)); + EXPECT_THAT(unknown_attr_set1.attributes().size(), Eq(1)); + + status = EvaluateLogic(CelValue::CreateUnknownSet(&unknown_set0), + CelValue::CreateUnknownSet(&unknown_set1), true, + &result, true); + ASSERT_OK(status); + ASSERT_TRUE(result.IsUnknownSet()); + ASSERT_THAT( + result.UnknownSetOrDie()->unknown_attributes().attributes().size(), + Eq(2)); +} + +INSTANTIATE_TEST_SUITE_P(LogicStepTest, LogicStepTest, testing::Bool()); +} // namespace + +} // namespace runtime +} // namespace expr +} // namespace api +} // namespace google diff --git a/eval/eval/select_step.cc b/eval/eval/select_step.cc index 60707bf68..effae14e0 100644 --- a/eval/eval/select_step.cc +++ b/eval/eval/select_step.cc @@ -1,9 +1,11 @@ #include "eval/eval/select_step.h" + +#include "absl/strings/str_cat.h" +#include "eval/eval/evaluator_core.h" #include "eval/eval/expression_step_base.h" #include "eval/eval/field_access.h" #include "eval/eval/field_backed_list_impl.h" #include "eval/eval/field_backed_map_impl.h" -#include "absl/strings/str_cat.h" namespace google { namespace api { @@ -12,9 +14,9 @@ namespace runtime { namespace { -using google::protobuf::Reflection; using google::protobuf::Descriptor; using google::protobuf::FieldDescriptor; +using google::protobuf::Reflection; // SelectStep performs message field access specified by Expr::Select // message. @@ -27,10 +29,10 @@ class SelectStep : public ExpressionStepBase { test_field_presence_(test_field_presence), select_path_(select_path) {} - cel_base::Status Evaluate(ExecutionFrame* frame) const override; + absl::Status Evaluate(ExecutionFrame* frame) const override; private: - cel_base::Status CreateValueFromField(const google::protobuf::Message* message, + absl::Status CreateValueFromField(const google::protobuf::Message* msg, google::protobuf::Arena* arena, CelValue* result) const; @@ -39,7 +41,7 @@ class SelectStep : public ExpressionStepBase { std::string select_path_; }; -cel_base::Status SelectStep::CreateValueFromField(const google::protobuf::Message* msg, +absl::Status SelectStep::CreateValueFromField(const google::protobuf::Message* msg, google::protobuf::Arena* arena, CelValue* result) const { const Reflection* reflection = msg->GetReflection(); @@ -48,36 +50,42 @@ cel_base::Status SelectStep::CreateValueFromField(const google::protobuf::Messag if (field_desc == nullptr) { *result = CreateNoSuchFieldError(arena); - return cel_base::OkStatus(); + return absl::OkStatus(); } if (field_desc->is_map()) { *result = CelValue::CreateMap(google::protobuf::Arena::Create( arena, msg, field_desc, arena)); - return cel_base::OkStatus(); + return absl::OkStatus(); } if (field_desc->is_repeated()) { *result = CelValue::CreateList(google::protobuf::Arena::Create( arena, msg, field_desc, arena)); - return cel_base::OkStatus(); + return absl::OkStatus(); } if (test_field_presence_) { *result = CelValue::CreateBool(reflection->HasField(*msg, field_desc)); - return cel_base::OkStatus(); + return absl::OkStatus(); } return CreateValueFromSingleField(msg, field_desc, arena, result); } -cel_base::Status SelectStep::Evaluate(ExecutionFrame* frame) const { +absl::Status SelectStep::Evaluate(ExecutionFrame* frame) const { if (!frame->value_stack().HasEnough(1)) { - return cel_base::Status(cel_base::StatusCode::kInternal, + return absl::Status(absl::StatusCode::kInternal, "No arguments supplied for Select-type expression"); } - CelValue arg = frame->value_stack().Peek(); + const CelValue& arg = frame->value_stack().Peek(); + const AttributeTrail& trail = frame->value_stack().PeekAttribute(); + + CelValue result; + AttributeTrail result_trail; // Non-empty select path - check if value mapped to unknown. bool unknown_value = false; + // TODO(issues/41) deprecate this path after proper support of unknown is + // implemented if (!select_path_.empty()) { unknown_value = frame->activation().IsPathUnknown(select_path_); } @@ -87,24 +95,36 @@ cel_base::Status SelectStep::Evaluate(ExecutionFrame* frame) const { case CelValue::Type::kMessage: { const google::protobuf::Message* msg = arg.MessageOrDie(); + if (frame->enable_unknowns()) { + result_trail = trail.Step(&field_, frame->arena()); + if (frame->unknowns_utility().CheckForUnknown(result_trail, + /*use_partial=*/false)) { + auto unknown_set = google::protobuf::Arena::Create( + frame->arena(), UnknownAttributeSet({result_trail.attribute()})); + result = CelValue::CreateUnknownSet(unknown_set); + frame->value_stack().PopAndPush(result, result_trail); + return absl::OkStatus(); + } + } + if (msg == nullptr) { CelValue error_value = CreateErrorValue(frame->arena(), "Message is NULL"); - frame->value_stack().PopAndPush(error_value); - return cel_base::OkStatus(); + frame->value_stack().PopAndPush(error_value, result_trail); + return absl::OkStatus(); } if (unknown_value) { CelValue error_value = CreateUnknownValueError(frame->arena(), select_path_); - frame->value_stack().PopAndPush(error_value); - return cel_base::OkStatus(); + frame->value_stack().PopAndPush(error_value, result_trail); + return absl::OkStatus(); } - cel_base::Status status = CreateValueFromField(msg, frame->arena(), &arg); + absl::Status status = CreateValueFromField(msg, frame->arena(), &result); if (status.ok()) { - frame->value_stack().PopAndPush(arg); + frame->value_stack().PopAndPush(result, result_trail); } return status; @@ -115,42 +135,57 @@ cel_base::Status SelectStep::Evaluate(ExecutionFrame* frame) const { if (cel_map == nullptr) { CelValue error_value = CreateErrorValue(frame->arena(), "Map is NULL"); frame->value_stack().PopAndPush(error_value); - return cel_base::OkStatus(); + return absl::OkStatus(); } if (unknown_value) { CelValue error_value = CreateErrorValue( frame->arena(), absl::StrCat("Unknown value ", select_path_)); frame->value_stack().PopAndPush(error_value); - return cel_base::OkStatus(); + return absl::OkStatus(); } auto lookup_result = (*cel_map)[CelValue::CreateString(&field_)]; // Test only Select expression. if (test_field_presence_) { - arg = CelValue::CreateBool(lookup_result.has_value()); - frame->value_stack().PopAndPush(arg); - return cel_base::OkStatus(); + result = CelValue::CreateBool(lookup_result.has_value()); + frame->value_stack().PopAndPush(result); + return absl::OkStatus(); + } + + if (frame->enable_unknowns()) { + result_trail = trail.Step(&field_, frame->arena()); + if (frame->unknowns_utility().CheckForUnknown(result_trail, false)) { + auto unknown_set = google::protobuf::Arena::Create( + frame->arena(), UnknownAttributeSet({result_trail.attribute()})); + result = CelValue::CreateUnknownSet(unknown_set); + frame->value_stack().PopAndPush(result, result_trail); + return absl::OkStatus(); + } } // If object is not found, we return Error, per CEL specification. if (lookup_result) { - arg = lookup_result.value(); + result = lookup_result.value(); } else { - arg = CreateNoSuchKeyError(frame->arena(), field_); + result = CreateNoSuchKeyError(frame->arena(), field_); } - frame->value_stack().PopAndPush(arg); + frame->value_stack().PopAndPush(result, result_trail); - return ::cel_base::OkStatus(); + return absl::OkStatus(); + } + case CelValue::Type::kUnknownSet: { + // Parent is unknown already, bubble it up. + return absl::OkStatus(); } case CelValue::Type::kError: { // If argument is CelError, we propagate it forward. // It is already on the top of the stack. - return ::cel_base::OkStatus(); + return absl::OkStatus(); } default: - return cel_base::Status(cel_base::StatusCode::kInvalidArgument, + return absl::Status(absl::StatusCode::kInvalidArgument, "Applying SELECT to non-message type"); } } diff --git a/eval/eval/select_step_test.cc b/eval/eval/select_step_test.cc index 4baf6b398..2a5cfa8f2 100644 --- a/eval/eval/select_step_test.cc +++ b/eval/eval/select_step_test.cc @@ -5,8 +5,11 @@ #include "gtest/gtest.h" #include "eval/eval/container_backed_map_impl.h" #include "eval/eval/ident_step.h" +#include "eval/public/cel_attribute.h" +#include "eval/public/unknown_attribute_set.h" #include "eval/testutil/test_message.pb.h" #include "testutil/util.h" +#include "base/status_macros.h" namespace google { namespace api { @@ -24,7 +27,8 @@ using google::api::expr::v1alpha1::Expr; cel_base::StatusOr RunExpression(const CelValue target, absl::string_view field, bool test, google::protobuf::Arena* arena, - absl::string_view unknown_path) { + absl::string_view unknown_path, + bool enable_unknowns) { ExecutionPath path; Expr dummy_expr; @@ -50,7 +54,8 @@ cel_base::StatusOr RunExpression(const CelValue target, path.push_back(std::move(step0_status.ValueOrDie())); path.push_back(std::move(step1_status.ValueOrDie())); - CelExpressionFlatImpl cel_expr(&dummy_expr, std::move(path), 0); + CelExpressionFlatImpl cel_expr(&dummy_expr, std::move(path), 0, {}, + enable_unknowns); Activation activation; activation.InsertValue("target", target); @@ -60,50 +65,57 @@ cel_base::StatusOr RunExpression(const CelValue target, cel_base::StatusOr RunExpression(const TestMessage* message, absl::string_view field, bool test, google::protobuf::Arena* arena, - absl::string_view unknown_path) { + absl::string_view unknown_path, + bool enable_unknowns) { return RunExpression(CelValue::CreateMessage(message, arena), field, test, - arena, unknown_path); + arena, unknown_path, enable_unknowns); } cel_base::StatusOr RunExpression(const TestMessage* message, absl::string_view field, bool test, - google::protobuf::Arena* arena) { - return RunExpression(message, field, test, arena, ""); + google::protobuf::Arena* arena, + bool enable_unknowns) { + return RunExpression(message, field, test, arena, "", enable_unknowns); } cel_base::StatusOr RunExpression(const CelMap* map_value, absl::string_view field, bool test, google::protobuf::Arena* arena, - absl::string_view unknown_path) { + absl::string_view unknown_path, + bool enable_unknowns) { return RunExpression(CelValue::CreateMap(map_value), field, test, arena, - unknown_path); + unknown_path, enable_unknowns); } cel_base::StatusOr RunExpression(const CelMap* map_value, absl::string_view field, bool test, - google::protobuf::Arena* arena) { - return RunExpression(map_value, field, test, arena, ""); + google::protobuf::Arena* arena, + bool enable_unknowns) { + return RunExpression(map_value, field, test, arena, "", enable_unknowns); } -TEST(SelectStepTest, SelectMessageIsNull) { +class SelectStepTest : public testing::TestWithParam {}; + +TEST_P(SelectStepTest, SelectMessageIsNull) { google::protobuf::Arena arena; auto run_status = RunExpression(static_cast(nullptr), - "bool_value", true, &arena); - ASSERT_TRUE(run_status.ok()); + "bool_value", true, &arena, GetParam()); + ASSERT_OK(run_status); CelValue result = run_status.ValueOrDie(); ASSERT_TRUE(result.IsError()); } -TEST(SelectStepTest, PresenseIsFalseTest) { +TEST_P(SelectStepTest, PresenseIsFalseTest) { TestMessage message; google::protobuf::Arena arena; - auto run_status = RunExpression(&message, "bool_value", true, &arena); - ASSERT_TRUE(run_status.ok()); + auto run_status = + RunExpression(&message, "bool_value", true, &arena, GetParam()); + ASSERT_OK(run_status); CelValue result = run_status.ValueOrDie(); @@ -111,14 +123,15 @@ TEST(SelectStepTest, PresenseIsFalseTest) { EXPECT_EQ(result.BoolOrDie(), false); } -TEST(SelectStepTest, PresenseIsTrueTest) { +TEST_P(SelectStepTest, PresenseIsTrueTest) { TestMessage message; message.set_bool_value(true); google::protobuf::Arena arena; - auto run_status = RunExpression(&message, "bool_value", true, &arena); - ASSERT_TRUE(run_status.ok()); + auto run_status = + RunExpression(&message, "bool_value", true, &arena, GetParam()); + ASSERT_OK(run_status); CelValue result = run_status.ValueOrDie(); @@ -126,7 +139,7 @@ TEST(SelectStepTest, PresenseIsTrueTest) { EXPECT_EQ(result.BoolOrDie(), true); } -TEST(SelectStepTest, MapPresenseIsFalseTest) { +TEST_P(SelectStepTest, MapPresenseIsFalseTest) { std::string key1 = "key1"; std::vector> key_values{ {CelValue::CreateString(&key1), CelValue::CreateInt64(1)}}; @@ -136,14 +149,15 @@ TEST(SelectStepTest, MapPresenseIsFalseTest) { google::protobuf::Arena arena; - auto run_status = RunExpression(map_value.get(), "key2", true, &arena); + auto run_status = + RunExpression(map_value.get(), "key2", true, &arena, GetParam()); CelValue result = run_status.ValueOrDie(); ASSERT_TRUE(result.IsBool()); EXPECT_EQ(result.BoolOrDie(), false); } -TEST(SelectStepTest, MapPresenseIsTrueTest) { +TEST_P(SelectStepTest, MapPresenseIsTrueTest) { std::string key1 = "key1"; std::vector> key_values{ {CelValue::CreateString(&key1), CelValue::CreateInt64(1)}}; @@ -153,34 +167,56 @@ TEST(SelectStepTest, MapPresenseIsTrueTest) { google::protobuf::Arena arena; - auto run_status = RunExpression(map_value.get(), "key1", true, &arena); + auto run_status = + RunExpression(map_value.get(), "key1", true, &arena, GetParam()); + CelValue result = run_status.ValueOrDie(); + + ASSERT_TRUE(result.IsBool()); + EXPECT_EQ(result.BoolOrDie(), true); +} + +TEST(SelectStepTest, MapPresenseIsTrueWithUnknownTest) { + UnknownSet unknown_set; + std::string key1 = "key1"; + std::vector> key_values{ + {CelValue::CreateString(&key1), + CelValue::CreateUnknownSet(&unknown_set)}}; + + auto map_value = CreateContainerBackedMap( + absl::Span>(key_values)); + + google::protobuf::Arena arena; + + auto run_status = RunExpression(map_value.get(), "key1", true, &arena, true); CelValue result = run_status.ValueOrDie(); ASSERT_TRUE(result.IsBool()); EXPECT_EQ(result.BoolOrDie(), true); } -TEST(SelectStepTest, FieldIsNotPresentInProtoTest) { +TEST_P(SelectStepTest, FieldIsNotPresentInProtoTest) { TestMessage message; google::protobuf::Arena arena; - auto run_status = RunExpression(&message, "fake_field", false, &arena); - ASSERT_TRUE(run_status.ok()); + auto run_status = + RunExpression(&message, "fake_field", false, &arena, GetParam()); + ASSERT_OK(run_status); CelValue result = run_status.ValueOrDie(); ASSERT_TRUE(result.IsError()); - EXPECT_THAT(result.ErrorOrDie()->code(), Eq(cel_base::StatusCode::kNotFound)); + EXPECT_THAT(result.ErrorOrDie()->code(), Eq(absl::StatusCode::kNotFound)); } -TEST(SelectStepTest, FieldIsNotSetTest) { +TEST_P(SelectStepTest, FieldIsNotSetTest) { TestMessage message; google::protobuf::Arena arena; - auto run_status = RunExpression(&message, "bool_value", false, &arena); - ASSERT_TRUE(run_status.ok()); + auto run_status = + RunExpression(&message, "bool_value", false, &arena, GetParam()); + ASSERT_OK(run_status); CelValue result = run_status.ValueOrDie(); @@ -188,14 +224,15 @@ TEST(SelectStepTest, FieldIsNotSetTest) { EXPECT_EQ(result.BoolOrDie(), false); } -TEST(SelectStepTest, SimpleBoolTest) { +TEST_P(SelectStepTest, SimpleBoolTest) { TestMessage message; message.set_bool_value(true); google::protobuf::Arena arena; - auto run_status = RunExpression(&message, "bool_value", false, &arena); - ASSERT_TRUE(run_status.ok()); + auto run_status = + RunExpression(&message, "bool_value", false, &arena, GetParam()); + ASSERT_OK(run_status); CelValue result = run_status.ValueOrDie(); @@ -203,14 +240,15 @@ TEST(SelectStepTest, SimpleBoolTest) { EXPECT_EQ(result.BoolOrDie(), true); } -TEST(SelectStepTest, SimpleInt32Test) { +TEST_P(SelectStepTest, SimpleInt32Test) { TestMessage message; message.set_int32_value(1); google::protobuf::Arena arena; - auto run_status = RunExpression(&message, "int32_value", false, &arena); - ASSERT_TRUE(run_status.ok()); + auto run_status = + RunExpression(&message, "int32_value", false, &arena, GetParam()); + ASSERT_OK(run_status); CelValue result = run_status.ValueOrDie(); @@ -218,14 +256,15 @@ TEST(SelectStepTest, SimpleInt32Test) { EXPECT_EQ(result.Int64OrDie(), 1); } -TEST(SelectStepTest, SimpleInt64Test) { +TEST_P(SelectStepTest, SimpleInt64Test) { TestMessage message; message.set_int64_value(1); google::protobuf::Arena arena; - auto run_status = RunExpression(&message, "int64_value", false, &arena); - ASSERT_TRUE(run_status.ok()); + auto run_status = + RunExpression(&message, "int64_value", false, &arena, GetParam()); + ASSERT_OK(run_status); CelValue result = run_status.ValueOrDie(); @@ -233,14 +272,15 @@ TEST(SelectStepTest, SimpleInt64Test) { EXPECT_EQ(result.Int64OrDie(), 1); } -TEST(SelectStepTest, SimpleUInt32Test) { +TEST_P(SelectStepTest, SimpleUInt32Test) { TestMessage message; message.set_uint32_value(1); google::protobuf::Arena arena; - auto run_status = RunExpression(&message, "uint32_value", false, &arena); - ASSERT_TRUE(run_status.ok()); + auto run_status = + RunExpression(&message, "uint32_value", false, &arena, GetParam()); + ASSERT_OK(run_status); CelValue result = run_status.ValueOrDie(); @@ -248,14 +288,15 @@ TEST(SelectStepTest, SimpleUInt32Test) { EXPECT_EQ(result.Uint64OrDie(), 1); } -TEST(SelectStepTest, SimpleUint64Test) { +TEST_P(SelectStepTest, SimpleUint64Test) { TestMessage message; message.set_uint64_value(1); google::protobuf::Arena arena; - auto run_status = RunExpression(&message, "uint64_value", false, &arena); - ASSERT_TRUE(run_status.ok()); + auto run_status = + RunExpression(&message, "uint64_value", false, &arena, GetParam()); + ASSERT_OK(run_status); CelValue result = run_status.ValueOrDie(); @@ -263,15 +304,16 @@ TEST(SelectStepTest, SimpleUint64Test) { EXPECT_EQ(result.Uint64OrDie(), 1); } -TEST(SelectStepTest, SimpleStringTest) { +TEST_P(SelectStepTest, SimpleStringTest) { TestMessage message; std::string value = "test"; message.set_string_value(value); google::protobuf::Arena arena; - auto run_status = RunExpression(&message, "string_value", false, &arena); - ASSERT_TRUE(run_status.ok()); + auto run_status = + RunExpression(&message, "string_value", false, &arena, GetParam()); + ASSERT_OK(run_status); CelValue result = run_status.ValueOrDie(); @@ -280,15 +322,16 @@ TEST(SelectStepTest, SimpleStringTest) { } -TEST(SelectStepTest, SimpleBytesTest) { +TEST_P(SelectStepTest, SimpleBytesTest) { TestMessage message; std::string value = "test"; message.set_bytes_value(value); google::protobuf::Arena arena; - auto run_status = RunExpression(&message, "bytes_value", false, &arena); - ASSERT_TRUE(run_status.ok()); + auto run_status = + RunExpression(&message, "bytes_value", false, &arena, GetParam()); + ASSERT_OK(run_status); CelValue result = run_status.ValueOrDie(); @@ -296,7 +339,7 @@ TEST(SelectStepTest, SimpleBytesTest) { EXPECT_EQ(result.BytesOrDie().value(), "test"); } -TEST(SelectStepTest, SimpleMessageTest) { +TEST_P(SelectStepTest, SimpleMessageTest) { TestMessage message; TestMessage* message2 = message.mutable_message_value(); @@ -305,8 +348,9 @@ TEST(SelectStepTest, SimpleMessageTest) { google::protobuf::Arena arena; - auto run_status = RunExpression(&message, "message_value", false, &arena); - ASSERT_TRUE(run_status.ok()); + auto run_status = + RunExpression(&message, "message_value", false, &arena, GetParam()); + ASSERT_OK(run_status); CelValue result = run_status.ValueOrDie(); @@ -314,15 +358,16 @@ TEST(SelectStepTest, SimpleMessageTest) { EXPECT_THAT(*message2, EqualsProto(*result.MessageOrDie())); } -TEST(SelectStepTest, SimpleEnumTest) { +TEST_P(SelectStepTest, SimpleEnumTest) { TestMessage message; message.set_enum_value(TestMessage::TEST_ENUM_1); google::protobuf::Arena arena; - auto run_status = RunExpression(&message, "enum_value", false, &arena); - ASSERT_TRUE(run_status.ok()); + auto run_status = + RunExpression(&message, "enum_value", false, &arena, GetParam()); + ASSERT_OK(run_status); CelValue result = run_status.ValueOrDie(); @@ -330,7 +375,7 @@ TEST(SelectStepTest, SimpleEnumTest) { EXPECT_THAT(result.Int64OrDie(), Eq(TestMessage::TEST_ENUM_1)); } -TEST(SelectStepTest, SimpleListTest) { +TEST_P(SelectStepTest, SimpleListTest) { TestMessage message; message.add_int32_list(1); @@ -338,8 +383,9 @@ TEST(SelectStepTest, SimpleListTest) { google::protobuf::Arena arena; - auto run_status = RunExpression(&message, "int32_list", false, &arena); - ASSERT_TRUE(run_status.ok()); + auto run_status = + RunExpression(&message, "int32_list", false, &arena, GetParam()); + ASSERT_OK(run_status); CelValue result = run_status.ValueOrDie(); @@ -350,7 +396,7 @@ TEST(SelectStepTest, SimpleListTest) { EXPECT_THAT(cel_list->size(), Eq(2)); } -TEST(SelectStepTest, SimpleMapTest) { +TEST_P(SelectStepTest, SimpleMapTest) { TestMessage message; auto map_field = message.mutable_string_int32_map(); (*map_field)["test0"] = 1; @@ -358,8 +404,9 @@ TEST(SelectStepTest, SimpleMapTest) { google::protobuf::Arena arena; - auto run_status = RunExpression(&message, "string_int32_map", false, &arena); - ASSERT_TRUE(run_status.ok()); + auto run_status = + RunExpression(&message, "string_int32_map", false, &arena, GetParam()); + ASSERT_OK(run_status); CelValue result = run_status.ValueOrDie(); @@ -370,7 +417,7 @@ TEST(SelectStepTest, SimpleMapTest) { EXPECT_THAT(cel_map->size(), Eq(2)); } -TEST(SelectStepTest, MapSimpleInt32Test) { +TEST_P(SelectStepTest, MapSimpleInt32Test) { std::string key1 = "key1"; std::string key2 = "key2"; std::vector> key_values{ @@ -382,8 +429,9 @@ TEST(SelectStepTest, MapSimpleInt32Test) { google::protobuf::Arena arena; - auto run_status = RunExpression(map_value.get(), "key1", false, &arena); - ASSERT_TRUE(run_status.ok()); + auto run_status = + RunExpression(map_value.get(), "key1", false, &arena, GetParam()); + ASSERT_OK(run_status); CelValue result = run_status.ValueOrDie(); @@ -392,7 +440,7 @@ TEST(SelectStepTest, MapSimpleInt32Test) { } // Test Select behavior, when expression to select from is an Error. -TEST(SelectStepTest, CelErrorAsArgument) { +TEST_P(SelectStepTest, CelErrorAsArgument) { ExecutionPath path; Expr dummy_expr; @@ -416,19 +464,20 @@ TEST(SelectStepTest, CelErrorAsArgument) { CelError error; google::protobuf::Arena arena; - CelExpressionFlatImpl cel_expr(&dummy_expr, std::move(path), 0); + CelExpressionFlatImpl cel_expr(&dummy_expr, std::move(path), 0, {}, + GetParam()); Activation activation; activation.InsertValue("message", CelValue::CreateError(&error)); auto status = cel_expr.Evaluate(activation, &arena); - ASSERT_TRUE(status.ok()); + ASSERT_OK(status); auto result = status.ValueOrDie(); ASSERT_TRUE(result.IsError()); EXPECT_THAT(result.ErrorOrDie(), Eq(&error)); } -TEST(SelectStepTest, UnknownValueProducesError) { +TEST_P(SelectStepTest, UnknownValueProducesError) { TestMessage message; message.set_bool_value(true); google::protobuf::Arena arena; @@ -447,18 +496,19 @@ TEST(SelectStepTest, UnknownValueProducesError) { auto step1_status = CreateSelectStep(select, dummy_expr.id(), "message.bool_value"); - ASSERT_TRUE(step0_status.ok()); - ASSERT_TRUE(step1_status.ok()); + ASSERT_OK(step0_status); + ASSERT_OK(step1_status); path.push_back(std::move(step0_status.ValueOrDie())); path.push_back(std::move(step1_status.ValueOrDie())); - CelExpressionFlatImpl cel_expr(&dummy_expr, std::move(path), 0); + CelExpressionFlatImpl cel_expr(&dummy_expr, std::move(path), 0, {}, + GetParam()); Activation activation; activation.InsertValue("message", CelValue::CreateMessage(&message, &arena)); auto eval_status0 = cel_expr.Evaluate(activation, &arena); - ASSERT_TRUE(eval_status0.ok()); + ASSERT_OK(eval_status0); CelValue result = eval_status0.ValueOrDie(); @@ -471,7 +521,7 @@ TEST(SelectStepTest, UnknownValueProducesError) { activation.set_unknown_paths(mask); auto eval_status1 = cel_expr.Evaluate(activation, &arena); - ASSERT_TRUE(eval_status1.ok()); + ASSERT_OK(eval_status1); result = eval_status1.ValueOrDie(); @@ -481,6 +531,124 @@ TEST(SelectStepTest, UnknownValueProducesError) { Eq(std::set({"message.bool_value"}))); } +TEST(SelectStepTest, UnknownPatternResolvesToUnknown) { + TestMessage message; + message.set_bool_value(true); + google::protobuf::Arena arena; + ExecutionPath path; + + Expr dummy_expr; + + auto select = dummy_expr.mutable_select_expr(); + select->set_field("bool_value"); + select->set_test_only(false); + Expr* expr0 = select->mutable_operand(); + + auto ident = expr0->mutable_ident_expr(); + ident->set_name("message"); + auto step0_status = CreateIdentStep(ident, expr0->id()); + auto step1_status = + CreateSelectStep(select, dummy_expr.id(), "message.bool_value"); + + ASSERT_OK(step0_status); + ASSERT_OK(step1_status); + + path.push_back(std::move(step0_status.ValueOrDie())); + path.push_back(std::move(step1_status.ValueOrDie())); + + CelExpressionFlatImpl cel_expr(&dummy_expr, std::move(path), 0, {}, true); + + { + std::vector unknown_patterns; + Activation activation; + activation.InsertValue("message", + CelValue::CreateMessage(&message, &arena)); + activation.set_unknown_attribute_patterns(unknown_patterns); + + auto eval_status0 = cel_expr.Evaluate(activation, &arena); + ASSERT_OK(eval_status0); + + CelValue result = eval_status0.ValueOrDie(); + + ASSERT_TRUE(result.IsBool()); + EXPECT_EQ(result.BoolOrDie(), true); + } + + const std::string kSegmentCorrect1 = "bool_value"; + const std::string kSegmentIncorrect = "message_value"; + + { + std::vector unknown_patterns; + unknown_patterns.push_back(CelAttributePattern("message", {})); + Activation activation; + activation.InsertValue("message", + CelValue::CreateMessage(&message, &arena)); + activation.set_unknown_attribute_patterns(unknown_patterns); + + auto eval_status0 = cel_expr.Evaluate(activation, &arena); + ASSERT_OK(eval_status0); + + CelValue result = eval_status0.ValueOrDie(); + + ASSERT_TRUE(result.IsUnknownSet()); + } + + { + std::vector unknown_patterns; + unknown_patterns.push_back(CelAttributePattern( + "message", {CelAttributeQualifierPattern::Create( + CelValue::CreateString(&kSegmentCorrect1))})); + Activation activation; + activation.InsertValue("message", + CelValue::CreateMessage(&message, &arena)); + activation.set_unknown_attribute_patterns(unknown_patterns); + + auto eval_status0 = cel_expr.Evaluate(activation, &arena); + ASSERT_OK(eval_status0); + + CelValue result = eval_status0.ValueOrDie(); + + ASSERT_TRUE(result.IsUnknownSet()); + } + + { + std::vector unknown_patterns; + unknown_patterns.push_back(CelAttributePattern( + "message", {CelAttributeQualifierPattern::CreateWildcard()})); + Activation activation; + activation.InsertValue("message", + CelValue::CreateMessage(&message, &arena)); + activation.set_unknown_attribute_patterns(unknown_patterns); + + auto eval_status0 = cel_expr.Evaluate(activation, &arena); + ASSERT_OK(eval_status0); + + CelValue result = eval_status0.ValueOrDie(); + + ASSERT_TRUE(result.IsUnknownSet()); + } + + { + std::vector unknown_patterns; + unknown_patterns.push_back(CelAttributePattern( + "message", {CelAttributeQualifierPattern::Create( + CelValue::CreateString(&kSegmentIncorrect))})); + Activation activation; + activation.InsertValue("message", + CelValue::CreateMessage(&message, &arena)); + activation.set_unknown_attribute_patterns(unknown_patterns); + + auto eval_status0 = cel_expr.Evaluate(activation, &arena); + ASSERT_OK(eval_status0); + + CelValue result = eval_status0.ValueOrDie(); + + ASSERT_TRUE(result.IsBool()); + EXPECT_EQ(result.BoolOrDie(), true); + } +} + +INSTANTIATE_TEST_SUITE_P(SelectStepTest, SelectStepTest, testing::Bool()); } // namespace } // namespace runtime } // namespace expr diff --git a/eval/eval/unknowns_utility.cc b/eval/eval/unknowns_utility.cc new file mode 100644 index 000000000..4c9d122fe --- /dev/null +++ b/eval/eval/unknowns_utility.cc @@ -0,0 +1,96 @@ +#include "eval/eval/unknowns_utility.h" + +#include "absl/status/status.h" +#include "eval/public/cel_value.h" +#include "eval/public/unknown_attribute_set.h" +#include "eval/public/unknown_set.h" +#include "base/statusor.h" + +namespace google { +namespace api { +namespace expr { +namespace runtime { + +using google::protobuf::Arena; + +// Checks whether particular corresponds to any patterns that define unknowns. +bool UnknownsUtility::CheckForUnknown(const AttributeTrail& trail, + bool use_partial) const { + if (trail.empty()) { + return false; + } + for (const auto& pattern : *unknown_patterns_) { + auto current_match = pattern.IsMatch(*trail.attribute()); + if (current_match == CelAttributePattern::MatchType::FULL || + (use_partial && + current_match == CelAttributePattern::MatchType::PARTIAL)) { + return true; + } + } + return false; +} + +// Creates merged UnknownAttributeSet. +// Scans over the args collection, merges any UnknownSets found in +// it together with initial_set (if initial_set is not null). +// Returns pointer to merged set or nullptr, if there were no sets to merge. +const UnknownSet* UnknownsUtility::MergeUnknowns( + absl::Span args, const UnknownSet* initial_set) const { + const UnknownSet* result = initial_set; + + for (const auto& value : args) { + if (!value.IsUnknownSet()) continue; + + auto current_set = value.UnknownSetOrDie(); + if (result == nullptr) { + result = current_set; + } else { + result = Arena::Create(arena_, *result, *current_set); + } + } + + return result; +} + +// Creates merged UnknownAttributeSet. +// Scans over the args collection, determines if there matches to unknown +// patterns, merges attributes together with those from initial_set +// (if initial_set is not null). +// Returns pointer to merged set or nullptr, if there were no sets to merge. +UnknownAttributeSet UnknownsUtility::CheckForUnknowns( + absl::Span args, bool use_partial) const { + std::vector unknown_attrs; + + for (auto trail : args) { + if (CheckForUnknown(trail, use_partial)) { + unknown_attrs.push_back(trail.attribute()); + } + } + + return UnknownAttributeSet(unknown_attrs); +} + +// Creates merged UnknownAttributeSet. +// Merges together attributes from UnknownAttributeSets found in the args +// collection, attributes from attr that match unknown pattern +// patterns, and attributes from initial_set +// (if initial_set is not null). +// Returns pointer to merged set or nullptr, if there were no sets to merge. +const UnknownSet* UnknownsUtility::MergeUnknowns( + absl::Span args, absl::Span attrs, + const UnknownSet* initial_set, bool use_partial) const { + UnknownAttributeSet attr_set = CheckForUnknowns(attrs, use_partial); + if (!attr_set.attributes().empty()) { + if (initial_set != nullptr) { + initial_set = + Arena::Create(arena_, *initial_set, UnknownSet(attr_set)); + } else { + initial_set = Arena::Create(arena_, attr_set); + } + } + return MergeUnknowns(args, initial_set); +} +} // namespace runtime +} // namespace expr +} // namespace api +} // namespace google diff --git a/eval/eval/unknowns_utility.h b/eval/eval/unknowns_utility.h new file mode 100644 index 000000000..5f0f7df6a --- /dev/null +++ b/eval/eval/unknowns_utility.h @@ -0,0 +1,68 @@ +#ifndef THIRD_PARTY_CEL_CPP_EVAL_EVAL_UNKNOWNS_UTILITY_H_ +#define THIRD_PARTY_CEL_CPP_EVAL_EVAL_UNKNOWNS_UTILITY_H_ + +#include +#include + +#include "google/api/expr/v1alpha1/syntax.pb.h" +#include "google/protobuf/arena.h" +#include "absl/types/optional.h" +#include "eval/eval/attribute_trail.h" +#include "eval/public/activation.h" +#include "eval/public/cel_attribute.h" +#include "eval/public/cel_expression.h" +#include "eval/public/cel_value.h" +#include "eval/public/unknown_attribute_set.h" +#include "eval/public/unknown_set.h" + +namespace google { +namespace api { +namespace expr { +namespace runtime { + +// Helper class for handling unknowns logic. Provides helpers for merging +// unknown sets from arguments on the stack and for identifying unknown +// attributes based on the patterns for a given Evaluation. +class UnknownsUtility { + public: + UnknownsUtility(const std::vector* unknown_patterns, + google::protobuf::Arena* arena) + : unknown_patterns_(unknown_patterns), arena_(arena) {} + + // Checks whether particular corresponds to any patterns that define unknowns. + bool CheckForUnknown(const AttributeTrail& trail, bool use_partial) const; + + // Creates merged UnknownAttributeSet. + // Scans over the args collection, determines if there matches to unknown + // patterns and returns the (possibly empty) collection. + UnknownAttributeSet CheckForUnknowns(absl::Span args, + bool use_partial) const; + + // Creates merged UnknownSet. + // Scans over the args collection, merges any UnknownAttributeSets found in + // it together with initial_set (if initial_set is not null). + // Returns pointer to merged set or nullptr, if there were no sets to merge. + const UnknownSet* MergeUnknowns(absl::Span args, + const UnknownSet* initial_set) const; + + // Creates merged UnknownSet. + // Merges together attributes from UnknownSets found in the args + // collection, attributes from attr that match unknown pattern + // patterns, and attributes from initial_set + // (if initial_set is not null). + // Returns pointer to merged set or nullptr, if there were no sets to merge. + const UnknownSet* MergeUnknowns(absl::Span args, + absl::Span attrs, + const UnknownSet* initial_set, + bool use_partial) const; + + private: + const std::vector* unknown_patterns_; + google::protobuf::Arena* arena_; +}; +} // namespace runtime +} // namespace expr +} // namespace api +} // namespace google + +#endif // THIRD_PARTY_CEL_CPP_EVAL_EVAL_UNKNOWNS_UTILITY_H_ diff --git a/eval/eval/unknowns_utility_test.cc b/eval/eval/unknowns_utility_test.cc new file mode 100644 index 000000000..52fd546bb --- /dev/null +++ b/eval/eval/unknowns_utility_test.cc @@ -0,0 +1,145 @@ +#include "eval/eval/unknowns_utility.h" + +#include "google/api/expr/v1alpha1/syntax.pb.h" +#include "gmock/gmock.h" +#include "gtest/gtest.h" +#include "eval/public/cel_attribute.h" +#include "eval/public/cel_value.h" +#include "eval/public/unknown_attribute_set.h" +#include "eval/public/unknown_set.h" + +namespace google { +namespace api { +namespace expr { +namespace runtime { + +using testing::Eq; +using testing::IsNull; +using testing::NotNull; +using testing::SizeIs; +using testing::UnorderedPointwise; + +TEST(UnknownsUtilityTest, UnknownsUtilityCheckUnknowns) { + google::protobuf::Arena arena; + std::vector patterns = { + CelAttributePattern("unknown0", {CelAttributeQualifierPattern::Create( + CelValue::CreateInt64(1))}), + CelAttributePattern("unknown0", {CelAttributeQualifierPattern::Create( + CelValue::CreateInt64(2))}), + CelAttributePattern("unknown1", {}), + CelAttributePattern("unknown2", {}), + }; + UnknownsUtility utility(&patterns, &arena); + // no match for void trail + ASSERT_FALSE(utility.CheckForUnknown(AttributeTrail(), true)); + ASSERT_FALSE(utility.CheckForUnknown(AttributeTrail(), false)); + + google::api::expr::v1alpha1::Expr unknown_expr0; + unknown_expr0.mutable_ident_expr()->set_name("unknown0"); + + AttributeTrail unknown_trail0(unknown_expr0, &arena); + + { ASSERT_FALSE(utility.CheckForUnknown(unknown_trail0, false)); } + + { ASSERT_TRUE(utility.CheckForUnknown(unknown_trail0, true)); } + + { + ASSERT_TRUE(utility.CheckForUnknown( + unknown_trail0.Step( + CelAttributeQualifier::Create(CelValue::CreateInt64(1)), &arena), + false)); + } + + { + ASSERT_TRUE(utility.CheckForUnknown( + unknown_trail0.Step( + CelAttributeQualifier::Create(CelValue::CreateInt64(1)), &arena), + true)); + } +} + +TEST(UnknownsUtilityTest, UnknownsUtilityMergeUnknownsFromValues) { + google::protobuf::Arena arena; + + google::api::expr::v1alpha1::Expr unknown_expr0; + unknown_expr0.mutable_ident_expr()->set_name("unknown0"); + + google::api::expr::v1alpha1::Expr unknown_expr1; + unknown_expr1.mutable_ident_expr()->set_name("unknown1"); + + google::api::expr::v1alpha1::Expr unknown_expr2; + unknown_expr2.mutable_ident_expr()->set_name("unknown2"); + + std::vector patterns; + + CelAttribute attribute0(unknown_expr0, {}); + CelAttribute attribute1(unknown_expr1, {}); + CelAttribute attribute2(unknown_expr2, {}); + + UnknownsUtility utility(&patterns, &arena); + + UnknownSet unknown_set0(UnknownAttributeSet({&attribute0})); + UnknownSet unknown_set1(UnknownAttributeSet({&attribute1})); + UnknownSet unknown_set2(UnknownAttributeSet({&attribute1, &attribute2})); + std::vector values = { + CelValue::CreateUnknownSet(&unknown_set0), + CelValue::CreateUnknownSet(&unknown_set1), + CelValue::CreateBool(true), + CelValue::CreateInt64(1), + }; + + const UnknownSet* unknown_set = utility.MergeUnknowns(values, nullptr); + ASSERT_THAT(unknown_set, NotNull()); + ASSERT_THAT(unknown_set->unknown_attributes().attributes(), + UnorderedPointwise(Eq(), std::vector{ + &attribute0, &attribute1})); + + unknown_set = utility.MergeUnknowns(values, &unknown_set2); + ASSERT_THAT(unknown_set, NotNull()); + ASSERT_THAT( + unknown_set->unknown_attributes().attributes(), + UnorderedPointwise(Eq(), std::vector{ + &attribute0, &attribute1, &attribute2})); +} + +TEST(UnknownsUtilityTest, UnknownsUtilityCheckForUnknownsFromAttributes) { + google::protobuf::Arena arena; + + std::vector patterns = { + CelAttributePattern("unknown0", + {CelAttributeQualifierPattern::CreateWildcard()}), + }; + + google::api::expr::v1alpha1::Expr unknown_expr0; + unknown_expr0.mutable_ident_expr()->set_name("unknown0"); + + google::api::expr::v1alpha1::Expr unknown_expr1; + unknown_expr1.mutable_ident_expr()->set_name("unknown1"); + + AttributeTrail trail0(unknown_expr0, &arena); + AttributeTrail trail1(unknown_expr1, &arena); + + CelAttribute attribute1(unknown_expr1, {}); + UnknownSet unknown_set1(UnknownAttributeSet({&attribute1})); + + UnknownsUtility utility(&patterns, &arena); + + UnknownSet unknown_attr_set(utility.CheckForUnknowns( + { + AttributeTrail(), // To make sure we handle empty trail gracefully. + trail0.Step(CelAttributeQualifier::Create(CelValue::CreateInt64(1)), + &arena), + trail0.Step(CelAttributeQualifier::Create(CelValue::CreateInt64(2)), + &arena), + }, + false)); + + UnknownSet unknown_set(unknown_set1, unknown_attr_set); + + ASSERT_THAT(unknown_set.unknown_attributes().attributes(), SizeIs(3)); +} + +} // namespace runtime +} // namespace expr +} // namespace api +} // namespace google diff --git a/eval/public/BUILD b/eval/public/BUILD index 878c2712e..c01383f73 100644 --- a/eval/public/BUILD +++ b/eval/public/BUILD @@ -26,9 +26,10 @@ cc_library( copts = ["-std=c++14"], deps = [ ":cel_value_internal", - "//base:status", + "//base:statusor", "//internal:proto_util", "@com_google_absl//absl/container:node_hash_map", + "@com_google_absl//absl/status", "@com_google_absl//absl/strings", "@com_google_absl//absl/synchronization", "@com_google_absl//absl/types:optional", @@ -47,7 +48,7 @@ cc_library( deps = [ ":cel_value", ":cel_value_internal", - "//base:status", + "//base:statusor", "//internal:proto_util", "@com_google_absl//absl/strings", "@com_google_absl//absl/types:optional", @@ -77,9 +78,7 @@ cc_library( copts = ["-std=c++14"], deps = [ ":cel_attribute", - "//base:status", "@com_google_absl//absl/container:flat_hash_set", - "@com_google_absl//absl/strings", ], ) @@ -93,11 +92,12 @@ cc_library( ], copts = ["-std=c++14"], deps = [ + ":cel_attribute", ":cel_function", ":cel_value", ":cel_value_producer", - "//base:status", "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/status", "@com_google_absl//absl/strings", "@com_google_protobuf//:protobuf", ], @@ -148,7 +148,8 @@ cc_library( deps = [ ":cel_function", ":cel_function_registry", - "//base:status", + "//base:statusor", + "@com_google_absl//absl/status", "@com_google_absl//absl/strings", "@com_google_protobuf//:protobuf", ], @@ -194,7 +195,7 @@ cc_library( ":cel_function_adapter", ":cel_function_registry", ":cel_options", - "//base:status", + "//base:status_macros", "//eval/eval:container_backed_list_impl", "@com_google_absl//absl/strings", "@com_google_protobuf//:protobuf", @@ -328,7 +329,7 @@ cc_library( copts = ["-std=c++14"], deps = [ ":cel_value", - "//base:status", + "@com_google_absl//absl/status", "@com_google_absl//absl/strings", "@com_google_protobuf//:protobuf", ], @@ -359,6 +360,7 @@ cc_test( deps = [ ":cel_value", ":unknown_attribute_set", + ":unknown_set", "//eval/testutil:test_message_cc_proto", "//testutil:util", "@com_github_google_googletest//:gtest_main", @@ -392,6 +394,7 @@ cc_test( deps = [ ":activation", ":cel_function", + "//base:status_macros", "@com_github_google_googletest//:gtest_main", ], ) @@ -418,6 +421,7 @@ cc_test( deps = [ ":activation", ":activation_bind_helper", + "//base:status_macros", "//eval/testutil:test_message_cc_proto", "@com_github_google_googletest//:gtest_main", "@com_google_protobuf//:protobuf", @@ -433,6 +437,7 @@ cc_test( deps = [ ":cel_function", ":cel_function_provider", + "//base:status_macros", "@com_github_google_googletest//:gtest_main", ], ) @@ -447,6 +452,7 @@ cc_test( ":cel_function", ":cel_function_provider", ":cel_function_registry", + "//base:status_macros", "@com_github_google_googletest//:gtest_main", ], ) @@ -462,6 +468,7 @@ cc_test( ":cel_function", ":cel_function_adapter", ":cel_value", + "//base:status_macros", "@com_github_google_googletest//:gtest_main", ], ) @@ -479,6 +486,7 @@ cc_test( ":cel_builtins", ":cel_expr_builder_factory", ":cel_function_registry", + "//base:status_macros", "@com_github_google_googletest//:gtest_main", "@com_google_absl//absl/strings", "@com_google_googleapis//google/api/expr/v1alpha1:syntax_cc_proto", @@ -498,6 +506,7 @@ cc_test( ":cel_function_registry", ":cel_value", ":extension_func_registrar", + "//base:status_macros", "@com_github_google_googletest//:gtest_main", "@com_google_absl//absl/strings", "@com_google_protobuf//:protobuf", @@ -530,8 +539,6 @@ cc_test( ":cel_value", ":unknown_attribute_set", "@com_github_google_googletest//:gtest_main", - "@com_google_absl//absl/strings", - "@com_google_protobuf//:protobuf", ], ) @@ -545,7 +552,7 @@ cc_test( deps = [ ":cel_value", ":value_export_util", - "//base:status", + "//base:status_macros", "//eval/eval:container_backed_list_impl", "//eval/eval:container_backed_map_impl", "//eval/testutil:test_message_cc_proto", @@ -555,3 +562,61 @@ cc_test( "@com_google_protobuf//:protobuf", ], ) + +cc_library( + name = "unknown_function_result_set", + srcs = ["unknown_function_result_set.cc"], + hdrs = ["unknown_function_result_set.h"], + copts = ["-std=c++14"], + deps = [ + ":cel_function", + ":cel_options", + ":cel_value", + "@com_google_googleapis//google/api/expr/v1alpha1:syntax_cc_proto", + ], +) + +cc_test( + name = "unknown_function_result_set_test", + size = "small", + srcs = [ + "unknown_function_result_set_test.cc", + ], + copts = ["-std=c++14"], + deps = [ + ":cel_function", + ":cel_value", + ":unknown_function_result_set", + "//eval/eval:container_backed_list_impl", + "//eval/eval:container_backed_map_impl", + "@com_github_google_googletest//:gtest_main", + "@com_google_absl//absl/time", + "@com_google_absl//absl/types:span", + "@com_google_protobuf//:protobuf", + ], +) + +cc_library( + name = "unknown_set", + hdrs = ["unknown_set.h"], + copts = ["-std=c++14"], + deps = [ + ":unknown_attribute_set", + ":unknown_function_result_set", + ], +) + +cc_test( + name = "unknown_set_test", + srcs = ["unknown_set_test.cc"], + copts = ["-std=c++14"], + deps = [ + ":cel_attribute", + ":unknown_attribute_set", + ":unknown_function_result_set", + ":unknown_set", + "@com_github_google_googletest//:gtest_main", + "@com_google_googleapis//google/api/expr/v1alpha1:syntax_cc_proto", + "@com_google_protobuf//:protobuf", + ], +) diff --git a/eval/public/activation.cc b/eval/public/activation.cc index 3047cc4bd..ecd95ee13 100644 --- a/eval/public/activation.cc +++ b/eval/public/activation.cc @@ -4,10 +4,9 @@ #include #include +#include "absl/status/status.h" #include "absl/strings/string_view.h" #include "eval/public/cel_function.h" -#include "base/canonical_errors.h" -#include "base/status.h" namespace google { namespace api { @@ -26,16 +25,16 @@ absl::optional Activation::FindValue(absl::string_view name, return entry->second.RetrieveValue(arena); } -cel_base::Status Activation::InsertFunction(std::unique_ptr function) { +absl::Status Activation::InsertFunction(std::unique_ptr function) { auto& overloads = function_map_[function->descriptor().name()]; for (const auto& overload : overloads) { if (overload->descriptor().ShapeMatches(function->descriptor())) { - return cel_base::InvalidArgumentError( + return absl::InvalidArgumentError( "Function with same shape already defined in activation"); } } overloads.emplace_back(std::move(function)); - return cel_base::OkStatus(); + return absl::OkStatus(); } std::vector Activation::FindFunctionOverloads( diff --git a/eval/public/activation.h b/eval/public/activation.h index 7d38f0ba8..f6b200ec6 100644 --- a/eval/public/activation.h +++ b/eval/public/activation.h @@ -3,11 +3,13 @@ #include #include +#include #include "google/protobuf/field_mask.pb.h" #include "google/protobuf/util/field_mask_util.h" #include "absl/container/flat_hash_map.h" #include "absl/strings/string_view.h" +#include "eval/public/cel_attribute.h" #include "eval/public/cel_function.h" #include "eval/public/cel_value.h" #include "eval/public/cel_value_producer.h" @@ -39,6 +41,14 @@ class BaseActivation { // Check whether a select path is unknown. virtual bool IsPathUnknown(absl::string_view) const = 0; + // Return FieldMask defining the list of unknown paths. + virtual const google::protobuf::FieldMask unknown_paths() const = 0; + + // Return the collection of attribute patterns that determine "unknown" + // values. + virtual const std::vector& unknown_attribute_patterns() + const = 0; + virtual ~BaseActivation() {} }; @@ -68,7 +78,7 @@ class Activation : public BaseActivation { // Insert a function into the activation (ie a lazily bound function). Returns // a status if the name and shape of the function matches another one that has // already been bound. - cel_base::Status InsertFunction(std::unique_ptr function); + absl::Status InsertFunction(std::unique_ptr function); // Insert value into Activation. void InsertValue(absl::string_view name, const CelValue& value); @@ -98,10 +108,24 @@ class Activation : public BaseActivation { } // Return FieldMask defining the list of unknown paths. - const google::protobuf::FieldMask unknown_paths() const { + const google::protobuf::FieldMask unknown_paths() const override { return unknown_paths_; } + // Sets the collection of attribute patterns that will be recognized as + // "unknown" values during expression evaluation. + void set_unknown_attribute_patterns( + std::vector unknown_attribute_patterns) { + unknown_attribute_patterns_ = std::move(unknown_attribute_patterns); + } + + // Return the collection of attribute patterns that determine "unknown" + // values. + const std::vector& unknown_attribute_patterns() + const override { + return unknown_attribute_patterns_; + } + private: class ValueEntry { public: @@ -140,7 +164,9 @@ class Activation : public BaseActivation { absl::flat_hash_map>> function_map_; + // TODO(issues/41) deprecate when unknowns support is done. google::protobuf::FieldMask unknown_paths_; + std::vector unknown_attribute_patterns_; }; } // namespace runtime diff --git a/eval/public/activation_bind_helper.cc b/eval/public/activation_bind_helper.cc index 0b06272bf..532acc08a 100644 --- a/eval/public/activation_bind_helper.cc +++ b/eval/public/activation_bind_helper.cc @@ -16,17 +16,17 @@ using google::protobuf::Message; using google::protobuf::FieldDescriptor; using google::protobuf::Descriptor; -cel_base::Status CreateValueFromField(const google::protobuf::Message* msg, +absl::Status CreateValueFromField(const google::protobuf::Message* msg, const FieldDescriptor* field_desc, google::protobuf::Arena* arena, CelValue* result) { if (field_desc->is_map()) { *result = CelValue::CreateMap(google::protobuf::Arena::Create( arena, msg, field_desc, arena)); - return cel_base::OkStatus(); + return absl::OkStatus(); } else if (field_desc->is_repeated()) { *result = CelValue::CreateList(google::protobuf::Arena::Create( arena, msg, field_desc, arena)); - return cel_base::OkStatus(); + return absl::OkStatus(); } else { return CreateValueFromSingleField(msg, field_desc, arena, result); } @@ -34,8 +34,8 @@ cel_base::Status CreateValueFromField(const google::protobuf::Message* msg, } // namespace -::cel_base::Status BindProtoToActivation(const Message* message, Arena* arena, - Activation* activation) { +absl::Status BindProtoToActivation(const Message* message, Arena* arena, + Activation* activation) { // TODO(issues/24): Improve the utilities to bind dynamic values as well. const Descriptor* desc = message->GetDescriptor(); const google::protobuf::Reflection* reflection = message->GetReflection(); @@ -56,7 +56,7 @@ ::cel_base::Status BindProtoToActivation(const Message* message, Arena* arena, activation->InsertValue(field_desc->name(), value); } - return ::cel_base::OkStatus(); + return absl::OkStatus(); } } // namespace runtime diff --git a/eval/public/activation_bind_helper.h b/eval/public/activation_bind_helper.h index 40a9c3351..a45ac9833 100644 --- a/eval/public/activation_bind_helper.h +++ b/eval/public/activation_bind_helper.h @@ -31,9 +31,9 @@ namespace runtime { // "name", with string value of "John Doe" // "age", with int value of 42. // -::cel_base::Status BindProtoToActivation(const google::protobuf::Message* message, - google::protobuf::Arena* arena, - Activation* activation); +absl::Status BindProtoToActivation(const google::protobuf::Message* message, + google::protobuf::Arena* arena, + Activation* activation); } // namespace runtime } // namespace expr diff --git a/eval/public/activation_bind_helper_test.cc b/eval/public/activation_bind_helper_test.cc index 7c00dcee1..d549d77a4 100644 --- a/eval/public/activation_bind_helper_test.cc +++ b/eval/public/activation_bind_helper_test.cc @@ -1,10 +1,10 @@ #include "eval/public/activation_bind_helper.h" -#include "eval/public/activation.h" - -#include "eval/testutil/test_message.pb.h" #include "gmock/gmock.h" #include "gtest/gtest.h" +#include "eval/public/activation.h" +#include "eval/testutil/test_message.pb.h" +#include "base/status_macros.h" namespace google { namespace api { @@ -21,7 +21,7 @@ TEST(ActivationBindHelperTest, TestSingleBoolBind) { Activation activation; - ASSERT_TRUE(BindProtoToActivation(&message, &arena, &activation).ok()); + ASSERT_OK(BindProtoToActivation(&message, &arena, &activation)); auto result = activation.FindValue("bool_value", &arena); @@ -41,7 +41,7 @@ TEST(ActivationBindHelperTest, TestSingleInt32Bind) { Activation activation; - ASSERT_TRUE(BindProtoToActivation(&message, &arena, &activation).ok()); + ASSERT_OK(BindProtoToActivation(&message, &arena, &activation)); auto result = activation.FindValue("int32_value", &arena); diff --git a/eval/public/activation_test.cc b/eval/public/activation_test.cc index 6094f5b38..4927cf353 100644 --- a/eval/public/activation_test.cc +++ b/eval/public/activation_test.cc @@ -3,6 +3,7 @@ #include "gmock/gmock.h" #include "gtest/gtest.h" #include "eval/public/cel_function.h" +#include "base/status_macros.h" namespace google { namespace api { @@ -33,10 +34,10 @@ class ConstCelFunction : public CelFunction { explicit ConstCelFunction(const CelFunctionDescriptor& desc) : CelFunction(desc) {} - cel_base::Status Evaluate(absl::Span args, CelValue* output, + absl::Status Evaluate(absl::Span args, CelValue* output, google::protobuf::Arena* arena) const override { *output = CelValue::CreateInt64(42); - return cel_base::OkStatus(); + return absl::OkStatus(); } }; @@ -110,7 +111,7 @@ TEST(ActivationTest, CheckInsertFunction) { Activation activation; auto insert_status = activation.InsertFunction( std::make_unique("ConstFunc")); - EXPECT_TRUE(insert_status.ok()); + EXPECT_OK(insert_status); auto overloads = activation.FindFunctionOverloads("ConstFunc"); EXPECT_THAT(overloads, @@ -118,7 +119,7 @@ TEST(ActivationTest, CheckInsertFunction) { &CelFunction::descriptor, Property(&CelFunctionDescriptor::name, Eq("ConstFunc"))))); - cel_base::Status status = activation.InsertFunction( + absl::Status status = activation.InsertFunction( std::make_unique("ConstFunc")); EXPECT_THAT(std::string(status.message()), @@ -134,10 +135,10 @@ TEST(ActivationTest, CheckRemoveFunction) { auto insert_status = activation.InsertFunction(std::make_unique( CelFunctionDescriptor{"ConstFunc", false, {CelValue::Type::kInt64}})); - EXPECT_TRUE(insert_status.ok()); + EXPECT_OK(insert_status); insert_status = activation.InsertFunction(std::make_unique( CelFunctionDescriptor{"ConstFunc", false, {CelValue::Type::kUint64}})); - EXPECT_TRUE(insert_status.ok()); + EXPECT_OK(insert_status); auto overloads = activation.FindFunctionOverloads("ConstFunc"); EXPECT_THAT( diff --git a/eval/public/builtin_func_registrar.cc b/eval/public/builtin_func_registrar.cc index 2a723f933..91578b2fe 100644 --- a/eval/public/builtin_func_registrar.cc +++ b/eval/public/builtin_func_registrar.cc @@ -4,13 +4,13 @@ #include "google/protobuf/util/time_util.h" #include "absl/strings/match.h" +#include "absl/strings/numbers.h" #include "absl/strings/str_cat.h" #include "eval/eval/container_backed_list_impl.h" #include "eval/public/cel_builtins.h" #include "eval/public/cel_function_adapter.h" #include "eval/public/cel_function_registry.h" #include "re2/re2.h" -#include "base/canonical_errors.h" namespace google { namespace api { @@ -261,9 +261,9 @@ CelValue Equal(Arena* arena, CelValue t1, CelValue t2) { // // Registers all equality functions for template parameters type. template -::cel_base::Status RegisterEqualityFunctionsForType(CelFunctionRegistry* registry) { +absl::Status RegisterEqualityFunctionsForType(CelFunctionRegistry* registry) { // Inequality - ::cel_base::Status status = + absl::Status status = FunctionAdapter::CreateAndRegister( builtin::kInequal, false, Inequal, registry); if (!status.ok()) return status; @@ -276,9 +276,8 @@ ::cel_base::Status RegisterEqualityFunctionsForType(CelFunctionRegistry* registr // Registers all comparison functions for template parameter type. template -::cel_base::Status RegisterComparisonFunctionsForType( - CelFunctionRegistry* registry) { - ::cel_base::Status status = RegisterEqualityFunctionsForType(registry); +absl::Status RegisterComparisonFunctionsForType(CelFunctionRegistry* registry) { + absl::Status status = RegisterEqualityFunctionsForType(registry); if (!status.ok()) return status; // Less than @@ -301,7 +300,7 @@ ::cel_base::Status RegisterComparisonFunctionsForType( builtin::kGreaterOrEqual, false, GreaterThanOrEqual, registry); if (!status.ok()) return status; - return ::cel_base::OkStatus(); + return absl::OkStatus(); } // Template functions providing arithmetic operations @@ -382,9 +381,8 @@ CelValue Modulo(Arena* arena, uint64_t value, uint64_t value2) { // Helper method // Registers all arithmetic functions for template parameter type. template -::cel_base::Status RegisterArithmeticFunctionsForType( - CelFunctionRegistry* registry) { - cel_base::Status status = FunctionAdapter::CreateAndRegister( +absl::Status RegisterArithmeticFunctionsForType(CelFunctionRegistry* registry) { + absl::Status status = FunctionAdapter::CreateAndRegister( builtin::kAdd, false, Add, registry); if (!status.ok()) return status; @@ -488,19 +486,19 @@ const CelList* ConcatList(Arena* arena, const CelList* value1, } // Timestamp -const cel_base::Status FindTimeBreakdown(absl::Time timestamp, absl::string_view tz, +const absl::Status FindTimeBreakdown(absl::Time timestamp, absl::string_view tz, absl::TimeZone::CivilInfo* breakdown) { absl::TimeZone time_zone; if (!tz.empty()) { bool found = absl::LoadTimeZone(std::string(tz), &time_zone); if (!found) { - return cel_base::InvalidArgumentError("Invalid timezone"); + return absl::InvalidArgumentError("Invalid timezone"); } } *breakdown = time_zone.At(timestamp); - return cel_base::OkStatus(); + return absl::OkStatus(); } CelValue GetTimeBreakdownPart( @@ -523,7 +521,7 @@ CelValue CreateTimestampFromString(Arena* arena, if (!absl::ParseTime(absl::RFC3339_full, std::string(time_str.value()), &ts, nullptr)) { return CreateErrorValue(arena, "String to Timestamp conversion failed", - cel_base::StatusCode::kInvalidArgument); + absl::StatusCode::kInvalidArgument); } return CelValue::CreateTimestamp(ts); } @@ -570,7 +568,7 @@ CelValue GetDayOfWeek(Arena* arena, absl::Time timestamp, absl::string_view tz) { return GetTimeBreakdownPart( arena, timestamp, tz, [](const absl::TimeZone::CivilInfo& breakdown) { - absl::Weekday weekday = absl::GetWeekday(absl::CivilDay(breakdown.cs)); + absl::Weekday weekday = absl::GetWeekday(breakdown.cs); // get day of week from the date in UTC, zero-based, zero for Sunday, // based on GetDayOfWeek CEL function definition. @@ -615,7 +613,7 @@ CelValue CreateDurationFromString(Arena* arena, absl::Duration d; if (!absl::ParseDuration(std::string(dur_str.value()), &d)) { return CreateErrorValue(arena, "String to Duration conversion failed", - cel_base::StatusCode::kInvalidArgument); + absl::StatusCode::kInvalidArgument); } return CelValue::CreateDuration(d); @@ -654,28 +652,9 @@ bool StringStartsWith(Arena*, CelValue::StringHolder value, return absl::StartsWith(value.value(), prefix.value()); } -} // namespace - -::cel_base::Status RegisterBuiltinFunctions(CelFunctionRegistry* registry, - const InterpreterOptions& options) { - // logical NOT - cel_base::Status status = FunctionAdapter::CreateAndRegister( - builtin::kNot, false, [](Arena*, bool value) -> bool { return !value; }, - registry); - if (!status.ok()) return status; - - // Negation group - status = FunctionAdapter::CreateAndRegister( - builtin::kNeg, false, [](Arena*, int64_t value) -> int64_t { return -value; }, - registry); - if (!status.ok()) return status; - - status = FunctionAdapter::CreateAndRegister( - builtin::kNeg, false, - [](Arena*, double value) -> double { return -value; }, registry); - if (!status.ok()) return status; - - status = RegisterComparisonFunctionsForType(registry); +absl::Status RegisterComparisonFunctions(CelFunctionRegistry* registry, + const InterpreterOptions& options) { + auto status = RegisterComparisonFunctionsForType(registry); if (!status.ok()) return status; status = RegisterComparisonFunctionsForType(registry); @@ -708,6 +687,247 @@ ::cel_base::Status RegisterBuiltinFunctions(CelFunctionRegistry* registry, status = RegisterEqualityFunctionsForType(registry); if (!status.ok()) return status; + return absl::OkStatus(); +} + +absl::Status RegisterStringFunctions(CelFunctionRegistry* registry, + const InterpreterOptions& options) { + auto status = + FunctionAdapter:: + CreateAndRegister(builtin::kStringContains, false, StringContains, + registry); + if (!status.ok()) return status; + + status = + FunctionAdapter:: + CreateAndRegister(builtin::kStringContains, true, StringContains, + registry); + if (!status.ok()) return status; + + status = + FunctionAdapter:: + CreateAndRegister(builtin::kStringEndsWith, false, StringEndsWith, + registry); + if (!status.ok()) return status; + + status = + FunctionAdapter:: + CreateAndRegister(builtin::kStringEndsWith, true, StringEndsWith, + registry); + if (!status.ok()) return status; + + status = + FunctionAdapter:: + CreateAndRegister(builtin::kStringStartsWith, false, StringStartsWith, + registry); + if (!status.ok()) return status; + + status = + FunctionAdapter:: + CreateAndRegister(builtin::kStringStartsWith, true, StringStartsWith, + registry); + if (!status.ok()) return status; + + return absl::OkStatus(); +} + +absl::Status RegisterTimestampFunctions(CelFunctionRegistry* registry, + const InterpreterOptions& options) { + // Timestamp + // + // timestamp() conversion from string.. + auto status = + FunctionAdapter::CreateAndRegister( + builtin::kTimestamp, false, CreateTimestampFromString, registry); + if (!status.ok()) return status; + + status = FunctionAdapter:: + CreateAndRegister( + builtin::kFullYear, true, + [](Arena* arena, absl::Time ts, CelValue::StringHolder tz) + -> CelValue { return GetFullYear(arena, ts, tz.value()); }, + registry); + if (!status.ok()) return status; + + status = FunctionAdapter::CreateAndRegister( + builtin::kFullYear, true, + [](Arena* arena, absl::Time ts) -> CelValue { + return GetFullYear(arena, ts, ""); + }, + registry); + if (!status.ok()) return status; + + status = FunctionAdapter:: + CreateAndRegister( + builtin::kMonth, true, + [](Arena* arena, absl::Time ts, CelValue::StringHolder tz) + -> CelValue { return GetMonth(arena, ts, tz.value()); }, + registry); + if (!status.ok()) return status; + + status = FunctionAdapter::CreateAndRegister( + builtin::kMonth, true, + [](Arena* arena, absl::Time ts) -> CelValue { + return GetMonth(arena, ts, ""); + }, + registry); + if (!status.ok()) return status; + + status = FunctionAdapter:: + CreateAndRegister( + builtin::kDayOfYear, true, + [](Arena* arena, absl::Time ts, CelValue::StringHolder tz) + -> CelValue { return GetDayOfYear(arena, ts, tz.value()); }, + registry); + if (!status.ok()) return status; + + status = FunctionAdapter::CreateAndRegister( + builtin::kDayOfYear, true, + [](Arena* arena, absl::Time ts) -> CelValue { + return GetDayOfYear(arena, ts, ""); + }, + registry); + if (!status.ok()) return status; + + status = FunctionAdapter:: + CreateAndRegister( + builtin::kDayOfMonth, true, + [](Arena* arena, absl::Time ts, CelValue::StringHolder tz) + -> CelValue { return GetDayOfMonth(arena, ts, tz.value()); }, + registry); + if (!status.ok()) return status; + + status = FunctionAdapter::CreateAndRegister( + builtin::kDayOfMonth, true, + [](Arena* arena, absl::Time ts) -> CelValue { + return GetDayOfMonth(arena, ts, ""); + }, + registry); + if (!status.ok()) return status; + + status = FunctionAdapter:: + CreateAndRegister( + builtin::kDate, true, + [](Arena* arena, absl::Time ts, CelValue::StringHolder tz) + -> CelValue { return GetDate(arena, ts, tz.value()); }, + registry); + if (!status.ok()) return status; + + status = FunctionAdapter::CreateAndRegister( + builtin::kDate, true, + [](Arena* arena, absl::Time ts) -> CelValue { + return GetDate(arena, ts, ""); + }, + registry); + if (!status.ok()) return status; + + status = FunctionAdapter:: + CreateAndRegister( + builtin::kDayOfWeek, true, + [](Arena* arena, absl::Time ts, CelValue::StringHolder tz) + -> CelValue { return GetDayOfWeek(arena, ts, tz.value()); }, + registry); + if (!status.ok()) return status; + + status = FunctionAdapter::CreateAndRegister( + builtin::kDayOfWeek, true, + [](Arena* arena, absl::Time ts) -> CelValue { + return GetDayOfWeek(arena, ts, ""); + }, + registry); + if (!status.ok()) return status; + + status = FunctionAdapter:: + CreateAndRegister( + builtin::kHours, true, + [](Arena* arena, absl::Time ts, CelValue::StringHolder tz) + -> CelValue { return GetHours(arena, ts, tz.value()); }, + registry); + if (!status.ok()) return status; + + status = FunctionAdapter::CreateAndRegister( + builtin::kHours, true, + [](Arena* arena, absl::Time ts) -> CelValue { + return GetHours(arena, ts, ""); + }, + registry); + if (!status.ok()) return status; + + status = FunctionAdapter:: + CreateAndRegister( + builtin::kMinutes, true, + [](Arena* arena, absl::Time ts, CelValue::StringHolder tz) + -> CelValue { return GetMinutes(arena, ts, tz.value()); }, + registry); + if (!status.ok()) return status; + + status = FunctionAdapter::CreateAndRegister( + builtin::kMinutes, true, + [](Arena* arena, absl::Time ts) -> CelValue { + return GetMinutes(arena, ts, ""); + }, + registry); + if (!status.ok()) return status; + + status = FunctionAdapter:: + CreateAndRegister( + builtin::kSeconds, true, + [](Arena* arena, absl::Time ts, CelValue::StringHolder tz) + -> CelValue { return GetSeconds(arena, ts, tz.value()); }, + registry); + if (!status.ok()) return status; + + status = FunctionAdapter::CreateAndRegister( + builtin::kSeconds, true, + [](Arena* arena, absl::Time ts) -> CelValue { + return GetSeconds(arena, ts, ""); + }, + registry); + if (!status.ok()) return status; + + status = FunctionAdapter:: + CreateAndRegister( + builtin::kMilliseconds, true, + [](Arena* arena, absl::Time ts, CelValue::StringHolder tz) + -> CelValue { return GetMilliseconds(arena, ts, tz.value()); }, + registry); + if (!status.ok()) return status; + + status = FunctionAdapter::CreateAndRegister( + builtin::kMilliseconds, true, + [](Arena* arena, absl::Time ts) -> CelValue { + return GetMilliseconds(arena, ts, ""); + }, + registry); + if (!status.ok()) return status; + + return absl::OkStatus(); +} + +} // namespace + +absl::Status RegisterBuiltinFunctions(CelFunctionRegistry* registry, + const InterpreterOptions& options) { + // logical NOT + absl::Status status = FunctionAdapter::CreateAndRegister( + builtin::kNot, false, [](Arena*, bool value) -> bool { return !value; }, + registry); + if (!status.ok()) return status; + + // Negation group + status = FunctionAdapter::CreateAndRegister( + builtin::kNeg, false, [](Arena*, int64_t value) -> int64_t { return -value; }, + registry); + if (!status.ok()) return status; + + status = FunctionAdapter::CreateAndRegister( + builtin::kNeg, false, + [](Arena*, double value) -> double { return -value; }, registry); + if (!status.ok()) return status; + + status = RegisterComparisonFunctions(registry, options); + if (!status.ok()) return status; + // Logical AND // This implementation is used when short-circuiting is off. status = FunctionAdapter::CreateAndRegister( @@ -1128,11 +1348,11 @@ ::cel_base::Status RegisterBuiltinFunctions(CelFunctionRegistry* registry, RE2 re2(regex.value().data()); if (max_size > 0 && re2.ProgramSize() > max_size) { return CreateErrorValue(arena, "exceeded RE2 max program size", - cel_base::StatusCode::kInvalidArgument); + absl::StatusCode::kInvalidArgument); } if (!re2.ok()) { return CreateErrorValue(arena, "invalid_argument", - cel_base::StatusCode::kInvalidArgument); + absl::StatusCode::kInvalidArgument); } return CelValue::CreateBool(RE2::PartialMatch(re2::StringPiece(target.value().data(), target.value().size()), re2)); }; @@ -1151,40 +1371,7 @@ ::cel_base::Status RegisterBuiltinFunctions(CelFunctionRegistry* registry, if (!status.ok()) return status; } - status = - FunctionAdapter:: - CreateAndRegister(builtin::kStringContains, false, StringContains, - registry); - if (!status.ok()) return status; - - status = - FunctionAdapter:: - CreateAndRegister(builtin::kStringContains, true, StringContains, - registry); - if (!status.ok()) return status; - - status = - FunctionAdapter:: - CreateAndRegister(builtin::kStringEndsWith, false, StringEndsWith, - registry); - if (!status.ok()) return status; - - status = - FunctionAdapter:: - CreateAndRegister(builtin::kStringEndsWith, true, StringEndsWith, - registry); - if (!status.ok()) return status; - - status = - FunctionAdapter:: - CreateAndRegister(builtin::kStringStartsWith, false, StringStartsWith, - registry); - if (!status.ok()) return status; - - status = - FunctionAdapter:: - CreateAndRegister(builtin::kStringStartsWith, true, StringStartsWith, - registry); + status = RegisterStringFunctions(registry, options); if (!status.ok()) return status; // Modulo @@ -1196,171 +1383,7 @@ ::cel_base::Status RegisterBuiltinFunctions(CelFunctionRegistry* registry, builtin::kModulo, false, Modulo, registry); if (!status.ok()) return status; - // Timestamp - // - // timestamp() conversion from string.. - status = FunctionAdapter::CreateAndRegister( - builtin::kTimestamp, false, CreateTimestampFromString, registry); - if (!status.ok()) return status; - - status = FunctionAdapter:: - CreateAndRegister( - builtin::kFullYear, true, - [](Arena* arena, absl::Time ts, CelValue::StringHolder tz) - -> CelValue { return GetFullYear(arena, ts, tz.value()); }, - registry); - if (!status.ok()) return status; - - status = FunctionAdapter::CreateAndRegister( - builtin::kFullYear, true, - [](Arena* arena, absl::Time ts) -> CelValue { - return GetFullYear(arena, ts, ""); - }, - registry); - if (!status.ok()) return status; - - status = FunctionAdapter:: - CreateAndRegister( - builtin::kMonth, true, - [](Arena* arena, absl::Time ts, CelValue::StringHolder tz) - -> CelValue { return GetMonth(arena, ts, tz.value()); }, - registry); - if (!status.ok()) return status; - - status = FunctionAdapter::CreateAndRegister( - builtin::kMonth, true, - [](Arena* arena, absl::Time ts) -> CelValue { - return GetMonth(arena, ts, ""); - }, - registry); - if (!status.ok()) return status; - - status = FunctionAdapter:: - CreateAndRegister( - builtin::kDayOfYear, true, - [](Arena* arena, absl::Time ts, CelValue::StringHolder tz) - -> CelValue { return GetDayOfYear(arena, ts, tz.value()); }, - registry); - if (!status.ok()) return status; - - status = FunctionAdapter::CreateAndRegister( - builtin::kDayOfYear, true, - [](Arena* arena, absl::Time ts) -> CelValue { - return GetDayOfYear(arena, ts, ""); - }, - registry); - if (!status.ok()) return status; - - status = FunctionAdapter:: - CreateAndRegister( - builtin::kDayOfMonth, true, - [](Arena* arena, absl::Time ts, CelValue::StringHolder tz) - -> CelValue { return GetDayOfMonth(arena, ts, tz.value()); }, - registry); - if (!status.ok()) return status; - - status = FunctionAdapter::CreateAndRegister( - builtin::kDayOfMonth, true, - [](Arena* arena, absl::Time ts) -> CelValue { - return GetDayOfMonth(arena, ts, ""); - }, - registry); - if (!status.ok()) return status; - - status = FunctionAdapter:: - CreateAndRegister( - builtin::kDate, true, - [](Arena* arena, absl::Time ts, CelValue::StringHolder tz) - -> CelValue { return GetDate(arena, ts, tz.value()); }, - registry); - if (!status.ok()) return status; - - status = FunctionAdapter::CreateAndRegister( - builtin::kDate, true, - [](Arena* arena, absl::Time ts) -> CelValue { - return GetDate(arena, ts, ""); - }, - registry); - if (!status.ok()) return status; - - status = FunctionAdapter:: - CreateAndRegister( - builtin::kDayOfWeek, true, - [](Arena* arena, absl::Time ts, CelValue::StringHolder tz) - -> CelValue { return GetDayOfWeek(arena, ts, tz.value()); }, - registry); - if (!status.ok()) return status; - - status = FunctionAdapter::CreateAndRegister( - builtin::kDayOfWeek, true, - [](Arena* arena, absl::Time ts) -> CelValue { - return GetDayOfWeek(arena, ts, ""); - }, - registry); - if (!status.ok()) return status; - - status = FunctionAdapter:: - CreateAndRegister( - builtin::kHours, true, - [](Arena* arena, absl::Time ts, CelValue::StringHolder tz) - -> CelValue { return GetHours(arena, ts, tz.value()); }, - registry); - if (!status.ok()) return status; - - status = FunctionAdapter::CreateAndRegister( - builtin::kHours, true, - [](Arena* arena, absl::Time ts) -> CelValue { - return GetHours(arena, ts, ""); - }, - registry); - if (!status.ok()) return status; - - status = FunctionAdapter:: - CreateAndRegister( - builtin::kMinutes, true, - [](Arena* arena, absl::Time ts, CelValue::StringHolder tz) - -> CelValue { return GetMinutes(arena, ts, tz.value()); }, - registry); - if (!status.ok()) return status; - - status = FunctionAdapter::CreateAndRegister( - builtin::kMinutes, true, - [](Arena* arena, absl::Time ts) -> CelValue { - return GetMinutes(arena, ts, ""); - }, - registry); - if (!status.ok()) return status; - - status = FunctionAdapter:: - CreateAndRegister( - builtin::kSeconds, true, - [](Arena* arena, absl::Time ts, CelValue::StringHolder tz) - -> CelValue { return GetSeconds(arena, ts, tz.value()); }, - registry); - if (!status.ok()) return status; - - status = FunctionAdapter::CreateAndRegister( - builtin::kSeconds, true, - [](Arena* arena, absl::Time ts) -> CelValue { - return GetSeconds(arena, ts, ""); - }, - registry); - if (!status.ok()) return status; - - status = FunctionAdapter:: - CreateAndRegister( - builtin::kMilliseconds, true, - [](Arena* arena, absl::Time ts, CelValue::StringHolder tz) - -> CelValue { return GetMilliseconds(arena, ts, tz.value()); }, - registry); - if (!status.ok()) return status; - - status = FunctionAdapter::CreateAndRegister( - builtin::kMilliseconds, true, - [](Arena* arena, absl::Time ts) -> CelValue { - return GetMilliseconds(arena, ts, ""); - }, - registry); + status = RegisterTimestampFunctions(registry, options); if (!status.ok()) return status; // type conversion to int @@ -1385,6 +1408,20 @@ ::cel_base::Status RegisterBuiltinFunctions(CelFunctionRegistry* registry, registry); if (!status.ok()) return status; + status = FunctionAdapter::CreateAndRegister( + builtin::kInt, false, + [](Arena* arena, CelValue::StringHolder s) { + int64_t result; + if (absl::SimpleAtoi(s.value(), &result)) { + return CelValue::CreateInt64(result); + } else { + return CreateErrorValue(arena, "doesn't convert to a string", + absl::StatusCode::kInvalidArgument); + } + }, + registry); + if (!status.ok()) return status; + // duration // duration() conversion from string.. @@ -1473,7 +1510,7 @@ ::cel_base::Status RegisterBuiltinFunctions(CelFunctionRegistry* registry, if (!status.ok()) return status; } - return ::cel_base::OkStatus(); + return absl::OkStatus(); } } // namespace runtime diff --git a/eval/public/builtin_func_registrar.h b/eval/public/builtin_func_registrar.h index 332487d87..2cf906857 100644 --- a/eval/public/builtin_func_registrar.h +++ b/eval/public/builtin_func_registrar.h @@ -10,7 +10,7 @@ namespace api { namespace expr { namespace runtime { -cel_base::Status RegisterBuiltinFunctions( +absl::Status RegisterBuiltinFunctions( CelFunctionRegistry* registry, const InterpreterOptions& options = InterpreterOptions()); diff --git a/eval/public/builtin_func_test.cc b/eval/public/builtin_func_test.cc index 183a7840c..497ede135 100644 --- a/eval/public/builtin_func_test.cc +++ b/eval/public/builtin_func_test.cc @@ -8,6 +8,7 @@ #include "eval/public/cel_builtins.h" #include "eval/public/cel_expr_builder_factory.h" #include "eval/public/cel_function_registry.h" +#include "base/status_macros.h" namespace google { namespace api { @@ -31,7 +32,7 @@ class BuiltinsTest : public ::testing::Test { protected: BuiltinsTest() {} - void SetUp() override { ASSERT_TRUE(RegisterBuiltinFunctions(®istry_).ok()); } + void SetUp() override { ASSERT_OK(RegisterBuiltinFunctions(®istry_)); } // Helper method. Looks up in registry and tests comparison operation. void PerformRun(absl::string_view operation, absl::optional target, @@ -68,18 +69,18 @@ class BuiltinsTest : public ::testing::Test { CreateCelExpressionBuilder(options); // Builtin registration. - ASSERT_TRUE(RegisterBuiltinFunctions(builder->GetRegistry(), options).ok()); + ASSERT_OK(RegisterBuiltinFunctions(builder->GetRegistry(), options)); // Create CelExpression from AST (Expr object). auto cel_expression_status = builder->CreateExpression(&expr, &source_info); - ASSERT_TRUE(cel_expression_status.ok()); + ASSERT_OK(cel_expression_status); auto cel_expression = std::move(cel_expression_status.ValueOrDie()); auto eval_status = cel_expression->Evaluate(activation, &arena_); - ASSERT_TRUE(eval_status.ok()); + ASSERT_OK(eval_status); *result = eval_status.ValueOrDie(); } @@ -895,7 +896,7 @@ TEST_F(BuiltinsTest, MapInt64Index) { ASSERT_TRUE(result_value.IsError()); EXPECT_THAT(result_value.ErrorOrDie()->code(), - Eq(cel_base::StatusCode::kNotFound)); + Eq(absl::StatusCode::kNotFound)); EXPECT_TRUE(CheckNoSuchKeyError(result_value)); } @@ -924,7 +925,7 @@ TEST_F(BuiltinsTest, MapUint64Index) { ASSERT_TRUE(result_value.IsError()); EXPECT_THAT(result_value.ErrorOrDie()->code(), - Eq(cel_base::StatusCode::kNotFound)); + Eq(absl::StatusCode::kNotFound)); EXPECT_TRUE(CheckNoSuchKeyError(result_value)); } @@ -955,7 +956,7 @@ TEST_F(BuiltinsTest, MapStringIndex) { ASSERT_TRUE(result_value.IsError()); EXPECT_THAT(result_value.ErrorOrDie()->code(), - Eq(cel_base::StatusCode::kNotFound)); + Eq(absl::StatusCode::kNotFound)); EXPECT_TRUE(CheckNoSuchKeyError(result_value)); } @@ -1368,6 +1369,23 @@ TEST_F(BuiltinsTest, MatchesMaxSize) { EXPECT_TRUE(result_value.IsError()); } +TEST_F(BuiltinsTest, StringToInt) { + std::string target = "-42"; + std::vector args = {CelValue::CreateString(&target)}; + CelValue result_value; + ASSERT_NO_FATAL_FAILURE(PerformRun(builtin::kInt, {}, args, &result_value)); + ASSERT_TRUE(result_value.IsInt64()); + EXPECT_EQ(result_value.Int64OrDie(), -42); +} + +TEST_F(BuiltinsTest, StringToIntNonInt) { + std::string target = "not_a_number"; + std::vector args = {CelValue::CreateString(&target)}; + CelValue result_value; + ASSERT_NO_FATAL_FAILURE(PerformRun(builtin::kInt, {}, args, &result_value)); + ASSERT_TRUE(result_value.IsError()); +} + TEST_F(BuiltinsTest, IntToString) { std::vector args = {CelValue::CreateInt64(-42)}; CelValue result_value; diff --git a/eval/public/cel_expr_builder_factory.cc b/eval/public/cel_expr_builder_factory.cc index 21be23507..715045c6b 100644 --- a/eval/public/cel_expr_builder_factory.cc +++ b/eval/public/cel_expr_builder_factory.cc @@ -1,5 +1,7 @@ #include "eval/public/cel_expr_builder_factory.h" + #include "eval/compiler/flat_expr_builder.h" +#include "eval/public/cel_options.h" namespace google { namespace api { @@ -15,6 +17,19 @@ std::unique_ptr CreateCelExpressionBuilder( builder->set_enable_comprehension(options.enable_comprehension); builder->set_comprehension_max_iterations( options.comprehension_max_iterations); + builder->set_fail_on_warnings(options.fail_on_warnings); + + switch (options.unknown_processing) { + case UnknownProcessingOptions::kAttributeAndFunction: + builder->set_enable_unknown_function_results(true); + builder->set_enable_unknowns(true); + break; + case UnknownProcessingOptions::kAttributeOnly: + builder->set_enable_unknowns(true); + break; + case UnknownProcessingOptions::kDisabled: + break; + } return std::move(builder); } diff --git a/eval/public/cel_expression.h b/eval/public/cel_expression.h index 74815369e..403c3f39a 100644 --- a/eval/public/cel_expression.h +++ b/eval/public/cel_expression.h @@ -22,13 +22,23 @@ namespace runtime { // then the order of the callback invocations is guaranteed to correspond // the order of variable sub-elements (e.g. the order of elements returned // by Comprehension.iter_range). -using CelEvaluationListener = std::function; +// An opaque state used for evaluation of a cell expression. +class CelEvaluationState { + public: + virtual ~CelEvaluationState() = default; +}; + // Base interface for expression evaluating objects. class CelExpression { public: - virtual ~CelExpression() {} + virtual ~CelExpression() = default; + + // Initializes the state + virtual std::unique_ptr InitializeState( + google::protobuf::Arena* arena) const = 0; // Evaluates expression and returns value. // activation contains bindings from parameter names to values @@ -37,10 +47,24 @@ class CelExpression { virtual cel_base::StatusOr Evaluate(const BaseActivation& activation, google::protobuf::Arena* arena) const = 0; + // Evaluates expression and returns value. + // activation contains bindings from parameter names to values + // state must be non-null and created prior to calling Evaluate by + // InitializeState. + virtual cel_base::StatusOr Evaluate( + const BaseActivation& activation, CelEvaluationState* state) const = 0; + // Trace evaluates expression calling the callback on each sub-tree. virtual cel_base::StatusOr Trace( const BaseActivation& activation, google::protobuf::Arena* arena, CelEvaluationListener callback) const = 0; + + // Trace evaluates expression calling the callback on each sub-tree. + // state must be non-null and created prior to calling Evaluate by + // InitializeState. + virtual cel_base::StatusOr Trace( + const BaseActivation& activation, CelEvaluationState* state, + CelEvaluationListener callback) const = 0; }; // Base class for Expression Builder implementations @@ -60,6 +84,14 @@ class CelExpressionBuilder { const google::api::expr::v1alpha1::Expr* expr, const google::api::expr::v1alpha1::SourceInfo* source_info) const = 0; + // Creates CelExpression object from AST tree. + // expr specifies root of AST tree. + // non-fatal build warnings are written to warnings if encountered. + virtual cel_base::StatusOr> CreateExpression( + const google::api::expr::v1alpha1::Expr* expr, + const google::api::expr::v1alpha1::SourceInfo* source_info, + std::vector* warnings) const = 0; + // CelFunction registry. Extension function should be registered with it // prior to expression creation. CelFunctionRegistry* GetRegistry() const { return registry_.get(); } @@ -70,12 +102,12 @@ class CelExpressionBuilder { } // Add Enum to the list of resolvable by the builder. - void addResolvableEnum(const google::protobuf::EnumDescriptor* enum_descriptor) { + void AddResolvableEnum(const google::protobuf::EnumDescriptor* enum_descriptor) { resolvable_enums_.emplace(enum_descriptor); } // Remove Enum from the list of resolvable by the builder. - void removeResolvableEnum(const google::protobuf::EnumDescriptor* enum_descriptor) { + void RemoveResolvableEnum(const google::protobuf::EnumDescriptor* enum_descriptor) { resolvable_enums_.erase(enum_descriptor); } diff --git a/eval/public/cel_function.h b/eval/public/cel_function.h index ba6876c08..7e1dc0275 100644 --- a/eval/public/cel_function.h +++ b/eval/public/cel_function.h @@ -75,9 +75,9 @@ class CelFunction { // zero). When former happens, error Status is returned and *result is // not changed. In case of business logic error, returned Status is Ok, and // error is provided as CelValue - wrapped CelError in *result. - virtual ::cel_base::Status Evaluate(absl::Span arguments, - CelValue* result, - google::protobuf::Arena* arena) const = 0; + virtual absl::Status Evaluate(absl::Span arguments, + CelValue* result, + google::protobuf::Arena* arena) const = 0; // Determines whether instance of CelFunction is applicable to // arguments supplied. diff --git a/eval/public/cel_function_adapter.h b/eval/public/cel_function_adapter.h index ec4ad2d06..855af377b 100644 --- a/eval/public/cel_function_adapter.h +++ b/eval/public/cel_function_adapter.h @@ -4,10 +4,10 @@ #include "google/protobuf/duration.pb.h" #include "google/protobuf/timestamp.pb.h" +#include "absl/status/status.h" #include "absl/strings/str_cat.h" #include "eval/public/cel_function.h" #include "eval/public/cel_function_registry.h" -#include "base/status.h" #include "base/statusor.h" namespace google { @@ -91,8 +91,8 @@ class FunctionAdapter : public CelFunction { std::vector arg_types; if (!internal::AddType<0, Arguments...>(&arg_types)) { - return cel_base::Status( - cel_base::StatusCode::kInternal, + return absl::Status( + absl::StatusCode::kInternal, absl::StrCat("Failed to create adapter for ", name, ": failed to determine input parameter type")); } @@ -105,7 +105,7 @@ class FunctionAdapter : public CelFunction { // Creates function handler and attempts to register it with // supplied function registry. - static cel_base::Status CreateAndRegister( + static absl::Status CreateAndRegister( absl::string_view name, bool receiver_type, std::function handler, CelFunctionRegistry* registry) { @@ -119,26 +119,26 @@ class FunctionAdapter : public CelFunction { #if defined(__clang_major_version__) && __clang_major_version__ >= 8 && !defined(__APPLE__) template - inline cel_base::Status RunWrap(absl::Span arguments, + inline absl::Status RunWrap(absl::Span arguments, std::tuple<::google::protobuf::Arena*, Arguments...> input, CelValue* result, ::google::protobuf::Arena* arena) const { if (!ConvertFromValue(arguments[arg_index], &std::get(input))) { - return cel_base::Status(cel_base::StatusCode::kInvalidArgument, + return absl::Status(absl::StatusCode::kInvalidArgument, "Type conversion failed"); } return RunWrap(arguments, input, result, arena); } template <> - inline cel_base::Status RunWrap( + inline absl::Status RunWrap( absl::Span, std::tuple<::google::protobuf::Arena*, Arguments...> input, CelValue* result, ::google::protobuf::Arena* arena) const { return CreateReturnValue(absl::apply(handler_, input), arena, result); } #else - inline cel_base::Status RunWrap(std::function func, + inline absl::Status RunWrap(std::function func, const absl::Span argset, ::google::protobuf::Arena* arena, CelValue* result, int arg_index) const { @@ -146,13 +146,13 @@ class FunctionAdapter : public CelFunction { } template - inline cel_base::Status RunWrap(std::function func, + inline absl::Status RunWrap(std::function func, const absl::Span argset, ::google::protobuf::Arena* arena, CelValue* result, int arg_index) const { Arg argument; if (!ConvertFromValue(argset[arg_index], &argument)) { - return cel_base::Status(cel_base::StatusCode::kInvalidArgument, + return absl::Status(absl::StatusCode::kInvalidArgument, "Type conversion failed"); } @@ -166,11 +166,10 @@ class FunctionAdapter : public CelFunction { } #endif - ::cel_base::Status Evaluate(absl::Span arguments, - CelValue* result, - ::google::protobuf::Arena* arena) const override { + absl::Status Evaluate(absl::Span arguments, CelValue* result, + ::google::protobuf::Arena* arena) const override { if (arguments.size() != sizeof...(Arguments)) { - return cel_base::Status(cel_base::StatusCode::kInternal, + return absl::Status(absl::StatusCode::kInternal, "Argument number mismatch"); } @@ -201,91 +200,91 @@ class FunctionAdapter : public CelFunction { } // CreateReturnValue method wraps evaluation result with CelValue. - static cel_base::Status CreateReturnValue(bool value, ::google::protobuf::Arena*, + static absl::Status CreateReturnValue(bool value, ::google::protobuf::Arena*, CelValue* result) { *result = CelValue::CreateBool(value); - return cel_base::OkStatus(); + return absl::OkStatus(); } - static cel_base::Status CreateReturnValue(int64_t value, ::google::protobuf::Arena*, + static absl::Status CreateReturnValue(int64_t value, ::google::protobuf::Arena*, CelValue* result) { *result = CelValue::CreateInt64(value); - return cel_base::OkStatus(); + return absl::OkStatus(); } - static cel_base::Status CreateReturnValue(uint64_t value, ::google::protobuf::Arena*, + static absl::Status CreateReturnValue(uint64_t value, ::google::protobuf::Arena*, CelValue* result) { *result = CelValue::CreateUint64(value); - return cel_base::OkStatus(); + return absl::OkStatus(); } - static cel_base::Status CreateReturnValue(double value, ::google::protobuf::Arena*, + static absl::Status CreateReturnValue(double value, ::google::protobuf::Arena*, CelValue* result) { *result = CelValue::CreateDouble(value); - return cel_base::OkStatus(); + return absl::OkStatus(); } - static cel_base::Status CreateReturnValue(CelValue::StringHolder value, + static absl::Status CreateReturnValue(CelValue::StringHolder value, ::google::protobuf::Arena*, CelValue* result) { *result = CelValue::CreateString(value); - return cel_base::OkStatus(); + return absl::OkStatus(); } - static cel_base::Status CreateReturnValue(CelValue::BytesHolder value, + static absl::Status CreateReturnValue(CelValue::BytesHolder value, ::google::protobuf::Arena*, CelValue* result) { *result = CelValue::CreateBytes(value); - return cel_base::OkStatus(); + return absl::OkStatus(); } - static cel_base::Status CreateReturnValue(const ::google::protobuf::Message* value, + static absl::Status CreateReturnValue(const ::google::protobuf::Message* value, ::google::protobuf::Arena* arena, CelValue* result) { if (value == nullptr) { - return cel_base::Status(cel_base::StatusCode::kInvalidArgument, + return absl::Status(absl::StatusCode::kInvalidArgument, "Null Message pointer returned"); } *result = CelValue::CreateMessage(value, arena); - return cel_base::OkStatus(); + return absl::OkStatus(); } - static cel_base::Status CreateReturnValue(const CelList* value, ::google::protobuf::Arena*, + static absl::Status CreateReturnValue(const CelList* value, ::google::protobuf::Arena*, CelValue* result) { if (value == nullptr) { - return cel_base::Status(cel_base::StatusCode::kInvalidArgument, + return absl::Status(absl::StatusCode::kInvalidArgument, "Null CelList pointer returned"); } *result = CelValue::CreateList(value); - return cel_base::OkStatus(); + return absl::OkStatus(); } - static cel_base::Status CreateReturnValue(const CelMap* value, ::google::protobuf::Arena*, + static absl::Status CreateReturnValue(const CelMap* value, ::google::protobuf::Arena*, CelValue* result) { if (value == nullptr) { - return cel_base::Status(cel_base::StatusCode::kInvalidArgument, + return absl::Status(absl::StatusCode::kInvalidArgument, "Null CelMap pointer returned"); } *result = CelValue::CreateMap(value); - return cel_base::OkStatus(); + return absl::OkStatus(); } - static cel_base::Status CreateReturnValue(const CelError* value, ::google::protobuf::Arena*, + static absl::Status CreateReturnValue(const CelError* value, ::google::protobuf::Arena*, CelValue* result) { if (value == nullptr) { - return cel_base::Status(cel_base::StatusCode::kInvalidArgument, + return absl::Status(absl::StatusCode::kInvalidArgument, "Null CelError pointer returned"); } *result = CelValue::CreateError(value); - return cel_base::OkStatus(); + return absl::OkStatus(); } - static cel_base::Status CreateReturnValue(const CelValue& value, ::google::protobuf::Arena*, + static absl::Status CreateReturnValue(const CelValue& value, ::google::protobuf::Arena*, CelValue* result) { *result = value; - return cel_base::OkStatus(); + return absl::OkStatus(); } template - static cel_base::Status CreateReturnValue(const cel_base::StatusOr& value, + static absl::Status CreateReturnValue(const cel_base::StatusOr& value, ::google::protobuf::Arena*, CelValue*) { if (!value) { return value.status(); diff --git a/eval/public/cel_function_adapter_test.cc b/eval/public/cel_function_adapter_test.cc index 89715c65e..8a6e1b50f 100644 --- a/eval/public/cel_function_adapter_test.cc +++ b/eval/public/cel_function_adapter_test.cc @@ -2,6 +2,7 @@ #include "gmock/gmock.h" #include "gtest/gtest.h" +#include "base/status_macros.h" namespace google { namespace api { @@ -15,7 +16,7 @@ TEST(CelFunctionAdapterTest, TestAdapterNoArg) { auto func_status = FunctionAdapter::Create("const", false, func); - ASSERT_TRUE(func_status.ok()); + ASSERT_OK(func_status); auto cel_func = std::move(func_status.ValueOrDie()); @@ -25,7 +26,7 @@ TEST(CelFunctionAdapterTest, TestAdapterNoArg) { google::protobuf::Arena arena; auto eval_status = cel_func->Evaluate(args, &result, &arena); - ASSERT_TRUE(eval_status.ok()); + ASSERT_OK(eval_status); ASSERT_TRUE( result.IsInt64()); // Obvious failure, for educational purposes only. @@ -37,7 +38,7 @@ TEST(CelFunctionAdapterTest, TestAdapterOneArg) { auto func_status = FunctionAdapter::Create("_++_", false, func); - ASSERT_TRUE(func_status.ok()); + ASSERT_OK(func_status); auto cel_func = std::move(func_status.ValueOrDie()); @@ -51,7 +52,7 @@ TEST(CelFunctionAdapterTest, TestAdapterOneArg) { auto eval_status = cel_func->Evaluate(args, &result, &arena); - ASSERT_TRUE(eval_status.ok()); + ASSERT_OK(eval_status); ASSERT_TRUE(result.IsInt64()); EXPECT_EQ(result.Int64OrDie(), 100); @@ -65,7 +66,7 @@ TEST(CelFunctionAdapterTest, TestAdapterTwoArgs) { auto func_status = FunctionAdapter::Create("_++_", false, func); - ASSERT_TRUE(func_status.ok()); + ASSERT_OK(func_status); auto cel_func = std::move(func_status.ValueOrDie()); @@ -80,7 +81,7 @@ TEST(CelFunctionAdapterTest, TestAdapterTwoArgs) { auto eval_status = cel_func->Evaluate(args, &result, &arena); - ASSERT_TRUE(eval_status.ok()); + ASSERT_OK(eval_status); ASSERT_TRUE(result.IsInt64()); EXPECT_EQ(result.Int64OrDie(), 42); @@ -100,7 +101,7 @@ TEST(CelFunctionAdapterTest, TestAdapterThreeArgs) { FunctionAdapter::Create("concat", false, func); - ASSERT_TRUE(func_status.ok()); + ASSERT_OK(func_status); auto cel_func = std::move(func_status.ValueOrDie()); @@ -120,7 +121,7 @@ TEST(CelFunctionAdapterTest, TestAdapterThreeArgs) { auto eval_status = cel_func->Evaluate(args, &result, &arena); - ASSERT_TRUE(eval_status.ok()); + ASSERT_OK(eval_status); ASSERT_TRUE(result.IsString()); EXPECT_EQ(result.StringOrDie().value(), "123"); @@ -139,7 +140,7 @@ TEST(CelFunctionAdapterTest, TestTypeDeductionForCelValueBasicTypes) { absl::Duration, absl::Time, const CelList*, const CelMap*, const CelError*>::Create("dummy_func", false, func); - ASSERT_TRUE(func_status.ok()); + ASSERT_OK(func_status); auto cel_func = std::move(func_status.ValueOrDie()); diff --git a/eval/public/cel_function_provider.cc b/eval/public/cel_function_provider.cc index 42ce62e34..a695c65b9 100644 --- a/eval/public/cel_function_provider.cc +++ b/eval/public/cel_function_provider.cc @@ -22,7 +22,7 @@ class ActivationFunctionProviderImpl : public CelFunctionProvider { for (const CelFunction* overload : overloads) { if (overload->descriptor().ShapeMatches(descriptor)) { if (matching_overload != nullptr) { - return cel_base::Status(cel_base::StatusCode::kInvalidArgument, + return absl::Status(absl::StatusCode::kInvalidArgument, "Couldn't resolve function."); } matching_overload = overload; diff --git a/eval/public/cel_function_provider_test.cc b/eval/public/cel_function_provider_test.cc index 80ffbbd9b..0f2d1ff41 100644 --- a/eval/public/cel_function_provider_test.cc +++ b/eval/public/cel_function_provider_test.cc @@ -2,6 +2,7 @@ #include "gmock/gmock.h" #include "gtest/gtest.h" +#include "base/status_macros.h" namespace google { namespace api { @@ -19,9 +20,9 @@ class ConstCelFunction : public CelFunction { ConstCelFunction() : CelFunction({"ConstFunction", false, {}}) {} explicit ConstCelFunction(const CelFunctionDescriptor& desc) : CelFunction(desc) {} - cel_base::Status Evaluate(absl::Span args, CelValue* output, + absl::Status Evaluate(absl::Span args, CelValue* output, google::protobuf::Arena* arena) const override { - return cel_base::Status(cel_base::StatusCode::kUnimplemented, "Not Implemented"); + return absl::Status(absl::StatusCode::kUnimplemented, "Not Implemented"); } }; @@ -31,7 +32,7 @@ TEST(CreateActivationFunctionProviderTest, NoOverloadFound) { auto func = provider->GetFunction({"LazyFunc", false, {}}, activation); - ASSERT_TRUE(func.status().ok()); + ASSERT_OK(func.status()); EXPECT_THAT(func.ValueOrDie(), Eq(nullptr)); } @@ -42,11 +43,11 @@ TEST(CreateActivationFunctionProviderTest, OverloadFound) { auto status = activation.InsertFunction(std::make_unique(desc)); - EXPECT_TRUE(status.ok()); + EXPECT_OK(status); auto func = provider->GetFunction(desc, activation); - ASSERT_TRUE(func.status().ok()); + ASSERT_OK(func.status()); EXPECT_THAT(func.ValueOrDie(), Ne(nullptr)); } @@ -60,9 +61,9 @@ TEST(CreateActivationFunctionProviderTest, AmbiguousLookup) { auto status = activation.InsertFunction(std::make_unique(desc1)); - EXPECT_TRUE(status.ok()); + EXPECT_OK(status); status = activation.InsertFunction(std::make_unique(desc2)); - EXPECT_TRUE(status.ok()); + EXPECT_OK(status); auto func = provider->GetFunction(match_desc, activation); diff --git a/eval/public/cel_function_registry.cc b/eval/public/cel_function_registry.cc index 44cd8efcc..34202afe4 100644 --- a/eval/public/cel_function_registry.cc +++ b/eval/public/cel_function_registry.cc @@ -5,27 +5,27 @@ namespace api { namespace expr { namespace runtime { -cel_base::Status CelFunctionRegistry::Register( +absl::Status CelFunctionRegistry::Register( std::unique_ptr function) { const CelFunctionDescriptor& descriptor = function->descriptor(); if (DescriptorRegistered(descriptor)) { - return cel_base::Status( - cel_base::StatusCode::kAlreadyExists, + return absl::Status( + absl::StatusCode::kAlreadyExists, "CelFunction with specified parameters already registered"); } auto& overloads = functions_[descriptor.name()]; overloads.static_overloads.push_back(std::move(function)); - return cel_base::OkStatus(); + return absl::OkStatus(); } -cel_base::Status CelFunctionRegistry::RegisterLazyFunction( +absl::Status CelFunctionRegistry::RegisterLazyFunction( const CelFunctionDescriptor& descriptor, std::unique_ptr factory) { if (DescriptorRegistered(descriptor)) { - return cel_base::Status( - cel_base::StatusCode::kAlreadyExists, + return absl::Status( + absl::StatusCode::kAlreadyExists, "CelFunction with specified parameters already registered"); } auto& overloads = functions_[descriptor.name()]; @@ -33,7 +33,7 @@ cel_base::Status CelFunctionRegistry::RegisterLazyFunction( descriptor, std::move(factory)); overloads.lazy_overloads.push_back(std::move(entry)); - return cel_base::OkStatus(); + return absl::OkStatus(); } std::vector CelFunctionRegistry::FindOverloads( diff --git a/eval/public/cel_function_registry.h b/eval/public/cel_function_registry.h index fbeafac4e..de1d64bcc 100644 --- a/eval/public/cel_function_registry.h +++ b/eval/public/cel_function_registry.h @@ -26,18 +26,18 @@ class CelFunctionRegistry { // passed to registry. // Function registration should be performed prior to // CelExpression creation. - cel_base::Status Register(std::unique_ptr function); + absl::Status Register(std::unique_ptr function); // Register a lazily provided function. CelFunctionProvider is used to get // a CelFunction ptr at evaluation time. The registry takes ownership of the // factory. - cel_base::Status RegisterLazyFunction( + absl::Status RegisterLazyFunction( const CelFunctionDescriptor& descriptor, std::unique_ptr factory); // Register a lazily provided function. This overload uses a default provider // that delegates to the activation at evaluation time. - cel_base::Status RegisterLazyFunction(const CelFunctionDescriptor& descriptor) { + absl::Status RegisterLazyFunction(const CelFunctionDescriptor& descriptor) { return RegisterLazyFunction(descriptor, CreateActivationFunctionProvider()); } diff --git a/eval/public/cel_function_registry_test.cc b/eval/public/cel_function_registry_test.cc index 2a9297a24..a8cc6a97f 100644 --- a/eval/public/cel_function_registry_test.cc +++ b/eval/public/cel_function_registry_test.cc @@ -6,6 +6,7 @@ #include "gtest/gtest.h" #include "eval/public/cel_function.h" #include "eval/public/cel_function_provider.h" +#include "base/status_macros.h" namespace google { namespace api { @@ -38,11 +39,11 @@ class ConstCelFunction : public CelFunction { return {"ConstFunction", false, {}}; } - cel_base::Status Evaluate(absl::Span args, CelValue* output, + absl::Status Evaluate(absl::Span args, CelValue* output, google::protobuf::Arena* arena) const override { *output = CelValue::CreateInt64(42); - return cel_base::OkStatus(); + return absl::OkStatus(); } }; @@ -52,12 +53,12 @@ TEST(CelFunctionRegistryTest, InsertAndRetrieveLazyFunction) { Activation activation; auto register_status = registry.RegisterLazyFunction( lazy_function_desc, std::make_unique()); - EXPECT_TRUE(register_status.ok()); + EXPECT_OK(register_status); const auto providers = registry.FindLazyOverloads("LazyFunction", false, {}); EXPECT_THAT(providers, testing::SizeIs(1)); auto func = providers[0]->GetFunction(lazy_function_desc, activation); - ASSERT_TRUE(func.status().ok()); + ASSERT_OK(func.status()); EXPECT_THAT(func.ValueOrDie(), Eq(nullptr)); } @@ -69,9 +70,9 @@ TEST(CelFunctionRegistryTest, LazyAndStaticFunctionShareDescriptorSpace) { CelFunctionDescriptor desc = ConstCelFunction::MakeDescriptor(); auto register_status = registry.RegisterLazyFunction( desc, std::make_unique()); - EXPECT_TRUE(register_status.ok()); + EXPECT_OK(register_status); - cel_base::Status status = registry.Register(std::make_unique()); + absl::Status status = registry.Register(std::make_unique()); EXPECT_FALSE(status.ok()); } @@ -81,8 +82,8 @@ TEST(CelFunctionRegistryTest, ListFunctions) { auto register_status = registry.RegisterLazyFunction( lazy_function_desc, std::make_unique()); - EXPECT_TRUE(register_status.ok()); - EXPECT_TRUE(registry.Register(std::make_unique()).ok()); + EXPECT_OK(register_status); + EXPECT_OK(registry.Register(std::make_unique())); auto registered_functions = registry.ListFunctions(); @@ -95,15 +96,15 @@ TEST(CelFunctionRegistryTest, DefaultLazyProvider) { CelFunctionDescriptor lazy_function_desc{"LazyFunction", false, {}}; CelFunctionRegistry registry; Activation activation; - EXPECT_TRUE(registry.RegisterLazyFunction(lazy_function_desc).ok()); + EXPECT_OK(registry.RegisterLazyFunction(lazy_function_desc)); auto insert_status = activation.InsertFunction( std::make_unique(lazy_function_desc)); - EXPECT_TRUE(insert_status.ok()); + EXPECT_OK(insert_status); const auto providers = registry.FindLazyOverloads("LazyFunction", false, {}); EXPECT_THAT(providers, testing::SizeIs(1)); auto func = providers[0]->GetFunction(lazy_function_desc, activation); - ASSERT_TRUE(func.status().ok()); + ASSERT_OK(func.status()); EXPECT_THAT(func.ValueOrDie(), Property(&CelFunction::descriptor, Property(&CelFunctionDescriptor::name, Eq("LazyFunction")))); diff --git a/eval/public/cel_options.h b/eval/public/cel_options.h index 46cbfb8d2..daff7f9dd 100644 --- a/eval/public/cel_options.h +++ b/eval/public/cel_options.h @@ -8,8 +8,23 @@ namespace api { namespace expr { namespace runtime { +// Options for unknown processing. +enum class UnknownProcessingOptions { + // No unknown processing. + kDisabled, + // Only attributes supported. + kAttributeOnly, + // Attributes and functions supported. Function results are dependent on the + // logic for handling unknown_attributes, so clients must opt in to both. + kAttributeAndFunction +}; + // Interpreter options for controlling evaluation and builtin functions. struct InterpreterOptions { + // Level of unknown support enabled. + UnknownProcessingOptions unknown_processing = + UnknownProcessingOptions::kDisabled; + // Enable short-circuiting of the logical operator evaluation. If enabled, // AND, OR, and TERNARY do not evaluate the entire expression once the the // resulting value is known from the left-hand side. @@ -20,6 +35,8 @@ struct InterpreterOptions { // Enable constant folding during the expression creation. If enabled, // an arena must be provided for constant generation. + // Note that expression tracing applies a modified expression if this option + // is enabled. bool constant_folding = false; google::protobuf::Arena* constant_arena = nullptr; @@ -50,6 +67,9 @@ struct InterpreterOptions { // Enable list membership overload. bool enable_list_contains = true; + + // Treat builder warnings as fatal errors. + bool fail_on_warnings = true; }; } // namespace runtime diff --git a/eval/public/cel_value.cc b/eval/public/cel_value.cc index b51e77698..000f3b2b4 100644 --- a/eval/public/cel_value.cc +++ b/eval/public/cel_value.cc @@ -4,6 +4,7 @@ #include "google/protobuf/struct.pb.h" #include "google/protobuf/wrappers.pb.h" #include "absl/container/node_hash_map.h" +#include "absl/status/status.h" #include "absl/strings/substitute.h" #include "absl/synchronization/mutex.h" #include "internal/proto_util.h" @@ -40,6 +41,8 @@ constexpr char kErrNoMatchingOverload[] = "No matching overloads found"; constexpr char kErrNoSuchKey[] = "Key not found in map"; constexpr absl::string_view kErrUnknownValue = "Unknown value "; constexpr absl::string_view kPayloadUrlUnknownPath = "unknown_path"; +constexpr absl::string_view kPayloadUrlUnknownFunctionResult = + "cel_is_unknown_function_result"; // Forward declaration for google.protobuf.Value CelValue ValueFromMessage(const Value* value, Arena* arena); @@ -368,7 +371,7 @@ std::string CelValue::TypeName(Type value_type) { case Type::kMap: return "CelMap"; case Type::kUnknownSet: - return "UnknownAttributeSet"; + return "UnknownSet"; case Type::kError: return "CelError"; default: @@ -377,14 +380,14 @@ std::string CelValue::TypeName(Type value_type) { } CelValue CreateErrorValue(Arena* arena, absl::string_view message, - cel_base::StatusCode error_code, int) { + absl::StatusCode error_code, int) { CelError* error = Arena::Create(arena, error_code, message); return CelValue::CreateError(error); } CelValue CreateNoMatchingOverloadError(google::protobuf::Arena* arena) { return CreateErrorValue(arena, kErrNoMatchingOverload, - cel_base::StatusCode::kUnknown); + absl::StatusCode::kUnknown); } bool CheckNoMatchingOverloadError(CelValue value) { @@ -393,11 +396,11 @@ bool CheckNoMatchingOverloadError(CelValue value) { } CelValue CreateNoSuchFieldError(google::protobuf::Arena* arena) { - return CreateErrorValue(arena, "no_such_field", cel_base::StatusCode::kNotFound); + return CreateErrorValue(arena, "no_such_field", absl::StatusCode::kNotFound); } CelValue CreateNoSuchKeyError(google::protobuf::Arena* arena, absl::string_view) { - return CreateErrorValue(arena, kErrNoSuchKey, cel_base::StatusCode::kNotFound); + return CreateErrorValue(arena, kErrNoSuchKey, absl::StatusCode::kNotFound); } bool CheckNoSuchKeyError(CelValue value) { @@ -407,9 +410,9 @@ bool CheckNoSuchKeyError(CelValue value) { CelValue CreateUnknownValueError(google::protobuf::Arena* arena, absl::string_view unknown_path) { CelError* error = - Arena::Create(arena, cel_base::StatusCode::kUnavailable, + Arena::Create(arena, absl::StatusCode::kUnavailable, absl::StrCat(kErrUnknownValue, unknown_path)); - error->SetPayload(kPayloadUrlUnknownPath, cel_base::StatusCord(unknown_path)); + error->SetPayload(kPayloadUrlUnknownPath, absl::Cord(unknown_path)); return CelValue::CreateError(error); } @@ -417,7 +420,7 @@ bool IsUnknownValueError(const CelValue& value) { // TODO(issues/41): replace with the implementation of go/cel-known-unknowns if (!value.IsError()) return false; const CelError* error = value.ErrorOrDie(); - if (error && error->code() == cel_base::StatusCode::kUnavailable) { + if (error && error->code() == absl::StatusCode::kUnavailable) { auto path = error->GetPayload(kPayloadUrlUnknownPath); return path.has_value(); } @@ -427,7 +430,7 @@ bool IsUnknownValueError(const CelValue& value) { std::set GetUnknownPathsSetOrDie(const CelValue& value) { // TODO(issues/41): replace with the implementation of go/cel-known-unknowns const CelError* error = value.ErrorOrDie(); - if (error && error->code() == cel_base::StatusCode::kUnavailable) { + if (error && error->code() == absl::StatusCode::kUnavailable) { auto path = error->GetPayload(kPayloadUrlUnknownPath); if (path.has_value()) return {std::string(path.value())}; } @@ -435,6 +438,27 @@ std::set GetUnknownPathsSetOrDie(const CelValue& value) { return {}; } +CelValue CreateUnknownFunctionResultError(google::protobuf::Arena* arena, + absl::string_view help_message) { + CelError* error = Arena::Create( + arena, absl::StatusCode::kUnavailable, + absl::StrCat("Unknown function result: ", help_message)); + error->SetPayload(kPayloadUrlUnknownFunctionResult, absl::Cord("true")); + return CelValue::CreateError(error); +} + +bool IsUnknownFunctionResult(const CelValue& value) { + if (!value.IsError()) { + return false; + } + const CelError* error = value.ErrorOrDie(); + if (error == nullptr || error->code() != absl::StatusCode::kUnavailable) { + return false; + } + auto payload = error->GetPayload(kPayloadUrlUnknownFunctionResult); + return payload.has_value() && payload.value() == "true"; +} + } // namespace runtime } // namespace expr } // namespace api diff --git a/eval/public/cel_value.h b/eval/public/cel_value.h index 173da9af3..16d1ce8cd 100644 --- a/eval/public/cel_value.h +++ b/eval/public/cel_value.h @@ -32,11 +32,12 @@ namespace api { namespace expr { namespace runtime { -using CelError = cel_base::Status; +using CelError = absl::Status; +// Break cyclic depdendencies for container types. class CelList; class CelMap; -class UnknownAttributeSet; +class UnknownSet; class CelValue { public: @@ -102,7 +103,7 @@ class CelValue { using ValueHolder = internal::ValueHolder< bool, int64_t, uint64_t, double, StringHolder, BytesHolder, const google::protobuf::Message *, absl::Duration, absl::Time, const CelList *, - const CelMap *, const UnknownAttributeSet *, const CelError *>; + const CelMap *, const UnknownSet *, const CelError *>; public: // Metafunction providing positions corresponding to specific @@ -123,7 +124,7 @@ class CelValue { kTimestamp = IndexOf::value, kList = IndexOf::value, kMap = IndexOf::value, - kUnknownSet = IndexOf::value, + kUnknownSet = IndexOf::value, kError = IndexOf::value, kAny // Special value. Used in function descriptors. }; @@ -197,7 +198,7 @@ class CelValue { return CelValue(value); } - static CelValue CreateUnknownSet(const UnknownAttributeSet *value) { + static CelValue CreateUnknownSet(const UnknownSet *value) { CheckNullPointer(value, Type::kUnknownSet); return CelValue(value); } @@ -270,8 +271,8 @@ class CelValue { // Returns stored const UnknownAttributeSet * value. // Fails if stored value type is not const UnknownAttributeSet *. - const UnknownAttributeSet *UnknownSetOrDie() const { - return GetValueOrDie(Type::kUnknownSet); + const UnknownSet *UnknownSetOrDie() const { + return GetValueOrDie(Type::kUnknownSet); } // Returns stored const CelError * value. @@ -304,7 +305,7 @@ class CelValue { bool IsMap() const { return value_.is(); } - bool IsUnknownSet() const { return value_.is(); } + bool IsUnknownSet() const { return value_.is(); } bool IsError() const { return value_.is(); } @@ -420,7 +421,7 @@ class CelMap { // parsed from. -1, if the position can not be determined. CelValue CreateErrorValue( google::protobuf::Arena *arena, absl::string_view message, - cel_base::StatusCode error_code = cel_base::StatusCode::kUnknown, + absl::StatusCode error_code = absl::StatusCode::kUnknown, int position = -1); CelValue CreateNoMatchingOverloadError(google::protobuf::Arena *arena); @@ -440,6 +441,18 @@ CelValue CreateUnknownValueError(google::protobuf::Arena *arena, // encountered a value marked as unknown in Activation unknown_paths. bool IsUnknownValueError(const CelValue &value); +// Returns error indicating the result of the function is unknown. This is used +// as a signal to create an unknown set if unknown function handling is opted +// into. +CelValue CreateUnknownFunctionResultError(google::protobuf::Arena *arena, + absl::string_view help_message); + +// Returns true if this is unknown value error indicating that evaluation +// called an extension function whose value is unknown for the given args. +// This is used as a signal to convert to an UnknownSet if the behavior is opted +// into. +bool IsUnknownFunctionResult(const CelValue &value); + // Returns set of unknown paths for unknown value error. The value must be // unknown error, see IsUnknownValueError() above, or it dies. std::set GetUnknownPathsSetOrDie(const CelValue &value); diff --git a/eval/public/cel_value_test.cc b/eval/public/cel_value_test.cc index 6b82e77c9..735dc6a00 100644 --- a/eval/public/cel_value_test.cc +++ b/eval/public/cel_value_test.cc @@ -6,6 +6,7 @@ #include "gmock/gmock.h" #include "gtest/gtest.h" #include "eval/public/unknown_attribute_set.h" +#include "eval/public/unknown_set.h" #include "eval/testutil/test_message.pb.h" #include "testutil/util.h" @@ -94,7 +95,7 @@ TEST(CelValueTest, TestType) { CelValue value_timestamp2 = CelValue::CreateMessage(&msg_timestamp, &arena); EXPECT_THAT(value_timestamp2.type(), Eq(CelValue::Type::kTimestamp)); - UnknownAttributeSet unknown_set; + UnknownSet unknown_set; CelValue value_unknown = CelValue::CreateUnknownSet(&unknown_set); EXPECT_THAT(value_unknown.type(), Eq(CelValue::Type::kUnknownSet)); } @@ -132,7 +133,7 @@ int CountTypeMatch(const CelValue& value) { const CelError* value_error; count += (value.GetValue(&value_error)) ? 1 : 0; - const UnknownAttributeSet* value_unknown; + const UnknownSet* value_unknown; count += (value.GetValue(&value_unknown)) ? 1 : 0; return count; @@ -215,7 +216,6 @@ TEST(CelValueTest, TestString) { TEST(CelValueTest, TestBytes) { constexpr char kTestStr0[] = "test0"; std::string v = kTestStr0; - absl::string_view sv(v); CelValue value = CelValue::CreateBytes(&v); // CelValue value = CelValue::CreateString("test"); @@ -304,14 +304,14 @@ TEST(CelValueTest, TestMap) { // This test verifies CelValue support of Unknown type. TEST(CelValueTest, TestUnknownSet) { - UnknownAttributeSet unknown_set; + UnknownSet unknown_set; CelValue value = CelValue::CreateUnknownSet(&unknown_set); EXPECT_TRUE(value.IsUnknownSet()); EXPECT_THAT(value.UnknownSetOrDie(), Eq(&unknown_set)); // test template getter - const UnknownAttributeSet* value2; + const UnknownSet* value2; EXPECT_TRUE(value.GetValue(&value2)); EXPECT_THAT(value2, Eq(&unknown_set)); EXPECT_THAT(CountTypeMatch(value), Eq(1)); @@ -585,6 +585,14 @@ TEST(CelValueTest, TestBytesWrapper) { EXPECT_EQ(value.BytesOrDie().value(), wrapper.value()); } +TEST(CelValueTest, UnknownFunctionResultErrors) { + ::google::protobuf::Arena arena; + + CelValue value = CreateUnknownFunctionResultError(&arena, "message"); + EXPECT_TRUE(value.IsError()); + EXPECT_TRUE(IsUnknownFunctionResult(value)); +} + } // namespace runtime } // namespace expr } // namespace api diff --git a/eval/public/extension_func_registrar.cc b/eval/public/extension_func_registrar.cc index 2611c24a7..1ac79e12c 100644 --- a/eval/public/extension_func_registrar.cc +++ b/eval/public/extension_func_registrar.cc @@ -7,8 +7,8 @@ namespace api { namespace expr { namespace runtime { -::cel_base::Status RegisterExtensionFunctions(CelFunctionRegistry*) { - return ::cel_base::OkStatus(); +absl::Status RegisterExtensionFunctions(CelFunctionRegistry*) { + return absl::OkStatus(); } } // namespace runtime diff --git a/eval/public/extension_func_registrar.h b/eval/public/extension_func_registrar.h index 4eb68cefb..dcd5ebc4e 100644 --- a/eval/public/extension_func_registrar.h +++ b/eval/public/extension_func_registrar.h @@ -10,7 +10,7 @@ namespace expr { namespace runtime { // Register generic/widely used extension functions. -cel_base::Status RegisterExtensionFunctions(CelFunctionRegistry* registry); +absl::Status RegisterExtensionFunctions(CelFunctionRegistry* registry); } // namespace runtime } // namespace expr diff --git a/eval/public/extension_func_test.cc b/eval/public/extension_func_test.cc index 35b9c32c8..58f45633e 100644 --- a/eval/public/extension_func_test.cc +++ b/eval/public/extension_func_test.cc @@ -5,6 +5,7 @@ #include "eval/public/builtin_func_registrar.h" #include "eval/public/cel_function_registry.h" #include "eval/public/extension_func_registrar.h" +#include "base/status_macros.h" namespace google { namespace api { @@ -23,8 +24,8 @@ class ExtensionTest : public ::testing::Test { ExtensionTest() {} void SetUp() override { - ASSERT_TRUE(RegisterBuiltinFunctions(®istry_).ok()); - ASSERT_TRUE(RegisterExtensionFunctions(®istry_).ok()); + ASSERT_OK(RegisterBuiltinFunctions(®istry_)); + ASSERT_OK(RegisterExtensionFunctions(®istry_)); } // Helper method to test string startsWith() function @@ -49,7 +50,7 @@ class ExtensionTest : public ::testing::Test { absl::Span arg_span(&args[0], args.size()); auto status = func->Evaluate(arg_span, &result_value, &arena); - ASSERT_TRUE(status.ok()); + ASSERT_OK(status); ASSERT_TRUE(result_value.IsBool()); ASSERT_EQ(result_value.BoolOrDie(), result); } @@ -79,7 +80,7 @@ class ExtensionTest : public ::testing::Test { absl::Span arg_span(&args[0], args.size()); auto status = func->Evaluate(arg_span, result, arena); - ASSERT_TRUE(status.ok()); + ASSERT_OK(status); } // Helper method to test duration() function @@ -95,7 +96,7 @@ class ExtensionTest : public ::testing::Test { absl::Span arg_span(&args[0], args.size()); auto status = func->Evaluate(arg_span, result, arena); - ASSERT_TRUE(status.ok()); + ASSERT_OK(status); } // Function registry object @@ -170,7 +171,7 @@ TEST_F(ExtensionTest, TestTimestampFromString) { // Invalid timestamp - empty string. EXPECT_NO_FATAL_FAILURE(PerformTimestampConversion(&arena, "", &result)); ASSERT_TRUE(result.IsError()); - ASSERT_EQ(result.ErrorOrDie()->code(), cel_base::StatusCode::kInvalidArgument); + ASSERT_EQ(result.ErrorOrDie()->code(), absl::StatusCode::kInvalidArgument); // Invalid timestamp. EXPECT_NO_FATAL_FAILURE( @@ -203,7 +204,7 @@ TEST_F(ExtensionTest, TestDurationFromString) { // Invalid duration - empty string. EXPECT_NO_FATAL_FAILURE(PerformDurationConversion(&arena, "", &result)); ASSERT_TRUE(result.IsError()); - ASSERT_EQ(result.ErrorOrDie()->code(), cel_base::StatusCode::kInvalidArgument); + ASSERT_EQ(result.ErrorOrDie()->code(), absl::StatusCode::kInvalidArgument); // Invalid duration. EXPECT_NO_FATAL_FAILURE(PerformDurationConversion(&arena, "100", &result)); diff --git a/eval/public/unknown_attribute_set.h b/eval/public/unknown_attribute_set.h index 87eedbe46..b3abdeeb2 100644 --- a/eval/public/unknown_attribute_set.h +++ b/eval/public/unknown_attribute_set.h @@ -19,29 +19,31 @@ class UnknownAttributeSet { UnknownAttributeSet& operator=(const UnknownAttributeSet& other) = default; UnknownAttributeSet() {} - UnknownAttributeSet( - const std::vector>& attributes) { + UnknownAttributeSet(const std::vector& attributes) { attributes_.reserve(attributes.size()); for (const auto& attr : attributes) { Add(attr); } } - std::vector> attributes() const { - return attributes_; + UnknownAttributeSet(const UnknownAttributeSet& set1, + const UnknownAttributeSet& set2) + : attributes_(set1.attributes()) { + attributes_.reserve(set1.attributes().size() + set2.attributes().size()); + for (const auto& attr : set2.attributes()) { + Add(attr); + } } + std::vector attributes() const { return attributes_; } + static UnknownAttributeSet Merge(const UnknownAttributeSet& set1, const UnknownAttributeSet& set2) { - UnknownAttributeSet attr_set = set1; - for (const auto& attr : set2.attributes()) { - attr_set.Add(attr); - } - return attr_set; + return UnknownAttributeSet(set1, set2); } private: - void Add(std::shared_ptr attribute) { + void Add(const CelAttribute* attribute) { if (!attribute) { return; } @@ -54,7 +56,7 @@ class UnknownAttributeSet { } // Attribute container. - std::vector> attributes_; + std::vector attributes_; }; } // namespace runtime diff --git a/eval/public/unknown_attribute_set_test.cc b/eval/public/unknown_attribute_set_test.cc index 7b98d6266..775628f4a 100644 --- a/eval/public/unknown_attribute_set_test.cc +++ b/eval/public/unknown_attribute_set_test.cc @@ -33,7 +33,7 @@ TEST(UnknownAttributeSetTest, TestCreate) { CelAttributeQualifier::Create(CelValue::CreateUint64(2)), CelAttributeQualifier::Create(CelValue::CreateBool(true))})); - UnknownAttributeSet unknown_set({cel_attr}); + UnknownAttributeSet unknown_set({cel_attr.get()}); EXPECT_THAT(unknown_set.attributes().size(), Eq(1)); EXPECT_THAT(*(unknown_set.attributes()[0]), Eq(*cel_attr)); } @@ -74,8 +74,8 @@ TEST(UnknownAttributeSetTest, TestMergeSets) { CelAttributeQualifier::Create(CelValue::CreateUint64(2)), CelAttributeQualifier::Create(CelValue::CreateBool(false))})); - UnknownAttributeSet unknown_set1({cel_attr1, cel_attr2}); - UnknownAttributeSet unknown_set2({cel_attr1_copy, cel_attr3}); + UnknownAttributeSet unknown_set1({cel_attr1.get(), cel_attr2.get()}); + UnknownAttributeSet unknown_set2({cel_attr1_copy.get(), cel_attr3.get()}); UnknownAttributeSet unknown_set3 = UnknownAttributeSet::Merge(unknown_set1, unknown_set2); diff --git a/eval/public/unknown_function_result_set.cc b/eval/public/unknown_function_result_set.cc new file mode 100644 index 000000000..6e484da78 --- /dev/null +++ b/eval/public/unknown_function_result_set.cc @@ -0,0 +1,161 @@ +#include "eval/public/unknown_function_result_set.h" + +#include + +#include "eval/public/cel_function.h" +#include "eval/public/cel_options.h" +#include "eval/public/cel_value.h" + +namespace google { +namespace api { +namespace expr { +namespace runtime { +namespace { + +// Forward declare. +bool CelValueEqual(const CelValue lhs, const CelValue rhs); + +// Default to operator== +template +bool CelValueEqualImpl(T lhs, T rhs) { + return lhs == rhs; +} + +// List equality specialization. Test that the lists are in-order elementwise +// equal. +template <> +bool CelValueEqualImpl(const CelList* lhs, const CelList* rhs) { + if (lhs->size() != rhs->size()) { + return false; + } + for (int i = 0; i < rhs->size(); i++) { + if (!CelValueEqual(lhs->operator[](i), rhs->operator[](i))) { + return false; + } + } + return true; +} + +// Map equality specialization. Compare that two maps have exactly the same +// key/value pairs. +template <> +bool CelValueEqualImpl(const CelMap* lhs, const CelMap* rhs) { + if (lhs->size() != rhs->size()) { + return false; + } + const CelList* key_set = rhs->ListKeys(); + for (int i = 0; i < key_set->size(); i++) { + CelValue key = key_set->operator[](i); + CelValue rhs_value = rhs->operator[](key).value(); + auto maybe_lhs_value = lhs->operator[](key); + if (!maybe_lhs_value.has_value()) { + return false; + } + if (!CelValueEqual(maybe_lhs_value.value(), rhs_value)) { + return false; + } + } + return true; +} + +// Visitor for implementing comparing the underlying value that two CelValues +// are wrapping. The visitor unwraps the lhs then tries to get the rhs +// underlying value if it is the same type as the lhs. +struct LhsCompareVisitor { + CelValue rhs; + + LhsCompareVisitor(CelValue rhs) : rhs(rhs) {} + + template + bool operator()(T lhs_value) { + T rhs_value; + bool is_same_type = rhs.GetValue(&rhs_value); + if (!is_same_type) { + return false; + } + return CelValueEqualImpl(lhs_value, rhs_value); + } +}; + +// This is a slightly different implementation than provided for the cel +// evaluator. Differences are: +// +// - this implementation doesn't need to support error forwarding in the same +// way -- this should only be used for situations when we can invoke the +// function. i.e. the function must specify that it consumes errors and/or +// unknown sets for them to appear in the arg list. +// - this implementation defines equality between messages based on ptr identity +bool CelValueEqual(const CelValue lhs, const CelValue rhs) { + if (lhs.type() != rhs.type()) { + return false; + } + return lhs.Visit(LhsCompareVisitor(rhs)); +} + +// Tests that two descriptors are equal (name, receiver call style, arg types). +// +// Argument type Any is not treated specially. For example: +// {"f", false, {kAny}} != {"f", false, {kInt64}} +bool DescriptorEqual(const CelFunctionDescriptor& lhs, + const CelFunctionDescriptor& rhs) { + if (lhs.name() != rhs.name()) { + return false; + } + + if (lhs.receiver_style() != rhs.receiver_style()) { + return false; + } + + if (lhs.types() != rhs.types()) { + return false; + } + + return true; +} + +} // namespace + +bool UnknownFunctionResult::IsEqualTo( + const UnknownFunctionResult& other) const { + if (!DescriptorEqual(descriptor_, other.descriptor())) { + return false; + } + + if (arguments_.size() != other.arguments().size()) { + return false; + } + + for (size_t i = 0; i < arguments_.size(); i++) { + if (!CelValueEqual(arguments_[i], other.arguments()[i])) { + return false; + } + } + + return true; +} + +// Implementation for merge constructor. +UnknownFunctionResultSet::UnknownFunctionResultSet( + const UnknownFunctionResultSet& lhs, const UnknownFunctionResultSet& rhs) + : unknown_function_results_(lhs.unknown_function_results()) { + unknown_function_results_.reserve(lhs.unknown_function_results().size() + + rhs.unknown_function_results().size()); + for (const UnknownFunctionResult* call : rhs.unknown_function_results()) { + Add(call); + } +} + +void UnknownFunctionResultSet::Add(const UnknownFunctionResult* result) { + for (const UnknownFunctionResult* existing_result : + unknown_function_results()) { + if (result->IsEqualTo(*existing_result)) { + return; + } + } + unknown_function_results_.push_back(result); +} + +} // namespace runtime +} // namespace expr +} // namespace api +} // namespace google diff --git a/eval/public/unknown_function_result_set.h b/eval/public/unknown_function_result_set.h new file mode 100644 index 000000000..27c6cbc9b --- /dev/null +++ b/eval/public/unknown_function_result_set.h @@ -0,0 +1,75 @@ +#ifndef THIRD_PARTY_CEL_CPP_EVAL_PUBLIC_UNKNOWN_FUNCTION_RESULT_SET_H_ +#define THIRD_PARTY_CEL_CPP_EVAL_PUBLIC_UNKNOWN_FUNCTION_RESULT_SET_H_ + +#include + +#include "google/api/expr/v1alpha1/syntax.pb.h" +#include "eval/public/cel_function.h" + +namespace google { +namespace api { +namespace expr { +namespace runtime { + +// Represents a function result that is unknown at the time of execution. This +// allows for lazy evaluation of expensive functions. +class UnknownFunctionResult { + public: + UnknownFunctionResult(const CelFunctionDescriptor& descriptor, int64_t expr_id, + const std::vector& arguments) + : descriptor_(descriptor), expr_id_(expr_id), arguments_(arguments) {} + + // The descriptor of the called function that return Unknown. + const CelFunctionDescriptor& descriptor() const { return descriptor_; } + + // The id of the |Expr| that triggered the function call step. Provided + // informationally -- if two different |Expr|s generate the same unknown call, + // they will be treated as the same unknown function result. + int64_t call_expr_id() const { return expr_id_; } + + // The arguments of the function call that generated the unknown. + const std::vector& arguments() const { return arguments_; } + + // Equality operator provided for set semantics. + // Compares descriptor then arguments elementwise. + bool IsEqualTo(const UnknownFunctionResult& other) const; + + private: + CelFunctionDescriptor descriptor_; + int64_t expr_id_; + std::vector arguments_; +}; + +// Represents a collection of unknown function results at a particular point in +// execution. Execution should advance further if this set of unknowns are +// provided. It may not advance if only a subset are provided. +// Set semantics use |IsEqualTo()| defined on |UnknownFunctionResult|. +class UnknownFunctionResultSet { + public: + // Empty set + UnknownFunctionResultSet() {} + + // Merge constructor -- effectively union(lhs, rhs). + UnknownFunctionResultSet(const UnknownFunctionResultSet& lhs, + const UnknownFunctionResultSet& rhs); + + // Initialize with a single UnknownFunctionResult. + UnknownFunctionResultSet(const UnknownFunctionResult* initial) + : unknown_function_results_{initial} {} + + const std::vector& unknown_function_results() + const { + return unknown_function_results_; + } + + private: + std::vector unknown_function_results_; + void Add(const UnknownFunctionResult* result); +}; + +} // namespace runtime +} // namespace expr +} // namespace api +} // namespace google + +#endif // THIRD_PARTY_CEL_CPP_EVAL_PUBLIC_UNKNOWN_FUNCTION_RESULT_SET_H_ diff --git a/eval/public/unknown_function_result_set_test.cc b/eval/public/unknown_function_result_set_test.cc new file mode 100644 index 000000000..35d438a62 --- /dev/null +++ b/eval/public/unknown_function_result_set_test.cc @@ -0,0 +1,487 @@ +#include "eval/public/unknown_function_result_set.h" + +#include + +#include + +#include "google/protobuf/duration.pb.h" +#include "google/protobuf/empty.pb.h" +#include "google/protobuf/struct.pb.h" +#include "google/protobuf/timestamp.pb.h" +#include "google/protobuf/arena.h" +#include "gmock/gmock.h" +#include "gtest/gtest.h" +#include "absl/time/clock.h" +#include "absl/time/time.h" +#include "absl/types/span.h" +#include "eval/eval/container_backed_list_impl.h" +#include "eval/eval/container_backed_map_impl.h" +#include "eval/public/cel_function.h" +#include "eval/public/cel_value.h" +namespace google { +namespace api { +namespace expr { +namespace runtime { +namespace { + +using ::google::protobuf::ListValue; +using ::google::protobuf::Struct; +using ::google::protobuf::Arena; +using testing::Eq; +using testing::SizeIs; + +CelFunctionDescriptor kTwoInt("TwoInt", false, + {CelValue::Type::kInt64, CelValue::Type::kInt64}); + +CelFunctionDescriptor kOneInt("OneInt", false, {CelValue::Type::kInt64}); + +TEST(UnknownFunctionResult, ArgumentCapture) { + UnknownFunctionResult call1( + kTwoInt, /*expr_id=*/0, + {CelValue::CreateInt64(1), CelValue::CreateInt64(2)}); + + EXPECT_THAT(call1.arguments(), SizeIs(2)); + EXPECT_THAT(call1.arguments().at(0).Int64OrDie(), Eq(1)); +} + +TEST(UnknownFunctionResult, Equals) { + UnknownFunctionResult call1( + kTwoInt, /*expr_id=*/0, + {CelValue::CreateInt64(1), CelValue::CreateInt64(2)}); + + UnknownFunctionResult call2( + kTwoInt, /*expr_id=*/0, + {CelValue::CreateInt64(1), CelValue::CreateInt64(2)}); + + EXPECT_TRUE(call1.IsEqualTo(call2)); + + UnknownFunctionResult call3(kOneInt, /*expr_id=*/0, + {CelValue::CreateInt64(1)}); + + UnknownFunctionResult call4(kOneInt, /*expr_id=*/0, + {CelValue::CreateInt64(1)}); + + EXPECT_TRUE(call3.IsEqualTo(call4)); +} + +TEST(UnknownFunctionResult, InequalDescriptor) { + UnknownFunctionResult call1( + kTwoInt, /*expr_id=*/0, + {CelValue::CreateInt64(1), CelValue::CreateInt64(2)}); + + UnknownFunctionResult call2(kOneInt, /*expr_id=*/0, + {CelValue::CreateInt64(1)}); + + EXPECT_FALSE(call1.IsEqualTo(call2)); + + CelFunctionDescriptor one_uint("OneInt", false, {CelValue::Type::kUint64}); + + UnknownFunctionResult call3(kOneInt, /*expr_id=*/0, + {CelValue::CreateInt64(1)}); + + UnknownFunctionResult call4(one_uint, /*expr_id=*/0, + {CelValue::CreateUint64(1)}); + + EXPECT_FALSE(call3.IsEqualTo(call4)); +} + +TEST(UnknownFunctionResult, InequalArgs) { + UnknownFunctionResult call1( + kTwoInt, /*expr_id=*/0, + {CelValue::CreateInt64(1), CelValue::CreateInt64(2)}); + + UnknownFunctionResult call2( + kTwoInt, /*expr_id=*/0, + {CelValue::CreateInt64(1), CelValue::CreateInt64(3)}); + + EXPECT_FALSE(call1.IsEqualTo(call2)); + + UnknownFunctionResult call3( + kTwoInt, /*expr_id=*/0, + {CelValue::CreateInt64(1), CelValue::CreateInt64(2)}); + + UnknownFunctionResult call4(kTwoInt, /*expr_id=*/0, + {CelValue::CreateInt64(1)}); + + EXPECT_FALSE(call3.IsEqualTo(call4)); +} + +TEST(UnknownFunctionResult, ListsEqual) { + ContainerBackedListImpl cel_list_1(std::vector{ + CelValue::CreateInt64(1), CelValue::CreateInt64(2)}); + + ContainerBackedListImpl cel_list_2(std::vector{ + CelValue::CreateInt64(1), CelValue::CreateInt64(2)}); + + CelFunctionDescriptor desc("OneList", false, {CelValue::Type::kList}); + + UnknownFunctionResult call1(desc, /*expr_id=*/0, + {CelValue::CreateList(&cel_list_1)}); + UnknownFunctionResult call2(desc, /*expr_id=*/0, + {CelValue::CreateList(&cel_list_2)}); + + // [1, 2] == [1, 2] + EXPECT_TRUE(call1.IsEqualTo(call2)); +} + +TEST(UnknownFunctionResult, ListsDifferentSizes) { + ContainerBackedListImpl cel_list_1(std::vector{ + CelValue::CreateInt64(1), CelValue::CreateInt64(2)}); + + ContainerBackedListImpl cel_list_2(std::vector{ + CelValue::CreateInt64(1), + CelValue::CreateInt64(2), + CelValue::CreateInt64(3), + }); + + CelFunctionDescriptor desc("OneList", false, {CelValue::Type::kList}); + + UnknownFunctionResult call1(desc, /*expr_id=*/0, + {CelValue::CreateList(&cel_list_1)}); + UnknownFunctionResult call2(desc, /*expr_id=*/0, + {CelValue::CreateList(&cel_list_2)}); + + // [1, 2] == [1, 2, 3] + EXPECT_FALSE(call1.IsEqualTo(call2)); +} + +TEST(UnknownFunctionResult, ListsDifferentMembers) { + ContainerBackedListImpl cel_list_1(std::vector{ + CelValue::CreateInt64(1), CelValue::CreateInt64(2)}); + + ContainerBackedListImpl cel_list_2(std::vector{ + CelValue::CreateInt64(2), CelValue::CreateInt64(2)}); + + CelFunctionDescriptor desc("OneList", false, {CelValue::Type::kList}); + + UnknownFunctionResult call1(desc, /*expr_id=*/0, + {CelValue::CreateList(&cel_list_1)}); + UnknownFunctionResult call2(desc, /*expr_id=*/0, + {CelValue::CreateList(&cel_list_2)}); + + // [1, 2] == [2, 2] + EXPECT_FALSE(call1.IsEqualTo(call2)); +} + +TEST(UnknownFunctionResult, MapsEqual) { + std::vector> values{ + {CelValue::CreateInt64(1), CelValue::CreateInt64(2)}, + {CelValue::CreateInt64(2), CelValue::CreateInt64(4)}}; + + auto cel_map_1 = CreateContainerBackedMap(absl::MakeSpan(values)); + auto cel_map_2 = CreateContainerBackedMap(absl::MakeSpan(values)); + + CelFunctionDescriptor desc("OneMap", false, {CelValue::Type::kMap}); + + UnknownFunctionResult call1(desc, /*expr_id=*/0, + {CelValue::CreateMap(cel_map_1.get())}); + UnknownFunctionResult call2(desc, /*expr_id=*/0, + {CelValue::CreateMap(cel_map_2.get())}); + + // {1: 2, 2: 4} == {1: 2, 2: 4} + EXPECT_TRUE(call1.IsEqualTo(call2)); +} + +TEST(UnknownFunctionResult, MapsDifferentSizes) { + std::vector> values{ + {CelValue::CreateInt64(1), CelValue::CreateInt64(2)}, + {CelValue::CreateInt64(2), CelValue::CreateInt64(4)}}; + + auto cel_map_1 = CreateContainerBackedMap(absl::MakeSpan(values)); + + std::vector> values2{ + {CelValue::CreateInt64(1), CelValue::CreateInt64(2)}, + {CelValue::CreateInt64(2), CelValue::CreateInt64(4)}, + {CelValue::CreateInt64(3), CelValue::CreateInt64(6)}}; + + auto cel_map_2 = CreateContainerBackedMap(absl::MakeSpan(values2)); + + CelFunctionDescriptor desc("OneMap", false, {CelValue::Type::kMap}); + + UnknownFunctionResult call1(desc, /*expr_id=*/0, + {CelValue::CreateMap(cel_map_1.get())}); + UnknownFunctionResult call2(desc, /*expr_id=*/0, + {CelValue::CreateMap(cel_map_2.get())}); + + // {1: 2, 2: 4} == {1: 2, 2: 4, 3: 6} + EXPECT_FALSE(call1.IsEqualTo(call2)); +} + +TEST(UnknownFunctionResult, MapsDifferentElements) { + std::vector> values{ + {CelValue::CreateInt64(1), CelValue::CreateInt64(2)}, + {CelValue::CreateInt64(2), CelValue::CreateInt64(4)}, + {CelValue::CreateInt64(3), CelValue::CreateInt64(6)}}; + + auto cel_map_1 = CreateContainerBackedMap(absl::MakeSpan(values)); + + std::vector> values2{ + {CelValue::CreateInt64(1), CelValue::CreateInt64(2)}, + {CelValue::CreateInt64(2), CelValue::CreateInt64(4)}, + {CelValue::CreateInt64(4), CelValue::CreateInt64(8)}}; + + auto cel_map_2 = CreateContainerBackedMap(absl::MakeSpan(values2)); + + std::vector> values3{ + {CelValue::CreateInt64(1), CelValue::CreateInt64(2)}, + {CelValue::CreateInt64(2), CelValue::CreateInt64(4)}, + {CelValue::CreateInt64(3), CelValue::CreateInt64(5)}}; + + auto cel_map_3 = CreateContainerBackedMap(absl::MakeSpan(values3)); + + CelFunctionDescriptor desc("OneMap", false, {CelValue::Type::kMap}); + + UnknownFunctionResult call1(desc, /*expr_id=*/0, + {CelValue::CreateMap(cel_map_1.get())}); + UnknownFunctionResult call2(desc, /*expr_id=*/0, + {CelValue::CreateMap(cel_map_2.get())}); + UnknownFunctionResult call3(desc, /*expr_id=*/0, + {CelValue::CreateMap(cel_map_3.get())}); + + // {1: 2, 2: 4, 3: 6} == {1: 2, 2: 4, 4: 8} + EXPECT_FALSE(call1.IsEqualTo(call2)); + // {1: 2, 2: 4, 3: 6} == {1: 2, 2: 4, 3: 5} + EXPECT_FALSE(call1.IsEqualTo(call3)); +} + +TEST(UnknownFunctionResult, Messages) { + protobuf::Empty message1; + protobuf::Empty message2; + google::protobuf::Arena arena; + + CelFunctionDescriptor desc("OneMessage", false, {CelValue::Type::kMessage}); + + UnknownFunctionResult call1(desc, /*expr_id=*/0, + {CelValue::CreateMessage(&message1, &arena)}); + UnknownFunctionResult call2(desc, /*expr_id=*/0, + {CelValue::CreateMessage(&message2, &arena)}); + UnknownFunctionResult call3(desc, /*expr_id=*/0, + {CelValue::CreateMessage(&message1, &arena)}); + + // &message1 == &message2 + EXPECT_FALSE(call1.IsEqualTo(call2)); + + // &message1 == &message1 + EXPECT_TRUE(call1.IsEqualTo(call3)); +} + +TEST(UnknownFunctionResult, AnyDescriptor) { + CelFunctionDescriptor anyDesc("OneAny", false, {CelValue::Type::kAny}); + + UnknownFunctionResult callAnyInt1(anyDesc, /*expr_id=*/0, + {CelValue::CreateInt64(2)}); + UnknownFunctionResult callInt(kOneInt, /*expr_id=*/0, + {CelValue::CreateInt64(2)}); + + UnknownFunctionResult callAnyInt2(anyDesc, /*expr_id=*/0, + {CelValue::CreateInt64(2)}); + UnknownFunctionResult callAnyUint(anyDesc, /*expr_id=*/0, + {CelValue::CreateUint64(2)}); + + EXPECT_FALSE(callAnyInt1.IsEqualTo(callInt)); + EXPECT_FALSE(callAnyInt1.IsEqualTo(callAnyUint)); + EXPECT_TRUE(callAnyInt1.IsEqualTo(callAnyInt2)); +} + +TEST(UnknownFunctionResult, Strings) { + CelFunctionDescriptor desc("OneString", false, {CelValue::Type::kString}); + + UnknownFunctionResult callStringSmile(desc, /*expr_id=*/0, + {CelValue::CreateStringView("😁")}); + UnknownFunctionResult callStringFrown(desc, /*expr_id=*/0, + {CelValue::CreateStringView("🙁")}); + UnknownFunctionResult callStringSmile2(desc, /*expr_id=*/0, + {CelValue::CreateStringView("😁")}); + + EXPECT_TRUE(callStringSmile.IsEqualTo(callStringSmile2)); + EXPECT_FALSE(callStringSmile.IsEqualTo(callStringFrown)); +} + +TEST(UnknownFunctionResult, DurationHandling) { + google::protobuf::Arena arena; + absl::Duration duration1 = absl::Seconds(5); + protobuf::Duration duration2; + duration2.set_seconds(5); + + CelFunctionDescriptor durationDesc("OneDuration", false, + {CelValue::Type::kDuration}); + + UnknownFunctionResult callDuration1(durationDesc, /*expr_id=*/0, + {CelValue::CreateDuration(duration1)}); + UnknownFunctionResult callDuration2( + durationDesc, /*expr_id=*/0, + {CelValue::CreateMessage(&duration2, &arena)}); + UnknownFunctionResult callDuration3(durationDesc, /*expr_id=*/0, + {CelValue::CreateDuration(&duration2)}); + + EXPECT_TRUE(callDuration1.IsEqualTo(callDuration2)); + EXPECT_TRUE(callDuration1.IsEqualTo(callDuration3)); +} + +TEST(UnknownFunctionResult, TimestampHandling) { + google::protobuf::Arena arena; + absl::Time ts1 = absl::FromUnixMillis(1000); + protobuf::Timestamp ts2; + ts2.set_seconds(1); + + CelFunctionDescriptor timestampDesc("OneTimestamp", false, + {CelValue::Type::kTimestamp}); + + UnknownFunctionResult callTimestamp1(timestampDesc, /*expr_id=*/0, + {CelValue::CreateTimestamp(ts1)}); + UnknownFunctionResult callTimestamp2(timestampDesc, /*expr_id=*/0, + {CelValue::CreateMessage(&ts2, &arena)}); + UnknownFunctionResult callTimestamp3(timestampDesc, /*expr_id=*/0, + {CelValue::CreateTimestamp(&ts2)}); + + EXPECT_TRUE(callTimestamp1.IsEqualTo(callTimestamp2)); + EXPECT_TRUE(callTimestamp1.IsEqualTo(callTimestamp3)); +} + +// This tests that the conversion and different map backing implementations are +// compatible with the equality tests. +TEST(UnknownFunctionResult, ProtoStructTreatedAsMap) { + Arena arena; + + const std::vector kFields = {"field1", "field2", "field3"}; + + Struct value_struct; + + auto& value1 = (*value_struct.mutable_fields())[kFields[0]]; + value1.set_bool_value(true); + + auto& value2 = (*value_struct.mutable_fields())[kFields[1]]; + value2.set_number_value(1.0); + + auto& value3 = (*value_struct.mutable_fields())[kFields[2]]; + value3.set_string_value("test"); + + CelValue proto_struct = CelValue::CreateMessage(&value_struct, &arena); + ASSERT_TRUE(proto_struct.IsMap()); + + std::vector> values{ + {CelValue::CreateStringView(kFields[2]), + CelValue::CreateStringView("test")}, + {CelValue::CreateStringView(kFields[1]), CelValue::CreateDouble(1.0)}, + {CelValue::CreateStringView(kFields[0]), CelValue::CreateBool(true)}}; + + auto backing_map = CreateContainerBackedMap(absl::MakeSpan(values)); + + CelValue cel_map = CelValue::CreateMap(backing_map.get()); + + CelFunctionDescriptor desc("OneMap", false, {CelValue::Type::kMap}); + + UnknownFunctionResult proto_struct_result(desc, /*expr_id=*/0, + {proto_struct}); + UnknownFunctionResult cel_map_result(desc, /*expr_id=*/0, {cel_map}); + + EXPECT_TRUE(proto_struct_result.IsEqualTo(cel_map_result)); +} + +// This tests that the conversion and different map backing implementations are +// compatible with the equality tests. +TEST(UnknownFunctionResult, ProtoListTreatedAsList) { + Arena arena; + + ListValue list_value; + + list_value.add_values()->set_bool_value(true); + list_value.add_values()->set_number_value(1.0); + list_value.add_values()->set_string_value("test"); + + CelValue proto_list = CelValue::CreateMessage(&list_value, &arena); + ASSERT_TRUE(proto_list.IsList()); + + std::vector list_values{CelValue::CreateBool(true), + CelValue::CreateDouble(1.0), + CelValue::CreateStringView("test")}; + + ContainerBackedListImpl list_backing(list_values); + + CelValue cel_list = CelValue::CreateList(&list_backing); + + CelFunctionDescriptor desc("OneList", false, {CelValue::Type::kList}); + + UnknownFunctionResult proto_list_result(desc, /*expr_id=*/0, {proto_list}); + UnknownFunctionResult cel_list_result(desc, /*expr_id=*/0, {cel_list}); + + EXPECT_TRUE(cel_list_result.IsEqualTo(proto_list_result)); +} + +TEST(UnknownFunctionResult, NestedProtoTypes) { + Arena arena; + + ListValue list_value; + + list_value.add_values()->set_bool_value(true); + list_value.add_values()->set_number_value(1.0); + list_value.add_values()->set_string_value("test"); + + std::vector list_values{CelValue::CreateBool(true), + CelValue::CreateDouble(1.0), + CelValue::CreateStringView("test")}; + + ContainerBackedListImpl list_backing(list_values); + + CelValue cel_list = CelValue::CreateList(&list_backing); + + Struct value_struct; + + *(value_struct.mutable_fields()->operator[]("field").mutable_list_value()) = + list_value; + + std::vector> values{ + {CelValue::CreateStringView("field"), cel_list}}; + + auto backing_map = CreateContainerBackedMap(absl::MakeSpan(values)); + + CelValue cel_map = CelValue::CreateMap(backing_map.get()); + CelValue proto_map = CelValue::CreateMessage(&value_struct, &arena); + + CelFunctionDescriptor desc("OneMap", false, {CelValue::Type::kMap}); + + UnknownFunctionResult cel_map_result(desc, /*expr_id=*/0, {cel_map}); + UnknownFunctionResult proto_struct_result(desc, /*expr_id=*/0, {proto_map}); + + EXPECT_TRUE(proto_struct_result.IsEqualTo(cel_map_result)); +} + +UnknownFunctionResult MakeUnknown(int64_t i) { + return UnknownFunctionResult(kOneInt, /*expr_id=*/0, + {CelValue::CreateInt64(i)}); +} + +testing::Matcher UnknownMatches( + const UnknownFunctionResult& obj) { + return testing::Truly([&](const UnknownFunctionResult* to_match) { + return obj.IsEqualTo(*to_match); + }); +} + +TEST(UnknownFunctionResultSet, Merge) { + UnknownFunctionResult a = MakeUnknown(1); + UnknownFunctionResult b = MakeUnknown(2); + UnknownFunctionResult c = MakeUnknown(3); + UnknownFunctionResult d = MakeUnknown(1); + + UnknownFunctionResultSet a1(&a); + UnknownFunctionResultSet b1(&b); + UnknownFunctionResultSet c1(&c); + UnknownFunctionResultSet d1(&d); + + UnknownFunctionResultSet ab(a1, b1); + UnknownFunctionResultSet cd(c1, d1); + + UnknownFunctionResultSet merged(ab, cd); + + EXPECT_THAT(merged.unknown_function_results(), SizeIs(3)); + EXPECT_THAT(merged.unknown_function_results(), + testing::UnorderedElementsAre( + UnknownMatches(a), UnknownMatches(b), UnknownMatches(c))); +} + +} // namespace +} // namespace runtime +} // namespace expr +} // namespace api +} // namespace google diff --git a/eval/public/unknown_set.h b/eval/public/unknown_set.h new file mode 100644 index 000000000..3b7168afe --- /dev/null +++ b/eval/public/unknown_set.h @@ -0,0 +1,52 @@ +#ifndef THIRD_PARTY_CEL_CPP_EVAL_PUBLIC_UNKNOWN_SET_H_ +#define THIRD_PARTY_CEL_CPP_EVAL_PUBLIC_UNKNOWN_SET_H_ + +#include "eval/public/unknown_attribute_set.h" +#include "eval/public/unknown_function_result_set.h" + +namespace google { +namespace api { +namespace expr { +namespace runtime { + +// Class representing a collection of unknowns from a single evaluation pass of +// a CEL expression. +class UnknownSet { + public: + // Initilization specifying subcontainers + explicit UnknownSet( + const google::api::expr::runtime::UnknownAttributeSet& attrs) + : unknown_attributes_(attrs) {} + explicit UnknownSet(const UnknownFunctionResultSet& function_results) + : unknown_function_results_(function_results) {} + UnknownSet(const UnknownAttributeSet& attrs, + const UnknownFunctionResultSet& function_results) + : unknown_attributes_(attrs), + unknown_function_results_(function_results) {} + // Initialization for empty set + UnknownSet() {} + // Merge constructor + UnknownSet(const UnknownSet& set1, const UnknownSet& set2) + : unknown_attributes_(set1.unknown_attributes(), + set2.unknown_attributes()), + unknown_function_results_(set1.unknown_function_results(), + set2.unknown_function_results()) {} + + const UnknownAttributeSet& unknown_attributes() const { + return unknown_attributes_; + } + const UnknownFunctionResultSet& unknown_function_results() const { + return unknown_function_results_; + } + + private: + UnknownAttributeSet unknown_attributes_; + UnknownFunctionResultSet unknown_function_results_; +}; + +} // namespace runtime +} // namespace expr +} // namespace api +} // namespace google + +#endif // THIRD_PARTY_CEL_CPP_EVAL_PUBLIC_UNKNOWN_SET_H_ diff --git a/eval/public/unknown_set_test.cc b/eval/public/unknown_set_test.cc new file mode 100644 index 000000000..3e4c06cda --- /dev/null +++ b/eval/public/unknown_set_test.cc @@ -0,0 +1,129 @@ +#include "eval/public/unknown_set.h" + +#include "google/api/expr/v1alpha1/syntax.pb.h" +#include "google/protobuf/arena.h" +#include "gmock/gmock.h" +#include "gtest/gtest.h" +#include "eval/public/cel_attribute.h" +#include "eval/public/unknown_attribute_set.h" +#include "eval/public/unknown_function_result_set.h" + +namespace google { +namespace api { +namespace expr { +namespace runtime { +namespace { + +using ::google::protobuf::Arena; +using testing::IsEmpty; +using testing::UnorderedElementsAre; + +UnknownFunctionResultSet MakeFunctionResult(Arena* arena, int64_t id) { + CelFunctionDescriptor desc("OneInt", false, {CelValue::Type::kInt64}); + std::vector call_args{CelValue::CreateInt64(id)}; + const auto* function_result = Arena::Create( + arena, desc, /*expr_id=*/0, call_args); + return UnknownFunctionResultSet(function_result); +} + +UnknownAttributeSet MakeAttribute(Arena* arena, int64_t id) { + google::api::expr::v1alpha1::Expr expr; + expr.mutable_ident_expr()->set_name("x"); + + std::vector attr_trail{ + CelAttributeQualifier::Create(CelValue::CreateInt64(id))}; + + const auto* attr = Arena::Create(arena, expr, attr_trail); + return UnknownAttributeSet({attr}); +} + +MATCHER_P(UnknownAttributeIs, id, "") { + const CelAttribute* attr = arg; + if (attr->qualifier_path().size() != 1) { + return false; + } + auto maybe_qualifier = attr->qualifier_path()[0].GetInt64Key(); + if (!maybe_qualifier.has_value()) { + return false; + } + return maybe_qualifier.value() == id; +} + +MATCHER_P(UnknownFunctionResultIs, id, "") { + const UnknownFunctionResult* result = arg; + if (result->arguments().size() != 1) { + return false; + } + if (!result->arguments()[0].IsInt64()) { + return false; + } + return result->arguments()[0].Int64OrDie() == id; +} + +TEST(UnknownSet, AttributesMerge) { + Arena arena; + UnknownSet a(MakeAttribute(&arena, 1)); + UnknownSet b(MakeAttribute(&arena, 2)); + UnknownSet c(MakeAttribute(&arena, 2)); + UnknownSet d(a, b); + UnknownSet e(c, d); + + EXPECT_THAT( + d.unknown_attributes().attributes(), + UnorderedElementsAre(UnknownAttributeIs(1), UnknownAttributeIs(2))); + EXPECT_THAT( + e.unknown_attributes().attributes(), + UnorderedElementsAre(UnknownAttributeIs(1), UnknownAttributeIs(2))); +} + +TEST(UnknownSet, FunctionsMerge) { + Arena arena; + + UnknownSet a(MakeFunctionResult(&arena, 1)); + UnknownSet b(MakeFunctionResult(&arena, 2)); + UnknownSet c(MakeFunctionResult(&arena, 2)); + UnknownSet d(a, b); + UnknownSet e(c, d); + + EXPECT_THAT(d.unknown_function_results().unknown_function_results(), + UnorderedElementsAre(UnknownFunctionResultIs(1), + UnknownFunctionResultIs(2))); + EXPECT_THAT(e.unknown_function_results().unknown_function_results(), + UnorderedElementsAre(UnknownFunctionResultIs(1), + UnknownFunctionResultIs(2))); +} + +TEST(UnknownSet, DefaultEmpty) { + UnknownSet empty_set; + EXPECT_THAT(empty_set.unknown_attributes().attributes(), IsEmpty()); + EXPECT_THAT(empty_set.unknown_function_results().unknown_function_results(), + IsEmpty()); +} + +TEST(UnknownSet, MixedMerges) { + Arena arena; + + UnknownSet a(MakeAttribute(&arena, 1), MakeFunctionResult(&arena, 1)); + UnknownSet b(MakeFunctionResult(&arena, 2)); + UnknownSet c(MakeAttribute(&arena, 2)); + UnknownSet d(a, b); + UnknownSet e(c, d); + + EXPECT_THAT(d.unknown_attributes().attributes(), + UnorderedElementsAre(UnknownAttributeIs(1))); + EXPECT_THAT(d.unknown_function_results().unknown_function_results(), + UnorderedElementsAre(UnknownFunctionResultIs(1), + UnknownFunctionResultIs(2))); + EXPECT_THAT( + e.unknown_attributes().attributes(), + UnorderedElementsAre(UnknownAttributeIs(1), UnknownAttributeIs(2))); + EXPECT_THAT(e.unknown_function_results().unknown_function_results(), + UnorderedElementsAre(UnknownFunctionResultIs(1), + UnknownFunctionResultIs(2))); +} + +} // namespace +} // namespace runtime +} // namespace expr +} // namespace api +} // namespace google diff --git a/eval/public/value_export_util.cc b/eval/public/value_export_util.cc index b557d6804..de72c135e 100644 --- a/eval/public/value_export_util.cc +++ b/eval/public/value_export_util.cc @@ -4,7 +4,6 @@ #include "google/protobuf/util/time_util.h" #include "absl/strings/escaping.h" #include "absl/strings/str_cat.h" -#include "base/canonical_errors.h" namespace google { namespace api { @@ -20,7 +19,7 @@ using google::protobuf::FieldDescriptor; using google::protobuf::Message; using google::protobuf::util::TimeUtil; -cel_base::Status KeyAsString(const CelValue& value, std::string* key) { +absl::Status KeyAsString(const CelValue& value, std::string* key) { switch (value.type()) { case CelValue::Type::kInt64: { *key = absl::StrCat(value.Int64OrDie()); @@ -35,16 +34,18 @@ cel_base::Status KeyAsString(const CelValue& value, std::string* key) { value.StringOrDie().value().size()); break; } - default: { return cel_base::InvalidArgumentError("Unsupported map type"); } + default: { + return absl::InvalidArgumentError("Unsupported map type"); + } } - return cel_base::OkStatus(); + return absl::OkStatus(); } // Export content of CelValue as google.protobuf.Value. -cel_base::Status ExportAsProtoValue(const CelValue& in_value, Value* out_value) { +absl::Status ExportAsProtoValue(const CelValue& in_value, Value* out_value) { if (in_value.IsNull()) { out_value->set_null_value(google::protobuf::NULL_VALUE); - return cel_base::OkStatus(); + return absl::OkStatus(); } switch (in_value.type()) { case CelValue::Type::kBool: { @@ -92,13 +93,13 @@ cel_base::Status ExportAsProtoValue(const CelValue& in_value, Value* out_value) auto status = google::protobuf::util::MessageToJsonString(*in_value.MessageOrDie(), &json, json_options); if (!status.ok()) { - return cel_base::InternalError(status.ToString()); + return absl::InternalError(status.ToString()); } google::protobuf::util::JsonParseOptions json_parse_options; status = google::protobuf::util::JsonStringToMessage(json, out_value, json_parse_options); if (!status.ok()) { - return cel_base::InternalError(status.ToString()); + return absl::InternalError(status.ToString()); } break; } @@ -135,9 +136,11 @@ cel_base::Status ExportAsProtoValue(const CelValue& in_value, Value* out_value) } break; } - default: { return cel_base::InvalidArgumentError("Unsupported value type"); } + default: { + return absl::InvalidArgumentError("Unsupported value type"); + } } - return cel_base::OkStatus(); + return absl::OkStatus(); } } // namespace runtime diff --git a/eval/public/value_export_util.h b/eval/public/value_export_util.h index 39018b3d7..6fbf9f8c4 100644 --- a/eval/public/value_export_util.h +++ b/eval/public/value_export_util.h @@ -1,10 +1,9 @@ #ifndef THIRD_PARTY_CEL_CPP_EVAL_PUBLIC_VALUE_EXPORT_UTIL_H_ #define THIRD_PARTY_CEL_CPP_EVAL_PUBLIC_VALUE_EXPORT_UTIL_H_ -#include "eval/public/cel_value.h" - #include "google/protobuf/struct.pb.h" -#include "base/status.h" +#include "absl/status/status.h" +#include "eval/public/cel_value.h" namespace google { namespace api { @@ -16,7 +15,7 @@ namespace runtime { // - exports integer values as doubles (Value.number_value); // - exports integer keys in maps as strings; // - handles Duration and Timestamp as generic messages. -cel_base::Status ExportAsProtoValue(const CelValue &in_value, +absl::Status ExportAsProtoValue(const CelValue &in_value, google::protobuf::Value *out_value); } // namespace runtime diff --git a/eval/public/value_export_util_test.cc b/eval/public/value_export_util_test.cc index e988ebe2d..7dd7773bd 100644 --- a/eval/public/value_export_util_test.cc +++ b/eval/public/value_export_util_test.cc @@ -9,6 +9,7 @@ #include "eval/eval/container_backed_map_impl.h" #include "eval/testutil/test_message.pb.h" #include "testutil/util.h" +#include "base/status_macros.h" namespace google { namespace api { @@ -27,7 +28,7 @@ using google::protobuf::Arena; TEST(ValueExportUtilTest, ConvertBoolValue) { CelValue cel_value = CelValue::CreateBool(true); Value value; - EXPECT_TRUE(ExportAsProtoValue(cel_value, &value).ok()); + EXPECT_OK(ExportAsProtoValue(cel_value, &value)); EXPECT_EQ(value.kind_case(), Value::KindCase::kBoolValue); EXPECT_EQ(value.bool_value(), true); } @@ -35,7 +36,7 @@ TEST(ValueExportUtilTest, ConvertBoolValue) { TEST(ValueExportUtilTest, ConvertInt64Value) { CelValue cel_value = CelValue::CreateInt64(-1); Value value; - EXPECT_TRUE(ExportAsProtoValue(cel_value, &value).ok()); + EXPECT_OK(ExportAsProtoValue(cel_value, &value)); EXPECT_EQ(value.kind_case(), Value::KindCase::kNumberValue); EXPECT_DOUBLE_EQ(value.number_value(), -1); } @@ -43,7 +44,7 @@ TEST(ValueExportUtilTest, ConvertInt64Value) { TEST(ValueExportUtilTest, ConvertUint64Value) { CelValue cel_value = CelValue::CreateUint64(1); Value value; - EXPECT_TRUE(ExportAsProtoValue(cel_value, &value).ok()); + EXPECT_OK(ExportAsProtoValue(cel_value, &value)); EXPECT_EQ(value.kind_case(), Value::KindCase::kNumberValue); EXPECT_DOUBLE_EQ(value.number_value(), 1); } @@ -51,7 +52,7 @@ TEST(ValueExportUtilTest, ConvertUint64Value) { TEST(ValueExportUtilTest, ConvertDoubleValue) { CelValue cel_value = CelValue::CreateDouble(1.3); Value value; - EXPECT_TRUE(ExportAsProtoValue(cel_value, &value).ok()); + EXPECT_OK(ExportAsProtoValue(cel_value, &value)); EXPECT_EQ(value.kind_case(), Value::KindCase::kNumberValue); EXPECT_DOUBLE_EQ(value.number_value(), 1.3); } @@ -60,7 +61,7 @@ TEST(ValueExportUtilTest, ConvertStringValue) { std::string test = "test"; CelValue cel_value = CelValue::CreateString(&test); Value value; - EXPECT_TRUE(ExportAsProtoValue(cel_value, &value).ok()); + EXPECT_OK(ExportAsProtoValue(cel_value, &value)); EXPECT_EQ(value.kind_case(), Value::KindCase::kStringValue); EXPECT_EQ(value.string_value(), "test"); } @@ -69,7 +70,7 @@ TEST(ValueExportUtilTest, ConvertBytesValue) { std::string test = "test"; CelValue cel_value = CelValue::CreateBytes(&test); Value value; - EXPECT_TRUE(ExportAsProtoValue(cel_value, &value).ok()); + EXPECT_OK(ExportAsProtoValue(cel_value, &value)); EXPECT_EQ(value.kind_case(), Value::KindCase::kStringValue); // Check that the result is BASE64 encoded. EXPECT_EQ(value.string_value(), "dGVzdA=="); @@ -81,7 +82,7 @@ TEST(ValueExportUtilTest, ConvertDurationValue) { duration.set_nanos(3); CelValue cel_value = CelValue::CreateDuration(&duration); Value value; - EXPECT_TRUE(ExportAsProtoValue(cel_value, &value).ok()); + EXPECT_OK(ExportAsProtoValue(cel_value, &value)); EXPECT_EQ(value.kind_case(), Value::KindCase::kStringValue); EXPECT_EQ(value.string_value(), "2.000000003s"); } @@ -92,7 +93,7 @@ TEST(ValueExportUtilTest, ConvertTimestampValue) { timestamp.set_nanos(3); CelValue cel_value = CelValue::CreateTimestamp(×tamp); Value value; - EXPECT_TRUE(ExportAsProtoValue(cel_value, &value).ok()); + EXPECT_OK(ExportAsProtoValue(cel_value, &value)); EXPECT_EQ(value.kind_case(), Value::KindCase::kStringValue); EXPECT_EQ(value.string_value(), "2001-09-09T01:46:40.000000003Z"); } @@ -103,7 +104,7 @@ TEST(ValueExportUtilTest, ConvertStructMessage) { Arena arena; CelValue cel_value = CelValue::CreateMessage(&struct_msg, &arena); Value value; - EXPECT_TRUE(ExportAsProtoValue(cel_value, &value).ok()); + EXPECT_OK(ExportAsProtoValue(cel_value, &value)); EXPECT_EQ(value.kind_case(), Value::KindCase::kStructValue); EXPECT_THAT(value.struct_value(), testutil::EqualsProto(struct_msg)); } @@ -116,7 +117,7 @@ TEST(ValueExportUtilTest, ConvertValueMessage) { Arena arena; CelValue cel_value = CelValue::CreateMessage(&value_in, &arena); Value value_out; - EXPECT_TRUE(ExportAsProtoValue(cel_value, &value_out).ok()); + EXPECT_OK(ExportAsProtoValue(cel_value, &value_out)); EXPECT_THAT(value_in, testutil::EqualsProto(value_out)); } @@ -127,7 +128,7 @@ TEST(ValueExportUtilTest, ConvertListValueMessage) { Arena arena; CelValue cel_value = CelValue::CreateMessage(&list_value, &arena); Value value_out; - EXPECT_TRUE(ExportAsProtoValue(cel_value, &value_out).ok()); + EXPECT_OK(ExportAsProtoValue(cel_value, &value_out)); EXPECT_THAT(list_value, testutil::EqualsProto(value_out.list_value())); } @@ -140,7 +141,7 @@ TEST(ValueExportUtilTest, ConvertRepeatedBoolValue) { msg->add_bool_list(false); CelValue cel_value = CelValue::CreateMessage(msg, &arena); - EXPECT_TRUE(ExportAsProtoValue(cel_value, &value).ok()); + EXPECT_OK(ExportAsProtoValue(cel_value, &value)); EXPECT_EQ(value.kind_case(), Value::KindCase::kStructValue); Value list_value = value.struct_value().fields().at("bool_list"); @@ -159,7 +160,7 @@ TEST(ValueExportUtilTest, ConvertRepeatedInt32Value) { msg->add_int32_list(3); CelValue cel_value = CelValue::CreateMessage(msg, &arena); - EXPECT_TRUE(ExportAsProtoValue(cel_value, &value).ok()); + EXPECT_OK(ExportAsProtoValue(cel_value, &value)); EXPECT_EQ(value.kind_case(), Value::KindCase::kStructValue); Value list_value = value.struct_value().fields().at("int32_list"); @@ -178,7 +179,7 @@ TEST(ValueExportUtilTest, ConvertRepeatedInt64Value) { msg->add_int64_list(3); CelValue cel_value = CelValue::CreateMessage(msg, &arena); - EXPECT_TRUE(ExportAsProtoValue(cel_value, &value).ok()); + EXPECT_OK(ExportAsProtoValue(cel_value, &value)); EXPECT_EQ(value.kind_case(), Value::KindCase::kStructValue); Value list_value = value.struct_value().fields().at("int64_list"); @@ -197,7 +198,7 @@ TEST(ValueExportUtilTest, ConvertRepeatedUint64Value) { msg->add_uint64_list(3); CelValue cel_value = CelValue::CreateMessage(msg, &arena); - EXPECT_TRUE(ExportAsProtoValue(cel_value, &value).ok()); + EXPECT_OK(ExportAsProtoValue(cel_value, &value)); EXPECT_EQ(value.kind_case(), Value::KindCase::kStructValue); Value list_value = value.struct_value().fields().at("uint64_list"); @@ -216,7 +217,7 @@ TEST(ValueExportUtilTest, ConvertRepeatedDoubleValue) { msg->add_double_list(3); CelValue cel_value = CelValue::CreateMessage(msg, &arena); - EXPECT_TRUE(ExportAsProtoValue(cel_value, &value).ok()); + EXPECT_OK(ExportAsProtoValue(cel_value, &value)); EXPECT_EQ(value.kind_case(), Value::KindCase::kStructValue); Value list_value = value.struct_value().fields().at("double_list"); @@ -235,7 +236,7 @@ TEST(ValueExportUtilTest, ConvertRepeatedStringValue) { msg->add_string_list("test2"); CelValue cel_value = CelValue::CreateMessage(msg, &arena); - EXPECT_TRUE(ExportAsProtoValue(cel_value, &value).ok()); + EXPECT_OK(ExportAsProtoValue(cel_value, &value)); EXPECT_EQ(value.kind_case(), Value::KindCase::kStructValue); Value list_value = value.struct_value().fields().at("string_list"); @@ -254,7 +255,7 @@ TEST(ValueExportUtilTest, ConvertRepeatedBytesValue) { msg->add_bytes_list("test2"); CelValue cel_value = CelValue::CreateMessage(msg, &arena); - EXPECT_TRUE(ExportAsProtoValue(cel_value, &value).ok()); + EXPECT_OK(ExportAsProtoValue(cel_value, &value)); EXPECT_EQ(value.kind_case(), Value::KindCase::kStructValue); Value list_value = value.struct_value().fields().at("bytes_list"); @@ -274,7 +275,7 @@ TEST(ValueExportUtilTest, ConvertCelList) { CelList *cel_list = Arena::Create(&arena, values); CelValue cel_value = CelValue::CreateList(cel_list); - EXPECT_TRUE(ExportAsProtoValue(cel_value, &value).ok()); + EXPECT_OK(ExportAsProtoValue(cel_value, &value)); EXPECT_EQ(value.kind_case(), Value::KindCase::kListValue); EXPECT_DOUBLE_EQ(value.list_value().values(0).number_value(), 2); @@ -299,7 +300,7 @@ TEST(ValueExportUtilTest, ConvertCelMapWithStringKey) { absl::Span>(map_entries)); CelValue cel_value = CelValue::CreateMap(cel_map.get()); - EXPECT_TRUE(ExportAsProtoValue(cel_value, &value).ok()); + EXPECT_OK(ExportAsProtoValue(cel_value, &value)); EXPECT_EQ(value.kind_case(), Value::KindCase::kStructValue); const auto &fields = value.struct_value().fields(); @@ -326,7 +327,7 @@ TEST(ValueExportUtilTest, ConvertCelMapWithInt64Key) { absl::Span>(map_entries)); CelValue cel_value = CelValue::CreateMap(cel_map.get()); - EXPECT_TRUE(ExportAsProtoValue(cel_value, &value).ok()); + EXPECT_OK(ExportAsProtoValue(cel_value, &value)); EXPECT_EQ(value.kind_case(), Value::KindCase::kStructValue); const auto &fields = value.struct_value().fields(); diff --git a/eval/tests/BUILD b/eval/tests/BUILD index 273f1a2eb..ce8a80822 100644 --- a/eval/tests/BUILD +++ b/eval/tests/BUILD @@ -41,6 +41,7 @@ cc_test( ], copts = ["-std=c++14"], deps = [ + "//base:status_macros", "//eval/public:activation", "//eval/public:builtin_func_registrar", "//eval/public:cel_expr_builder_factory", @@ -48,12 +49,35 @@ cc_test( "//eval/public:cel_value", "//eval/testutil:test_message_cc_proto", "@com_github_google_googletest//:gtest_main", - "@com_google_absl//absl/strings", "@com_google_googleapis//google/api/expr/v1alpha1:syntax_cc_proto", "@com_google_protobuf//:protobuf", ], ) +cc_test( + name = "unknowns_end_to_end_test", + size = "small", + srcs = [ + "unknowns_end_to_end_test.cc", + ], + copts = ["-std=c++14"], + deps = [ + "//base:status_macros", + "//eval/public:activation", + "//eval/public:builtin_func_registrar", + "//eval/public:cel_attribute", + "//eval/public:cel_expr_builder_factory", + "//eval/public:cel_expression", + "//eval/public:cel_function", + "//eval/public:cel_options", + "//eval/public:cel_value", + "//eval/public:unknown_set", + "@com_github_google_googletest//:gtest_main", + "@com_google_absl//absl/strings", + "@com_google_protobuf//:protobuf", + ], +) + proto_library( name = "request_context_protos", srcs = [ @@ -65,3 +89,15 @@ cc_proto_library( name = "request_context_cc_proto", deps = [":request_context_protos"], ) + +cc_library( + name = "mock_cel_expression", + testonly = 1, + hdrs = ["mock_cel_expression.h"], + copts = ["-std=c++14"], + deps = [ + "//eval/public:activation", + "//eval/public:cel_expression", + "@com_github_google_googletest//:gtest_main", + ], +) diff --git a/eval/tests/end_to_end_test.cc b/eval/tests/end_to_end_test.cc index 7c267b94e..c95f919b3 100644 --- a/eval/tests/end_to_end_test.cc +++ b/eval/tests/end_to_end_test.cc @@ -8,6 +8,7 @@ #include "eval/public/cel_expression.h" #include "eval/public/cel_value.h" #include "eval/testutil/test_message.pb.h" +#include "base/status_macros.h" namespace google { namespace api { @@ -49,12 +50,12 @@ TEST(EndToEndTest, SimpleOnePlusOne) { std::unique_ptr builder = CreateCelExpressionBuilder(); // Builtin registration. - ASSERT_TRUE(RegisterBuiltinFunctions(builder->GetRegistry()).ok()); + ASSERT_OK(RegisterBuiltinFunctions(builder->GetRegistry())); // Create CelExpression from AST (Expr object). auto cel_expression_status = builder->CreateExpression(&expr, &source_info); - ASSERT_TRUE(cel_expression_status.ok()); + ASSERT_OK(cel_expression_status); auto cel_expression = std::move(cel_expression_status.ValueOrDie()); @@ -68,7 +69,7 @@ TEST(EndToEndTest, SimpleOnePlusOne) { // Run evaluation. auto eval_status = cel_expression->Evaluate(activation, &arena); - ASSERT_TRUE(eval_status.ok()); + ASSERT_OK(eval_status); CelValue result = eval_status.ValueOrDie(); @@ -133,12 +134,12 @@ TEST(EndToEndTest, EmptyStringCompare) { std::unique_ptr builder = CreateCelExpressionBuilder(); // Builtin registration. - ASSERT_TRUE(RegisterBuiltinFunctions(builder->GetRegistry()).ok()); + ASSERT_OK(RegisterBuiltinFunctions(builder->GetRegistry())); // Create CelExpression from AST (Expr object). auto cel_expression_status = builder->CreateExpression(&expr, &source_info); - ASSERT_TRUE(cel_expression_status.ok()); + ASSERT_OK(cel_expression_status); auto cel_expression = std::move(cel_expression_status.ValueOrDie()); @@ -158,7 +159,7 @@ TEST(EndToEndTest, EmptyStringCompare) { // Run evaluation. auto eval_status = cel_expression->Evaluate(activation, &arena); - ASSERT_TRUE(eval_status.ok()); + ASSERT_OK(eval_status); CelValue result = eval_status.ValueOrDie(); diff --git a/eval/tests/mock_cel_expression.h b/eval/tests/mock_cel_expression.h new file mode 100644 index 000000000..ba0ab2041 --- /dev/null +++ b/eval/tests/mock_cel_expression.h @@ -0,0 +1,44 @@ +#ifndef THIRD_PARTY_CEL_CPP_EVAL_TESTS_MOCK_CEL_EXPRESION_H_ +#define THIRD_PARTY_CEL_CPP_EVAL_TESTS_MOCK_CEL_EXPRESION_H_ + +#include + +#include "gmock/gmock.h" +#include "eval/public/activation.h" +#include "eval/public/cel_expression.h" + +namespace google { +namespace api { +namespace expr { +namespace runtime { + +class MockCelExpression : public CelExpression { + public: + MOCK_CONST_METHOD1(InitializeState, + std::unique_ptr(google::protobuf::Arena* arena)); + + MOCK_CONST_METHOD2( + Evaluate, ::cel_base::StatusOr(const BaseActivation& activation, + google::protobuf::Arena* arena)); + + MOCK_CONST_METHOD2( + Evaluate, ::cel_base::StatusOr(const BaseActivation& activation, + CelEvaluationState* state)); + + MOCK_CONST_METHOD3( + Trace, ::cel_base::StatusOr(const BaseActivation& activation, + google::protobuf::Arena* arena, + CelEvaluationListener callback)); + + MOCK_CONST_METHOD3( + Trace, ::cel_base::StatusOr(const BaseActivation& activation, + CelEvaluationState* state, + CelEvaluationListener callback)); +}; + +} // namespace runtime +} // namespace expr +} // namespace api +} // namespace google + +#endif // THIRD_PARTY_CEL_CPP_EVAL_TESTS_MOCK_CEL_EXPRESION_H_ diff --git a/eval/tests/unknowns_end_to_end_test.cc b/eval/tests/unknowns_end_to_end_test.cc new file mode 100644 index 000000000..0a5edefe3 --- /dev/null +++ b/eval/tests/unknowns_end_to_end_test.cc @@ -0,0 +1,305 @@ +#include + +#include "google/protobuf/arena.h" +#include "google/protobuf/text_format.h" +#include "gmock/gmock.h" +#include "gtest/gtest.h" +#include "absl/strings/string_view.h" +#include "eval/public/activation.h" +#include "eval/public/builtin_func_registrar.h" +#include "eval/public/cel_attribute.h" +#include "eval/public/cel_expr_builder_factory.h" +#include "eval/public/cel_expression.h" +#include "eval/public/cel_function.h" +#include "eval/public/cel_options.h" +#include "eval/public/cel_value.h" +#include "eval/public/unknown_set.h" +#include "base/status_macros.h" + +namespace google { +namespace api { +namespace expr { +namespace runtime { +namespace { + +using ::google::protobuf::Arena; +using testing::ElementsAre; + +// var1 > 3 && F1('arg1') || var2 > 3 && F2('arg2') +constexpr char kExprTextproto[] = R"pb( + id: 13 + call_expr { + function: "_||_" + args { + id: 6 + call_expr { + function: "_&&_" + args { + id: 2 + call_expr { + function: "_>_" + args { + id: 1 + ident_expr { name: "var1" } + } + args { + id: 3 + const_expr { int64_value: 3 } + } + } + } + args { + id: 4 + call_expr { + function: "F1" + args { + id: 5 + const_expr { string_value: "arg1" } + } + } + } + } + } + args { + id: 12 + call_expr { + function: "_&&_" + args { + id: 8 + call_expr { + function: "_>_" + args { + id: 7 + ident_expr { name: "var2" } + } + args { + id: 9 + const_expr { int64_value: 3 } + } + } + } + args { + id: 10 + call_expr { + function: "F2" + args { + id: 11 + const_expr { string_value: "arg2" } + } + } + } + } + } + })pb"; + +enum class FunctionResponse { kUnknown, kTrue, kFalse }; + +CelFunctionDescriptor CreateDescriptor(absl::string_view name) { + return CelFunctionDescriptor(std::string(name), false, + {CelValue::Type::kString}); +} + +class FunctionImpl : public CelFunction { + public: + FunctionImpl(absl::string_view name, FunctionResponse response) + : CelFunction(CreateDescriptor(name)), response_(response) {} + + absl::Status Evaluate(absl::Span arguments, CelValue* result, + Arena* arena) const override { + switch (response_) { + case FunctionResponse::kUnknown: + *result = CreateUnknownFunctionResultError(arena, "help message"); + break; + case FunctionResponse::kTrue: + *result = CelValue::CreateBool(true); + break; + case FunctionResponse::kFalse: + *result = CelValue::CreateBool(false); + break; + } + return absl::OkStatus(); + } + + private: + FunctionResponse response_; +}; + +// Text fixture for unknowns. Holds on to state needed for execution to work +// correctly. +class UnknownsTest : public testing::Test { + public: + void PrepareBuilder(UnknownProcessingOptions opts) { + InterpreterOptions options; + options.unknown_processing = opts; + builder_ = CreateCelExpressionBuilder(options); + ASSERT_OK(RegisterBuiltinFunctions(builder_->GetRegistry())); + ASSERT_OK( + builder_->GetRegistry()->RegisterLazyFunction(CreateDescriptor("F1"))); + ASSERT_OK( + builder_->GetRegistry()->RegisterLazyFunction(CreateDescriptor("F2"))); + ASSERT_TRUE(google::protobuf::TextFormat::ParseFromString(kExprTextproto, &expr_)) + << "error parsing expr"; + } + + protected: + Arena arena_; + Activation activation_; + std::unique_ptr builder_; + google::api::expr::v1alpha1::Expr expr_; +}; + +MATCHER_P2(FunctionCallIs, fn_name, fn_arg, "") { + const UnknownFunctionResult* result = arg; + return result->arguments().size() == 1 && result->arguments()[0].IsString() && + result->arguments()[0].StringOrDie().value() == fn_arg && + result->descriptor().name() == fn_name; +} + +MATCHER_P(AttributeIs, attr, "") { + const CelAttribute* result = arg; + return result->variable().ident_expr().name() == attr; +} + +TEST_F(UnknownsTest, NoUnknowns) { + PrepareBuilder(UnknownProcessingOptions::kDisabled); + // activation_.set_unknown_attribute_patterns({CelAttributePattern("var1", + // {})}); + activation_.InsertValue("var1", CelValue::CreateInt64(3)); + activation_.InsertValue("var2", CelValue::CreateInt64(5)); + ASSERT_OK(activation_.InsertFunction( + std::make_unique("F1", FunctionResponse::kFalse))); + ASSERT_OK(activation_.InsertFunction( + std::make_unique("F2", FunctionResponse::kTrue))); + + // var1 > 3 && F1('arg1') || var2 > 3 && F2('arg2') + auto plan = builder_->CreateExpression(&expr_, nullptr); + ASSERT_OK(plan); + + auto maybe_response = plan.ValueOrDie()->Evaluate(activation_, &arena_); + ASSERT_OK(maybe_response); + CelValue response = maybe_response.ValueOrDie(); + + ASSERT_TRUE(response.IsBool()); + EXPECT_TRUE(response.BoolOrDie()); +} + +TEST_F(UnknownsTest, UnknownAttributes) { + PrepareBuilder(UnknownProcessingOptions::kAttributeOnly); + activation_.set_unknown_attribute_patterns({CelAttributePattern("var1", {})}); + activation_.InsertValue("var2", CelValue::CreateInt64(3)); + ASSERT_OK(activation_.InsertFunction( + std::make_unique("F1", FunctionResponse::kTrue))); + ASSERT_OK(activation_.InsertFunction( + std::make_unique("F2", FunctionResponse::kFalse))); + + // var1 > 3 && F1('arg1') || var2 > 3 && F2('arg2') + auto plan = builder_->CreateExpression(&expr_, nullptr); + ASSERT_OK(plan); + + auto maybe_response = plan.ValueOrDie()->Evaluate(activation_, &arena_); + ASSERT_OK(maybe_response); + CelValue response = maybe_response.ValueOrDie(); + + ASSERT_TRUE(response.IsUnknownSet()); + EXPECT_THAT(response.UnknownSetOrDie()->unknown_attributes().attributes(), + ElementsAre(AttributeIs("var1"))); +} + +TEST_F(UnknownsTest, UnknownAttributesPruning) { + PrepareBuilder(UnknownProcessingOptions::kAttributeOnly); + activation_.set_unknown_attribute_patterns({CelAttributePattern("var1", {})}); + activation_.InsertValue("var2", CelValue::CreateInt64(5)); + ASSERT_OK(activation_.InsertFunction( + std::make_unique("F1", FunctionResponse::kTrue))); + ASSERT_OK(activation_.InsertFunction( + std::make_unique("F2", FunctionResponse::kTrue))); + + // var1 > 3 && F1('arg1') || var2 > 3 && F2('arg2') + auto plan = builder_->CreateExpression(&expr_, nullptr); + ASSERT_OK(plan); + + auto maybe_response = plan.ValueOrDie()->Evaluate(activation_, &arena_); + ASSERT_OK(maybe_response); + CelValue response = maybe_response.ValueOrDie(); + + ASSERT_TRUE(response.IsBool()); + EXPECT_TRUE(response.BoolOrDie()); +} + +TEST_F(UnknownsTest, UnknownFunctionsWithoutOptionError) { + PrepareBuilder(UnknownProcessingOptions::kAttributeOnly); + activation_.InsertValue("var1", CelValue::CreateInt64(5)); + activation_.InsertValue("var2", CelValue::CreateInt64(3)); + ASSERT_OK(activation_.InsertFunction( + std::make_unique("F1", FunctionResponse::kUnknown))); + ASSERT_OK(activation_.InsertFunction( + std::make_unique("F2", FunctionResponse::kFalse))); + + // var1 > 3 && F1('arg1') || var2 > 3 && F2('arg2') + auto plan = builder_->CreateExpression(&expr_, nullptr); + ASSERT_OK(plan); + + auto maybe_response = plan.ValueOrDie()->Evaluate(activation_, &arena_); + ASSERT_OK(maybe_response); + CelValue response = maybe_response.ValueOrDie(); + + ASSERT_TRUE(response.IsError()); + EXPECT_EQ(response.ErrorOrDie()->code(), absl::StatusCode::kUnavailable); +} + +TEST_F(UnknownsTest, UnknownFunctions) { + PrepareBuilder(UnknownProcessingOptions::kAttributeAndFunction); + activation_.InsertValue("var1", CelValue::CreateInt64(5)); + activation_.InsertValue("var2", CelValue::CreateInt64(5)); + ASSERT_OK(activation_.InsertFunction( + std::make_unique("F1", FunctionResponse::kUnknown))); + ASSERT_OK(activation_.InsertFunction( + std::make_unique("F2", FunctionResponse::kFalse))); + + // var1 > 3 && F1('arg1') || var2 > 3 && F2('arg2') + auto plan = builder_->CreateExpression(&expr_, nullptr); + ASSERT_OK(plan); + + auto maybe_response = plan.ValueOrDie()->Evaluate(activation_, &arena_); + ASSERT_OK(maybe_response); + CelValue response = maybe_response.ValueOrDie(); + + ASSERT_TRUE(response.IsUnknownSet()) << response.ErrorOrDie()->ToString(); + EXPECT_THAT(response.UnknownSetOrDie() + ->unknown_function_results() + .unknown_function_results(), + ElementsAre(FunctionCallIs("F1", "arg1"))); +} + +TEST_F(UnknownsTest, UnknownsMerge) { + PrepareBuilder(UnknownProcessingOptions::kAttributeAndFunction); + activation_.InsertValue("var1", CelValue::CreateInt64(5)); + activation_.set_unknown_attribute_patterns({CelAttributePattern("var2", {})}); + + ASSERT_OK(activation_.InsertFunction( + std::make_unique("F1", FunctionResponse::kUnknown))); + ASSERT_OK(activation_.InsertFunction( + std::make_unique("F2", FunctionResponse::kTrue))); + + // var1 > 3 && F1('arg1') || var2 > 3 && F2('arg2') + auto plan = builder_->CreateExpression(&expr_, nullptr); + ASSERT_OK(plan); + + auto maybe_response = plan.ValueOrDie()->Evaluate(activation_, &arena_); + ASSERT_OK(maybe_response); + CelValue response = maybe_response.ValueOrDie(); + + ASSERT_TRUE(response.IsUnknownSet()) << response.ErrorOrDie()->ToString(); + EXPECT_THAT(response.UnknownSetOrDie() + ->unknown_function_results() + .unknown_function_results(), + ElementsAre(FunctionCallIs("F1", "arg1"))); + EXPECT_THAT(response.UnknownSetOrDie()->unknown_attributes().attributes(), + ElementsAre(AttributeIs("var2"))); +} + +} // namespace +} // namespace runtime +} // namespace expr +} // namespace api +} // namespace google diff --git a/internal/cel_printer.h b/internal/cel_printer.h index d86729a07..11a696390 100644 --- a/internal/cel_printer.h +++ b/internal/cel_printer.h @@ -35,9 +35,7 @@ struct RawString { */ struct ScalarPrinter { inline std::string operator()(std::nullptr_t) { return "null"; } - inline std::string operator()(bool value) { - return value ? "true" : "false"; - } + inline std::string operator()(bool value) { return value ? "true" : "false"; } std::string operator()(absl::Time value); std::string operator()(absl::Duration value); @@ -83,8 +81,8 @@ struct ForwardingPrinter { // If the value defines a ToDebugString function, call it. template - specialize_ifd().ToDebugString())> operator()( - T&& value) { + specialize_ifd().ToDebugString())> + operator()(T&& value) { return std::string(value.ToDebugString()); } }; diff --git a/internal/hash_util.cc b/internal/hash_util.cc index c3b914bf7..c44bd347e 100644 --- a/internal/hash_util.cc +++ b/internal/hash_util.cc @@ -7,7 +7,9 @@ namespace api { namespace expr { namespace internal { -std::size_t HashImpl(const std::string& value, specialize) { return StdHash(value); } +std::size_t HashImpl(const std::string& value, specialize) { + return StdHash(value); +} std::size_t HashImpl(const google::rpc::Status& value, specialize) { std::size_t hash = Hash(value.code()); diff --git a/parser/BUILD b/parser/BUILD index 3b634e599..367694156 100644 --- a/parser/BUILD +++ b/parser/BUILD @@ -27,7 +27,8 @@ cc_library( ":cel_cc_parser", ":macro", ":visitor", - "//base:status", + "//base:status_macros", + "//base:statusor", "@antlr4_runtimes//:cpp", "@com_google_absl//absl/types:optional", "@com_google_googleapis//google/api/expr/v1alpha1:syntax_cc_proto", diff --git a/parser/parser.cc b/parser/parser.cc index 20fae9c04..bed72a4ca 100644 --- a/parser/parser.cc +++ b/parser/parser.cc @@ -5,7 +5,6 @@ #include "parser/cel_grammar.inc/cel_grammar/CelParser.h" #include "parser/visitor.h" #include "antlr4-runtime.h" -#include "base/canonical_errors.h" namespace google { namespace api { @@ -47,15 +46,15 @@ cel_base::StatusOr ParseWithMacros(const std::string& expression, try { root = parser.start(); } catch (ParseCancellationException& e) { - return cel_base::CancelledError(e.what()); + return absl::CancelledError(e.what()); } catch (std::exception& e) { - return cel_base::AbortedError(e.what()); + return absl::AbortedError(e.what()); } Expr expr = visitor.visit(root).as(); if (visitor.hasErrored()) { - return cel_base::InvalidArgumentError(visitor.errorMessage()); + return absl::InvalidArgumentError(visitor.errorMessage()); } // root is deleted as part of the parser context diff --git a/parser/parser_test.cc b/parser/parser_test.cc index a61e309d9..1467c95f4 100644 --- a/parser/parser_test.cc +++ b/parser/parser_test.cc @@ -1,5 +1,6 @@ #include "parser/parser.h" +#include #include #include "gmock/gmock.h" @@ -1132,6 +1133,7 @@ class ExpressionTest : public testing::TestWithParam {}; TEST_P(ExpressionTest, Parse) { const TestInfo& test_info = GetParam(); + /* ::testing::internal::ColoredPrintf(::testing::internal::COLOR_GREEN, "[ ]"); ::testing::internal::ColoredPrintf(::testing::internal::COLOR_DEFAULT, @@ -1141,6 +1143,7 @@ TEST_P(ExpressionTest, Parse) { ::testing::internal::ColoredPrintf( ::testing::internal::COLOR_DEFAULT, "%s\n", !test_info.E.empty() ? " (error expected)" : ""); + */ auto result = Parse(test_info.I); if (test_info.E.empty()) { diff --git a/protoutil/type_registry.cc b/protoutil/type_registry.cc index 5e82bb941..ba5c7854a 100644 --- a/protoutil/type_registry.cc +++ b/protoutil/type_registry.cc @@ -100,8 +100,9 @@ class ProtoStrList final : public BaseProtoRefList { common::Value Get(std::size_t index) const override { std::string scratch; - const std::string& value = msg_->GetReflection()->GetRepeatedStringReference( - *msg_, field_, index, &scratch); + const std::string& value = + msg_->GetReflection()->GetRepeatedStringReference(*msg_, field_, index, + &scratch); if (&value == &scratch) { return common::Value::From(value); } diff --git a/testutil/test_data_io.cc b/testutil/test_data_io.cc index 9566b942e..9e486c00d 100644 --- a/testutil/test_data_io.cc +++ b/testutil/test_data_io.cc @@ -67,8 +67,8 @@ std::unique_ptr OpenForWrite( return nullptr; } } -std::string GetTestCaseFileName(absl::string_view dir, absl::string_view test_name, - bool binary) { +std::string GetTestCaseFileName(absl::string_view dir, + absl::string_view test_name, bool binary) { return absl::StrCat(dir, test_name, binary ? kBinaryPbExt : kTextPbExt); } diff --git a/testutil/test_data_util.cc b/testutil/test_data_util.cc index 06d9f8c2f..3a1a10928 100644 --- a/testutil/test_data_util.cc +++ b/testutil/test_data_util.cc @@ -121,7 +121,7 @@ bool rep_as(F value) { template std::string MakeName(absl::string_view type, T&& value, - absl::string_view name = "") { + absl::string_view name = "") { if (name.empty()) { return absl::StrCat(type, "(", value, ")"); }