diff --git a/base/BUILD b/base/BUILD index 8b92d6fc2..2ddce9ff0 100644 --- a/base/BUILD +++ b/base/BUILD @@ -5,26 +5,26 @@ package( ) cc_library( - name = "status", + name = "statusor", srcs = [ - "canonical_errors.cc", - "status.cc", "statusor.cc", ], hdrs = [ - "canonical_errors.h", - "status.h", "statusor.h", "statusor_internals.h", ], copts = ["-std=c++14"], deps = [ - "@com_google_absl//absl/base", "@com_google_absl//absl/base:core_headers", - "@com_google_absl//absl/base:log_severity", - "@com_google_absl//absl/container:node_hash_map", "@com_google_absl//absl/meta:type_traits", - "@com_google_absl//absl/strings", - "@com_google_absl//absl/types:optional", + "@com_google_absl//absl/status", ], ) + +cc_library( + name = "status_macros", + hdrs = [ + "status_macros.h", + ], + copts = ["-std=c++14"], +) diff --git a/base/canonical_errors.cc b/base/canonical_errors.cc deleted file mode 100644 index 4588e0f88..000000000 --- a/base/canonical_errors.cc +++ /dev/null @@ -1,137 +0,0 @@ -/* - * Copyright 2019 Google LLC - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "base/canonical_errors.h" - -namespace cel_base { - -Status AbortedError(absl::string_view message) { - return Status(ABORTED, message); -} - -Status AlreadyExistsError(absl::string_view message) { - return Status(ALREADY_EXISTS, message); -} - -Status CancelledError(absl::string_view message) { - return Status(CANCELLED, message); -} - -Status DataLossError(absl::string_view message) { - return Status(DATA_LOSS, message); -} - -Status DeadlineExceededError(absl::string_view message) { - return Status(DEADLINE_EXCEEDED, message); -} - -Status FailedPreconditionError(absl::string_view message) { - return Status(FAILED_PRECONDITION, message); -} - -Status InternalError(absl::string_view message) { - return Status(INTERNAL, message); -} - -Status InvalidArgumentError(absl::string_view message) { - return Status(INVALID_ARGUMENT, message); -} - -Status NotFoundError(absl::string_view message) { - return Status(NOT_FOUND, message); -} - -Status OutOfRangeError(absl::string_view message) { - return Status(OUT_OF_RANGE, message); -} - -Status PermissionDeniedError(absl::string_view message) { - return Status(PERMISSION_DENIED, message); -} - -Status ResourceExhaustedError(absl::string_view message) { - return Status(RESOURCE_EXHAUSTED, message); -} - -Status UnauthenticatedError(absl::string_view message) { - return Status(UNAUTHENTICATED, message); -} - -Status UnavailableError(absl::string_view message) { - return Status(UNAVAILABLE, message); -} - -Status UnimplementedError(absl::string_view message) { - return Status(UNIMPLEMENTED, message); -} - -Status UnknownError(absl::string_view message) { - return Status(UNKNOWN, message); -} - -bool IsAborted(const Status& status) { return status.code() == ABORTED; } - -bool IsAlreadyExists(const Status& status) { - return status.code() == ALREADY_EXISTS; -} - -bool IsCancelled(const Status& status) { return status.code() == CANCELLED; } - -bool IsDataLoss(const Status& status) { return status.code() == DATA_LOSS; } - -bool IsDeadlineExceeded(const Status& status) { - return status.code() == DEADLINE_EXCEEDED; -} - -bool IsFailedPrecondition(const Status& status) { - return status.code() == FAILED_PRECONDITION; -} - -bool IsInternal(const Status& status) { return status.code() == INTERNAL; } - -bool IsInvalidArgument(const Status& status) { - return status.code() == INVALID_ARGUMENT; -} - -bool IsNotFound(const Status& status) { return status.code() == NOT_FOUND; } - -bool IsOutOfRange(const Status& status) { - return status.code() == OUT_OF_RANGE; -} - -bool IsPermissionDenied(const Status& status) { - return status.code() == PERMISSION_DENIED; -} - -bool IsResourceExhausted(const Status& status) { - return status.code() == RESOURCE_EXHAUSTED; -} - -bool IsUnauthenticated(const Status& status) { - return status.code() == UNAUTHENTICATED; -} - -bool IsUnavailable(const Status& status) { - return status.code() == UNAVAILABLE; -} - -bool IsUnimplemented(const Status& status) { - return status.code() == UNIMPLEMENTED; -} - -bool IsUnknown(const Status& status) { return status.code() == UNKNOWN; } - -} // namespace cel_base diff --git a/base/canonical_errors.h b/base/canonical_errors.h deleted file mode 100644 index a498f8172..000000000 --- a/base/canonical_errors.h +++ /dev/null @@ -1,71 +0,0 @@ -/* - * Copyright 2019 Google LLC - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef THIRD_PARTY_CEL_CPP_BASE_CANONICAL_ERRORS_H_ -#define THIRD_PARTY_CEL_CPP_BASE_CANONICAL_ERRORS_H_ - -// This file declares a set of functions for working with Status objects from -// the canonical errors. There are functions to easily generate such -// status object and function for classifying them. -#include "absl/base/attributes.h" -#include "absl/base/macros.h" -#include "absl/strings/string_view.h" -#include "base/status.h" - -namespace cel_base { - -// Each of the functions below creates a canonical error with the given -// message. The error code of the returned status object matches the name of -// the function. -Status AbortedError(absl::string_view message); -Status AlreadyExistsError(absl::string_view message); -Status CancelledError(absl::string_view message); -Status DataLossError(absl::string_view message); -Status DeadlineExceededError(absl::string_view message); -Status FailedPreconditionError(absl::string_view message); -Status InternalError(absl::string_view message); -Status InvalidArgumentError(absl::string_view message); -Status NotFoundError(absl::string_view message); -Status OutOfRangeError(absl::string_view message); -Status PermissionDeniedError(absl::string_view message); -Status ResourceExhaustedError(absl::string_view message); -Status UnauthenticatedError(absl::string_view message); -Status UnavailableError(absl::string_view message); -Status UnimplementedError(absl::string_view message); -Status UnknownError(absl::string_view message); - -// Each of the functions below returns true if the given status matches the -// canonical error code implied by the function's name. -ABSL_MUST_USE_RESULT bool IsAborted(const Status& status); -ABSL_MUST_USE_RESULT bool IsAlreadyExists(const Status& status); -ABSL_MUST_USE_RESULT bool IsCancelled(const Status& status); -ABSL_MUST_USE_RESULT bool IsDataLoss(const Status& status); -ABSL_MUST_USE_RESULT bool IsDeadlineExceeded(const Status& status); -ABSL_MUST_USE_RESULT bool IsFailedPrecondition(const Status& status); -ABSL_MUST_USE_RESULT bool IsInternal(const Status& status); -ABSL_MUST_USE_RESULT bool IsInvalidArgument(const Status& status); -ABSL_MUST_USE_RESULT bool IsNotFound(const Status& status); -ABSL_MUST_USE_RESULT bool IsOutOfRange(const Status& status); -ABSL_MUST_USE_RESULT bool IsPermissionDenied(const Status& status); -ABSL_MUST_USE_RESULT bool IsResourceExhausted(const Status& status); -ABSL_MUST_USE_RESULT bool IsUnauthenticated(const Status& status); -ABSL_MUST_USE_RESULT bool IsUnavailable(const Status& status); -ABSL_MUST_USE_RESULT bool IsUnimplemented(const Status& status); -ABSL_MUST_USE_RESULT bool IsUnknown(const Status& status); - -} // namespace cel_base - -#endif // THIRD_PARTY_CEL_CPP_BASE_CANONICAL_ERRORS_H_ diff --git a/base/status.cc b/base/status.cc deleted file mode 100644 index ee59645d6..000000000 --- a/base/status.cc +++ /dev/null @@ -1,97 +0,0 @@ -/* - * Copyright 2019 Google LLC - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "base/status.h" - -#include "absl/strings/str_cat.h" -#include "absl/types/optional.h" - -namespace cel_base { - -std::string StatusCodeToString(StatusCode code) { - switch (code) { - case StatusCode::kOk: - return "OK"; - case StatusCode::kCancelled: - return "CANCELLED"; - case StatusCode::kUnknown: - return "UNKNOWN"; - case StatusCode::kInvalidArgument: - return "INVALID_ARGUMENT"; - case StatusCode::kDeadlineExceeded: - return "DEADLINE_EXCEEDED"; - case StatusCode::kNotFound: - return "NOT_FOUND"; - case StatusCode::kAlreadyExists: - return "ALREADY_EXISTS"; - case StatusCode::kPermissionDenied: - return "PERMISSION_DENIED"; - case StatusCode::kUnauthenticated: - return "UNAUTHENTICATED"; - case StatusCode::kResourceExhausted: - return "RESOURCE_EXHAUSTED"; - case StatusCode::kFailedPrecondition: - return "FAILED_PRECONDITION"; - case StatusCode::kAborted: - return "ABORTED"; - case StatusCode::kOutOfRange: - return "OUT_OF_RANGE"; - case StatusCode::kUnimplemented: - return "UNIMPLEMENTED"; - case StatusCode::kInternal: - return "INTERNAL"; - case StatusCode::kUnavailable: - return "UNAVAILABLE"; - case StatusCode::kDataLoss: - return "DATA_LOSS"; - default: - return ""; - } -} - -std::ostream& operator<<(std::ostream& os, StatusCode code) { - return os << StatusCodeToString(code); -} - -Status::Status(StatusCode code, absl::string_view message) - : code_(code), message_(code == StatusCode::kOk ? "" : message) {} - -std::string Status::ToString() const { - return ok() ? "OK" : absl::StrCat(StatusCodeToString(code()), ": ", message_); -} - -void Status::SetPayload(absl::string_view type_url, const StatusCord& payload) { - if (!ok()) { - payload_.try_emplace(std::string(type_url), payload); - } -} - -absl::optional Status::GetPayload( - absl::string_view type_url) const { - auto it = payload_.find(std::string(type_url)); - if (it == payload_.end()) return absl::nullopt; - return it->second; -} - -void Status::ErasePayload(absl::string_view type_url) { - payload_.erase(std::string(type_url)); -} - -std::ostream& operator<<(std::ostream& os, const Status& x) { - return os << x.ToString(); -} - -} // namespace cel_base diff --git a/base/status.h b/base/status.h deleted file mode 100644 index 2e215339e..000000000 --- a/base/status.h +++ /dev/null @@ -1,196 +0,0 @@ -/* - * Copyright 2019 Google LLC - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef THIRD_PARTY_CEL_CPP_BASE_STATUS_H_ -#define THIRD_PARTY_CEL_CPP_BASE_STATUS_H_ - -#include -#include -#include - -#include "absl/base/attributes.h" -#include "absl/container/node_hash_map.h" -#include "absl/strings/string_view.h" -#include "absl/types/optional.h" - -namespace cel_base { - -// replace with cel_base::StatusCord once available. -using StatusCord = std::string; - -enum class StatusCode { - kOk = 0, - kCancelled = 1, - kUnknown = 2, - kInvalidArgument = 3, - kDeadlineExceeded = 4, - kNotFound = 5, - kAlreadyExists = 6, - kPermissionDenied = 7, - kResourceExhausted = 8, - kFailedPrecondition = 9, - kAborted = 10, - kOutOfRange = 11, - kUnimplemented = 12, - kInternal = 13, - kUnavailable = 14, - kDataLoss = 15, - kUnauthenticated = 16, - kDoNotUseReservedForFutureExpansionUseDefaultInSwitchInstead_ = 20 -}; - -std::string StatusCodeToString(StatusCode e); - -std::ostream& operator<<(std::ostream& os, StatusCode code); - -// Handle both for now. This is meant to be _very short lived. Once internal -// code can safely use the new naming, we will switch that that and drop this. -constexpr StatusCode OK = StatusCode::kOk; -constexpr StatusCode CANCELLED = StatusCode::kCancelled; -constexpr StatusCode UNKNOWN = StatusCode::kUnknown; -constexpr StatusCode INVALID_ARGUMENT = StatusCode::kInvalidArgument; -constexpr StatusCode DEADLINE_EXCEEDED = StatusCode::kDeadlineExceeded; -constexpr StatusCode NOT_FOUND = StatusCode::kNotFound; -constexpr StatusCode ALREADY_EXISTS = StatusCode::kAlreadyExists; -constexpr StatusCode PERMISSION_DENIED = StatusCode::kPermissionDenied; -constexpr StatusCode UNAUTHENTICATED = StatusCode::kUnauthenticated; -constexpr StatusCode RESOURCE_EXHAUSTED = StatusCode::kResourceExhausted; -constexpr StatusCode FAILED_PRECONDITION = StatusCode::kFailedPrecondition; -constexpr StatusCode ABORTED = StatusCode::kAborted; -constexpr StatusCode OUT_OF_RANGE = StatusCode::kOutOfRange; -constexpr StatusCode UNIMPLEMENTED = StatusCode::kUnimplemented; -constexpr StatusCode INTERNAL = StatusCode::kInternal; -constexpr StatusCode UNAVAILABLE = StatusCode::kUnavailable; -constexpr StatusCode DATA_LOSS = StatusCode::kDataLoss; - -class ABSL_MUST_USE_RESULT Status; - -class Status final { - public: - // Builds an OK Status. - Status() = default; - - // Constructs a Status object containing a status code and message. - // If `code == StatusCode::kOk`, `msg` is ignored and an object identical to - // an OK status is constructed. - Status(StatusCode code, absl::string_view message); - - // Return the error message (if any). - absl::string_view message() const { return message_; } - - // Returns true if the Status is OK. - ABSL_MUST_USE_RESULT bool ok() const; - - // Deprecated. Use code(). - int error_code() const; - - // Deprecated. Use message(). - std::string error_message() const; - - // Deprecated. Use code(). - StatusCode CanonicalCode() const; - - // If "ok()", does nothing. Else adds the given `payload` specified, by - // `type_url` as an additional payload. - void SetPayload(absl::string_view type_url, const StatusCord& payload); - - bool operator==(const Status& x) const; - bool operator!=(const Status& x) const; - - void Update(const Status& rhs) { - if (ok()) *this = rhs; - } - - // Return a combination of the error code name and message. - // Note, no guarantees are made as to the exact nature of the returned string. - // Subject to change at any time. - std::string ToString() const; - - // Deprecated. Just returns self. - Status ToCanonical() const; - - // Ignores any errors. This method does nothing except potentially suppress - // complaints from any tools that are checking that errors are not dropped on - // the floor. - void IgnoreError() const; - - // Returns the stored status code. - StatusCode code() const { return code_; } - - // Retrieve a single value associated with `type_url`. Returns absl::nullopt - // if no value is associated with `type_url`. - absl::optional GetPayload(const absl::string_view type_url) const; - - // Erase the payload associated with `type_url`, if present. - void ErasePayload(absl::string_view type_url); - - void ForEachPayload( - const std::function& visitor) - const; - - private: - StatusCode code_ = StatusCode::kOk; - std::string message_; - // Structured error payload. String is a 'type_url' for example, a proto - // descriptor full name. - absl::node_hash_map payload_; -}; - -inline bool Status::ok() const { return StatusCode::kOk == code_; } - -inline int Status::error_code() const { return static_cast(code()); } - -inline std::string Status::error_message() const { - return std::string(message()); -} - -inline StatusCode Status::CanonicalCode() const { return code(); } - -inline Status Status::ToCanonical() const { return *this; } - -inline bool Status::operator==(const Status& x) const { - return (code_ == x.code_) && (message_ == x.message_) && - (payload_ == x.payload_); -} - -inline bool Status::operator!=(const Status& x) const { return !(*this == x); } - -inline void Status::IgnoreError() const { - // no-op -} - -inline void Status::ForEachPayload( - const std::function& visitor) - const { - for (auto it = payload_.begin(); it != payload_.end(); ++it) { - visitor(it->first, it->second); - } -} - -// Prints a human-readable representation of 'x' to 'os'. -std::ostream& operator<<(std::ostream& os, const Status& x); - -// Constructs an OK status object. -inline Status OkStatus() { return Status(); } - -// This is better than GOOGLE_CHECK((val).ok()) because the embedded -// error string gets printed by the CHECK_EQ. -#define CHECK_OK(val) \ - ABSL_RAW_CHECK(val == ::cel_base::OkStatus(), "Status not OK") - -} // namespace cel_base - -#endif // THIRD_PARTY_CEL_CPP_BASE_STATUS_H_ diff --git a/base/status_macros.h b/base/status_macros.h new file mode 100644 index 000000000..675fe656a --- /dev/null +++ b/base/status_macros.h @@ -0,0 +1,53 @@ +/* + * Copyright 2020 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef THIRD_PARTY_CEL_CPP_BASE_STATUS_MACROS_H_ +#define THIRD_PARTY_CEL_CPP_BASE_STATUS_MACROS_H_ + +#include // for use with down_cast<> + +#include + +// Early-returns the status if it is in error; otherwise, proceeds. +// +// The argument expression is guaranteed to be evaluated exactly once. +#define RETURN_IF_ERROR(__status) \ + do { \ + auto _status = __status; \ + if (!_status.ok()) { \ + return _status; \ + } \ + } while (false) + +template // use like this: down_cast(foo); +inline To down_cast(From* f) { // so we only accept pointers + static_assert( + (std::is_base_of::type>::value), + "target type not derived from source type"); + + // We skip the assert and hence the dynamic_cast if RTTI is disabled. +#if !defined(__GNUC__) || defined(__GXX_RTTI) + // Uses RTTI in dbg and fastbuild. asserts are disabled in opt builds. + assert(f == nullptr || dynamic_cast(f) != nullptr); +#endif // !defined(__GNUC__) || defined(__GXX_RTTI) + + return static_cast(f); +} + +#define ASSERT_OK(expression) ASSERT_TRUE(expression.ok()) +#define EXPECT_OK(expression) EXPECT_TRUE(expression.ok()) + +#endif // THIRD_PARTY_CEL_CPP_BASE_STATUS_MACROS_H_ diff --git a/base/statusor.cc b/base/statusor.cc index 95386e056..5215922dc 100644 --- a/base/statusor.cc +++ b/base/statusor.cc @@ -18,24 +18,18 @@ #include -#include "base/status.h" - namespace cel_base { namespace statusor_internal { -void Helper::HandleInvalidStatusCtorArg(Status* status) { +void Helper::HandleInvalidStatusCtorArg(absl::Status* status) { const char* kMessage = "An OK status is not a valid constructor argument to StatusOr"; - ABSL_RAW_CHECK(false, kMessage); // In optimized builds, we will fall back to ::util::error::INTERNAL. - *status = Status(StatusCode::kInternal, kMessage); + *status = absl::Status(absl::StatusCode::kInternal, kMessage); } -void Helper::Crash(const Status&) { - ABSL_RAW_CHECK(false, "Attempting to fetch value instead of handling error"); - abort(); -} +void Helper::Crash(const absl::Status&) { abort(); } } // namespace statusor_internal diff --git a/base/statusor.h b/base/statusor.h index 0a26bade2..26f43d208 100644 --- a/base/statusor.h +++ b/base/statusor.h @@ -70,7 +70,6 @@ #include "absl/base/attributes.h" #include "absl/base/macros.h" -#include "base/status.h" #include "base/statusor_internals.h" namespace cel_base { @@ -144,8 +143,8 @@ class StatusOr : private statusor_internal::StatusOrData, // REQUIRES: !status.ok(). This requirement is DCHECKed. // In optimized builds, passing OkStatus() here will have the effect // of passing INTERNAL as a fallback. - StatusOr(const Status& status); - StatusOr& operator=(const Status& status); + StatusOr(const absl::Status& status); + StatusOr& operator=(const absl::Status& status); // Similar to the `const T&` overload. // @@ -153,8 +152,8 @@ class StatusOr : private statusor_internal::StatusOrData, StatusOr(T&& value); // RValue versions of the operations declared above. - StatusOr(Status&& status); - StatusOr& operator=(Status&& status); + StatusOr(absl::Status&& status); + StatusOr& operator=(absl::Status&& status); // Returns this->ok() explicit operator bool() const { return ok(); } @@ -164,8 +163,8 @@ class StatusOr : private statusor_internal::StatusOrData, // Returns a reference to our status. If this contains a T, then // returns OkStatus(). - const Status& status() const&; - Status status() &&; + const absl::Status& status() const&; + absl::Status status() &&; // Returns a reference to our current value, or CHECK-fails if !this->ok(). If // you have already checked the status using this->ok() or operator bool(), @@ -237,16 +236,16 @@ class StatusOr : private statusor_internal::StatusOrData, // Implementation details for StatusOr template -StatusOr::StatusOr() : Base(Status(UNKNOWN, "")) {} +StatusOr::StatusOr() : Base(absl::Status(absl::StatusCode::kUnknown, "")) {} template StatusOr::StatusOr(const T& value) : Base(value) {} template -StatusOr::StatusOr(const Status& status) : Base(status) {} +StatusOr::StatusOr(const absl::Status& status) : Base(status) {} template -StatusOr& StatusOr::operator=(const Status& status) { +StatusOr& StatusOr::operator=(const absl::Status& status) { this->Assign(status); return *this; } @@ -255,10 +254,10 @@ template StatusOr::StatusOr(T&& value) : Base(std::move(value)) {} template -StatusOr::StatusOr(Status&& status) : Base(std::move(status)) {} +StatusOr::StatusOr(absl::Status&& status) : Base(std::move(status)) {} template -StatusOr& StatusOr::operator=(Status&& status) { +StatusOr& StatusOr::operator=(absl::Status&& status) { this->Assign(std::move(status)); return *this; } @@ -295,12 +294,12 @@ inline StatusOr& StatusOr::operator=(StatusOr&& other) { } template -const Status& StatusOr::status() const& { +const absl::Status& StatusOr::status() const& { return this->status_; } template -Status StatusOr::status() && { - return ok() ? OkStatus() : std::move(this->status_); +absl::Status StatusOr::status() && { + return ok() ? absl::OkStatus() : std::move(this->status_); } template diff --git a/base/statusor_internals.h b/base/statusor_internals.h index fab2bec94..a524f7848 100644 --- a/base/statusor_internals.h +++ b/base/statusor_internals.h @@ -23,7 +23,7 @@ #include "absl/base/attributes.h" #include "absl/meta/type_traits.h" -#include "base/status.h" +#include "absl/status/status.h" namespace cel_base { @@ -32,8 +32,8 @@ namespace statusor_internal { class Helper { public: // Move type-agnostic error handling to the .cc. - static void HandleInvalidStatusCtorArg(Status*); - ABSL_ATTRIBUTE_NORETURN static void Crash(const Status& status); + static void HandleInvalidStatusCtorArg(absl::Status*); + ABSL_ATTRIBUTE_NORETURN static void Crash(const absl::Status& status); }; // Construct an instance of T in `p` through placement new, passing Args... to @@ -100,10 +100,10 @@ class StatusOrData { explicit StatusOrData(const T& value) : data_(value) { MakeStatus(); } explicit StatusOrData(T&& value) : data_(std::move(value)) { MakeStatus(); } - explicit StatusOrData(const Status& status) : status_(status) { + explicit StatusOrData(const absl::Status& status) : status_(status) { EnsureNotOk(); } - explicit StatusOrData(Status&& status) : status_(std::move(status)) { + explicit StatusOrData(absl::Status&& status) : status_(std::move(status)) { EnsureNotOk(); } @@ -140,7 +140,7 @@ class StatusOrData { MakeValue(value); } else { MakeValue(value); - status_ = OkStatus(); + status_ = absl::OkStatus(); } } @@ -150,17 +150,17 @@ class StatusOrData { MakeValue(std::move(value)); } else { MakeValue(std::move(value)); - status_ = OkStatus(); + status_ = absl::OkStatus(); } } - void Assign(const Status& status) { + void Assign(const absl::Status& status) { Clear(); status_ = status; EnsureNotOk(); } - void Assign(Status&& status) { + void Assign(absl::Status&& status) { Clear(); status_ = std::move(status); EnsureNotOk(); @@ -175,7 +175,7 @@ class StatusOrData { // Eg. in the copy constructor we use the default constructor of // Status in the ok() path to avoid an extra Ref call. union { - Status status_; + absl::Status status_; }; // data_ is active iff status_.ok()==true @@ -210,8 +210,8 @@ class StatusOrData { // argument. template void MakeStatus(Args&&... args) { - statusor_internal::PlacementNew(&status_, - std::forward(args)...); + statusor_internal::PlacementNew(&status_, + std::forward(args)...); } }; diff --git a/common/escaping.cc b/common/escaping.cc index b7f8b16c4..39aba6d0a 100644 --- a/common/escaping.cc +++ b/common/escaping.cc @@ -27,11 +27,11 @@ inline std::pair unhex(char c) { // least 4 bytes long. Return the number of bytes written. inline int get_utf8(absl::string_view s, char* buffer) { buffer[0] = s[0]; - if (s[0] < 0x80 || s.size() < 2) return 1; + if (static_cast(s[0]) < 0x80 || s.size() < 2) return 1; buffer[1] = s[1]; - if (s[0] < 0xE0 || s.size() < 3) return 2; + if (static_cast(s[0]) < 0xE0 || s.size() < 3) return 2; buffer[2] = s[2]; - if (s[0] < 0xF0 || s.size() < 4) return 3; + if (static_cast(s[0]) < 0xF0 || s.size() < 4) return 3; buffer[3] = s[3]; return 4; } @@ -84,7 +84,7 @@ inline std::tuple unescape_char( char c = s[0]; // 1. Character is not an escape sequence. - if (c >= 0x80 && !is_bytes) { + if (static_cast(c) >= 0x80 && !is_bytes) { char tmp[5]; int len = get_utf8(s, tmp); tmp[len] = '\0'; diff --git a/common/escaping_test.cc b/common/escaping_test.cc index 85cdc7f8b..8275b48ec 100644 --- a/common/escaping_test.cc +++ b/common/escaping_test.cc @@ -66,6 +66,7 @@ class UnescapeTest : public testing::TestWithParam {}; TEST_P(UnescapeTest, Unescape) { const TestInfo& test_info = GetParam(); + /* ::testing::internal::ColoredPrintf(::testing::internal::COLOR_GREEN, "[ ]"); ::testing::internal::ColoredPrintf(::testing::internal::COLOR_DEFAULT, @@ -82,6 +83,7 @@ TEST_P(UnescapeTest, Unescape) { ::testing::internal::ColoredPrintf(::testing::internal::COLOR_YELLOW, " Expecting ERROR\n"); } + */ auto result = unescape(test_info.I, test_info.is_bytes); if (test_info.O == EXPECT_ERROR) { diff --git a/common/operators.cc b/common/operators.cc index 803409a77..669469c9a 100644 --- a/common/operators.cc +++ b/common/operators.cc @@ -14,85 +14,91 @@ namespace { // Expr to textual mapping, e.g., from "_&&_" to "&&". const std::map& UnaryOperators() { - static std::shared_ptr> unaries_map = [&]() { - auto u = - std::make_shared>(std::map{ - {CelOperator::NEGATE, "-"}, {CelOperator::LOGICAL_NOT, "!"}}); - return u; - }(); + static std::shared_ptr> unaries_map = + [&]() { + auto u = std::make_shared>( + std::map{ + {CelOperator::NEGATE, "-"}, {CelOperator::LOGICAL_NOT, "!"}}); + return u; + }(); return *unaries_map; } const std::map& BinaryOperators() { - static std::shared_ptr> binops_map = [&]() { - auto c = std::make_shared>( - std::map{{CelOperator::LOGICAL_OR, "||"}, - {CelOperator::LOGICAL_AND, "&&"}, - {CelOperator::LESS_EQUALS, "<="}, - {CelOperator::LESS, "<"}, - {CelOperator::GREATER_EQUALS, ">="}, - {CelOperator::GREATER, ">"}, - {CelOperator::EQUALS, "=="}, - {CelOperator::NOT_EQUALS, "!="}, - {CelOperator::IN_DEPRECATED, "in"}, - {CelOperator::IN, "in"}, - {CelOperator::ADD, "+"}, - {CelOperator::SUBTRACT, "-"}, - {CelOperator::MULTIPLY, "*"}, - {CelOperator::DIVIDE, "/"}, - {CelOperator::MODULO, "%"}}); - return c; - }(); + static std::shared_ptr> binops_map = + [&]() { + auto c = std::make_shared>( + std::map{ + {CelOperator::LOGICAL_OR, "||"}, + {CelOperator::LOGICAL_AND, "&&"}, + {CelOperator::LESS_EQUALS, "<="}, + {CelOperator::LESS, "<"}, + {CelOperator::GREATER_EQUALS, ">="}, + {CelOperator::GREATER, ">"}, + {CelOperator::EQUALS, "=="}, + {CelOperator::NOT_EQUALS, "!="}, + {CelOperator::IN_DEPRECATED, "in"}, + {CelOperator::IN, "in"}, + {CelOperator::ADD, "+"}, + {CelOperator::SUBTRACT, "-"}, + {CelOperator::MULTIPLY, "*"}, + {CelOperator::DIVIDE, "/"}, + {CelOperator::MODULO, "%"}}); + return c; + }(); return *binops_map; } const std::map& ReverseOperators() { - static std::shared_ptr> operators_map = [&]() { - auto c = - std::make_shared>(std::map{ - {"+", CelOperator::ADD}, - {"-", CelOperator::SUBTRACT}, - {"*", CelOperator::MULTIPLY}, - {"/", CelOperator::DIVIDE}, - {"%", CelOperator::MODULO}, - {"==", CelOperator::EQUALS}, - {"!=", CelOperator::NOT_EQUALS}, - {">", CelOperator::GREATER}, - {">=", CelOperator::GREATER_EQUALS}, - {"<", CelOperator::LESS}, - {"<=", CelOperator::LESS_EQUALS}, - {"&&", CelOperator::LOGICAL_AND}, - {"!", CelOperator::LOGICAL_NOT}, - {"||", CelOperator::LOGICAL_OR}, - {"in", CelOperator::IN}, - }); - return c; - }(); + static std::shared_ptr> operators_map = + [&]() { + auto c = std::make_shared>( + std::map{ + {"+", CelOperator::ADD}, + {"-", CelOperator::SUBTRACT}, + {"*", CelOperator::MULTIPLY}, + {"/", CelOperator::DIVIDE}, + {"%", CelOperator::MODULO}, + {"==", CelOperator::EQUALS}, + {"!=", CelOperator::NOT_EQUALS}, + {">", CelOperator::GREATER}, + {">=", CelOperator::GREATER_EQUALS}, + {"<", CelOperator::LESS}, + {"<=", CelOperator::LESS_EQUALS}, + {"&&", CelOperator::LOGICAL_AND}, + {"!", CelOperator::LOGICAL_NOT}, + {"||", CelOperator::LOGICAL_OR}, + {"in", CelOperator::IN}, + }); + return c; + }(); return *operators_map; } const std::map& Operators() { - static std::shared_ptr> operators_map = [&]() { - auto c = std::make_shared>( - std::map{{CelOperator::ADD, "+"}, - {CelOperator::SUBTRACT, "-"}, - {CelOperator::MULTIPLY, "*"}, - {CelOperator::DIVIDE, "/"}, - {CelOperator::MODULO, "%"}, - {CelOperator::EQUALS, "=="}, - {CelOperator::NOT_EQUALS, "!="}, - {CelOperator::GREATER, ">"}, - {CelOperator::GREATER_EQUALS, ">="}, - {CelOperator::LESS, "<"}, - {CelOperator::LESS_EQUALS, "<="}, - {CelOperator::LOGICAL_AND, "&&"}, - {CelOperator::LOGICAL_NOT, "!"}, - {CelOperator::LOGICAL_OR, "||"}, - {CelOperator::IN, "in"}, - {CelOperator::IN_DEPRECATED, "in"}, - {CelOperator::NEGATE, "-"}}); - return c; - }(); + static std::shared_ptr> operators_map = + [&]() { + auto c = std::make_shared>( + std::map{ + {CelOperator::ADD, "+"}, + {CelOperator::SUBTRACT, "-"}, + {CelOperator::MULTIPLY, "*"}, + {CelOperator::DIVIDE, "/"}, + {CelOperator::MODULO, "%"}, + {CelOperator::EQUALS, "=="}, + {CelOperator::NOT_EQUALS, "!="}, + {CelOperator::GREATER, ">"}, + {CelOperator::GREATER_EQUALS, ">="}, + {CelOperator::LESS, "<"}, + {CelOperator::LESS_EQUALS, "<="}, + {CelOperator::LOGICAL_AND, "&&"}, + {CelOperator::LOGICAL_NOT, "!"}, + {CelOperator::LOGICAL_OR, "||"}, + {CelOperator::IN, "in"}, + {CelOperator::IN_DEPRECATED, "in"}, + {CelOperator::NEGATE, "-"}}); + return c; + }(); return *operators_map; } @@ -102,30 +108,30 @@ const std::map& Precedences() { auto c = std::make_shared>( std::map{{CelOperator::CONDITIONAL, 8}, - {CelOperator::LOGICAL_OR, 7}, + {CelOperator::LOGICAL_OR, 7}, - {CelOperator::LOGICAL_AND, 6}, + {CelOperator::LOGICAL_AND, 6}, - {CelOperator::EQUALS, 5}, - {CelOperator::GREATER, 5}, - {CelOperator::GREATER_EQUALS, 5}, - {CelOperator::IN, 5}, - {CelOperator::LESS, 5}, - {CelOperator::LESS_EQUALS, 5}, - {CelOperator::NOT_EQUALS, 5}, - {CelOperator::IN_DEPRECATED, 5}, + {CelOperator::EQUALS, 5}, + {CelOperator::GREATER, 5}, + {CelOperator::GREATER_EQUALS, 5}, + {CelOperator::IN, 5}, + {CelOperator::LESS, 5}, + {CelOperator::LESS_EQUALS, 5}, + {CelOperator::NOT_EQUALS, 5}, + {CelOperator::IN_DEPRECATED, 5}, - {CelOperator::ADD, 4}, - {CelOperator::SUBTRACT, 4}, + {CelOperator::ADD, 4}, + {CelOperator::SUBTRACT, 4}, - {CelOperator::DIVIDE, 3}, - {CelOperator::MODULO, 3}, - {CelOperator::MULTIPLY, 3}, + {CelOperator::DIVIDE, 3}, + {CelOperator::MODULO, 3}, + {CelOperator::MULTIPLY, 3}, - {CelOperator::LOGICAL_NOT, 2}, - {CelOperator::NEGATE, 2}, + {CelOperator::LOGICAL_NOT, 2}, + {CelOperator::NEGATE, 2}, - {CelOperator::INDEX, 1}}); + {CelOperator::INDEX, 1}}); return c; }(); return *precedence_map; diff --git a/common/type.cc b/common/type.cc index d0b02637f..ad9fa0ddc 100644 --- a/common/type.cc +++ b/common/type.cc @@ -79,7 +79,8 @@ absl::string_view UnrecognizedType::full_name() const { return absl::string_view(string_rep_).substr(6, string_rep_.size() - 8); } -Type::Type(const std::string& full_name) : data_(BasicType(BasicTypeValue::kNull)) { +Type::Type(const std::string& full_name) + : data_(BasicType(BasicTypeValue::kNull)) { auto itr = kBasicTypeMap->find(full_name); if (itr != kBasicTypeMap->end()) { data_ = itr->second; diff --git a/common/value.h b/common/value.h index 944be76e4..777021a4b 100644 --- a/common/value.h +++ b/common/value.h @@ -506,14 +506,18 @@ class Object : public Container { Value Value::FromDouble(double value) { return Create(value); } Value Value::FromString(absl::string_view value) { return Create(value); } -Value Value::FromString(const std::string& value) { return Create(value); } +Value Value::FromString(const std::string& value) { + return Create(value); +} Value Value::FromString(std::string&& value) { return Create(std::move(value)); } Value Value::FromBytes(absl::string_view value) { return Create(value); } -Value Value::FromBytes(const std::string& value) { return Create(value); } +Value Value::FromBytes(const std::string& value) { + return Create(value); +} Value Value::FromBytes(std::string&& value) { return Create(std::move(value)); } diff --git a/common/value_test.cc b/common/value_test.cc index 2c4def0d4..3a6a60c3a 100644 --- a/common/value_test.cc +++ b/common/value_test.cc @@ -118,7 +118,8 @@ struct ValueTestCase { return Value::Kind::kNull; } - static ValueTestCase ForInline(const Value& value, const std::string& debug_string, + static ValueTestCase ForInline(const Value& value, + const std::string& debug_string, const std::string& type_debug_string) { return ValueTestCase{value, value, value.kind(), value, debug_string, type_debug_string, diff --git a/conformance/BUILD b/conformance/BUILD index 2a402ceb0..2aacf821a 100644 --- a/conformance/BUILD +++ b/conformance/BUILD @@ -33,7 +33,7 @@ cc_binary( srcs = ["server.cc"], copts = ["-std=c++14"], deps = [ - "//base:status", + "//base:statusor", "//eval/eval:container_backed_list_impl", "//eval/eval:container_backed_map_impl", "//eval/public:builtin_func_registrar", @@ -59,14 +59,18 @@ cc_binary( "--check_server=$(location @com_google_cel_go//server/main:cel_server)", # Requires container support "--skip_test=basic/namespace/self_eval_container_lookup,self_eval_container_lookup_unchecked", + "--skip_test=basic/self_eval_nonzeroish/self_eval_bytes_invalid_utf8", # Requires heteregenous equality spec clarification "--skip_test=comparisons/eq_literal/eq_bytes", "--skip_test=comparisons/ne_literal/not_ne_bytes", "--skip_test=comparisons/in_list_literal/elem_in_mixed_type_list_error", "--skip_test=comparisons/in_map_literal/key_in_mixed_key_type_map_error", + "--skip_test=fields/in/singleton", # Requires qualified bindings error message relaxation "--skip_test=fields/qualified_identifier_resolution/ident_with_longest_prefix_check,int64_field_select_unsupported,list_field_select_unsupported,map_key_null,qualified_identifier_resolution_unchecked", "--skip_test=integer_math/int64_math/int64_overflow_positive,int64_overflow_negative,uint64_overflow_positive,uint64_overflow_negative", + "--skip_test=string/size/one_unicode,unicode", + "--skip_test=string/bytes_concat/left_unit", ] + ["$(location " + test + ")" for test in ALL_TESTS], data = [ ":server", @@ -93,4 +97,8 @@ sh_test( "@com_google_cel_go//server/main:cel_server", "@com_google_cel_spec//tests/simple:simple_test", ] + DASHBOARD_TESTS, + visibility = [ + "//:__subpackages__", + "//third_party/cel:__pkg__", + ], ) diff --git a/eval/compiler/BUILD b/eval/compiler/BUILD index d0636922f..93eaf8444 100644 --- a/eval/compiler/BUILD +++ b/eval/compiler/BUILD @@ -17,13 +17,14 @@ cc_library( copts = ["-std=c++14"], deps = [ ":constant_folding", - "//base:status", + "//base:status_macros", "//eval/eval:comprehension_step", "//eval/eval:const_value_step", "//eval/eval:container_access_step", "//eval/eval:create_list_step", "//eval/eval:create_struct_step", "//eval/eval:evaluator_core", + "//eval/eval:expression_build_warning", "//eval/eval:function_step", "//eval/eval:ident_step", "//eval/eval:jump_step", @@ -47,9 +48,44 @@ cc_test( copts = ["-std=c++14"], deps = [ ":flat_expr_builder", + "//base:status_macros", "//eval/public:builtin_func_registrar", + "//eval/public:cel_attribute", + "//eval/public:cel_builtins", + "//eval/public:cel_expression", + "//eval/public:cel_options", + "//eval/public:cel_value", + "//eval/public:unknown_attribute_set", + "//eval/public:unknown_set", + "//eval/testutil:test_message_cc_proto", + "@com_github_google_googletest//:gtest_main", + "@com_google_absl//absl/status", + "@com_google_absl//absl/strings", + "@com_google_googleapis//google/api/expr/v1alpha1:syntax_cc_proto", + "@com_google_protobuf//:protobuf", + ], +) + +cc_test( + name = "flat_expr_builder_comprehensions_test", + srcs = [ + "flat_expr_builder_comprehensions_test.cc", + ], + copts = ["-std=c++14"], + deps = [ + ":flat_expr_builder", + "//base:status_macros", + "//eval/public:builtin_func_registrar", + "//eval/public:cel_attribute", + "//eval/public:cel_builtins", + "//eval/public:cel_expression", + "//eval/public:cel_options", + "//eval/public:cel_value", + "//eval/public:unknown_attribute_set", + "//eval/public:unknown_set", "//eval/testutil:test_message_cc_proto", "@com_github_google_googletest//:gtest_main", + "@com_google_absl//absl/status", "@com_google_absl//absl/strings", "@com_google_googleapis//google/api/expr/v1alpha1:syntax_cc_proto", "@com_google_protobuf//:protobuf", @@ -86,6 +122,7 @@ cc_test( copts = ["-std=c++14"], deps = [ ":constant_folding", + "//base:status_macros", "//eval/public:builtin_func_registrar", "//eval/public:cel_function_registry", "//eval/testutil:test_message_cc_proto", diff --git a/eval/compiler/constant_folding.cc b/eval/compiler/constant_folding.cc index f9ccce95d..970ba7107 100644 --- a/eval/compiler/constant_folding.cc +++ b/eval/compiler/constant_folding.cc @@ -150,8 +150,8 @@ class ConstantFoldingTransform { } case Expr::kStructExpr: { auto struct_expr = out->mutable_struct_expr(); + struct_expr->set_message_name(expr.struct_expr().message_name()); int entries_size = expr.struct_expr().entries_size(); - bool all_constant = true; for (int i = 0; i < entries_size; i++) { auto& entry = expr.struct_expr().entries(i); auto new_entry = struct_expr->add_entries(); @@ -161,16 +161,13 @@ class ConstantFoldingTransform { new_entry->set_field_key(entry.field_key()); break; case Expr::CreateStruct::Entry::kMapKey: - all_constant = - Transform(entry.map_key(), new_entry->mutable_map_key()) && - all_constant; + Transform(entry.map_key(), new_entry->mutable_map_key()); break; default: GOOGLE_LOG(ERROR) << "Unsupported Entry kind: " << entry.key_kind_case(); break; } - all_constant = Transform(entry.value(), new_entry->mutable_value()) && - all_constant; + Transform(entry.value(), new_entry->mutable_value()); } return false; } diff --git a/eval/compiler/constant_folding_test.cc b/eval/compiler/constant_folding_test.cc index facfd7986..f7e24e7e4 100644 --- a/eval/compiler/constant_folding_test.cc +++ b/eval/compiler/constant_folding_test.cc @@ -8,6 +8,8 @@ #include "eval/public/builtin_func_registrar.h" #include "eval/public/cel_function_registry.h" #include "eval/testutil/test_message.pb.h" +#include "base/status_macros.h" + namespace google { namespace api { namespace expr { @@ -43,6 +45,62 @@ TEST(ConstantFoldingTest, Select) { EXPECT_TRUE(idents.empty()); } +// Validate struct message creation +TEST(ConstantFoldingTest, StructMessage) { + Expr expr; + // {"field1": "y", "field2": "t"} + google::protobuf::TextFormat::ParseFromString( + R"pb( + id: 5 + struct_expr { + entries { + id: 11 + field_key: "field1" + value { const_expr { string_value: "value1" } } + } + entries { + id: 7 + field_key: "field2" + value { const_expr { int64_value: 12 } } + } + message_name: "MyProto" + })pb", + &expr); + + google::protobuf::Arena arena; + CelFunctionRegistry registry; + + absl::flat_hash_map idents; + Expr out; + FoldConstants(expr, registry, &arena, idents, &out); + + Expr expected; + google::protobuf::TextFormat::ParseFromString(R"( + id: 5 + struct_expr { + entries { + id: 11 + field_key: "field1" + value { ident_expr { name: "$v0" } } + } + entries { + id: 7 + field_key: "field2" + value { ident_expr { name: "$v1" } } + } + message_name: "MyProto" + })", + &expected); + google::protobuf::util::MessageDifferencer md; + EXPECT_TRUE(md.Compare(out, expected)) << out.DebugString(); + + EXPECT_EQ(idents.size(), 2); + EXPECT_TRUE(idents["$v0"].IsString()); + EXPECT_EQ(idents["$v0"].StringOrDie().value(), "value1"); + EXPECT_TRUE(idents["$v1"].IsInt64()); + EXPECT_EQ(idents["$v1"].Int64OrDie(), 12); +} + // Validate struct creation is not folded but recursed into TEST(ConstantFoldingTest, StructComprehension) { Expr expr; @@ -151,7 +209,7 @@ TEST(ConstantFoldingTest, LogicApplication) { google::protobuf::Arena arena; CelFunctionRegistry registry; - ASSERT_TRUE(RegisterBuiltinFunctions(®istry).ok()); + ASSERT_OK(RegisterBuiltinFunctions(®istry)); absl::flat_hash_map idents; Expr out; @@ -184,7 +242,7 @@ TEST(ConstantFoldingTest, FunctionApplication) { google::protobuf::Arena arena; CelFunctionRegistry registry; - ASSERT_TRUE(RegisterBuiltinFunctions(®istry).ok()); + ASSERT_OK(RegisterBuiltinFunctions(®istry)); absl::flat_hash_map idents; Expr out; @@ -218,7 +276,7 @@ TEST(ConstantFoldingTest, FunctionApplicationWithReceiver) { google::protobuf::Arena arena; CelFunctionRegistry registry; - ASSERT_TRUE(RegisterBuiltinFunctions(®istry).ok()); + ASSERT_OK(RegisterBuiltinFunctions(®istry)); absl::flat_hash_map idents; Expr out; @@ -251,7 +309,7 @@ TEST(ConstantFoldingTest, FunctionApplicationNoOverload) { google::protobuf::Arena arena; CelFunctionRegistry registry; - ASSERT_TRUE(RegisterBuiltinFunctions(®istry).ok()); + ASSERT_OK(RegisterBuiltinFunctions(®istry)); absl::flat_hash_map idents; Expr out; diff --git a/eval/compiler/flat_expr_builder.cc b/eval/compiler/flat_expr_builder.cc index a856afc89..d2cd46209 100644 --- a/eval/compiler/flat_expr_builder.cc +++ b/eval/compiler/flat_expr_builder.cc @@ -11,6 +11,7 @@ #include "eval/eval/create_list_step.h" #include "eval/eval/create_struct_step.h" #include "eval/eval/evaluator_core.h" +#include "eval/eval/expression_build_warning.h" #include "eval/eval/function_step.h" #include "eval/eval/ident_step.h" #include "eval/eval/jump_step.h" @@ -20,7 +21,6 @@ #include "eval/public/ast_visitor.h" #include "eval/public/cel_builtins.h" #include "eval/public/cel_function_registry.h" -#include "base/canonical_errors.h" namespace google { namespace api { @@ -49,14 +49,19 @@ class FlatExprVisitor : public AstVisitor { const std::set& enums, absl::string_view container, const absl::flat_hash_map& constant_idents, - bool enable_comprehension) + bool enable_comprehension, BuilderWarnings* warnings, + std::set* iter_variable_names) : flattened_path_(path), - progress_status_(cel_base::OkStatus()), + progress_status_(absl::OkStatus()), resolved_select_expr_(nullptr), function_registry_(function_registry), shortcircuiting_(shortcircuiting), constant_idents_(constant_idents), - enable_comprehension_(enable_comprehension) { + enable_comprehension_(enable_comprehension), + builder_warnings_(warnings), + iter_variable_names_(iter_variable_names) { + GOOGLE_CHECK(iter_variable_names_); + auto container_elements = absl::StrSplit(container, '.'); // Build list of prefixes from container. Non-empty prefixes must end with @@ -121,7 +126,7 @@ class FlatExprVisitor : public AstVisitor { if (value.has_value()) { AddStep(CreateConstValueStep(value.value(), expr->id())); } else { - SetProgressStatusError(cel_base::Status(cel_base::StatusCode::kInvalidArgument, + SetProgressStatusError(absl::Status(absl::StatusCode::kInvalidArgument, "Unsupported constant type")); } } @@ -166,7 +171,7 @@ class FlatExprVisitor : public AstVisitor { if (resolved_select_expr_) { if (!resolved_select_expr_->has_select_expr()) { - progress_status_ = cel_base::InternalError("Unexpected Expr type"); + progress_status_ = absl::InternalError("Unexpected Expr type"); return; } AddStep(CreateConstValueStep(value_desc, resolved_select_expr_->id())); @@ -201,9 +206,9 @@ class FlatExprVisitor : public AstVisitor { } // Check if we are "in the middle" of namespaced name. - // This is currently enum specific. Constant expression that corresponds to - // resolved enum value has been already created, thus preceding chain of - // selects is no longer relevant. + // This is currently enum specific. Constant expression that corresponds + // to resolved enum value has been already created, thus preceding chain + // of selects is no longer relevant. if (resolved_select_expr_) { if (expr == resolved_select_expr_) { resolved_select_expr_ = nullptr; @@ -267,7 +272,8 @@ class FlatExprVisitor : public AstVisitor { return; } // For regular functions, just create one based on registry. - AddStep(CreateFunctionStep(call_expr, expr->id(), *function_registry_)); + AddStep(CreateFunctionStep(call_expr, expr->id(), *function_registry_, + builder_warnings_)); } } @@ -277,7 +283,7 @@ class FlatExprVisitor : public AstVisitor { return; } if (!enable_comprehension_) { - SetProgressStatusError(cel_base::Status(cel_base::StatusCode::kInvalidArgument, + SetProgressStatusError(absl::Status(absl::StatusCode::kInvalidArgument, "Comprehension support is disabled")); } cond_visitor_stack_.emplace(expr, @@ -287,7 +293,8 @@ class FlatExprVisitor : public AstVisitor { } // Invoked after all child nodes are processed. - void PostVisitComprehension(const Comprehension*, const Expr* expr, + void PostVisitComprehension(const Comprehension* comprehension_expr, + const Expr* expr, const SourcePosition*) override { if (!progress_status_.ok()) { return; @@ -295,6 +302,15 @@ class FlatExprVisitor : public AstVisitor { auto cond_visitor = FindCondVisitor(expr); cond_visitor->PostVisit(expr); cond_visitor_stack_.pop(); + + // Save off the names of the variables we're using, such that we have a + // full set of the names from the entire evaluation tree at the end. + if (!comprehension_expr->accu_var().empty()) { + iter_variable_names_->insert(comprehension_expr->accu_var()); + } + if (!comprehension_expr->iter_var().empty()) { + iter_variable_names_->insert(comprehension_expr->iter_var()); + } } // Invoked after each argument node processed. @@ -331,7 +347,7 @@ class FlatExprVisitor : public AstVisitor { AddStep(CreateCreateStructStep(struct_expr, expr->id())); } - cel_base::Status progress_status() const { return progress_status_; } + absl::Status progress_status() const { return progress_status_; } private: class CondVisitor { @@ -363,6 +379,8 @@ class FlatExprVisitor : public AstVisitor { }; // Visitor managing the "?" operation. + // TODO(issues/41) Make sure Unknowns are properly supported by ternary + // operation. class TernaryCondVisitor : public CondVisitor { public: explicit TernaryCondVisitor(FlatExprVisitor* visitor) @@ -382,10 +400,7 @@ class FlatExprVisitor : public AstVisitor { class ComprehensionVisitor : public CondVisitor { public: explicit ComprehensionVisitor(FlatExprVisitor* visitor) - : CondVisitor(visitor) - , next_step_(nullptr) - , cond_step_(nullptr) - {} + : CondVisitor(visitor), next_step_(nullptr), cond_step_(nullptr) {} void PreVisit(const Expr* expr) override; void PostVisitArg(int arg_num, const Expr* expr) override; @@ -414,7 +429,7 @@ class FlatExprVisitor : public AstVisitor { } } - void SetProgressStatusError(const cel_base::Status& status) { + void SetProgressStatusError(const absl::Status& status) { if (progress_status_.ok() && !status.ok()) { progress_status_ = status; } @@ -433,13 +448,13 @@ class FlatExprVisitor : public AstVisitor { } ExecutionPath* flattened_path_; - cel_base::Status progress_status_; + absl::Status progress_status_; std::stack>> cond_visitor_stack_; - // Maps effective namespace names to Expr objects (IDENTs/SELECTs) that define - // scopes for those namespaces. + // Maps effective namespace names to Expr objects (IDENTs/SELECTs) that + // define scopes for those namespaces. std::unordered_map namespace_map_; // Tracks SELECT-...SELECT-IDENT chains. std::deque> namespace_stack_; @@ -458,6 +473,10 @@ class FlatExprVisitor : public AstVisitor { const absl::flat_hash_map& constant_idents_; bool enable_comprehension_; + + BuilderWarnings* builder_warnings_; + + std::set* iter_variable_names_; }; void FlatExprVisitor::BinaryCondVisitor::PreVisit(const Expr*) {} @@ -529,7 +548,7 @@ void FlatExprVisitor::TernaryCondVisitor::PostVisitArg(int arg_num, if (jump_to_second_.exists()) { jump_to_second_.set_target(visitor_->GetCurrentIndex()); } else { - visitor_->SetProgressStatusError(cel_base::InvalidArgumentError( + visitor_->SetProgressStatusError(absl::InvalidArgumentError( "Error configuring ternary operator: jump_to_second_ is null")); } } @@ -543,14 +562,14 @@ void FlatExprVisitor::TernaryCondVisitor::PostVisit(const Expr*) { if (error_jump_.exists()) { error_jump_.set_target(visitor_->GetCurrentIndex()); } else { - visitor_->SetProgressStatusError(cel_base::InvalidArgumentError( + visitor_->SetProgressStatusError(absl::InvalidArgumentError( "Error configuring ternary operator: error_jump_ is null")); return; } if (jump_after_first_.exists()) { jump_after_first_.set_target(visitor_->GetCurrentIndex()); } else { - visitor_->SetProgressStatusError(cel_base::InvalidArgumentError( + visitor_->SetProgressStatusError(absl::InvalidArgumentError( "Error configuring ternary operator: jump_after_first_ is null")); return; } @@ -623,12 +642,10 @@ void FlatExprVisitor::ComprehensionVisitor::PostVisitArg(int arg_num, visitor_->AddStep(std::move(jump_to_next)); } // Set offsets. - cond_step_->set_jump_offset( - visitor_->GetCurrentIndex() - cond_step_pos_ - 1 - ); - next_step_->set_jump_offset( - visitor_->GetCurrentIndex() - next_step_pos_ - 1 - ); + cond_step_->set_jump_offset(visitor_->GetCurrentIndex() - cond_step_pos_ - + 1); + next_step_->set_jump_offset(visitor_->GetCurrentIndex() - next_step_pos_ - + 1); break; } case RESULT: { @@ -647,11 +664,13 @@ void FlatExprVisitor::ComprehensionVisitor::PostVisit(const Expr*) {} cel_base::StatusOr> FlatExprBuilder::CreateExpression(const Expr* expr, - const SourceInfo* source_info) const { + const SourceInfo* source_info, + std::vector* warnings) const { ExecutionPath execution_path; + BuilderWarnings warnings_builder(fail_on_warnings_); if (absl::StartsWith(container(), ".") || absl::EndsWith(container(), ".")) { - return cel_base::InvalidArgumentError( + return absl::InvalidArgumentError( absl::StrCat("Invalid expression container:", container())); } @@ -663,9 +682,11 @@ FlatExprBuilder::CreateExpression(const Expr* expr, FoldConstants(*expr, *this->GetRegistry(), constant_arena_, idents, &out); } + std::set iter_variable_names; FlatExprVisitor visitor(this->GetRegistry(), &execution_path, shortcircuiting_, resolvable_enums(), container(), - idents, enable_comprehension_); + idents, enable_comprehension_, &warnings_builder, + &iter_variable_names); AstTraverse(constant_folding_ ? &out : expr, source_info, &visitor); @@ -674,12 +695,23 @@ FlatExprBuilder::CreateExpression(const Expr* expr, } std::unique_ptr expression_impl = - absl::make_unique(expr, std::move(execution_path), - comprehension_max_iterations_); + absl::make_unique( + expr, std::move(execution_path), comprehension_max_iterations_, + std::move(iter_variable_names), enable_unknowns_, + enable_unknown_function_results_); + if (warnings != nullptr) { + *warnings = std::move(warnings_builder).warnings(); + } return std::move(expression_impl); } +cel_base::StatusOr> +FlatExprBuilder::CreateExpression(const Expr* expr, + const SourceInfo* source_info) const { + return CreateExpression(expr, source_info, nullptr); +} + } // namespace runtime } // namespace expr } // namespace api diff --git a/eval/compiler/flat_expr_builder.h b/eval/compiler/flat_expr_builder.h index 48beaee18..62dfb84b0 100644 --- a/eval/compiler/flat_expr_builder.h +++ b/eval/compiler/flat_expr_builder.h @@ -1,8 +1,8 @@ #ifndef THIRD_PARTY_CEL_CPP_EVAL_COMPILER_FLAT_EXPR_BUILDER_H_ #define THIRD_PARTY_CEL_CPP_EVAL_COMPILER_FLAT_EXPR_BUILDER_H_ -#include "eval/public/cel_expression.h" #include "google/api/expr/v1alpha1/syntax.pb.h" +#include "eval/public/cel_expression.h" namespace google { namespace api { @@ -14,11 +14,23 @@ namespace runtime { class FlatExprBuilder : public CelExpressionBuilder { public: FlatExprBuilder() - : shortcircuiting_(true), + : enable_unknowns_(false), + enable_unknown_function_results_(false), + shortcircuiting_(true), constant_folding_(false), constant_arena_(nullptr), enable_comprehension_(true), - comprehension_max_iterations_(0) {} + comprehension_max_iterations_(0), + fail_on_warnings_(true) {} + + // set_enable_unknowns controls support for unknowns in expressions created. + void set_enable_unknowns(bool enabled) { enable_unknowns_ = enabled; } + + // set_enable_unknown_function_results controls support for unknown function + // results. + void set_enable_unknown_function_results(bool enabled) { + enable_unknown_function_results_ = enabled; + } // set_shortcircuiting regulates shortcircuiting of some expressions. // Be default shortcircuiting is enabled. @@ -39,17 +51,30 @@ class FlatExprBuilder : public CelExpressionBuilder { comprehension_max_iterations_ = max_iterations; } + // Warnings (e.g. no function bound) fail immediately. + void set_fail_on_warnings(bool should_fail) { + fail_on_warnings_ = should_fail; + } + cel_base::StatusOr> CreateExpression( const google::api::expr::v1alpha1::Expr* expr, const google::api::expr::v1alpha1::SourceInfo* source_info) const override; + cel_base::StatusOr> CreateExpression( + const google::api::expr::v1alpha1::Expr* expr, + const google::api::expr::v1alpha1::SourceInfo* source_info, + std::vector* warnings) const override; + private: + bool enable_unknowns_; + bool enable_unknown_function_results_; bool shortcircuiting_; bool constant_folding_; google::protobuf::Arena* constant_arena_; bool enable_comprehension_; int comprehension_max_iterations_; + bool fail_on_warnings_; }; } // namespace runtime diff --git a/eval/compiler/flat_expr_builder_comprehensions_test.cc b/eval/compiler/flat_expr_builder_comprehensions_test.cc new file mode 100644 index 000000000..331821853 --- /dev/null +++ b/eval/compiler/flat_expr_builder_comprehensions_test.cc @@ -0,0 +1,180 @@ +#include "google/api/expr/v1alpha1/syntax.pb.h" +#include "google/protobuf/field_mask.pb.h" +#include "google/protobuf/text_format.h" +#include "gmock/gmock.h" +#include "gtest/gtest.h" +#include "absl/status/status.h" +#include "absl/strings/str_split.h" +#include "absl/strings/string_view.h" +#include "eval/compiler/flat_expr_builder.h" +#include "eval/public/builtin_func_registrar.h" +#include "eval/public/cel_attribute.h" +#include "eval/public/cel_builtins.h" +#include "eval/public/cel_expression.h" +#include "eval/public/cel_options.h" +#include "eval/public/cel_value.h" +#include "eval/public/unknown_attribute_set.h" +#include "eval/public/unknown_set.h" +#include "eval/testutil/test_message.pb.h" +#include "base/status_macros.h" + +namespace google { +namespace api { +namespace expr { +namespace runtime { + +namespace { + +using google::api::expr::v1alpha1::Expr; +using google::api::expr::v1alpha1::SourceInfo; + +// [1, 2].filter(x, [3, 4].all(y, x < y)) +const char kNestedComprehension[] = R"pb( + id: 27 + comprehension_expr { + iter_var: "x" + iter_range { + id: 1 + list_expr { + elements { + id: 2 + const_expr { int64_value: 1 } + } + elements { + id: 3 + const_expr { int64_value: 2 } + } + } + } + accu_var: "__result__" + accu_init { + id: 22 + list_expr {} + } + loop_condition { + id: 23 + const_expr { bool_value: true } + } + loop_step { + id: 26 + call_expr { + function: "_?_:_" + args { + id: 20 + comprehension_expr { + iter_var: "y" + iter_range { + id: 6 + list_expr { + elements { + id: 7 + const_expr { int64_value: 3 } + } + elements { + id: 8 + const_expr { int64_value: 4 } + } + } + } + accu_var: "__result__" + accu_init { + id: 14 + const_expr { bool_value: true } + } + loop_condition { + id: 16 + call_expr { + function: "@not_strictly_false" + args { + id: 15 + ident_expr { name: "__result__" } + } + } + } + loop_step { + id: 18 + call_expr { + function: "_&&_" + args { + id: 17 + ident_expr { name: "__result__" } + } + args { + id: 12 + call_expr { + function: "_<_" + args { + id: 11 + ident_expr { name: "x" } + } + args { + id: 13 + ident_expr { name: "y" } + } + } + } + } + } + result { + id: 19 + ident_expr { name: "__result__" } + } + } + } + args { + id: 25 + call_expr { + function: "_+_" + args { + id: 21 + ident_expr { name: "__result__" } + } + args { + id: 24 + list_expr { + elements { + id: 5 + ident_expr { name: "x" } + } + } + } + } + } + args { + id: 21 + ident_expr { name: "__result__" } + } + } + } + result { + id: 21 + ident_expr { name: "__result__" } + } + })pb"; + +TEST(FlatExprBuilderComprehensionsTest, NestedComp) { + FlatExprBuilder builder; + Expr expr; + ASSERT_TRUE(google::protobuf::TextFormat::ParseFromString(kNestedComprehension, &expr)); + ASSERT_OK(RegisterBuiltinFunctions(builder.GetRegistry())); + SourceInfo source_info; + auto build_status = builder.CreateExpression(&expr, &source_info); + ASSERT_OK(build_status); + + auto cel_expr = std::move(build_status.ValueOrDie()); + + Activation activation; + google::protobuf::Arena arena; + auto result_or = cel_expr->Evaluate(activation, &arena); + ASSERT_OK(result_or); + CelValue result = result_or.ValueOrDie(); + ASSERT_TRUE(result.IsList()); + EXPECT_THAT(*result.ListOrDie(), testing::SizeIs(2)); +} + +} // namespace + +} // namespace runtime +} // namespace expr +} // namespace api +} // namespace google diff --git a/eval/compiler/flat_expr_builder_test.cc b/eval/compiler/flat_expr_builder_test.cc index 2b01bb22f..d9ddfd23a 100644 --- a/eval/compiler/flat_expr_builder_test.cc +++ b/eval/compiler/flat_expr_builder_test.cc @@ -5,10 +5,20 @@ #include "google/protobuf/text_format.h" #include "gmock/gmock.h" #include "gtest/gtest.h" +#include "absl/status/status.h" #include "absl/strings/str_split.h" #include "absl/strings/string_view.h" #include "eval/public/builtin_func_registrar.h" +#include "eval/public/cel_attribute.h" +#include "eval/public/cel_builtins.h" +#include "eval/public/cel_expression.h" +#include "eval/public/cel_options.h" +#include "eval/public/cel_value.h" +#include "eval/public/unknown_attribute_set.h" +#include "eval/public/unknown_set.h" #include "eval/testutil/test_message.pb.h" +#include "base/status_macros.h" + namespace google { namespace api { namespace expr { @@ -32,10 +42,10 @@ class ConcatFunction : public CelFunction { "concat", false, {CelValue::Type::kString, CelValue::Type::kString}}; } - cel_base::Status Evaluate(absl::Span args, CelValue* result, + absl::Status Evaluate(absl::Span args, CelValue* result, google::protobuf::Arena* arena) const override { if (args.size() != 2) { - return cel_base::Status(cel_base::StatusCode::kInvalidArgument, + return absl::Status(absl::StatusCode::kInvalidArgument, "Bad arguments number"); } @@ -47,7 +57,7 @@ class ConcatFunction : public CelFunction { *result = CelValue::CreateString(concatenated); - return cel_base::OkStatus(); + return absl::OkStatus(); } }; @@ -67,10 +77,10 @@ TEST(FlatExprBuilderTest, SimpleEndToEnd) { auto register_status = builder.GetRegistry()->Register(absl::make_unique()); - ASSERT_TRUE(register_status.ok()); + ASSERT_OK(register_status); auto build_status = builder.CreateExpression(&expr, &source_info); - ASSERT_TRUE(build_status.ok()); + ASSERT_OK(build_status); auto cel_expr = std::move(build_status.ValueOrDie()); @@ -82,7 +92,7 @@ TEST(FlatExprBuilderTest, SimpleEndToEnd) { google::protobuf::Arena arena; auto eval_status = cel_expr->Evaluate(activation, &arena); - ASSERT_TRUE(eval_status.ok()); + ASSERT_OK(eval_status); CelValue result = eval_status.ValueOrDie(); @@ -91,20 +101,66 @@ TEST(FlatExprBuilderTest, SimpleEndToEnd) { EXPECT_THAT(result.StringOrDie().value(), Eq("prefixtest")); } +TEST(FlatExprBuilderTest, DelayedFunctionResolutionErrors) { + Expr expr; + SourceInfo source_info; + auto call_expr = expr.mutable_call_expr(); + call_expr->set_function("concat"); + + auto arg1 = call_expr->add_args(); + arg1->mutable_const_expr()->set_string_value("prefix"); + + auto arg2 = call_expr->add_args(); + arg2->mutable_ident_expr()->set_name("value"); + + FlatExprBuilder builder; + builder.set_fail_on_warnings(false); + std::vector warnings; + + // Concat function not registered. + + auto build_status = builder.CreateExpression(&expr, &source_info, &warnings); + ASSERT_OK(build_status); + + auto cel_expr = std::move(build_status.ValueOrDie()); + + std::string variable = "test"; + + Activation activation; + activation.InsertValue("value", CelValue::CreateString(&variable)); + + google::protobuf::Arena arena; + + auto eval_status = cel_expr->Evaluate(activation, &arena); + ASSERT_OK(eval_status); + + CelValue result = eval_status.ValueOrDie(); + + ASSERT_TRUE(result.IsError()); + + EXPECT_THAT(result.ErrorOrDie()->message(), + Eq("No matching overloads found")); + + ASSERT_THAT(warnings, testing::SizeIs(1)); + EXPECT_EQ(warnings[0].code(), absl::StatusCode::kInvalidArgument); + EXPECT_THAT(warnings[0].message(), + testing::HasSubstr("No overloads provided")); +} + class RecorderFunction : public CelFunction { public: explicit RecorderFunction(const std::string& name, int* count) : CelFunction(CelFunctionDescriptor{name, false, {}}), count_(count) {} - cel_base::Status Evaluate(absl::Span args, CelValue* result, + absl::Status Evaluate(absl::Span args, CelValue* result, google::protobuf::Arena* arena) const override { if (!args.empty()) { - return cel_base::Status(cel_base::StatusCode::kInvalidArgument, + return absl::Status(absl::StatusCode::kInvalidArgument, "Bad arguments number"); } (*count_)++; *result = CelValue::CreateBool(true); - return cel_base::OkStatus(); + return absl::OkStatus(); } int* count_; @@ -130,21 +186,21 @@ TEST(FlatExprBuilderTest, Shortcircuiting) { auto register_status1 = builder.GetRegistry()->Register( absl::make_unique("recorder1", &count1)); - ASSERT_TRUE(register_status1.ok()); + ASSERT_OK(register_status1); auto register_status2 = builder.GetRegistry()->Register( absl::make_unique("recorder2", &count2)); - ASSERT_TRUE(register_status2.ok()); + ASSERT_OK(register_status2); // Shortcircuiting on. auto build_status_on = builder.CreateExpression(&expr, &source_info); - ASSERT_TRUE(build_status_on.ok()); + ASSERT_OK(build_status_on); auto cel_expr_on = std::move(build_status_on.ValueOrDie()); Activation activation; google::protobuf::Arena arena; auto eval_status_on = cel_expr_on->Evaluate(activation, &arena); - ASSERT_TRUE(eval_status_on.ok()); + ASSERT_OK(eval_status_on); EXPECT_THAT(count1, Eq(1)); EXPECT_THAT(count2, Eq(0)); @@ -152,7 +208,7 @@ TEST(FlatExprBuilderTest, Shortcircuiting) { // Shortcircuiting off. builder.set_shortcircuiting(false); auto build_status_off = builder.CreateExpression(&expr, &source_info); - ASSERT_TRUE(build_status_off.ok()); + ASSERT_OK(build_status_off); auto cel_expr_off = std::move(build_status_off.ValueOrDie()); @@ -160,7 +216,7 @@ TEST(FlatExprBuilderTest, Shortcircuiting) { count2 = 0; auto eval_status_off = cel_expr_off->Evaluate(activation, &arena); - ASSERT_TRUE(eval_status_off.ok()); + ASSERT_OK(eval_status_off); EXPECT_THAT(count1, Eq(1)); EXPECT_THAT(count2, Eq(1)); @@ -193,32 +249,32 @@ TEST(FlatExprBuilderTest, ShortcircuitingComprehension) { int count = 0; auto register_status = builder.GetRegistry()->Register( absl::make_unique("loop_step", &count)); - ASSERT_TRUE(register_status.ok()); + ASSERT_OK(register_status); // Shortcircuiting on. auto build_status_on = builder.CreateExpression(&expr, &source_info); - ASSERT_TRUE(build_status_on.ok()); + ASSERT_OK(build_status_on); auto cel_expr_on = std::move(build_status_on.ValueOrDie()); Activation activation; google::protobuf::Arena arena; auto eval_status_on = cel_expr_on->Evaluate(activation, &arena); - ASSERT_TRUE(eval_status_on.ok()); + ASSERT_OK(eval_status_on); EXPECT_THAT(count, Eq(0)); // Shortcircuiting off. builder.set_shortcircuiting(false); auto build_status_off = builder.CreateExpression(&expr, &source_info); - ASSERT_TRUE(build_status_off.ok()); + ASSERT_OK(build_status_off); auto cel_expr_off = std::move(build_status_off.ValueOrDie()); count = 0; auto eval_status_off = cel_expr_off->Evaluate(activation, &arena); - ASSERT_TRUE(eval_status_off.ok()); + ASSERT_OK(eval_status_off); EXPECT_THAT(count, Eq(3)); } @@ -266,17 +322,17 @@ TEST(FlatExprBuilderTest, MapComprehension) { &expr); FlatExprBuilder builder; - ASSERT_TRUE(RegisterBuiltinFunctions(builder.GetRegistry()).ok()); + ASSERT_OK(RegisterBuiltinFunctions(builder.GetRegistry())); SourceInfo source_info; auto build_status = builder.CreateExpression(&expr, &source_info); - ASSERT_TRUE(build_status.ok()); + ASSERT_OK(build_status); auto cel_expr = std::move(build_status.ValueOrDie()); Activation activation; google::protobuf::Arena arena; auto result_or = cel_expr->Evaluate(activation, &arena); - ASSERT_TRUE(result_or.ok()); + ASSERT_OK(result_or); CelValue result = result_or.ValueOrDie(); ASSERT_TRUE(result.IsBool()); EXPECT_TRUE(result.BoolOrDie()); @@ -353,17 +409,17 @@ TEST(FlatExprBuilderTest, ComprehensionWorksForError) { &expr); FlatExprBuilder builder; - ASSERT_TRUE(RegisterBuiltinFunctions(builder.GetRegistry()).ok()); + ASSERT_OK(RegisterBuiltinFunctions(builder.GetRegistry())); SourceInfo source_info; auto build_status = builder.CreateExpression(&expr, &source_info); - ASSERT_TRUE(build_status.ok()); + ASSERT_OK(build_status); auto cel_expr = std::move(build_status.ValueOrDie()); Activation activation; google::protobuf::Arena arena; auto result_or = cel_expr->Evaluate(activation, &arena); - ASSERT_TRUE(result_or.ok()); + ASSERT_OK(result_or); CelValue result = result_or.ValueOrDie(); ASSERT_TRUE(result.IsError()); } @@ -428,17 +484,17 @@ TEST(FlatExprBuilderTest, ComprehensionWorksForNonContainer) { &expr); FlatExprBuilder builder; - ASSERT_TRUE(RegisterBuiltinFunctions(builder.GetRegistry()).ok()); + ASSERT_OK(RegisterBuiltinFunctions(builder.GetRegistry())); SourceInfo source_info; auto build_status = builder.CreateExpression(&expr, &source_info); - ASSERT_TRUE(build_status.ok()); + ASSERT_OK(build_status); auto cel_expr = std::move(build_status.ValueOrDie()); Activation activation; google::protobuf::Arena arena; auto result_or = cel_expr->Evaluate(activation, &arena); - ASSERT_TRUE(result_or.ok()); + ASSERT_OK(result_or); CelValue result = result_or.ValueOrDie(); ASSERT_TRUE(result.IsError()); EXPECT_THAT(result.ErrorOrDie()->message(), @@ -483,10 +539,10 @@ TEST(FlatExprBuilderTest, ComprehensionBudget) { FlatExprBuilder builder; builder.set_comprehension_max_iterations(1); - ASSERT_TRUE(RegisterBuiltinFunctions(builder.GetRegistry()).ok()); + ASSERT_OK(RegisterBuiltinFunctions(builder.GetRegistry())); SourceInfo source_info; auto build_status = builder.CreateExpression(&expr, &source_info); - ASSERT_TRUE(build_status.ok()); + ASSERT_OK(build_status); auto cel_expr = std::move(build_status.ValueOrDie()); @@ -517,7 +573,7 @@ TEST(FlatExprBuilderTest, UnknownSupportTest) { FlatExprBuilder builder; auto build_status = builder.CreateExpression(&expr, &source_info); - ASSERT_TRUE(build_status.ok()); + ASSERT_OK(build_status); auto cel_expr = std::move(build_status.ValueOrDie()); @@ -529,7 +585,7 @@ TEST(FlatExprBuilderTest, UnknownSupportTest) { auto eval_status = cel_expr->Evaluate(activation, &arena); - ASSERT_TRUE(eval_status.ok()); + ASSERT_OK(eval_status); CelValue result = eval_status.ValueOrDie(); ASSERT_TRUE(result.IsInt64()); @@ -539,7 +595,7 @@ TEST(FlatExprBuilderTest, UnknownSupportTest) { mask.add_paths("message.message_value.int32_value"); activation.set_unknown_paths(mask); eval_status = cel_expr->Evaluate(activation, &arena); - ASSERT_TRUE(eval_status.ok()); + ASSERT_OK(eval_status); result = eval_status.ValueOrDie(); ASSERT_TRUE(result.IsError()); ASSERT_TRUE(IsUnknownValueError(result)); @@ -550,7 +606,7 @@ TEST(FlatExprBuilderTest, UnknownSupportTest) { mask.add_paths("message.message_value"); activation.set_unknown_paths(mask); eval_status = cel_expr->Evaluate(activation, &arena); - ASSERT_TRUE(eval_status.ok()); + ASSERT_OK(eval_status); result = eval_status.ValueOrDie(); ASSERT_TRUE(result.IsError()); ASSERT_TRUE(IsUnknownValueError(result)); @@ -579,10 +635,10 @@ TEST(FlatExprBuilderTest, SimpleEnumTest) { cur_expr->mutable_ident_expr()->set_name(enum_name_parts[0]); FlatExprBuilder builder; - builder.addResolvableEnum(TestMessage::TestEnum_descriptor()); + builder.AddResolvableEnum(TestMessage::TestEnum_descriptor()); auto build_status = builder.CreateExpression(&expr, &source_info); - ASSERT_TRUE(build_status.ok()); + ASSERT_OK(build_status); auto cel_expr = std::move(build_status.ValueOrDie()); @@ -590,7 +646,7 @@ TEST(FlatExprBuilderTest, SimpleEnumTest) { Activation activation; auto eval_status = cel_expr->Evaluate(activation, &arena); - ASSERT_TRUE(eval_status.ok()); + ASSERT_OK(eval_status); CelValue result = eval_status.ValueOrDie(); ASSERT_TRUE(result.IsInt64()); @@ -607,13 +663,13 @@ TEST(FlatExprBuilderTest, ContainerStringFormat) { builder.set_container(""); { auto build_status = builder.CreateExpression(&expr, &source_info); - ASSERT_TRUE(build_status.ok()); + ASSERT_OK(build_status); } builder.set_container("random.namespace"); { auto build_status = builder.CreateExpression(&expr, &source_info); - ASSERT_TRUE(build_status.ok()); + ASSERT_OK(build_status); } // Leading '.' @@ -650,19 +706,19 @@ void EvalExpressionWithEnum(absl::string_view enum_name, cur_expr->mutable_ident_expr()->set_name(enum_name_parts[0]); FlatExprBuilder builder; - builder.addResolvableEnum(TestMessage::TestEnum_descriptor()); - builder.addResolvableEnum(TestEnum_descriptor()); + builder.AddResolvableEnum(TestMessage::TestEnum_descriptor()); + builder.AddResolvableEnum(TestEnum_descriptor()); builder.set_container(std::string(container)); auto build_status = builder.CreateExpression(&expr, &source_info); - ASSERT_TRUE(build_status.ok()); + ASSERT_OK(build_status); auto cel_expr = std::move(build_status.ValueOrDie()); google::protobuf::Arena arena; Activation activation; auto eval_status = cel_expr->Evaluate(activation, &arena); - ASSERT_TRUE(eval_status.ok()); + ASSERT_OK(eval_status); *result = eval_status.ValueOrDie(); } @@ -720,6 +776,164 @@ TEST(FlatExprBuilderTest, PartialQualifiedEnumResolution) { EXPECT_THAT(result.Int64OrDie(), Eq(TestMessage::TEST_ENUM_1)); } +absl::Status RunTernaryExpression(CelValue selector, CelValue value1, + CelValue value2, google::protobuf::Arena* arena, + CelValue* result) { + Expr expr; + SourceInfo source_info; + auto call_expr = expr.mutable_call_expr(); + call_expr->set_function(builtin::kTernary); + + auto arg0 = call_expr->add_args(); + arg0->mutable_ident_expr()->set_name("selector"); + auto arg1 = call_expr->add_args(); + arg1->mutable_ident_expr()->set_name("value1"); + auto arg2 = call_expr->add_args(); + arg2->mutable_ident_expr()->set_name("value2"); + + FlatExprBuilder builder; + auto build_status = builder.CreateExpression(&expr, &source_info); + if (!build_status.ok()) { + return build_status.status(); + } + + auto cel_expr = std::move(build_status.ValueOrDie()); + + std::string variable = "test"; + + Activation activation; + activation.InsertValue("selector", selector); + activation.InsertValue("value1", value1); + activation.InsertValue("value2", value2); + + auto eval_status = cel_expr->Evaluate(activation, arena); + if (!eval_status.ok()) { + return eval_status.status(); + } + + *result = eval_status.ValueOrDie(); + return eval_status.status(); +} + +TEST(FlatExprBuilderTest, Ternary) { + Expr expr; + SourceInfo source_info; + auto call_expr = expr.mutable_call_expr(); + call_expr->set_function(builtin::kTernary); + + auto arg0 = call_expr->add_args(); + arg0->mutable_ident_expr()->set_name("selector"); + auto arg1 = call_expr->add_args(); + arg1->mutable_ident_expr()->set_name("value1"); + auto arg2 = call_expr->add_args(); + arg2->mutable_ident_expr()->set_name("value1"); + + FlatExprBuilder builder; + // builder.set_enable_unknowns(true); + auto build_status = builder.CreateExpression(&expr, &source_info); + ASSERT_OK(build_status); + + auto cel_expr = std::move(build_status.ValueOrDie()); + + google::protobuf::Arena arena; + + // On True, value 1 + { + CelValue result; + ASSERT_OK(RunTernaryExpression(CelValue::CreateBool(true), + CelValue::CreateInt64(1), + CelValue::CreateInt64(2), &arena, &result)); + ASSERT_TRUE(result.IsInt64()); + EXPECT_THAT(result.Int64OrDie(), Eq(1)); + + // Unknown handling + UnknownSet unknown_set; + ASSERT_OK(RunTernaryExpression(CelValue::CreateBool(true), + CelValue::CreateUnknownSet(&unknown_set), + CelValue::CreateInt64(2), &arena, &result)); + ASSERT_TRUE(result.IsUnknownSet()); + + ASSERT_OK(RunTernaryExpression( + CelValue::CreateBool(true), CelValue::CreateInt64(1), + CelValue::CreateUnknownSet(&unknown_set), &arena, &result)); + ASSERT_TRUE(result.IsInt64()); + EXPECT_THAT(result.Int64OrDie(), Eq(1)); + } + + // On False, value 2 + { + CelValue result; + ASSERT_OK(RunTernaryExpression(CelValue::CreateBool(false), + CelValue::CreateInt64(1), + CelValue::CreateInt64(2), &arena, &result)); + ASSERT_TRUE(result.IsInt64()); + EXPECT_THAT(result.Int64OrDie(), Eq(2)); + + // Unknown handling + UnknownSet unknown_set; + ASSERT_OK(RunTernaryExpression(CelValue::CreateBool(false), + CelValue::CreateUnknownSet(&unknown_set), + CelValue::CreateInt64(2), &arena, &result)); + ASSERT_TRUE(result.IsInt64()); + EXPECT_THAT(result.Int64OrDie(), Eq(2)); + + ASSERT_OK(RunTernaryExpression( + CelValue::CreateBool(false), CelValue::CreateInt64(1), + CelValue::CreateUnknownSet(&unknown_set), &arena, &result)); + ASSERT_TRUE(result.IsUnknownSet()); + } + // On Error, surface error + { + CelValue result; + ASSERT_OK(RunTernaryExpression(CreateErrorValue(&arena, "error"), + CelValue::CreateInt64(1), + CelValue::CreateInt64(2), &arena, &result)); + ASSERT_TRUE(result.IsError()); + } + // On Unknown, surface Unknown + { + UnknownSet unknown_set; + CelValue result; + ASSERT_OK(RunTernaryExpression(CelValue::CreateUnknownSet(&unknown_set), + CelValue::CreateInt64(1), + CelValue::CreateInt64(2), &arena, &result)); + ASSERT_TRUE(result.IsUnknownSet()); + EXPECT_THAT(&unknown_set, Eq(result.UnknownSetOrDie())); + } + // We should not merge unknowns + { + Expr selector; + selector.mutable_ident_expr()->set_name("selector"); + CelAttribute selector_attr(selector, {}); + + Expr value1; + value1.mutable_ident_expr()->set_name("value1"); + CelAttribute value1_attr(value1, {}); + + Expr value2; + value2.mutable_ident_expr()->set_name("value2"); + CelAttribute value2_attr(value2, {}); + + UnknownSet unknown_selector(UnknownAttributeSet({&selector_attr})); + UnknownSet unknown_value1(UnknownAttributeSet({&value1_attr})); + UnknownSet unknown_value2(UnknownAttributeSet({&value2_attr})); + CelValue result; + ASSERT_OK(RunTernaryExpression( + CelValue::CreateUnknownSet(&unknown_selector), + CelValue::CreateUnknownSet(&unknown_value1), + CelValue::CreateUnknownSet(&unknown_value2), &arena, &result)); + ASSERT_TRUE(result.IsUnknownSet()); + const UnknownSet* result_set = result.UnknownSetOrDie(); + EXPECT_THAT(result_set->unknown_attributes().attributes().size(), Eq(1)); + EXPECT_THAT(result_set->unknown_attributes() + .attributes()[0] + ->variable() + .ident_expr() + .name(), + Eq("selector")); + } +} + } // namespace } // namespace runtime diff --git a/eval/eval/BUILD b/eval/eval/BUILD index b66cd3c5b..f86f160ef 100644 --- a/eval/eval/BUILD +++ b/eval/eval/BUILD @@ -16,11 +16,19 @@ cc_library( ], copts = ["-std=c++14"], deps = [ + ":attribute_trail", + ":unknowns_utility", + "//base:status_macros", + "//base:statusor", "//eval/public:activation", + "//eval/public:cel_attribute", "//eval/public:cel_expression", "//eval/public:cel_value", - "@com_google_absl//absl/strings", + "//eval/public:unknown_attribute_set", + "@com_google_absl//absl/status", + "@com_google_absl//absl/types:optional", "@com_google_googleapis//google/api/expr/v1alpha1:syntax_cc_proto", + "@com_google_protobuf//:protobuf", ], ) @@ -74,9 +82,10 @@ cc_library( deps = [ ":evaluator_core", ":expression_step_base", - "//base:status", "//eval/public:activation", "//eval/public:cel_value", + "//eval/public:unknown_attribute_set", + "@com_google_absl//absl/status", "@com_google_absl//absl/strings", "@com_google_protobuf//:protobuf", ], @@ -96,8 +105,9 @@ cc_library( ":expression_step_base", "//eval/public:activation", "//eval/public:cel_value", + "//eval/public:unknown_attribute_set", "@com_google_absl//absl/strings", - "@com_google_googleapis//google/api/expr/v1alpha1:syntax_cc_proto", + "@com_google_protobuf//:protobuf", ], ) @@ -112,14 +122,18 @@ cc_library( copts = ["-std=c++14"], deps = [ ":evaluator_core", + ":expression_build_warning", ":expression_step_base", - "//base:status", + "//base:status_macros", "//eval/public:activation", "//eval/public:cel_function", "//eval/public:cel_function_provider", "//eval/public:cel_function_registry", "//eval/public:cel_value", + "//eval/public:unknown_function_result_set", + "//eval/public:unknown_set", "@com_google_absl//absl/strings", + "@com_google_protobuf//:protobuf", ], ) @@ -133,9 +147,9 @@ cc_library( ], copts = ["-std=c++14"], deps = [ - "//base:status", "//eval/public:cel_value", "//internal:proto_util", + "@com_google_absl//absl/status", "@com_google_absl//absl/strings", "@com_google_protobuf//:protobuf", ], @@ -223,7 +237,6 @@ cc_library( "//eval/public:activation", "//eval/public:cel_value", "@com_google_absl//absl/strings", - "@com_google_googleapis//google/api/expr/v1alpha1:syntax_cc_proto", ], ) @@ -262,7 +275,7 @@ cc_library( ":evaluator_core", ":expression_step_base", ":field_access", - "//base:status", + "@com_google_absl//absl/status", "@com_google_absl//absl/strings", "@com_google_absl//absl/types:span", "@com_google_googleapis//google/api/expr/v1alpha1:syntax_cc_proto", @@ -303,6 +316,7 @@ cc_library( "//eval/public:activation", "//eval/public:cel_function", "//eval/public:cel_value", + "//eval/public:unknown_attribute_set", "@com_google_absl//absl/strings", "@com_google_googleapis//google/api/expr/v1alpha1:syntax_cc_proto", ], @@ -320,6 +334,7 @@ cc_library( deps = [ ":evaluator_core", ":expression_step_base", + "//base:status_macros", "//eval/public:activation", "//eval/public:cel_function", "//eval/public:cel_value", @@ -337,8 +352,11 @@ cc_test( copts = ["-std=c++14"], deps = [ ":evaluator_core", + "//base:status_macros", "//eval/compiler:flat_expr_builder", "//eval/public:builtin_func_registrar", + "//eval/public:cel_attribute", + "//eval/public:cel_value", "@com_github_google_googletest//:gtest_main", "@com_google_googleapis//google/api/expr/v1alpha1:syntax_cc_proto", ], @@ -354,6 +372,7 @@ cc_test( deps = [ ":const_value_step", ":evaluator_core", + "//base:status_macros", "@com_github_google_googletest//:gtest_main", "@com_google_googleapis//google/api/expr/v1alpha1:syntax_cc_proto", ], @@ -371,6 +390,8 @@ cc_test( ":container_backed_list_impl", ":container_backed_map_impl", ":ident_step", + "//base:status_macros", + "//eval/public:cel_attribute", "//eval/public:cel_builtins", "//eval/public:cel_value", "@com_github_google_googletest//:gtest_main", @@ -388,6 +409,7 @@ cc_test( deps = [ ":evaluator_core", ":ident_step", + "//base:status_macros", "@com_github_google_googletest//:gtest_main", "@com_google_googleapis//google/api/expr/v1alpha1:syntax_cc_proto", ], @@ -402,9 +424,16 @@ cc_test( copts = ["-std=c++14"], deps = [ ":evaluator_core", + ":expression_build_warning", ":function_step", - "//base:status", + ":ident_step", + "//base:status_macros", + "//eval/public:cel_attribute", "//eval/public:cel_function", + "//eval/public:cel_options", + "//eval/public:cel_value", + "//eval/public:unknown_function_result_set", + "//eval/testutil:test_message_cc_proto", "@com_github_google_googletest//:gtest_main", "@com_google_absl//absl/memory", "@com_google_absl//absl/strings", @@ -459,6 +488,23 @@ cc_test( ], ) +cc_test( + name = "logic_step_test", + size = "small", + srcs = [ + "logic_step_test.cc", + ], + copts = ["-std=c++14"], + deps = [ + ":ident_step", + ":logic_step", + "//base:status_macros", + "//eval/public:unknown_attribute_set", + "//eval/public:unknown_set", + "@com_github_google_googletest//:gtest_main", + ], +) + cc_test( name = "select_step_test", size = "small", @@ -468,9 +514,11 @@ cc_test( copts = ["-std=c++14"], deps = [ ":container_backed_map_impl", - ":evaluator_core", ":ident_step", ":select_step", + "//base:status_macros", + "//eval/public:cel_attribute", + "//eval/public:unknown_attribute_set", "//eval/testutil:test_message_cc_proto", "//testutil:util", "@com_github_google_googletest//:gtest_main", @@ -488,10 +536,13 @@ cc_test( deps = [ ":const_value_step", ":create_list_step", - ":evaluator_core", - "//eval/testutil:test_message_cc_proto", + ":ident_step", + "//base:status_macros", + "//eval/public:activation", + "//eval/public:cel_attribute", + "//eval/public:unknown_attribute_set", "@com_github_google_googletest//:gtest_main", - "@com_google_googleapis//google/api/expr/v1alpha1:syntax_cc_proto", + "@com_google_absl//absl/strings", ], ) @@ -507,6 +558,7 @@ cc_test( ":container_backed_map_impl", ":create_struct_step", ":ident_step", + "//base:status_macros", "//eval/testutil:test_message_cc_proto", "//testutil:util", "@com_github_google_googletest//:gtest_main", @@ -514,3 +566,105 @@ cc_test( "@com_google_googleapis//google/api/expr/v1alpha1:syntax_cc_proto", ], ) + +cc_library( + name = "expression_build_warning", + srcs = [ + "expression_build_warning.cc", + ], + hdrs = [ + "expression_build_warning.h", + ], + copts = ["-std=c++14"], + deps = [ + "@com_google_absl//absl/status", + ], +) + +cc_test( + name = "expression_build_warning_test", + size = "small", + srcs = [ + "expression_build_warning_test.cc", + ], + copts = ["-std=c++14"], + deps = [ + ":expression_build_warning", + "@com_github_google_googletest//:gtest_main", + "@com_google_absl//absl/status", + ], +) + +cc_library( + name = "attribute_trail", + srcs = ["attribute_trail.cc"], + hdrs = ["attribute_trail.h"], + copts = ["-std=c++14"], + deps = [ + "//base:statusor", + "//eval/public:activation", + "//eval/public:cel_attribute", + "//eval/public:cel_expression", + "//eval/public:cel_value", + "//eval/public:unknown_attribute_set", + "@com_google_absl//absl/status", + "@com_google_absl//absl/types:optional", + "@com_google_googleapis//google/api/expr/v1alpha1:syntax_cc_proto", + "@com_google_protobuf//:protobuf", + ], +) + +cc_test( + name = "attribute_trail_test", + size = "small", + srcs = [ + "attribute_trail_test.cc", + ], + copts = ["-std=c++14"], + deps = [ + ":attribute_trail", + "//eval/public:cel_attribute", + "//eval/public:cel_value", + "@com_github_google_googletest//:gtest_main", + "@com_google_googleapis//google/api/expr/v1alpha1:syntax_cc_proto", + ], +) + +cc_library( + name = "unknowns_utility", + srcs = ["unknowns_utility.cc"], + hdrs = ["unknowns_utility.h"], + copts = ["-std=c++14"], + deps = [ + ":attribute_trail", + "//base:statusor", + "//eval/public:activation", + "//eval/public:cel_attribute", + "//eval/public:cel_expression", + "//eval/public:cel_value", + "//eval/public:unknown_attribute_set", + "//eval/public:unknown_set", + "@com_google_absl//absl/status", + "@com_google_absl//absl/types:optional", + "@com_google_googleapis//google/api/expr/v1alpha1:syntax_cc_proto", + "@com_google_protobuf//:protobuf", + ], +) + +cc_test( + name = "unknowns_utility_test", + size = "small", + srcs = [ + "unknowns_utility_test.cc", + ], + copts = ["-std=c++14"], + deps = [ + ":unknowns_utility", + "//eval/public:cel_attribute", + "//eval/public:cel_value", + "//eval/public:unknown_attribute_set", + "//eval/public:unknown_set", + "@com_github_google_googletest//:gtest_main", + "@com_google_googleapis//google/api/expr/v1alpha1:syntax_cc_proto", + ], +) diff --git a/eval/eval/attribute_trail.cc b/eval/eval/attribute_trail.cc new file mode 100644 index 000000000..ec07728ee --- /dev/null +++ b/eval/eval/attribute_trail.cc @@ -0,0 +1,25 @@ +#include "eval/eval/attribute_trail.h" + +#include "absl/status/status.h" +#include "eval/public/cel_value.h" +#include "base/statusor.h" + +namespace google { +namespace api { +namespace expr { +namespace runtime { +// Creates AttributeTrail with attribute path incremented by "qualifier". +AttributeTrail AttributeTrail::Step(CelAttributeQualifier qualifier, + google::protobuf::Arena* arena) const { + // Cannot continue void trail + if (empty()) return AttributeTrail(); + + std::vector qualifiers = attribute_->qualifier_path(); + qualifiers.push_back(qualifier); + return AttributeTrail(google::protobuf::Arena::Create( + arena, attribute_->variable(), std::move(qualifiers))); +} +} // namespace runtime +} // namespace expr +} // namespace api +} // namespace google diff --git a/eval/eval/attribute_trail.h b/eval/eval/attribute_trail.h new file mode 100644 index 000000000..eb18ac0d5 --- /dev/null +++ b/eval/eval/attribute_trail.h @@ -0,0 +1,63 @@ +#ifndef THIRD_PARTY_CEL_CPP_EVAL_EVAL_ATTRIBUTE_TRAIL_H_ +#define THIRD_PARTY_CEL_CPP_EVAL_EVAL_ATTRIBUTE_TRAIL_H_ + +#include +#include + +#include "google/api/expr/v1alpha1/syntax.pb.h" +#include "google/protobuf/arena.h" +#include "absl/types/optional.h" +#include "eval/public/activation.h" +#include "eval/public/cel_attribute.h" +#include "eval/public/cel_expression.h" +#include "eval/public/cel_value.h" +#include "eval/public/unknown_attribute_set.h" + +namespace google { +namespace api { +namespace expr { +namespace runtime { + +// AttributeTrail reflects current attribute path. +// It is functionally similar to CelAttribute, yet intended to have better +// complexity on attribute path increment operations. +// TODO(issues/41) Current AttributeTrail implementation is equivalent to +// CelAttribute - improve it. +// Intended to be used in conjunction with CelValue, describing the attribute +// value originated from. +// Empty AttributeTrail denotes object with attribute path not defined +// or supported. +class AttributeTrail { + public: + AttributeTrail() : attribute_(nullptr) {} + AttributeTrail(google::api::expr::v1alpha1::Expr root, google::protobuf::Arena* arena) + : AttributeTrail(google::protobuf::Arena::Create( + arena, root, std::vector())) {} + + // Creates AttributeTrail with attribute path incremented by "qualifier". + AttributeTrail Step(CelAttributeQualifier qualifier, + google::protobuf::Arena* arena) const; + + // Creates AttributeTrail with attribute path incremented by "qualifier". + AttributeTrail Step(const std::string* qualifier, + google::protobuf::Arena* arena) const { + return Step( + CelAttributeQualifier::Create(CelValue::CreateString(qualifier)), + arena); + } + + // Returns CelAttribute that corresponds to content of AttributeTrail. + const CelAttribute* attribute() const { return attribute_; } + + bool empty() const { return !attribute_; } + + private: + AttributeTrail(const CelAttribute* attribute) : attribute_(attribute) {} + const CelAttribute* attribute_; +}; +} // namespace runtime +} // namespace expr +} // namespace api +} // namespace google + +#endif // THIRD_PARTY_CEL_CPP_EVAL_EVAL_ATTRIBUTE_TRAIL_H_ diff --git a/eval/eval/attribute_trail_test.cc b/eval/eval/attribute_trail_test.cc new file mode 100644 index 000000000..ccf0d36c8 --- /dev/null +++ b/eval/eval/attribute_trail_test.cc @@ -0,0 +1,41 @@ +#include "eval/eval/attribute_trail.h" + +#include "google/api/expr/v1alpha1/syntax.pb.h" +#include "gmock/gmock.h" +#include "gtest/gtest.h" +#include "eval/public/cel_attribute.h" +#include "eval/public/cel_value.h" + +namespace google { +namespace api { +namespace expr { +namespace runtime { + +// Attribute Trail behavior +TEST(AttributeTrailTest, AttributeTrailEmptyStep) { + google::protobuf::Arena arena; + std::string step = "step"; + CelValue step_value = CelValue::CreateString(&step); + AttributeTrail trail; + ASSERT_TRUE(trail.Step(&step, &arena).empty()); + ASSERT_TRUE( + trail.Step(CelAttributeQualifier::Create(step_value), &arena).empty()); +} + +TEST(AttributeTrailTest, AttributeTrailStep) { + google::protobuf::Arena arena; + std::string step = "step"; + CelValue step_value = CelValue::CreateString(&step); + google::api::expr::v1alpha1::Expr root; + root.mutable_ident_expr()->set_name("ident"); + AttributeTrail trail = AttributeTrail(root, &arena).Step(&step, &arena); + + ASSERT_TRUE(trail.attribute() != nullptr); + ASSERT_EQ(*trail.attribute(), + CelAttribute(root, {CelAttributeQualifier::Create(step_value)})); +} + +} // namespace runtime +} // namespace expr +} // namespace api +} // namespace google diff --git a/eval/eval/comprehension_step.cc b/eval/eval/comprehension_step.cc index e835149cf..afe88984a 100644 --- a/eval/eval/comprehension_step.cc +++ b/eval/eval/comprehension_step.cc @@ -1,5 +1,7 @@ #include "eval/eval/comprehension_step.h" + #include "absl/strings/str_cat.h" +#include "base/status_macros.h" namespace google { namespace api { @@ -66,7 +68,7 @@ void ComprehensionNextStep::set_error_jump_offset(int offset) { // // Stack on error: // 0. error -cel_base::Status ComprehensionNextStep::Evaluate(ExecutionFrame* frame) const { +absl::Status ComprehensionNextStep::Evaluate(ExecutionFrame* frame) const { enum { POS_PREVIOUS_LOOP_STEP, POS_ITER_RANGE, @@ -75,7 +77,7 @@ cel_base::Status ComprehensionNextStep::Evaluate(ExecutionFrame* frame) const { POS_LOOP_STEP, }; if (!frame->value_stack().HasEnough(5)) { - return cel_base::Status(cel_base::StatusCode::kInternal, "Value stack underflow"); + return absl::Status(absl::StatusCode::kInternal, "Value stack underflow"); } auto state = frame->value_stack().GetSpan(5); CelValue iter_range = state[POS_ITER_RANGE]; @@ -95,19 +97,22 @@ cel_base::Status ComprehensionNextStep::Evaluate(ExecutionFrame* frame) const { "ComprehensionNextStep: want int64_t, got ", CelValue::TypeName(current_index_value.type()) ); - return cel_base::Status(cel_base::StatusCode::kInternal, message); + return absl::Status(absl::StatusCode::kInternal, message); } auto increment_status = frame->IncrementIterations(); if (!increment_status.ok()) { return increment_status; } int64_t current_index = current_index_value.Int64OrDie(); + if (current_index == -1) { + RETURN_IF_ERROR(frame->PushIterFrame()); + } CelValue loop_step = state[POS_LOOP_STEP]; frame->value_stack().Pop(5); frame->value_stack().Push(loop_step); - frame->iter_vars()[accu_var_] = loop_step; + RETURN_IF_ERROR(frame->SetIterVar(accu_var_, loop_step)); if (current_index >= cel_list->size() - 1) { - frame->iter_vars().erase(iter_var_); + RETURN_IF_ERROR(frame->ClearIterVar(iter_var_)); return frame->JumpTo(jump_offset_); } frame->value_stack().Push(iter_range); @@ -115,8 +120,8 @@ cel_base::Status ComprehensionNextStep::Evaluate(ExecutionFrame* frame) const { CelValue current_value = (*cel_list)[current_index]; frame->value_stack().Push(CelValue::CreateInt64(current_index)); frame->value_stack().Push(current_value); - frame->iter_vars()[iter_var_] = current_value; - return cel_base::OkStatus(); + RETURN_IF_ERROR(frame->SetIterVar(iter_var_, current_value)); + return absl::OkStatus(); } ComprehensionCondStep::ComprehensionCondStep(const std::string&, @@ -136,9 +141,9 @@ void ComprehensionCondStep::set_jump_offset(int offset) { // Stack size before: 5. // Stack size after: 4. // Stack size on break: 1. -cel_base::Status ComprehensionCondStep::Evaluate(ExecutionFrame* frame) const { +absl::Status ComprehensionCondStep::Evaluate(ExecutionFrame* frame) const { if (!frame->value_stack().HasEnough(5)) { - return cel_base::Status(cel_base::StatusCode::kInternal, "Value stack underflow"); + return absl::Status(absl::StatusCode::kInternal, "Value stack underflow"); } CelValue loop_condition_value = frame->value_stack().Peek(); if (!loop_condition_value.IsBool()) { @@ -146,16 +151,16 @@ cel_base::Status ComprehensionCondStep::Evaluate(ExecutionFrame* frame) const { "ComprehensionCondStep:: want bool, got ", CelValue::TypeName(loop_condition_value.type()) ); - return cel_base::Status(cel_base::StatusCode::kInternal, message); + return absl::Status(absl::StatusCode::kInternal, message); } bool loop_condition = loop_condition_value.BoolOrDie(); frame->value_stack().Pop(1); // loop_condition if (!loop_condition && shortcircuiting_) { frame->value_stack().Pop(3); // current_value, current_index, iter_range - frame->iter_vars().erase(iter_var_); + RETURN_IF_ERROR(frame->ClearIterVar(iter_var_)); return frame->JumpTo(jump_offset_); } - return cel_base::OkStatus(); + return absl::OkStatus(); } ComprehensionFinish::ComprehensionFinish(const std::string& accu_var, const std::string&, @@ -166,38 +171,38 @@ ComprehensionFinish::ComprehensionFinish(const std::string& accu_var, const std: // // Stack size before: 2. // Stack size after: 1. -cel_base::Status ComprehensionFinish::Evaluate(ExecutionFrame* frame) const { +absl::Status ComprehensionFinish::Evaluate(ExecutionFrame* frame) const { if (!frame->value_stack().HasEnough(2)) { - return cel_base::Status(cel_base::StatusCode::kInternal, "Value stack underflow"); + return absl::Status(absl::StatusCode::kInternal, "Value stack underflow"); } CelValue result = frame->value_stack().Peek(); frame->value_stack().Pop(1); // result frame->value_stack().PopAndPush(result); - frame->iter_vars().erase(accu_var_); - return cel_base::OkStatus(); + RETURN_IF_ERROR(frame->PopIterFrame()); + return absl::OkStatus(); } class ListKeysStep : public ExpressionStepBase { public: ListKeysStep(int64_t expr_id) : ExpressionStepBase(expr_id, false) {} - cel_base::Status Evaluate(ExecutionFrame* frame) const override; + absl::Status Evaluate(ExecutionFrame* frame) const override; }; std::unique_ptr CreateListKeysStep(int64_t expr_id) { return absl::make_unique(expr_id); } -cel_base::Status ListKeysStep::Evaluate(ExecutionFrame* frame) const { +absl::Status ListKeysStep::Evaluate(ExecutionFrame* frame) const { if (!frame->value_stack().HasEnough(1)) { - return cel_base::Status(cel_base::StatusCode::kInternal, "Value stack underflow"); + return absl::Status(absl::StatusCode::kInternal, "Value stack underflow"); } CelValue map_value = frame->value_stack().Peek(); if (map_value.IsMap()) { const CelMap* cel_map = map_value.MapOrDie(); frame->value_stack().PopAndPush(CelValue::CreateList(cel_map->ListKeys())); - return cel_base::OkStatus(); + return absl::OkStatus(); } - return cel_base::OkStatus(); + return absl::OkStatus(); } } // namespace runtime diff --git a/eval/eval/comprehension_step.h b/eval/eval/comprehension_step.h index 9a0c3146d..1fa9ee6bf 100644 --- a/eval/eval/comprehension_step.h +++ b/eval/eval/comprehension_step.h @@ -21,7 +21,7 @@ class ComprehensionNextStep : public ExpressionStepBase { void set_jump_offset(int offset); void set_error_jump_offset(int offset); - cel_base::Status Evaluate(ExecutionFrame* frame) const override; + absl::Status Evaluate(ExecutionFrame* frame) const override; private: std::string accu_var_; @@ -37,7 +37,7 @@ class ComprehensionCondStep : public ExpressionStepBase { void set_jump_offset(int offset); - cel_base::Status Evaluate(ExecutionFrame* frame) const override; + absl::Status Evaluate(ExecutionFrame* frame) const override; private: std::string iter_var_; @@ -50,7 +50,7 @@ class ComprehensionFinish : public ExpressionStepBase { ComprehensionFinish(const std::string& accu_var, const std::string& iter_var, int64_t expr_id); - cel_base::Status Evaluate(ExecutionFrame* frame) const override; + absl::Status Evaluate(ExecutionFrame* frame) const override; private: std::string accu_var_; diff --git a/eval/eval/const_value_step.cc b/eval/eval/const_value_step.cc index 465e40132..ee070a302 100644 --- a/eval/eval/const_value_step.cc +++ b/eval/eval/const_value_step.cc @@ -18,16 +18,16 @@ class ConstValueStep : public ExpressionStepBase { ConstValueStep(const CelValue& value, int64_t expr_id, bool comes_from_ast) : ExpressionStepBase(expr_id, comes_from_ast), value_(value) {} - cel_base::Status Evaluate(ExecutionFrame* context) const override; + absl::Status Evaluate(ExecutionFrame* frame) const override; private: CelValue value_; }; -cel_base::Status ConstValueStep::Evaluate(ExecutionFrame* frame) const { +absl::Status ConstValueStep::Evaluate(ExecutionFrame* frame) const { frame->value_stack().Push(value_); - return cel_base::OkStatus(); + return absl::OkStatus(); } } // namespace diff --git a/eval/eval/const_value_step_test.cc b/eval/eval/const_value_step_test.cc index d2da6a773..51acef959 100644 --- a/eval/eval/const_value_step_test.cc +++ b/eval/eval/const_value_step_test.cc @@ -1,9 +1,10 @@ #include "eval/eval/const_value_step.h" -#include "eval/eval/evaluator_core.h" -#include "google/api/expr/v1alpha1/syntax.pb.h" +#include "google/api/expr/v1alpha1/syntax.pb.h" #include "gmock/gmock.h" #include "gtest/gtest.h" +#include "eval/eval/evaluator_core.h" +#include "base/status_macros.h" namespace google { namespace api { @@ -14,8 +15,8 @@ namespace { using testing::Eq; -using google::api::expr::v1alpha1::Expr; using google::api::expr::v1alpha1::Constant; +using google::api::expr::v1alpha1::Expr; using google::protobuf::Arena; @@ -31,7 +32,7 @@ cel_base::StatusOr RunConstantExpression(const Expr* expr, google::api::expr::v1alpha1::Expr dummy_expr; - CelExpressionFlatImpl impl(&dummy_expr, std::move(path), 0); + CelExpressionFlatImpl impl(&dummy_expr, std::move(path), 0, {}); Activation activation; @@ -47,7 +48,7 @@ TEST(ConstValueStepTest, TestEvaluationConstInt64) { auto status = RunConstantExpression(&expr, const_expr, &arena); - ASSERT_TRUE(status.ok()); + ASSERT_OK(status); auto value = status.ValueOrDie(); @@ -64,7 +65,7 @@ TEST(ConstValueStepTest, TestEvaluationConstUint64) { auto status = RunConstantExpression(&expr, const_expr, &arena); - ASSERT_TRUE(status.ok()); + ASSERT_OK(status); auto value = status.ValueOrDie(); @@ -81,7 +82,7 @@ TEST(ConstValueStepTest, TestEvaluationConstBool) { auto status = RunConstantExpression(&expr, const_expr, &arena); - ASSERT_TRUE(status.ok()); + ASSERT_OK(status); auto value = status.ValueOrDie(); @@ -98,7 +99,7 @@ TEST(ConstValueStepTest, TestEvaluationConstNull) { auto status = RunConstantExpression(&expr, const_expr, &arena); - ASSERT_TRUE(status.ok()); + ASSERT_OK(status); auto value = status.ValueOrDie(); @@ -114,7 +115,7 @@ TEST(ConstValueStepTest, TestEvaluationConstString) { auto status = RunConstantExpression(&expr, const_expr, &arena); - ASSERT_TRUE(status.ok()); + ASSERT_OK(status); auto value = status.ValueOrDie(); @@ -131,7 +132,7 @@ TEST(ConstValueStepTest, TestEvaluationConstDouble) { auto status = RunConstantExpression(&expr, const_expr, &arena); - ASSERT_TRUE(status.ok()); + ASSERT_OK(status); auto value = status.ValueOrDie(); @@ -150,7 +151,7 @@ TEST(ConstValueStepTest, TestEvaluationConstBytes) { auto status = RunConstantExpression(&expr, const_expr, &arena); - ASSERT_TRUE(status.ok()); + ASSERT_OK(status); auto value = status.ValueOrDie(); diff --git a/eval/eval/container_access_step.cc b/eval/eval/container_access_step.cc index 15246a18a..37e4ba731 100644 --- a/eval/eval/container_access_step.cc +++ b/eval/eval/container_access_step.cc @@ -1,10 +1,12 @@ #include "eval/eval/container_access_step.h" #include "google/protobuf/arena.h" +#include "absl/status/status.h" #include "absl/strings/str_cat.h" +#include "eval/eval/evaluator_core.h" #include "eval/eval/expression_step_base.h" #include "eval/public/cel_value.h" -#include "base/status.h" +#include "eval/public/unknown_attribute_set.h" namespace google { namespace api { @@ -13,17 +15,20 @@ namespace runtime { namespace { +constexpr int NUM_CONTAINER_ACCESS_ARGUMENTS = 2; + // ContainerAccessStep performs message field access specified by Expr::Select // message. class ContainerAccessStep : public ExpressionStepBase { public: ContainerAccessStep(int64_t expr_id) : ExpressionStepBase(expr_id) {} - cel_base::Status Evaluate(ExecutionFrame* frame) const override; + absl::Status Evaluate(ExecutionFrame* frame) const override; private: - CelValue PerformLookup(const CelValue& container, const CelValue& key, - google::protobuf::Arena* arena) const; + using ValueAttributePair = std::pair; + + ValueAttributePair PerformLookup(ExecutionFrame* frame) const; CelValue LookupInMap(const CelMap* cel_map, const CelValue& key, google::protobuf::Arena* arena) const; CelValue LookupInList(const CelList* cel_list, const CelValue& key, @@ -72,52 +77,78 @@ inline CelValue ContainerAccessStep::LookupInList(const CelList* cel_list, } } -CelValue ContainerAccessStep::PerformLookup(const CelValue& container, - const CelValue& key, - google::protobuf::Arena* arena) const { - if (container.IsError()) { - return container; +ContainerAccessStep::ValueAttributePair ContainerAccessStep::PerformLookup( + ExecutionFrame* frame) const { + auto input_args = + frame->value_stack().GetSpan(NUM_CONTAINER_ACCESS_ARGUMENTS); + AttributeTrail trail; + + const CelValue& container = input_args[0]; + const CelValue& key = input_args[1]; + + if (frame->enable_unknowns()) { + auto unknown_set = + frame->unknowns_utility().MergeUnknowns(input_args, nullptr); + + if (unknown_set) { + return {CelValue::CreateUnknownSet(unknown_set), trail}; + } + + // We guarantee that GetAttributeSpan can aquire this number of arguments + // by calling HasEnough() at the beginning of Execute() method. + auto input_attrs = + frame->value_stack().GetAttributeSpan(NUM_CONTAINER_ACCESS_ARGUMENTS); + auto container_trail = input_attrs[0]; + trail = container_trail.Step(CelAttributeQualifier::Create(key), + frame->arena()); + + if (frame->unknowns_utility().CheckForUnknown(trail, + /*use_partial=*/false)) { + auto unknown_set = google::protobuf::Arena::Create( + frame->arena(), UnknownAttributeSet({trail.attribute()})); + + return {CelValue::CreateUnknownSet(unknown_set), trail}; + } } - if (key.IsError()) { - return key; + + for (const auto& value : input_args) { + if (value.IsError()) { + return {value, trail}; + } } + // Select steps can be applied to either maps or messages switch (container.type()) { case CelValue::Type::kMap: { const CelMap* cel_map = container.MapOrDie(); - return LookupInMap(cel_map, key, arena); + return {LookupInMap(cel_map, key, frame->arena()), trail}; } case CelValue::Type::kList: { const CelList* cel_list = container.ListOrDie(); - return LookupInList(cel_list, key, arena); + return {LookupInList(cel_list, key, frame->arena()), trail}; } default: { - return CreateErrorValue( - arena, absl::StrCat("Unexpected container type for [] operation: ", - CelValue::TypeName(key.type()))); + auto error = CreateErrorValue( + frame->arena(), + absl::StrCat("Unexpected container type for [] operation: ", + CelValue::TypeName(key.type()))); + return {error, trail}; } } } -cel_base::Status ContainerAccessStep::Evaluate(ExecutionFrame* frame) const { - const int NUM_ARGUMENTS = 2; - - if (!frame->value_stack().HasEnough(NUM_ARGUMENTS)) { - return cel_base::Status( - cel_base::StatusCode::kInternal, +absl::Status ContainerAccessStep::Evaluate(ExecutionFrame* frame) const { + if (!frame->value_stack().HasEnough(NUM_CONTAINER_ACCESS_ARGUMENTS)) { + return absl::Status( + absl::StatusCode::kInternal, "Insufficient arguments supplied for ContainerAccess-type expression"); } - auto input_args = frame->value_stack().GetSpan(NUM_ARGUMENTS); - - const CelValue& container = input_args[0]; - const CelValue& key = input_args[1]; - - CelValue result = PerformLookup(container, key, frame->arena()); - frame->value_stack().Pop(NUM_ARGUMENTS); - frame->value_stack().Push(result); + auto result = PerformLookup(frame); + frame->value_stack().Pop(NUM_CONTAINER_ACCESS_ARGUMENTS); + frame->value_stack().Push(result.first, result.second); - return cel_base::OkStatus(); + return absl::OkStatus(); } } // namespace diff --git a/eval/eval/container_access_step_test.cc b/eval/eval/container_access_step_test.cc index 22a13fd62..d3b507310 100644 --- a/eval/eval/container_access_step_test.cc +++ b/eval/eval/container_access_step_test.cc @@ -10,8 +10,10 @@ #include "eval/eval/container_backed_list_impl.h" #include "eval/eval/container_backed_map_impl.h" #include "eval/eval/ident_step.h" +#include "eval/public/cel_attribute.h" #include "eval/public/cel_builtins.h" #include "eval/public/cel_value.h" +#include "base/status_macros.h" namespace google { namespace api { @@ -25,95 +27,140 @@ using ::google::protobuf::Struct; using google::api::expr::v1alpha1::Expr; using google::api::expr::v1alpha1::SourceInfo; -class ContainerAccessStepTest : public ::testing::Test { - protected: - ContainerAccessStepTest() {} +using TestParamType = std::tuple; - void SetUp() override {} +// Helper method. Looks up in registry and tests comparison operation. +CelValue EvaluateAttributeHelper( + google::protobuf::Arena* arena, CelValue container, CelValue key, bool receiver_style, + bool enable_unknown, const std::vector& patterns) { + ExecutionPath path; - // Helper method. Looks up in registry and tests comparison operation. - CelValue PerformRun(CelValue container, CelValue key, bool receiver_style) { - ExecutionPath path; + Expr expr; + SourceInfo source_info; + auto call = expr.mutable_call_expr(); + + call->set_function(builtin::kIndex); + + Expr* container_expr = + (receiver_style) ? call->mutable_target() : call->add_args(); + Expr* key_expr = call->add_args(); + + container_expr->mutable_ident_expr()->set_name("container"); + key_expr->mutable_ident_expr()->set_name("key"); - Expr expr; - SourceInfo source_info; - auto call = expr.mutable_call_expr(); + path.push_back(std::move( + CreateIdentStep(&container_expr->ident_expr(), 1).ValueOrDie())); + path.push_back( + std::move(CreateIdentStep(&key_expr->ident_expr(), 2).ValueOrDie())); + path.push_back(std::move(CreateContainerAccessStep(call, 3).ValueOrDie())); - call->set_function(builtin::kIndex); + CelExpressionFlatImpl cel_expr(&expr, std::move(path), 0, {}, enable_unknown); + Activation activation; - Expr* container_expr = - (receiver_style) ? call->mutable_target() : call->add_args(); - Expr* key_expr = call->add_args(); + activation.InsertValue("container", container); + activation.InsertValue("key", key); + + activation.set_unknown_attribute_patterns(patterns); + auto eval_status = cel_expr.Evaluate(activation, arena); + + EXPECT_OK(eval_status); + return eval_status.ValueOrDie(); +} + +class ContainerAccessStepTest : public ::testing::Test { + protected: + ContainerAccessStepTest() {} - container_expr->mutable_ident_expr()->set_name("container"); - key_expr->mutable_ident_expr()->set_name("key"); + void SetUp() override {} - path.push_back(std::move( - CreateIdentStep(&container_expr->ident_expr(), 1).ValueOrDie())); - path.push_back( - std::move(CreateIdentStep(&key_expr->ident_expr(), 2).ValueOrDie())); - path.push_back(std::move(CreateContainerAccessStep(call, 3).ValueOrDie())); + CelValue EvaluateAttribute( + CelValue container, CelValue key, bool receiver_style, + bool enable_unknown, + const std::vector& patterns = {}) { + return EvaluateAttributeHelper(&arena_, container, key, receiver_style, + enable_unknown, patterns); + } + google::protobuf::Arena arena_; +}; - CelExpressionFlatImpl cel_expr(&expr, std::move(path), 0); - Activation activation; +class ContainerAccessStepUniformityTest + : public ::testing::TestWithParam { + protected: + ContainerAccessStepUniformityTest() {} - activation.InsertValue("container", container); - activation.InsertValue("key", key); - auto eval_status = cel_expr.Evaluate(activation, &arena_); + void SetUp() override {} - EXPECT_TRUE(eval_status.ok()); - return eval_status.ValueOrDie(); + // Helper method. Looks up in registry and tests comparison operation. + CelValue EvaluateAttribute( + CelValue container, CelValue key, bool receiver_style, + bool enable_unknown, + const std::vector& patterns = {}) { + return EvaluateAttributeHelper(&arena_, container, key, receiver_style, + enable_unknown, patterns); } google::protobuf::Arena arena_; }; -TEST_F(ContainerAccessStepTest, TestListIndexAccess) { +TEST_P(ContainerAccessStepUniformityTest, TestListIndexAccess) { ContainerBackedListImpl cel_list({CelValue::CreateInt64(1), CelValue::CreateInt64(2), CelValue::CreateInt64(3)}); - CelValue result = PerformRun(CelValue::CreateList(&cel_list), - CelValue::CreateInt64(1), true); + TestParamType param = GetParam(); + CelValue result = EvaluateAttribute(CelValue::CreateList(&cel_list), + CelValue::CreateInt64(1), + std::get<0>(param), std::get<1>(param)); ASSERT_TRUE(result.IsInt64()); ASSERT_EQ(result.Int64OrDie(), 2); } -TEST_F(ContainerAccessStepTest, TestListIndexAccessOutOfBounds) { +TEST_P(ContainerAccessStepUniformityTest, TestListIndexAccessOutOfBounds) { ContainerBackedListImpl cel_list({CelValue::CreateInt64(1), CelValue::CreateInt64(2), CelValue::CreateInt64(3)}); - CelValue result = PerformRun(CelValue::CreateList(&cel_list), - CelValue::CreateInt64(0), true); + TestParamType param = GetParam(); + + CelValue result = EvaluateAttribute(CelValue::CreateList(&cel_list), + CelValue::CreateInt64(0), + std::get<0>(param), std::get<1>(param)); ASSERT_TRUE(result.IsInt64()); - result = PerformRun(CelValue::CreateList(&cel_list), CelValue::CreateInt64(2), - true); + result = EvaluateAttribute(CelValue::CreateList(&cel_list), + CelValue::CreateInt64(2), std::get<0>(param), + std::get<1>(param)); ASSERT_TRUE(result.IsInt64()); - result = PerformRun(CelValue::CreateList(&cel_list), - CelValue::CreateInt64(-1), true); + result = EvaluateAttribute(CelValue::CreateList(&cel_list), + CelValue::CreateInt64(-1), std::get<0>(param), + std::get<1>(param)); ASSERT_TRUE(result.IsError()); - result = PerformRun(CelValue::CreateList(&cel_list), CelValue::CreateInt64(3), - true); + result = EvaluateAttribute(CelValue::CreateList(&cel_list), + CelValue::CreateInt64(3), std::get<0>(param), + std::get<1>(param)); ASSERT_TRUE(result.IsError()); } -TEST_F(ContainerAccessStepTest, TestListIndexAccessNotAnInt) { +TEST_P(ContainerAccessStepUniformityTest, TestListIndexAccessNotAnInt) { ContainerBackedListImpl cel_list({CelValue::CreateInt64(1), CelValue::CreateInt64(2), CelValue::CreateInt64(3)}); - CelValue result = PerformRun(CelValue::CreateList(&cel_list), - CelValue::CreateUint64(1), true); + TestParamType param = GetParam(); + + CelValue result = EvaluateAttribute(CelValue::CreateList(&cel_list), + CelValue::CreateUint64(1), + std::get<0>(param), std::get<1>(param)); ASSERT_TRUE(result.IsError()); } -TEST_F(ContainerAccessStepTest, TestMapKeyAccess) { +TEST_P(ContainerAccessStepUniformityTest, TestMapKeyAccess) { + TestParamType param = GetParam(); + const std::string kKey0 = "testkey0"; const std::string kKey1 = "testkey1"; const std::string kKey2 = "testkey2"; @@ -122,25 +169,92 @@ TEST_F(ContainerAccessStepTest, TestMapKeyAccess) { (*cel_struct.mutable_fields())[kKey1].set_string_value("value1"); (*cel_struct.mutable_fields())[kKey2].set_string_value("value2"); - CelValue result = PerformRun(CelValue::CreateMessage(&cel_struct, &arena_), - CelValue::CreateString(&kKey0), true); + CelValue result = EvaluateAttribute( + CelValue::CreateMessage(&cel_struct, &arena_), + CelValue::CreateString(&kKey0), std::get<0>(param), std::get<1>(param)); ASSERT_TRUE(result.IsString()); ASSERT_EQ(result.StringOrDie().value(), "value0"); } -TEST_F(ContainerAccessStepTest, TestMapKeyAccessNotFound) { +TEST_P(ContainerAccessStepUniformityTest, TestMapKeyAccessNotFound) { + TestParamType param = GetParam(); + const std::string kKey0 = "testkey0"; const std::string kKey1 = "testkey1"; Struct cel_struct; (*cel_struct.mutable_fields())[kKey0].set_string_value("value0"); - CelValue result = PerformRun(CelValue::CreateMessage(&cel_struct, &arena_), - CelValue::CreateString(&kKey1), true); + CelValue result = EvaluateAttribute( + CelValue::CreateMessage(&cel_struct, &arena_), + CelValue::CreateString(&kKey1), std::get<0>(param), std::get<1>(param)); ASSERT_TRUE(result.IsError()); } +TEST_F(ContainerAccessStepTest, TestListIndexAccessUnknown) { + ContainerBackedListImpl cel_list({CelValue::CreateInt64(1), + CelValue::CreateInt64(2), + CelValue::CreateInt64(3)}); + + CelValue result = EvaluateAttribute(CelValue::CreateList(&cel_list), + CelValue::CreateInt64(1), true, true, {}); + + ASSERT_TRUE(result.IsInt64()); + ASSERT_EQ(result.Int64OrDie(), 2); + + std::vector patterns = {CelAttributePattern( + "container", + {CelAttributeQualifierPattern::Create(CelValue::CreateInt64(1))})}; + + result = EvaluateAttribute(CelValue::CreateList(&cel_list), + CelValue::CreateInt64(1), true, true, patterns); + + ASSERT_TRUE(result.IsUnknownSet()); +} + +TEST_F(ContainerAccessStepTest, TestListUnknownKey) { + ContainerBackedListImpl cel_list({CelValue::CreateInt64(1), + CelValue::CreateInt64(2), + CelValue::CreateInt64(3)}); + + UnknownSet unknown_set; + CelValue result = + EvaluateAttribute(CelValue::CreateList(&cel_list), + CelValue::CreateUnknownSet(&unknown_set), true, true); + + ASSERT_TRUE(result.IsUnknownSet()); +} + +TEST_F(ContainerAccessStepTest, TestMapUnknownKey) { + const std::string kKey0 = "testkey0"; + const std::string kKey1 = "testkey1"; + const std::string kKey2 = "testkey2"; + Struct cel_struct; + (*cel_struct.mutable_fields())[kKey0].set_string_value("value0"); + (*cel_struct.mutable_fields())[kKey1].set_string_value("value1"); + (*cel_struct.mutable_fields())[kKey2].set_string_value("value2"); + + UnknownSet unknown_set; + CelValue result = + EvaluateAttribute(CelValue::CreateMessage(&cel_struct, &arena_), + CelValue::CreateUnknownSet(&unknown_set), true, true); + + ASSERT_TRUE(result.IsUnknownSet()); +} + +TEST_F(ContainerAccessStepTest, TestUnknownContainer) { + UnknownSet unknown_set; + CelValue result = EvaluateAttribute(CelValue::CreateUnknownSet(&unknown_set), + CelValue::CreateInt64(1), true, true); + + ASSERT_TRUE(result.IsUnknownSet()); +} + +INSTANTIATE_TEST_SUITE_P(CombinedContainerTest, + ContainerAccessStepUniformityTest, + testing::Combine(testing::Bool(), testing::Bool())); + } // namespace } // namespace runtime diff --git a/eval/eval/create_list_step.cc b/eval/eval/create_list_step.cc index d2b7fe66c..27cc6654e 100644 --- a/eval/eval/create_list_step.cc +++ b/eval/eval/create_list_step.cc @@ -1,4 +1,5 @@ #include "eval/eval/create_list_step.h" + #include "eval/eval/container_backed_list_impl.h" namespace google { @@ -13,32 +14,48 @@ class CreateListStep : public ExpressionStepBase { CreateListStep(int64_t expr_id, int list_size) : ExpressionStepBase(expr_id), list_size_(list_size) {} - cel_base::Status Evaluate(ExecutionFrame* frame) const override; + absl::Status Evaluate(ExecutionFrame* frame) const override; private: int list_size_; }; -cel_base::Status CreateListStep::Evaluate(ExecutionFrame* frame) const { +absl::Status CreateListStep::Evaluate(ExecutionFrame* frame) const { if (list_size_ < 0) { - return cel_base::Status(cel_base::StatusCode::kInternal, + return absl::Status(absl::StatusCode::kInternal, "CreateListStep: list size is <0"); } if (!frame->value_stack().HasEnough(list_size_)) { - return cel_base::Status(cel_base::StatusCode::kInternal, - "CreateListStep: stack undeflow"); + return absl::Status(absl::StatusCode::kInternal, + "CreateListStep: stack underflow"); } auto args = frame->value_stack().GetSpan(list_size_); - CelList* cel_list = google::protobuf::Arena::Create( - frame->arena(), std::vector(args.begin(), args.end())); + CelValue result; + + const UnknownSet* unknown_set = nullptr; + if (frame->enable_unknowns()) { + unknown_set = frame->unknowns_utility().MergeUnknowns( + args, frame->value_stack().GetAttributeSpan(list_size_), + /*initial_set=*/nullptr, + /*use_partial=*/true); + if (unknown_set != nullptr) { + result = CelValue::CreateUnknownSet(unknown_set); + } + } + + if (unknown_set == nullptr) { + CelList* cel_list = google::protobuf::Arena::Create( + frame->arena(), std::vector(args.begin(), args.end())); + result = CelValue::CreateList(cel_list); + } frame->value_stack().Pop(list_size_); - frame->value_stack().Push(CelValue::CreateList(cel_list)); + frame->value_stack().Push(result); - return cel_base::OkStatus(); + return absl::OkStatus(); } } // namespace diff --git a/eval/eval/create_list_step_test.cc b/eval/eval/create_list_step_test.cc index fc796b264..cef1737a4 100644 --- a/eval/eval/create_list_step_test.cc +++ b/eval/eval/create_list_step_test.cc @@ -1,8 +1,14 @@ #include "eval/eval/create_list_step.h" -#include "eval/eval/const_value_step.h" #include "gmock/gmock.h" #include "gtest/gtest.h" +#include "absl/strings/str_cat.h" +#include "eval/eval/const_value_step.h" +#include "eval/eval/ident_step.h" +#include "eval/public/activation.h" +#include "eval/public/cel_attribute.h" +#include "eval/public/unknown_attribute_set.h" +#include "base/status_macros.h" namespace google { namespace api { @@ -17,7 +23,8 @@ using google::api::expr::v1alpha1::Expr; // Helper method. Creates simple pipeline containing Select step and runs it. cel_base::StatusOr RunExpression(const std::vector& values, - google::protobuf::Arena* arena) { + google::protobuf::Arena* arena, + bool enable_unknowns) { ExecutionPath path; Expr dummy_expr; @@ -42,12 +49,54 @@ cel_base::StatusOr RunExpression(const std::vector& values, path.push_back(std::move(step0_status.ValueOrDie())); - CelExpressionFlatImpl cel_expr(&dummy_expr, std::move(path), 0); + CelExpressionFlatImpl cel_expr(&dummy_expr, std::move(path), 0, {}, + enable_unknowns); Activation activation; return cel_expr.Evaluate(activation, arena); } +// Helper method. Creates simple pipeline containing Select step and runs it. +cel_base::StatusOr RunExpressionWithCelValues( + const std::vector& values, google::protobuf::Arena* arena, + bool enable_unknowns) { + ExecutionPath path; + Expr dummy_expr; + + Activation activation; + auto create_list = dummy_expr.mutable_list_expr(); + int ind = 0; + for (auto value : values) { + std::string var_name = absl::StrCat("name_", ind++); + auto expr0 = create_list->add_elements(); + expr0->set_id(ind); + expr0->mutable_ident_expr()->set_name(var_name); + + auto ident_step_status = CreateIdentStep(&expr0->ident_expr(), expr0->id()); + if (!ident_step_status.ok()) { + return ident_step_status.status(); + } + + path.push_back(std::move(ident_step_status.ValueOrDie())); + activation.InsertValue(var_name, value); + } + + auto step0_status = CreateCreateListStep(create_list, dummy_expr.id()); + + if (!step0_status.ok()) { + return step0_status.status(); + } + + path.push_back(std::move(step0_status.ValueOrDie())); + + CelExpressionFlatImpl cel_expr(&dummy_expr, std::move(path), 0, {}, + enable_unknowns); + + return cel_expr.Evaluate(activation, arena); +} + +class CreateListStepTest : public testing::TestWithParam {}; + // Tests error when not enough list elements are on the stack during list // creation. TEST(CreateListStepTest, TestCreateListStackUndeflow) { @@ -60,11 +109,11 @@ TEST(CreateListStepTest, TestCreateListStackUndeflow) { auto step0_status = CreateCreateListStep(create_list, dummy_expr.id()); - ASSERT_TRUE(step0_status.ok()); + ASSERT_OK(step0_status); path.push_back(std::move(step0_status.ValueOrDie())); - CelExpressionFlatImpl cel_expr(&dummy_expr, std::move(path), 0); + CelExpressionFlatImpl cel_expr(&dummy_expr, std::move(path), 0, {}); Activation activation; google::protobuf::Arena arena; @@ -73,36 +122,36 @@ TEST(CreateListStepTest, TestCreateListStackUndeflow) { ASSERT_FALSE(status.ok()); } -TEST(CreateListStepTest, CreateListEmpty) { +TEST_P(CreateListStepTest, CreateListEmpty) { google::protobuf::Arena arena; - auto eval_result = RunExpression({}, &arena); + auto eval_result = RunExpression({}, &arena, GetParam()); - ASSERT_TRUE(eval_result.ok()); + ASSERT_OK(eval_result); const CelValue result_value = eval_result.ValueOrDie(); ASSERT_TRUE(result_value.IsList()); EXPECT_THAT(result_value.ListOrDie()->size(), Eq(0)); } -TEST(CreateListStepTest, CreateListOne) { +TEST_P(CreateListStepTest, CreateListOne) { google::protobuf::Arena arena; - auto eval_result = RunExpression({100}, &arena); + auto eval_result = RunExpression({100}, &arena, GetParam()); - ASSERT_TRUE(eval_result.ok()); + ASSERT_OK(eval_result); const CelValue result_value = eval_result.ValueOrDie(); ASSERT_TRUE(result_value.IsList()); EXPECT_THAT(result_value.ListOrDie()->size(), Eq(1)); EXPECT_THAT((*result_value.ListOrDie())[0].Int64OrDie(), Eq(100)); } -TEST(CreateListStepTest, CreateListHundred) { +TEST_P(CreateListStepTest, CreateListHundred) { google::protobuf::Arena arena; std::vector values; for (size_t i = 0; i < 100; i++) { values.push_back(i); } - auto eval_result = RunExpression(values, &arena); + auto eval_result = RunExpression(values, &arena, GetParam()); - ASSERT_TRUE(eval_result.ok()); + ASSERT_OK(eval_result); const CelValue result_value = eval_result.ValueOrDie(); ASSERT_TRUE(result_value.IsList()); EXPECT_THAT(result_value.ListOrDie()->size(), @@ -112,6 +161,36 @@ TEST(CreateListStepTest, CreateListHundred) { } } +TEST(CreateListStepTest, CreateListHundredAnd2Unknowns) { + google::protobuf::Arena arena; + std::vector values; + + Expr expr0; + expr0.mutable_ident_expr()->set_name("name0"); + CelAttribute attr0(expr0, {}); + Expr expr1; + expr1.mutable_ident_expr()->set_name("name1"); + CelAttribute attr1(expr1, {}); + UnknownSet unknown_set0(UnknownAttributeSet({&attr0})); + UnknownSet unknown_set1(UnknownAttributeSet({&attr1})); + for (size_t i = 0; i < 100; i++) { + values.push_back(CelValue::CreateInt64(i)); + } + values.push_back(CelValue::CreateUnknownSet(&unknown_set0)); + values.push_back(CelValue::CreateUnknownSet(&unknown_set1)); + + auto eval_result = RunExpressionWithCelValues(values, &arena, true); + + ASSERT_OK(eval_result); + const CelValue result_value = eval_result.ValueOrDie(); + ASSERT_TRUE(result_value.IsUnknownSet()); + const UnknownSet* result_set = result_value.UnknownSetOrDie(); + EXPECT_THAT(result_set->unknown_attributes().attributes().size(), Eq(2)); +} + +INSTANTIATE_TEST_SUITE_P(CombinedCreateListTest, CreateListStepTest, + testing::Bool()); + } // namespace } // namespace runtime diff --git a/eval/eval/create_struct_step.cc b/eval/eval/create_struct_step.cc index c417b9a32..7bba18c50 100644 --- a/eval/eval/create_struct_step.cc +++ b/eval/eval/create_struct_step.cc @@ -1,10 +1,10 @@ #include "eval/eval/create_struct_step.h" -#include "eval/eval/container_backed_map_impl.h" -#include "eval/eval/field_access.h" #include "google/api/expr/v1alpha1/syntax.pb.h" +#include "absl/status/status.h" #include "absl/strings/substitute.h" -#include "base/canonical_errors.h" +#include "eval/eval/container_backed_map_impl.h" +#include "eval/eval/field_access.h" namespace google { namespace api { @@ -13,12 +13,12 @@ namespace runtime { namespace { -using ::google::protobuf::Message; -using ::google::protobuf::MessageFactory; using ::google::protobuf::Arena; using ::google::protobuf::Descriptor; using ::google::protobuf::DescriptorPool; using ::google::protobuf::FieldDescriptor; +using ::google::protobuf::Message; +using ::google::protobuf::MessageFactory; class CreateStructStepForMessage : public ExpressionStepBase { public: @@ -32,10 +32,10 @@ class CreateStructStepForMessage : public ExpressionStepBase { descriptor_(descriptor), entries_(std::move(entries)) {} - cel_base::Status Evaluate(ExecutionFrame* frame) const override; + absl::Status Evaluate(ExecutionFrame* frame) const override; private: - cel_base::Status DoEvaluate(ExecutionFrame* frame, CelValue* result) const; + absl::Status DoEvaluate(ExecutionFrame* frame, CelValue* result) const; const Descriptor* descriptor_; std::vector entries_; @@ -46,20 +46,31 @@ class CreateStructStepForMap : public ExpressionStepBase { CreateStructStepForMap(int64_t expr_id, size_t entry_count) : ExpressionStepBase(expr_id), entry_count_(entry_count) {} - cel_base::Status Evaluate(ExecutionFrame* frame) const override; + absl::Status Evaluate(ExecutionFrame* frame) const override; private: - cel_base::Status DoEvaluate(ExecutionFrame* frame, CelValue* result) const; + absl::Status DoEvaluate(ExecutionFrame* frame, CelValue* result) const; size_t entry_count_; }; -::cel_base::Status CreateStructStepForMessage::DoEvaluate(ExecutionFrame* frame, - CelValue* result) const { +absl::Status CreateStructStepForMessage::DoEvaluate(ExecutionFrame* frame, + CelValue* result) const { int entries_size = entries_.size(); absl::Span args = frame->value_stack().GetSpan(entries_size); + if (frame->enable_unknowns()) { + auto unknown_set = frame->unknowns_utility().MergeUnknowns( + args, frame->value_stack().GetAttributeSpan(entries_size), + /*initial_set=*/nullptr, + /*use_partial=*/true); + if (unknown_set != nullptr) { + *result = CelValue::CreateUnknownSet(unknown_set); + return absl::OkStatus(); + } + } + const Message* prototype = MessageFactory::generated_factory()->GetPrototype(descriptor_); @@ -70,14 +81,14 @@ ::cel_base::Status CreateStructStepForMessage::DoEvaluate(ExecutionFrame* frame, *result = CreateErrorValue( frame->arena(), absl::Substitute("Failed to create message $0", descriptor_->name())); - return ::cel_base::OkStatus(); + return absl::OkStatus(); } int index = 0; for (const auto& entry : entries_) { const CelValue& arg = args[index++]; - ::cel_base::Status status = ::cel_base::OkStatus(); + absl::Status status = absl::OkStatus(); if (entry.field->is_map()) { constexpr int kKeyField = 1; @@ -85,7 +96,7 @@ ::cel_base::Status CreateStructStepForMessage::DoEvaluate(ExecutionFrame* frame, const CelMap* cel_map; if (!arg.GetValue(&cel_map) || cel_map == nullptr) { - status = cel_base::InvalidArgumentError(absl::Substitute( + status = absl::InvalidArgumentError(absl::Substitute( "Failed to create message $0, field $1: value is not CelMap", descriptor_->name(), entry.field->name())); break; @@ -94,7 +105,7 @@ ::cel_base::Status CreateStructStepForMessage::DoEvaluate(ExecutionFrame* frame, auto entry_descriptor = entry.field->message_type(); if (entry_descriptor == nullptr) { - status = cel_base::InvalidArgumentError( + status = absl::InvalidArgumentError( absl::Substitute("Failed to create message $0, field $1: failed to " "find map entry descriptor", descriptor_->name(), entry.field->name())); @@ -107,14 +118,14 @@ ::cel_base::Status CreateStructStepForMessage::DoEvaluate(ExecutionFrame* frame, entry_descriptor->FindFieldByNumber(kValueField); if (key_field_descriptor == nullptr) { - status = cel_base::InvalidArgumentError( + status = absl::InvalidArgumentError( absl::Substitute("Failed to create message $0, field $1: failed to " "find key field descriptor", descriptor_->name(), entry.field->name())); break; } if (value_field_descriptor == nullptr) { - status = cel_base::InvalidArgumentError( + status = absl::InvalidArgumentError( absl::Substitute("Failed to create message $0, field $1: failed to " "find value field descriptor", descriptor_->name(), entry.field->name())); @@ -127,7 +138,7 @@ ::cel_base::Status CreateStructStepForMessage::DoEvaluate(ExecutionFrame* frame, auto value = (*cel_map)[key]; if (!value.has_value()) { - status = cel_base::InvalidArgumentError(absl::Substitute( + status = absl::InvalidArgumentError(absl::Substitute( "Failed to create message $0, field $1: Error serializing CelMap", descriptor_->name(), entry.field->name())); break; @@ -153,7 +164,7 @@ ::cel_base::Status CreateStructStepForMessage::DoEvaluate(ExecutionFrame* frame, absl::Substitute( "Failed to create message $0: value $1 is not CelList", descriptor_->name(), entry.field->name())); - return ::cel_base::OkStatus(); + return absl::OkStatus(); } for (int i = 0; i < cel_list->size(); i++) { @@ -169,25 +180,24 @@ ::cel_base::Status CreateStructStepForMessage::DoEvaluate(ExecutionFrame* frame, frame->arena(), absl::Substitute("Failed to create message $0: reason $1", descriptor_->name(), status.ToString())); - return ::cel_base::OkStatus(); + return absl::OkStatus(); } } *result = CelValue::CreateMessage(msg, frame->arena()); - return ::cel_base::OkStatus(); + return absl::OkStatus(); } -::cel_base::Status CreateStructStepForMessage::Evaluate( - ExecutionFrame* frame) const { +absl::Status CreateStructStepForMessage::Evaluate(ExecutionFrame* frame) const { if (frame->value_stack().size() < entries_.size()) { - return cel_base::Status(cel_base::StatusCode::kInternal, + return absl::Status(absl::StatusCode::kInternal, "CreateStructStepForMessage: stack undeflow"); } CelValue result; - ::cel_base::Status status = DoEvaluate(frame, &result); + absl::Status status = DoEvaluate(frame, &result); if (!status.ok()) { return status; } @@ -195,14 +205,24 @@ ::cel_base::Status CreateStructStepForMessage::Evaluate( frame->value_stack().Pop(entries_.size()); frame->value_stack().Push(result); - return cel_base::OkStatus(); + return absl::OkStatus(); } -::cel_base::Status CreateStructStepForMap::DoEvaluate(ExecutionFrame* frame, - CelValue* result) const { +absl::Status CreateStructStepForMap::DoEvaluate(ExecutionFrame* frame, + CelValue* result) const { absl::Span args = frame->value_stack().GetSpan(2 * entry_count_); + if (frame->enable_unknowns()) { + const UnknownSet* unknown_set = frame->unknowns_utility().MergeUnknowns( + args, frame->value_stack().GetAttributeSpan(args.size()), + /*initial_set=*/nullptr, true); + if (unknown_set != nullptr) { + *result = CelValue::CreateUnknownSet(unknown_set); + return absl::OkStatus(); + } + } + std::vector> map_entries; map_entries.reserve(entry_count_); for (size_t i = 0; i < entry_count_; i += 1) { @@ -216,7 +236,7 @@ ::cel_base::Status CreateStructStepForMap::DoEvaluate(ExecutionFrame* frame, if (cel_map == nullptr) { *result = CreateErrorValue(frame->arena(), "Failed to create map"); - return ::cel_base::OkStatus(); + return absl::OkStatus(); } *result = CelValue::CreateMap(cel_map.get()); @@ -224,18 +244,18 @@ ::cel_base::Status CreateStructStepForMap::DoEvaluate(ExecutionFrame* frame, // Pass object ownership to Arena. frame->arena()->Own(cel_map.release()); - return ::cel_base::OkStatus(); + return absl::OkStatus(); } -::cel_base::Status CreateStructStepForMap::Evaluate(ExecutionFrame* frame) const { +absl::Status CreateStructStepForMap::Evaluate(ExecutionFrame* frame) const { if (frame->value_stack().size() < 2 * entry_count_) { - return cel_base::Status(cel_base::StatusCode::kInternal, + return absl::Status(absl::StatusCode::kInternal, "CreateStructStepForMap: stack undeflow"); } CelValue result; - ::cel_base::Status status = DoEvaluate(frame, &result); + absl::Status status = DoEvaluate(frame, &result); if (!status.ok()) { return status; } @@ -243,7 +263,7 @@ ::cel_base::Status CreateStructStepForMap::Evaluate(ExecutionFrame* frame) const frame->value_stack().Pop(2 * entry_count_); frame->value_stack().Push(result); - return cel_base::OkStatus(); + return absl::OkStatus(); } } // namespace @@ -260,20 +280,20 @@ cel_base::StatusOr> CreateCreateStructStep( create_struct_expr->message_name()); if (desc == nullptr) { - return cel_base::InvalidArgumentError( + return absl::InvalidArgumentError( "Error configuring message creation: message descriptor not found"); } for (const auto& entry : create_struct_expr->entries()) { if (entry.field_key().empty()) { - return cel_base::InvalidArgumentError( + return absl::InvalidArgumentError( "Error configuring message creation: field name missing"); } const FieldDescriptor* field_desc = desc->FindFieldByName(entry.field_key()); if (field_desc == nullptr) { - return cel_base::InvalidArgumentError( + return absl::InvalidArgumentError( "Error configuring message creation: field name not found"); } entries.push_back({field_desc}); diff --git a/eval/eval/create_struct_step_test.cc b/eval/eval/create_struct_step_test.cc index c8fef5922..94749a6f3 100644 --- a/eval/eval/create_struct_step_test.cc +++ b/eval/eval/create_struct_step_test.cc @@ -9,6 +9,7 @@ #include "eval/eval/ident_step.h" #include "eval/testutil/test_message.pb.h" #include "testutil/util.h" +#include "base/status_macros.h" namespace google { namespace api { @@ -16,12 +17,12 @@ namespace expr { namespace runtime { namespace { -using ::google::protobuf::Message; using ::google::protobuf::Arena; +using ::google::protobuf::Message; using testing::Eq; -using testing::Not; using testing::IsNull; +using testing::Not; using testing::Pointwise; using testutil::EqualsProto; @@ -32,7 +33,8 @@ using google::api::expr::v1alpha1::Expr; // builds message and runs it. cel_base::StatusOr RunExpression(absl::string_view field, const CelValue& value, - google::protobuf::Arena* arena) { + google::protobuf::Arena* arena, + bool enable_unknowns) { ExecutionPath path; Expr expr0; @@ -61,7 +63,8 @@ cel_base::StatusOr RunExpression(absl::string_view field, path.push_back(std::move(step0_status.ValueOrDie())); path.push_back(std::move(step1_status.ValueOrDie())); - CelExpressionFlatImpl cel_expr(&expr1, std::move(path), 0); + CelExpressionFlatImpl cel_expr(&expr1, std::move(path), 0, {}, + enable_unknowns); Activation activation; activation.InsertValue("message", value); @@ -69,9 +72,10 @@ cel_base::StatusOr RunExpression(absl::string_view field, } void RunExpressionAndGetMessage(absl::string_view field, const CelValue& value, - google::protobuf::Arena* arena, TestMessage* test_msg) { - auto status = RunExpression(field, value, arena); - ASSERT_TRUE(status.ok()); + google::protobuf::Arena* arena, TestMessage* test_msg, + bool enable_unknowns) { + auto status = RunExpression(field, value, arena, enable_unknowns); + ASSERT_OK(status); CelValue result = status.ValueOrDie(); ASSERT_TRUE(result.IsMessage()); @@ -85,13 +89,14 @@ void RunExpressionAndGetMessage(absl::string_view field, const CelValue& value, void RunExpressionAndGetMessage(absl::string_view field, std::vector values, - google::protobuf::Arena* arena, TestMessage* test_msg) { + google::protobuf::Arena* arena, TestMessage* test_msg, + bool enable_unknowns) { ContainerBackedListImpl cel_list(std::move(values)); CelValue value = CelValue::CreateList(&cel_list); - auto status = RunExpression(field, value, arena); - ASSERT_TRUE(status.ok()); + auto status = RunExpression(field, value, arena, enable_unknowns); + ASSERT_OK(status); CelValue result = status.ValueOrDie(); ASSERT_TRUE(result.IsMessage()); @@ -107,7 +112,7 @@ void RunExpressionAndGetMessage(absl::string_view field, // builds Map and runs it. cel_base::StatusOr RunCreateMapExpression( const std::vector> values, - google::protobuf::Arena* arena) { + google::protobuf::Arena* arena, bool enable_unknowns) { ExecutionPath path; Activation activation; @@ -158,11 +163,14 @@ cel_base::StatusOr RunCreateMapExpression( path.push_back(std::move(step1_status.ValueOrDie())); - CelExpressionFlatImpl cel_expr(&expr1, std::move(path), 0); + CelExpressionFlatImpl cel_expr(&expr1, std::move(path), 0, {}, + enable_unknowns); return cel_expr.Evaluate(activation, arena); } -TEST(CreateCreateStructStepTest, TestEmptyMessageCreation) { +class CreateCreateStructStepTest : public testing::TestWithParam {}; + +TEST_P(CreateCreateStructStepTest, TestEmptyMessageCreation) { ExecutionPath path; Expr expr1; @@ -172,17 +180,17 @@ TEST(CreateCreateStructStepTest, TestEmptyMessageCreation) { auto step_status = CreateCreateStructStep(create_struct, expr1.id()); - ASSERT_TRUE(step_status.ok()); + ASSERT_OK(step_status); path.push_back(std::move(step_status.ValueOrDie())); - CelExpressionFlatImpl cel_expr(&expr1, std::move(path), 0); + CelExpressionFlatImpl cel_expr(&expr1, std::move(path), 0, {}, GetParam()); Activation activation; google::protobuf::Arena arena; auto status = cel_expr.Evaluate(activation, &arena); - ASSERT_TRUE(status.ok()); + ASSERT_OK(status); CelValue result = status.ValueOrDie(); ASSERT_TRUE(result.IsMessage()); @@ -193,108 +201,126 @@ TEST(CreateCreateStructStepTest, TestEmptyMessageCreation) { ASSERT_EQ(msg->GetDescriptor(), TestMessage::descriptor()); } +// Test message creation if unknown argument is passed +TEST(CreateCreateStructStepTest, TestMessageCreateWithUnknown) { + Arena arena; + TestMessage test_msg; + UnknownSet unknown_set; + + auto eval_status = RunExpression( + "bool_value", CelValue::CreateUnknownSet(&unknown_set), &arena, true); + ASSERT_OK(eval_status); + ASSERT_TRUE(eval_status->IsUnknownSet()); +} + // Test that fields of type bool are set correctly -TEST(CreateCreateStructStepTest, TestSetBoolField) { +TEST_P(CreateCreateStructStepTest, TestSetBoolField) { Arena arena; TestMessage test_msg; ASSERT_NO_FATAL_FAILURE(RunExpressionAndGetMessage( - "bool_value", CelValue::CreateBool(true), &arena, &test_msg)); + "bool_value", CelValue::CreateBool(true), &arena, &test_msg, GetParam())); ASSERT_EQ(test_msg.bool_value(), true); } // Test that fields of type int32_t are set correctly -TEST(CreateCreateStructStepTest, TestSetInt32Field) { +TEST_P(CreateCreateStructStepTest, TestSetInt32Field) { Arena arena; TestMessage test_msg; ASSERT_NO_FATAL_FAILURE(RunExpressionAndGetMessage( - "int32_value", CelValue::CreateInt64(1), &arena, &test_msg)); + "int32_value", CelValue::CreateInt64(1), &arena, &test_msg, GetParam())); ASSERT_EQ(test_msg.int32_value(), 1); } // Test that fields of type uint32_t are set correctly. -TEST(CreateCreateStructStepTest, TestSetUInt32Field) { +TEST_P(CreateCreateStructStepTest, TestSetUInt32Field) { Arena arena; TestMessage test_msg; - ASSERT_NO_FATAL_FAILURE(RunExpressionAndGetMessage( - "uint32_value", CelValue::CreateUint64(1), &arena, &test_msg)); + ASSERT_NO_FATAL_FAILURE( + RunExpressionAndGetMessage("uint32_value", CelValue::CreateUint64(1), + &arena, &test_msg, GetParam())); ASSERT_EQ(test_msg.uint32_value(), 1); } // Test that fields of type int64_t are set correctly. -TEST(CreateCreateStructStepTest, TestSetInt64Field) { +TEST_P(CreateCreateStructStepTest, TestSetInt64Field) { Arena arena; TestMessage test_msg; ASSERT_NO_FATAL_FAILURE(RunExpressionAndGetMessage( - "int64_value", CelValue::CreateInt64(1), &arena, &test_msg)); + "int64_value", CelValue::CreateInt64(1), &arena, &test_msg, GetParam())); EXPECT_EQ(test_msg.int64_value(), 1); } // Test that fields of type uint64_t are set correctly. -TEST(CreateCreateStructStepTest, TestSetUInt64Field) { +TEST_P(CreateCreateStructStepTest, TestSetUInt64Field) { Arena arena; TestMessage test_msg; - ASSERT_NO_FATAL_FAILURE(RunExpressionAndGetMessage( - "uint64_value", CelValue::CreateUint64(1), &arena, &test_msg)); + ASSERT_NO_FATAL_FAILURE( + RunExpressionAndGetMessage("uint64_value", CelValue::CreateUint64(1), + &arena, &test_msg, GetParam())); EXPECT_EQ(test_msg.uint64_value(), 1); } // Test that fields of type float are set correctly -TEST(CreateCreateStructStepTest, TestSetFloatField) { +TEST_P(CreateCreateStructStepTest, TestSetFloatField) { Arena arena; TestMessage test_msg; - ASSERT_NO_FATAL_FAILURE(RunExpressionAndGetMessage( - "float_value", CelValue::CreateDouble(2.0), &arena, &test_msg)); + ASSERT_NO_FATAL_FAILURE( + RunExpressionAndGetMessage("float_value", CelValue::CreateDouble(2.0), + &arena, &test_msg, GetParam())); EXPECT_DOUBLE_EQ(test_msg.float_value(), 2.0); } // Test that fields of type double are set correctly -TEST(CreateCreateStructStepTest, TestSetDoubleField) { +TEST_P(CreateCreateStructStepTest, TestSetDoubleField) { Arena arena; TestMessage test_msg; - ASSERT_NO_FATAL_FAILURE(RunExpressionAndGetMessage( - "double_value", CelValue::CreateDouble(2.0), &arena, &test_msg)); + ASSERT_NO_FATAL_FAILURE( + RunExpressionAndGetMessage("double_value", CelValue::CreateDouble(2.0), + &arena, &test_msg, GetParam())); EXPECT_DOUBLE_EQ(test_msg.double_value(), 2.0); } // Test that fields of type string are set correctly. -TEST(CreateCreateStructStepTest, TestSetStringField) { +TEST_P(CreateCreateStructStepTest, TestSetStringField) { const std::string kTestStr = "test"; Arena arena; TestMessage test_msg; ASSERT_NO_FATAL_FAILURE(RunExpressionAndGetMessage( - "string_value", CelValue::CreateString(&kTestStr), &arena, &test_msg)); + "string_value", CelValue::CreateString(&kTestStr), &arena, &test_msg, + GetParam())); EXPECT_EQ(test_msg.string_value(), kTestStr); } // Test that fields of type bytes are set correctly. -TEST(CreateCreateStructStepTest, TestSetBytesField) { +TEST_P(CreateCreateStructStepTest, TestSetBytesField) { Arena arena; const std::string kTestStr = "test"; TestMessage test_msg; ASSERT_NO_FATAL_FAILURE(RunExpressionAndGetMessage( - "bytes_value", CelValue::CreateBytes(&kTestStr), &arena, &test_msg)); + "bytes_value", CelValue::CreateBytes(&kTestStr), &arena, &test_msg, + GetParam())); EXPECT_EQ(test_msg.bytes_value(), kTestStr); } // Test that fields of type duration are set correctly. -TEST(CreateCreateStructStepTest, TestSetDurationField) { +TEST_P(CreateCreateStructStepTest, TestSetDurationField) { Arena arena; google::protobuf::Duration test_duration; @@ -304,12 +330,12 @@ TEST(CreateCreateStructStepTest, TestSetDurationField) { ASSERT_NO_FATAL_FAILURE(RunExpressionAndGetMessage( "duration_value", CelValue::CreateDuration(&test_duration), &arena, - &test_msg)); + &test_msg, GetParam())); EXPECT_THAT(test_msg.duration_value(), EqualsProto(test_duration)); } // Test that fields of type timestamp are set correctly. -TEST(CreateCreateStructStepTest, TestSetTimestampField) { +TEST_P(CreateCreateStructStepTest, TestSetTimestampField) { Arena arena; google::protobuf::Timestamp test_timestamp; @@ -319,12 +345,12 @@ TEST(CreateCreateStructStepTest, TestSetTimestampField) { ASSERT_NO_FATAL_FAILURE(RunExpressionAndGetMessage( "timestamp_value", CelValue::CreateTimestamp(&test_timestamp), &arena, - &test_msg)); + &test_msg, GetParam())); EXPECT_THAT(test_msg.timestamp_value(), EqualsProto(test_timestamp)); } // Test that fields of type Message are set correctly. -TEST(CreateCreateStructStepTest, TestSetMessageField) { +TEST_P(CreateCreateStructStepTest, TestSetMessageField) { Arena arena; // Create payload message and set some fields. @@ -336,23 +362,23 @@ TEST(CreateCreateStructStepTest, TestSetMessageField) { ASSERT_NO_FATAL_FAILURE(RunExpressionAndGetMessage( "message_value", CelValue::CreateMessage(&orig_msg, &arena), &arena, - &test_msg)); + &test_msg, GetParam())); EXPECT_THAT(test_msg.message_value(), EqualsProto(orig_msg)); } // Test that fields of type Message are set correctly. -TEST(CreateCreateStructStepTest, TestSetEnumField) { +TEST_P(CreateCreateStructStepTest, TestSetEnumField) { Arena arena; TestMessage test_msg; ASSERT_NO_FATAL_FAILURE(RunExpressionAndGetMessage( "enum_value", CelValue::CreateInt64(TestMessage::TEST_ENUM_2), &arena, - &test_msg)); + &test_msg, GetParam())); EXPECT_EQ(test_msg.enum_value(), TestMessage::TEST_ENUM_2); } // Test that fields of type bool are set correctly -TEST(CreateCreateStructStepTest, TestSetRepeatedBoolField) { +TEST_P(CreateCreateStructStepTest, TestSetRepeatedBoolField) { Arena arena; TestMessage test_msg; @@ -362,13 +388,13 @@ TEST(CreateCreateStructStepTest, TestSetRepeatedBoolField) { values.push_back(CelValue::CreateBool(value)); } - ASSERT_NO_FATAL_FAILURE( - RunExpressionAndGetMessage("bool_list", values, &arena, &test_msg)); + ASSERT_NO_FATAL_FAILURE(RunExpressionAndGetMessage( + "bool_list", values, &arena, &test_msg, GetParam())); ASSERT_THAT(test_msg.bool_list(), Pointwise(Eq(), kValues)); } // Test that repeated fields of type int32_t are set correctly -TEST(CreateCreateStructStepTest, TestSetRepeatedInt32Field) { +TEST_P(CreateCreateStructStepTest, TestSetRepeatedInt32Field) { Arena arena; TestMessage test_msg; @@ -378,13 +404,13 @@ TEST(CreateCreateStructStepTest, TestSetRepeatedInt32Field) { values.push_back(CelValue::CreateInt64(value)); } - ASSERT_NO_FATAL_FAILURE( - RunExpressionAndGetMessage("int32_list", values, &arena, &test_msg)); + ASSERT_NO_FATAL_FAILURE(RunExpressionAndGetMessage( + "int32_list", values, &arena, &test_msg, GetParam())); ASSERT_THAT(test_msg.int32_list(), Pointwise(Eq(), kValues)); } // Test that repeated fields of type uint32_t are set correctly -TEST(CreateCreateStructStepTest, TestSetRepeatedUInt32Field) { +TEST_P(CreateCreateStructStepTest, TestSetRepeatedUInt32Field) { Arena arena; TestMessage test_msg; @@ -394,13 +420,13 @@ TEST(CreateCreateStructStepTest, TestSetRepeatedUInt32Field) { values.push_back(CelValue::CreateUint64(value)); } - ASSERT_NO_FATAL_FAILURE( - RunExpressionAndGetMessage("uint32_list", values, &arena, &test_msg)); + ASSERT_NO_FATAL_FAILURE(RunExpressionAndGetMessage( + "uint32_list", values, &arena, &test_msg, GetParam())); ASSERT_THAT(test_msg.uint32_list(), Pointwise(Eq(), kValues)); } // Test that repeated fields of type int64_t are set correctly -TEST(CreateCreateStructStepTest, TestSetRepeatedInt64Field) { +TEST_P(CreateCreateStructStepTest, TestSetRepeatedInt64Field) { Arena arena; TestMessage test_msg; @@ -410,13 +436,13 @@ TEST(CreateCreateStructStepTest, TestSetRepeatedInt64Field) { values.push_back(CelValue::CreateInt64(value)); } - ASSERT_NO_FATAL_FAILURE( - RunExpressionAndGetMessage("int64_list", values, &arena, &test_msg)); + ASSERT_NO_FATAL_FAILURE(RunExpressionAndGetMessage( + "int64_list", values, &arena, &test_msg, GetParam())); ASSERT_THAT(test_msg.int64_list(), Pointwise(Eq(), kValues)); } // Test that repeated fields of type uint64_t are set correctly -TEST(CreateCreateStructStepTest, TestSetRepeatedUInt64Field) { +TEST_P(CreateCreateStructStepTest, TestSetRepeatedUInt64Field) { Arena arena; TestMessage test_msg; @@ -426,13 +452,13 @@ TEST(CreateCreateStructStepTest, TestSetRepeatedUInt64Field) { values.push_back(CelValue::CreateUint64(value)); } - ASSERT_NO_FATAL_FAILURE( - RunExpressionAndGetMessage("uint64_list", values, &arena, &test_msg)); + ASSERT_NO_FATAL_FAILURE(RunExpressionAndGetMessage( + "uint64_list", values, &arena, &test_msg, GetParam())); ASSERT_THAT(test_msg.uint64_list(), Pointwise(Eq(), kValues)); } // Test that repeated fields of type float are set correctly -TEST(CreateCreateStructStepTest, TestSetRepeatedFloatField) { +TEST_P(CreateCreateStructStepTest, TestSetRepeatedFloatField) { Arena arena; TestMessage test_msg; @@ -442,13 +468,13 @@ TEST(CreateCreateStructStepTest, TestSetRepeatedFloatField) { values.push_back(CelValue::CreateDouble(value)); } - ASSERT_NO_FATAL_FAILURE( - RunExpressionAndGetMessage("float_list", values, &arena, &test_msg)); + ASSERT_NO_FATAL_FAILURE(RunExpressionAndGetMessage( + "float_list", values, &arena, &test_msg, GetParam())); ASSERT_THAT(test_msg.float_list(), Pointwise(Eq(), kValues)); } // Test that repeated fields of type uint32_t are set correctly -TEST(CreateCreateStructStepTest, TestSetRepeatedDoubleField) { +TEST_P(CreateCreateStructStepTest, TestSetRepeatedDoubleField) { Arena arena; TestMessage test_msg; @@ -458,13 +484,13 @@ TEST(CreateCreateStructStepTest, TestSetRepeatedDoubleField) { values.push_back(CelValue::CreateDouble(value)); } - ASSERT_NO_FATAL_FAILURE( - RunExpressionAndGetMessage("double_list", values, &arena, &test_msg)); + ASSERT_NO_FATAL_FAILURE(RunExpressionAndGetMessage( + "double_list", values, &arena, &test_msg, GetParam())); ASSERT_THAT(test_msg.double_list(), Pointwise(Eq(), kValues)); } // Test that repeated fields of type String are set correctly -TEST(CreateCreateStructStepTest, TestSetRepeatedStringField) { +TEST_P(CreateCreateStructStepTest, TestSetRepeatedStringField) { Arena arena; TestMessage test_msg; @@ -474,13 +500,13 @@ TEST(CreateCreateStructStepTest, TestSetRepeatedStringField) { values.push_back(CelValue::CreateString(&value)); } - ASSERT_NO_FATAL_FAILURE( - RunExpressionAndGetMessage("string_list", values, &arena, &test_msg)); + ASSERT_NO_FATAL_FAILURE(RunExpressionAndGetMessage( + "string_list", values, &arena, &test_msg, GetParam())); ASSERT_THAT(test_msg.string_list(), Pointwise(Eq(), kValues)); } // Test that repeated fields of type String are set correctly -TEST(CreateCreateStructStepTest, TestSetRepeatedBytesField) { +TEST_P(CreateCreateStructStepTest, TestSetRepeatedBytesField) { Arena arena; TestMessage test_msg; @@ -490,14 +516,14 @@ TEST(CreateCreateStructStepTest, TestSetRepeatedBytesField) { values.push_back(CelValue::CreateBytes(&value)); } - ASSERT_NO_FATAL_FAILURE( - RunExpressionAndGetMessage("bytes_list", values, &arena, &test_msg)); + ASSERT_NO_FATAL_FAILURE(RunExpressionAndGetMessage( + "bytes_list", values, &arena, &test_msg, GetParam())); ASSERT_THAT(test_msg.bytes_list(), Pointwise(Eq(), kValues)); } // Test that repeated fields of type Message are set correctly -TEST(CreateCreateStructStepTest, TestSetRepeatedMessageField) { +TEST_P(CreateCreateStructStepTest, TestSetRepeatedMessageField) { Arena arena; TestMessage test_msg; @@ -509,15 +535,15 @@ TEST(CreateCreateStructStepTest, TestSetRepeatedMessageField) { values.push_back(CelValue::CreateMessage(&value, &arena)); } - ASSERT_NO_FATAL_FAILURE( - RunExpressionAndGetMessage("message_list", values, &arena, &test_msg)); + ASSERT_NO_FATAL_FAILURE(RunExpressionAndGetMessage( + "message_list", values, &arena, &test_msg, GetParam())); ASSERT_THAT(test_msg.message_list()[0], EqualsProto(kValues[0])); ASSERT_THAT(test_msg.message_list()[1], EqualsProto(kValues[1])); } // Test that fields of type map are set correctly -TEST(CreateCreateStructStepTest, TestSetStringMapField) { +TEST_P(CreateCreateStructStepTest, TestSetStringMapField) { Arena arena; TestMessage test_msg; @@ -535,8 +561,8 @@ TEST(CreateCreateStructStepTest, TestSetStringMapField) { entries.data(), entries.size())); ASSERT_NO_FATAL_FAILURE(RunExpressionAndGetMessage( - "string_int32_map", CelValue::CreateMap(cel_map.get()), &arena, - &test_msg)); + "string_int32_map", CelValue::CreateMap(cel_map.get()), &arena, &test_msg, + GetParam())); ASSERT_EQ(test_msg.string_int32_map().size(), 2); ASSERT_EQ(test_msg.string_int32_map().at(kKeys[0]), 2); @@ -544,7 +570,7 @@ TEST(CreateCreateStructStepTest, TestSetStringMapField) { } // Test that fields of type map are set correctly -TEST(CreateCreateStructStepTest, TestSetInt64MapField) { +TEST_P(CreateCreateStructStepTest, TestSetInt64MapField) { Arena arena; TestMessage test_msg; @@ -562,8 +588,8 @@ TEST(CreateCreateStructStepTest, TestSetInt64MapField) { entries.data(), entries.size())); ASSERT_NO_FATAL_FAILURE(RunExpressionAndGetMessage( - "int64_int32_map", CelValue::CreateMap(cel_map.get()), &arena, - &test_msg)); + "int64_int32_map", CelValue::CreateMap(cel_map.get()), &arena, &test_msg, + GetParam())); ASSERT_EQ(test_msg.int64_int32_map().size(), 2); ASSERT_EQ(test_msg.int64_int32_map().at(kKeys[0]), 1); @@ -571,7 +597,7 @@ TEST(CreateCreateStructStepTest, TestSetInt64MapField) { } // Test that fields of type map are set correctly -TEST(CreateCreateStructStepTest, TestSetUInt64MapField) { +TEST_P(CreateCreateStructStepTest, TestSetUInt64MapField) { Arena arena; TestMessage test_msg; @@ -589,8 +615,8 @@ TEST(CreateCreateStructStepTest, TestSetUInt64MapField) { entries.data(), entries.size())); ASSERT_NO_FATAL_FAILURE(RunExpressionAndGetMessage( - "uint64_int32_map", CelValue::CreateMap(cel_map.get()), &arena, - &test_msg)); + "uint64_int32_map", CelValue::CreateMap(cel_map.get()), &arena, &test_msg, + GetParam())); ASSERT_EQ(test_msg.uint64_int32_map().size(), 2); ASSERT_EQ(test_msg.uint64_int32_map().at(kKeys[0]), 1); @@ -598,11 +624,11 @@ TEST(CreateCreateStructStepTest, TestSetUInt64MapField) { } // Test that Empty Map is created successfully. -TEST(CreateCreateStructStepTest, TestCreateEmptyMap) { +TEST_P(CreateCreateStructStepTest, TestCreateEmptyMap) { Arena arena; - auto status = RunCreateMapExpression({}, &arena); + auto status = RunCreateMapExpression({}, &arena, GetParam()); - ASSERT_TRUE(status.ok()); + ASSERT_OK(status); CelValue result_value = status.ValueOrDie(); ASSERT_TRUE(result_value.IsMap()); @@ -611,8 +637,29 @@ TEST(CreateCreateStructStepTest, TestCreateEmptyMap) { ASSERT_EQ(cel_map->size(), 0); } +// Test message creation if unknown argument is passed +TEST(CreateCreateStructStepTest, TestMapCreateWithUnknown) { + Arena arena; + UnknownSet unknown_set; + std::vector> entries; + + std::vector kKeys = {"test2", "test1"}; + + entries.push_back( + {CelValue::CreateString(&kKeys[0]), CelValue::CreateInt64(2)}); + entries.push_back({CelValue::CreateString(&kKeys[1]), + CelValue::CreateUnknownSet(&unknown_set)}); + + auto status = RunCreateMapExpression(entries, &arena, true); + + ASSERT_OK(status); + + CelValue result_value = status.ValueOrDie(); + ASSERT_TRUE(result_value.IsUnknownSet()); +} + // Test that String Map is created successfully. -TEST(CreateCreateStructStepTest, TestCreateStringMap) { +TEST_P(CreateCreateStructStepTest, TestCreateStringMap) { Arena arena; std::vector> entries; @@ -624,9 +671,9 @@ TEST(CreateCreateStructStepTest, TestCreateStringMap) { entries.push_back( {CelValue::CreateString(&kKeys[1]), CelValue::CreateInt64(1)}); - auto status = RunCreateMapExpression(entries, &arena); + auto status = RunCreateMapExpression(entries, &arena, GetParam()); - ASSERT_TRUE(status.ok()); + ASSERT_OK(status); CelValue result_value = status.ValueOrDie(); ASSERT_TRUE(result_value.IsMap()); @@ -645,6 +692,9 @@ TEST(CreateCreateStructStepTest, TestCreateStringMap) { EXPECT_EQ(lookup1.value().Int64OrDie(), 1); } +INSTANTIATE_TEST_SUITE_P(CombinedCreateStructTest, CreateCreateStructStepTest, + testing::Bool()); + } // namespace } // namespace runtime diff --git a/eval/eval/evaluator_core.cc b/eval/eval/evaluator_core.cc index 5e0579aa4..29d601dd9 100644 --- a/eval/eval/evaluator_core.cc +++ b/eval/eval/evaluator_core.cc @@ -1,31 +1,147 @@ #include "eval/eval/evaluator_core.h" +#include "absl/status/status.h" +#include "absl/types/optional.h" +#include "eval/public/cel_value.h" +#include "base/status_macros.h" +#include "base/statusor.h" + namespace google { namespace api { namespace expr { namespace runtime { +namespace { + +absl::Status CheckIterAccess(CelExpressionFlatEvaluationState* state, + const std::string& name) { + if (state->iter_stack().empty()) { + return absl::Status( + absl::StatusCode::kInternal, + absl::StrCat( + "Attempted to update iteration variable outside of comprehension.'", + name, "'")); + } + auto iter = state->iter_variable_names().find(name); + if (iter == state->iter_variable_names().end()) { + return absl::Status( + absl::StatusCode::kInternal, + absl::StrCat("Attempted to set unknown variable '", name, "'")); + } + + return absl::OkStatus(); +} + +} // namespace + +void ValueStack::Clear() { + for (auto& v : stack_) { + v = CelValue(); + } + for (auto& attr : attribute_stack_) { + attr = AttributeTrail(); + } + + current_size_ = 0; +} -using google::api::expr::v1alpha1::Expr; +CelExpressionFlatEvaluationState::CelExpressionFlatEvaluationState( + size_t value_stack_size, const std::set& iter_variable_names, + google::protobuf::Arena* arena) + : value_stack_(value_stack_size), + iter_variable_names_(iter_variable_names), + arena_(arena) {} + +void CelExpressionFlatEvaluationState::Reset() { + iter_stack_.clear(); + value_stack_.Clear(); +} const ExpressionStep* ExecutionFrame::Next() { - size_t end_pos = execution_path_->size(); + size_t end_pos = execution_path_.size(); - if (pc_ < end_pos) return (*execution_path_)[pc_++].get(); + if (pc_ < end_pos) return execution_path_[pc_++].get(); if (pc_ > end_pos) { GOOGLE_LOG(ERROR) << "Attempting to step beyond the end of execution path."; } return nullptr; } +absl::Status ExecutionFrame::PushIterFrame() { + state_->iter_stack().push_back({}); + return absl::OkStatus(); +} + +absl::Status ExecutionFrame::PopIterFrame() { + if (state_->iter_stack().empty()) { + return absl::InternalError("Loop stack underflow."); + } + state_->iter_stack().pop_back(); + return absl::OkStatus(); +} + +absl::Status ExecutionFrame::SetIterVar(const std::string& name, + const CelValue& val) { + RETURN_IF_ERROR(CheckIterAccess(state_, name)); + state_->IterStackTop()[name] = val; + + return absl::OkStatus(); +} + +absl::Status ExecutionFrame::ClearIterVar(const std::string& name) { + RETURN_IF_ERROR(CheckIterAccess(state_, name)); + state_->IterStackTop()[name] = absl::nullopt; + return absl::OkStatus(); +} + +bool ExecutionFrame::GetIterVar(const std::string& name, CelValue* val) const { + absl::Status status = CheckIterAccess(state_, name); + if (!status.ok()) { + return false; + } + + for (auto iter = state_->iter_stack().rbegin(); + iter != state_->iter_stack().rend(); ++iter) { + auto& frame = *iter; + auto frame_iter = frame.find(name); + if (frame_iter != frame.end()) { + if (frame_iter->second.has_value()) { + *val = frame_iter->second.value(); + return true; + } + } + } + + return false; +} + +std::unique_ptr CelExpressionFlatImpl::InitializeState( + google::protobuf::Arena* arena) const { + return absl::make_unique( + path_.size(), iter_variable_names_, arena); +} + cel_base::StatusOr CelExpressionFlatImpl::Evaluate( - const BaseActivation& activation, google::protobuf::Arena* arena) const { - return Trace(activation, arena, CelEvaluationListener()); + const BaseActivation& activation, CelEvaluationState* state) const { + return Trace(activation, state, CelEvaluationListener()); } cel_base::StatusOr CelExpressionFlatImpl::Trace( - const BaseActivation& activation, google::protobuf::Arena* arena, + const BaseActivation& activation, CelEvaluationState* _state, CelEvaluationListener callback) const { - ExecutionFrame frame(&path_, activation, arena, max_iterations_); + auto state = down_cast(_state); + state->Reset(); + + // Using both unknown attribute patterns and unknown paths via FieldMask is + // not allowed. + if (activation.unknown_paths().paths_size() != 0 && + !activation.unknown_attribute_patterns().empty()) { + return absl::InvalidArgumentError( + "Attempting to evaluate expression with both unknown_paths and " + "unknown_attribute_patterns set in the Activation"); + } + + ExecutionFrame frame(path_, activation, max_iterations_, state, + enable_unknowns_, enable_unknown_function_results_); ValueStack* stack = &frame.value_stack(); size_t initial_stack_size = stack->size(); @@ -48,7 +164,7 @@ cel_base::StatusOr CelExpressionFlatImpl::Trace( "Try to disable short-circuiting."; continue; } - auto status2 = callback(expr->id(), stack->Peek(), arena); + auto status2 = callback(expr->id(), stack->Peek(), state->arena()); if (!status2.ok()) { return status2; } @@ -56,7 +172,7 @@ cel_base::StatusOr CelExpressionFlatImpl::Trace( size_t final_stack_size = stack->size(); if (initial_stack_size + 1 != final_stack_size || final_stack_size == 0) { - return cel_base::Status(cel_base::StatusCode::kInternal, + return absl::Status(absl::StatusCode::kInternal, "Stack error during evaluation"); } CelValue value = stack->Peek(); diff --git a/eval/eval/evaluator_core.h b/eval/eval/evaluator_core.h index 50d2d4c0d..475613a0e 100644 --- a/eval/eval/evaluator_core.h +++ b/eval/eval/evaluator_core.h @@ -1,10 +1,19 @@ #ifndef THIRD_PARTY_CEL_CPP_EVAL_EVAL_EVALUATOR_CORE_H_ #define THIRD_PARTY_CEL_CPP_EVAL_EVAL_EVALUATOR_CORE_H_ +#include +#include + +#include "google/api/expr/v1alpha1/syntax.pb.h" +#include "google/protobuf/arena.h" +#include "absl/types/optional.h" +#include "eval/eval/attribute_trail.h" +#include "eval/eval/unknowns_utility.h" #include "eval/public/activation.h" +#include "eval/public/cel_attribute.h" #include "eval/public/cel_expression.h" #include "eval/public/cel_value.h" -#include "google/api/expr/v1alpha1/syntax.pb.h" +#include "eval/public/unknown_attribute_set.h" namespace google { namespace api { @@ -26,7 +35,7 @@ class ExpressionStep { // interface. // ExpressionStep instances can in specific cases // modify execution order(perform jumps). - virtual cel_base::Status Evaluate(ExecutionFrame* context) const = 0; + virtual absl::Status Evaluate(ExecutionFrame* context) const = 0; // Returns corresponding expression object ID. // Requires that the input expression has IDs assigned to sub-expressions, @@ -40,9 +49,6 @@ class ExpressionStep { virtual bool ComesFromAst() const = 0; }; -// CelValue stack. -// Implementation is based on vector to allow passing parameters from -// stack as Span<>. using ExecutionPath = std::vector>; // CelValue stack. @@ -50,42 +56,144 @@ using ExecutionPath = std::vector>; // stack as Span<>. class ValueStack { public: - ValueStack() = default; + ValueStack(size_t max_size) : current_size_(0) { + stack_.resize(max_size); + attribute_stack_.resize(max_size); + } + + // Return the current stack size. + size_t size() const { return current_size_; } + + // Return the maximum size of the stack. + size_t max_size() const { return stack_.size(); } - // Stack size. - size_t size() const { return stack_.size(); } + // Returns true if stack is empty. + bool empty() const { return current_size_ == 0; } + + // Attributes stack size. + size_t attribute_size() const { return current_size_; } // Check that stack has enough elements. - bool HasEnough(size_t size) const { return stack_.size() >= size; } + bool HasEnough(size_t size) const { return current_size_ >= size; } + + // Dumps the entire stack state as is. + void Clear(); // Gets the last size elements of the stack. // Checking that stack has enough elements is caller's responsibility. // Please note that calls to Push may invalidate returned Span object. absl::Span GetSpan(size_t size) const { - return absl::Span(stack_.data() + stack_.size() - size, + if (!HasEnough(size)) { + GOOGLE_LOG(ERROR) << "Requested span size (" << size + << ") exceeds current stack size: " << current_size_; + } + return absl::Span(stack_.data() + current_size_ - size, size); } + // Gets the last size attribute trails of the stack. + // Checking that stack has enough elements is caller's responsibility. + // Please note that calls to Push may invalidate returned Span object. + absl::Span GetAttributeSpan(size_t size) const { + return absl::Span( + attribute_stack_.data() + current_size_ - size, size); + } + // Peeks the last element of the stack. // Checking that stack is not empty is caller's responsibility. - const CelValue& Peek() const { return stack_.back(); } + const CelValue& Peek() const { + if (empty()) { + GOOGLE_LOG(ERROR) << "Peeking on empty ValueStack"; + } + return stack_[current_size_ - 1]; + } + + // Peeks the last element of the attribute stack. + // Checking that stack is not empty is caller's responsibility. + const AttributeTrail& PeekAttribute() const { + if (empty()) { + GOOGLE_LOG(ERROR) << "Peeking on empty ValueStack"; + } + return attribute_stack_[current_size_ - 1]; + } // Clears the last size elements of the stack. // Checking that stack has enough elements is caller's responsibility. - void Pop(size_t size) { stack_.resize(stack_.size() - size); } + void Pop(size_t size) { + if (!HasEnough(size)) { + GOOGLE_LOG(ERROR) << "Trying to pop more elements (" << size + << ") than the current stack size: " << current_size_; + } + current_size_ -= size; + } // Put element on the top of the stack. - void Push(const CelValue& value) { stack_.push_back(value); } + void Push(const CelValue& value) { Push(value, AttributeTrail()); } + + void Push(const CelValue& value, AttributeTrail attribute) { + if (current_size_ >= stack_.size()) { + GOOGLE_LOG(ERROR) << "No room to push more elements on to ValueStack"; + } + stack_[current_size_] = value; + attribute_stack_[current_size_] = attribute; + current_size_++; + } // Replace element on the top of the stack. // Checking that stack is not empty is caller's responsibility. - void PopAndPush(const CelValue& value) { stack_.back() = value; } + void PopAndPush(const CelValue& value) { + PopAndPush(value, AttributeTrail()); + } + + // Replace element on the top of the stack. + // Checking that stack is not empty is caller's responsibility. + void PopAndPush(const CelValue& value, AttributeTrail attribute) { + if (empty()) { + GOOGLE_LOG(ERROR) << "Cannot PopAndPush on empty stack."; + } + stack_[current_size_ - 1] = value; + attribute_stack_[current_size_ - 1] = attribute; + } // Preallocate stack. - void Reserve(size_t size) { stack_.reserve(size); } + void Reserve(size_t size) { + stack_.reserve(size); + attribute_stack_.reserve(size); + } private: std::vector stack_; + std::vector attribute_stack_; + size_t current_size_; +}; + +class CelExpressionFlatEvaluationState : public CelEvaluationState { + public: + CelExpressionFlatEvaluationState( + size_t value_stack_size, const std::set& iter_variable_names, + google::protobuf::Arena* arena); + + void Reset(); + + ValueStack& value_stack() { return value_stack_; } + + std::vector>>& iter_stack() { + return iter_stack_; + } + + std::map>& IterStackTop() { + return iter_stack_[iter_stack().size() - 1]; + } + + std::set& iter_variable_names() { return iter_variable_names_; } + + google::protobuf::Arena* arena() { return arena_; } + + private: + ValueStack value_stack_; + std::set iter_variable_names_; + std::vector>> iter_stack_; + google::protobuf::Arena* arena_; }; // ExecutionFrame provides context for expression evaluation. @@ -95,68 +203,90 @@ class ExecutionFrame { // flat is the flattened sequence of execution steps that will be evaluated. // activation provides bindings between parameter names and values. // arena serves as allocation manager during the expression evaluation. - ExecutionFrame(const ExecutionPath* flat, const BaseActivation& activation, - google::protobuf::Arena* arena, int max_iterations) + + ExecutionFrame(const ExecutionPath& flat, const BaseActivation& activation, + int max_iterations, CelExpressionFlatEvaluationState* state, + bool enable_unknowns, bool enable_unknown_function_results) : pc_(0UL), execution_path_(flat), activation_(activation), - arena_(arena), + enable_unknowns_(enable_unknowns), + enable_unknown_function_results_(enable_unknown_function_results), + unknowns_utility_(&activation.unknown_attribute_patterns(), + state->arena()), max_iterations_(max_iterations), - iterations_(0) { - // Reserve space on stack to minimize reallocations - // on stack resize. - value_stack_.Reserve(flat->size()); - } + iterations_(0), + state_(state) {} // Returns next expression to evaluate. const ExpressionStep* Next(); // Intended for use only in conditionals. - cel_base::Status JumpTo(int offset) { + absl::Status JumpTo(int offset) { int new_pc = static_cast(pc_) + offset; - if (new_pc < 0 || new_pc > static_cast(execution_path_->size())) { - return cel_base::Status(cel_base::StatusCode::kInternal, + if (new_pc < 0 || new_pc > static_cast(execution_path_.size())) { + return absl::Status(absl::StatusCode::kInternal, absl::StrCat("Jump address out of range: position: ", pc_, ",offset: ", offset, - ", range: ", execution_path_->size())); + ", range: ", execution_path_.size())); } pc_ = static_cast(new_pc); - return cel_base::OkStatus(); + return absl::OkStatus(); } - ValueStack& value_stack() { return value_stack_; } + ValueStack& value_stack() { return state_->value_stack(); } + bool enable_unknowns() const { return enable_unknowns_; } + bool enable_unknown_function_results() const { + return enable_unknown_function_results_; + } - google::protobuf::Arena* arena() { return arena_; } + google::protobuf::Arena* arena() { return state_->arena(); } + const UnknownsUtility& unknowns_utility() const { return unknowns_utility_; } // Returns reference to Activation const BaseActivation& activation() const { return activation_; } - // Returns reference to iter_vars - std::map& iter_vars() { return iter_vars_; } + // Creates a new frame for iteration variables. + absl::Status PushIterFrame(); + + // Discards the top frame for iteration variables. + absl::Status PopIterFrame(); + + // Sets the value of an iteration variable + absl::Status SetIterVar(const std::string& name, const CelValue& val); + + // Clears the value of an iteration variable + absl::Status ClearIterVar(const std::string& name); + + // Gets the current value of an iteration variable. + // Returns false if the variable is not currently in use (Set has been called + // since init or last clear). + bool GetIterVar(const std::string& name, CelValue* val) const; // Increment iterations and return an error if the iteration budget is // exceeded - cel_base::Status IncrementIterations() { + absl::Status IncrementIterations() { if (max_iterations_ == 0) { - return cel_base::OkStatus(); + return absl::OkStatus(); } iterations_++; if (iterations_ >= max_iterations_) { - return cel_base::Status(cel_base::StatusCode::kInternal, + return absl::Status(absl::StatusCode::kInternal, "Iteration budget exceeded"); } - return cel_base::OkStatus(); + return absl::OkStatus(); } private: size_t pc_; // pc_ - Program Counter. Current position on execution path. - const ExecutionPath* execution_path_; + const ExecutionPath& execution_path_; const BaseActivation& activation_; - ValueStack value_stack_; - google::protobuf::Arena* arena_; + bool enable_unknowns_; + bool enable_unknown_function_results_; + UnknownsUtility unknowns_utility_; const int max_iterations_; int iterations_; - std::map iter_vars_; // variables declared in the frame. + CelExpressionFlatEvaluationState* state_; }; // Implementation of the CelExpression that utilizes flattening @@ -168,22 +298,50 @@ class CelExpressionFlatImpl : public CelExpression { // flattened AST tree. Max iterations dictates the maximum number of // iterations in the comprehension expressions (use 0 to disable the upper // bound). - CelExpressionFlatImpl(const google::api::expr::v1alpha1::Expr*, ExecutionPath path, - int max_iterations) - : path_(std::move(path)), max_iterations_(max_iterations) {} + CelExpressionFlatImpl(const google::api::expr::v1alpha1::Expr* root_expr, + ExecutionPath path, int max_iterations, + std::set iter_variable_names, + bool enable_unknowns = false, + bool enable_unknown_function_results = false) + : path_(std::move(path)), + max_iterations_(max_iterations), + iter_variable_names_(std::move(iter_variable_names)), + enable_unknowns_(enable_unknowns), + enable_unknown_function_results_(enable_unknown_function_results) {} + + // Move-only + CelExpressionFlatImpl(const CelExpressionFlatImpl&) = delete; + CelExpressionFlatImpl& operator=(const CelExpressionFlatImpl&) = delete; + + std::unique_ptr InitializeState( + google::protobuf::Arena* arena) const override; // Implementation of CelExpression evaluate method. cel_base::StatusOr Evaluate(const BaseActivation& activation, - google::protobuf::Arena* arena) const override; + google::protobuf::Arena* arena) const override { + return Evaluate(activation, InitializeState(arena).get()); + } + + cel_base::StatusOr Evaluate(const BaseActivation& activation, + CelEvaluationState* state) const override; // Implementation of CelExpression trace method. + cel_base::StatusOr Trace( + const BaseActivation& activation, google::protobuf::Arena* arena, + CelEvaluationListener callback) const override { + return Trace(activation, InitializeState(arena).get(), callback); + } + cel_base::StatusOr Trace(const BaseActivation& activation, - google::protobuf::Arena* arena, + CelEvaluationState* state, CelEvaluationListener callback) const override; private: const ExecutionPath path_; const int max_iterations_; + const std::set iter_variable_names_; + bool enable_unknowns_; + bool enable_unknown_function_results_; }; } // namespace runtime diff --git a/eval/eval/evaluator_core_test.cc b/eval/eval/evaluator_core_test.cc index d8eae3cf5..70a32c41e 100644 --- a/eval/eval/evaluator_core_test.cc +++ b/eval/eval/evaluator_core_test.cc @@ -1,10 +1,13 @@ #include "eval/eval/evaluator_core.h" -#include "google/api/expr/v1alpha1/syntax.pb.h" -#include "eval/compiler/flat_expr_builder.h" -#include "eval/public/builtin_func_registrar.h" +#include "google/api/expr/v1alpha1/syntax.pb.h" #include "gmock/gmock.h" #include "gtest/gtest.h" +#include "eval/compiler/flat_expr_builder.h" +#include "eval/public/builtin_func_registrar.h" +#include "eval/public/cel_attribute.h" +#include "eval/public/cel_value.h" +#include "base/status_macros.h" namespace google { namespace api { @@ -13,15 +16,16 @@ namespace runtime { using ::google::api::expr::runtime::RegisterBuiltinFunctions; using testing::_; -using testing::Eq; // Optional ::testing aliases. Remove if unused. +using testing::Eq; +using testing::NotNull; // Fake expression implementation // Pushes int64_t(0) on top of value stack. class FakeConstExpressionStep : public ExpressionStep { public: - cel_base::Status Evaluate(ExecutionFrame* frame) const override { + absl::Status Evaluate(ExecutionFrame* frame) const override { frame->value_stack().Push(CelValue::CreateInt64(0)); - return cel_base::OkStatus(); + return absl::OkStatus(); } int64_t id() const override { return 0; } @@ -33,13 +37,13 @@ class FakeConstExpressionStep : public ExpressionStep { // Increments argument on top of the stack. class FakeIncrementExpressionStep : public ExpressionStep { public: - cel_base::Status Evaluate(ExecutionFrame* frame) const override { + absl::Status Evaluate(ExecutionFrame* frame) const override { CelValue value = frame->value_stack().Peek(); frame->value_stack().Pop(1); EXPECT_TRUE(value.IsInt64()); int64_t val = value.Int64OrDie(); frame->value_stack().Push(CelValue::CreateInt64(val + 1)); - return cel_base::OkStatus(); + return absl::OkStatus(); } int64_t id() const override { return 0; } @@ -47,6 +51,52 @@ class FakeIncrementExpressionStep : public ExpressionStep { bool ComesFromAst() const override { return true; } }; +// Test Value Stack Push/Pop operation +TEST(EvaluatorCoreTest, ValueStackPushPop) { + google::protobuf::Arena arena; + google::api::expr::v1alpha1::Expr expr; + expr.mutable_ident_expr()->set_name("name"); + CelAttribute attribute(expr, {}); + ValueStack stack(10); + stack.Push(CelValue::CreateInt64(1)); + stack.Push(CelValue::CreateInt64(2), AttributeTrail()); + stack.Push(CelValue::CreateInt64(3), AttributeTrail(expr, &arena)); + + ASSERT_EQ(stack.Peek().Int64OrDie(), 3); + ASSERT_THAT(stack.PeekAttribute().attribute(), NotNull()); + ASSERT_EQ(*stack.PeekAttribute().attribute(), attribute); + + stack.Pop(1); + + ASSERT_EQ(stack.Peek().Int64OrDie(), 2); + ASSERT_EQ(stack.PeekAttribute().attribute(), nullptr); + + stack.Pop(1); + + ASSERT_EQ(stack.Peek().Int64OrDie(), 1); + ASSERT_EQ(stack.PeekAttribute().attribute(), nullptr); +} + +// Test that inner stacks within value stack retain the equality of their sizes. +TEST(EvaluatorCoreTest, ValueStackBalanced) { + ValueStack stack(10); + ASSERT_EQ(stack.size(), stack.attribute_size()); + + stack.Push(CelValue::CreateInt64(1)); + ASSERT_EQ(stack.size(), stack.attribute_size()); + stack.Push(CelValue::CreateInt64(2), AttributeTrail()); + stack.Push(CelValue::CreateInt64(3), AttributeTrail()); + ASSERT_EQ(stack.size(), stack.attribute_size()); + + stack.PopAndPush(CelValue::CreateInt64(4), AttributeTrail()); + ASSERT_EQ(stack.size(), stack.attribute_size()); + stack.PopAndPush(CelValue::CreateInt64(5)); + ASSERT_EQ(stack.size(), stack.attribute_size()); + + stack.Pop(3); + ASSERT_EQ(stack.size(), stack.attribute_size()); +} + TEST(EvaluatorCoreTest, ExecutionFrameNext) { ExecutionPath path; auto const_step = absl::make_unique(); @@ -60,7 +110,8 @@ TEST(EvaluatorCoreTest, ExecutionFrameNext) { auto dummy_expr = absl::make_unique(); Activation activation; - ExecutionFrame frame(&path, activation, nullptr, 0); + CelExpressionFlatEvaluationState state(path.size(), {}, nullptr); + ExecutionFrame frame(path, activation, 0, &state, false, false); EXPECT_THAT(frame.Next(), Eq(path[0].get())); EXPECT_THAT(frame.Next(), Eq(path[1].get())); @@ -68,6 +119,55 @@ TEST(EvaluatorCoreTest, ExecutionFrameNext) { EXPECT_THAT(frame.Next(), Eq(nullptr)); } +// Test the set, get, and clear functions for "IterVar" on ExecutionFrame +TEST(EvaluatorCoreTest, ExecutionFrameSetGetClearVar) { + const std::string test_key = "test_key"; + const int64_t test_value = 0xF00F00; + + Activation activation; + ExecutionPath path; + CelExpressionFlatEvaluationState state(path.size(), {test_key}, nullptr); + ExecutionFrame frame(path, activation, 0, &state, false, false); + + CelValue original = CelValue::CreateInt64(test_value); + CelValue result; + + ASSERT_OK(frame.PushIterFrame()); + + // Nothing is there yet + ASSERT_FALSE(frame.GetIterVar(test_key, &result)); + ASSERT_OK(frame.SetIterVar(test_key, original)); + + // Make sure its now there + ASSERT_TRUE(frame.GetIterVar(test_key, &result)); + + int64_t result_value; + ASSERT_TRUE(result.GetValue(&result_value)); + EXPECT_EQ(test_value, result_value); + + // Test that it goes away properly + ASSERT_OK(frame.ClearIterVar(test_key)); + ASSERT_FALSE(frame.GetIterVar(test_key, &result)); + + // Test that bogus names return the right thing + ASSERT_FALSE(frame.SetIterVar("foo", original).ok()); + ASSERT_FALSE(frame.ClearIterVar("bar").ok()); + + // Test error conditions for accesses outside of comprehension. + ASSERT_OK(frame.SetIterVar(test_key, original)); + ASSERT_OK(frame.PopIterFrame()); + + // Access on empty stack ok, but no value. + ASSERT_FALSE(frame.GetIterVar(test_key, &result)); + + // Pop empty stack + ASSERT_FALSE(frame.PopIterFrame().ok()); + + // Updates on empty stack not ok. + ASSERT_FALSE(frame.SetIterVar(test_key, original).ok()); + ASSERT_FALSE(frame.ClearIterVar(test_key).ok()); +} + TEST(EvaluatorCoreTest, SimpleEvaluatorTest) { ExecutionPath path; auto const_step = absl::make_unique(); @@ -80,13 +180,13 @@ TEST(EvaluatorCoreTest, SimpleEvaluatorTest) { auto dummy_expr = absl::make_unique(); - CelExpressionFlatImpl impl(dummy_expr.get(), std::move(path), 0); + CelExpressionFlatImpl impl(dummy_expr.get(), std::move(path), 0, {}); Activation activation; google::protobuf::Arena arena; auto status = impl.Evaluate(activation, &arena); - EXPECT_TRUE(status.ok()); + EXPECT_OK(status); auto value = status.ValueOrDie(); EXPECT_TRUE(value.IsInt64()); @@ -158,10 +258,10 @@ TEST(EvaluatorCoreTest, TraceTest) { FlatExprBuilder builder; auto builtin_status = RegisterBuiltinFunctions(builder.GetRegistry()); - ASSERT_TRUE(builtin_status.ok()); + ASSERT_OK(builtin_status); builder.set_shortcircuiting(false); auto build_status = builder.CreateExpression(&expr, &source_info); - ASSERT_TRUE(build_status.ok()); + ASSERT_OK(build_status); auto cel_expr = std::move(build_status.ValueOrDie()); @@ -191,9 +291,9 @@ TEST(EvaluatorCoreTest, TraceTest) { activation, &arena, [&](int64_t expr_id, const CelValue& value, google::protobuf::Arena* arena) { callback.Call(expr_id, value, arena); - return ::cel_base::OkStatus(); + return absl::OkStatus(); }); - ASSERT_TRUE(eval_status.ok()); + ASSERT_OK(eval_status); } } // namespace runtime diff --git a/eval/eval/expression_build_warning.cc b/eval/eval/expression_build_warning.cc new file mode 100644 index 000000000..59a54651a --- /dev/null +++ b/eval/eval/expression_build_warning.cc @@ -0,0 +1,19 @@ +#include "eval/eval/expression_build_warning.h" + +namespace google { +namespace api { +namespace expr { +namespace runtime { + +absl::Status BuilderWarnings::AddWarning(const absl::Status& warning) { + if (fail_immediately_) { + return warning; + } + warnings_.push_back(warning); + return absl::OkStatus(); +} + +} // namespace runtime +} // namespace expr +} // namespace api +} // namespace google diff --git a/eval/eval/expression_build_warning.h b/eval/eval/expression_build_warning.h new file mode 100644 index 000000000..20575abe9 --- /dev/null +++ b/eval/eval/expression_build_warning.h @@ -0,0 +1,36 @@ +#ifndef THIRD_PARTY_CEL_CPP_EVAL_EVAL_EXPRESSION_BUILD_WARNING_H_ +#define THIRD_PARTY_CEL_CPP_EVAL_EVAL_EXPRESSION_BUILD_WARNING_H_ + +#include + +#include "absl/status/status.h" + +namespace google { +namespace api { +namespace expr { +namespace runtime { + +// Container for recording warnings. +class BuilderWarnings { + public: + explicit BuilderWarnings(bool fail_immediately = false) + : fail_immediately_(fail_immediately) {} + + // Add a warning. Returns the util:Status immediately if fail on warning is + // set. + absl::Status AddWarning(const absl::Status& warning); + + // Return the list of recorded warnings. + const std::vector& warnings() const { return warnings_; } + + private: + std::vector warnings_; + bool fail_immediately_; +}; + +} // namespace runtime +} // namespace expr +} // namespace api +} // namespace google + +#endif // THIRD_PARTY_CEL_CPP_EVAL_EVAL_EXPRESSION_BUILD_WARNING_H_ diff --git a/eval/eval/expression_build_warning_test.cc b/eval/eval/expression_build_warning_test.cc new file mode 100644 index 000000000..212b2e5ae --- /dev/null +++ b/eval/eval/expression_build_warning_test.cc @@ -0,0 +1,36 @@ +#include "eval/eval/expression_build_warning.h" + +#include "gmock/gmock.h" +#include "gtest/gtest.h" +#include "absl/status/status.h" + +namespace google { +namespace api { +namespace expr { +namespace runtime { +namespace { + + +TEST(BuilderWarnings, NoFailCollects) { + BuilderWarnings warnings(false); + + auto status = warnings.AddWarning(absl::InternalError("internal")); + EXPECT_TRUE(status.ok()); + auto status2 = warnings.AddWarning(absl::InternalError("internal error 2")); + EXPECT_TRUE(status2.ok()); + + EXPECT_THAT(warnings.warnings(), testing::SizeIs(2)); +} + +TEST(BuilderWarnings, FailReturnsStatus) { + BuilderWarnings warnings(true); + + EXPECT_EQ(warnings.AddWarning(absl::InternalError("internal")).code(), + absl::StatusCode::kInternal); +} + +} // namespace +} // namespace runtime +} // namespace expr +} // namespace api +} // namespace google diff --git a/eval/eval/field_access.cc b/eval/eval/field_access.cc index e7921a4f4..316698af1 100644 --- a/eval/eval/field_access.cc +++ b/eval/eval/field_access.cc @@ -3,10 +3,10 @@ #include #include "google/protobuf/map_field.h" +#include "absl/status/status.h" #include "absl/strings/str_cat.h" #include "absl/strings/substitute.h" #include "internal/proto_util.h" -#include "base/canonical_errors.h" namespace google { namespace api { @@ -73,7 +73,7 @@ class FieldAccessor { // If value provided successfully, returns Ok. // arena Arena to use for allocations if needed. // result pointer to object to store value in. - cel_base::Status CreateValueFromFieldAccessor(Arena* arena, CelValue* result) { + absl::Status CreateValueFromFieldAccessor(Arena* arena, CelValue* result) { switch (field_desc_->cpp_type()) { case FieldDescriptor::CPPTYPE_BOOL: { bool value = GetBool(); @@ -124,7 +124,7 @@ class FieldAccessor { *result = CelValue::CreateBytes(value); break; default: - return cel_base::Status(cel_base::StatusCode::kInvalidArgument, + return absl::Status(absl::StatusCode::kInvalidArgument, "Error handling C++ string conversion"); } break; @@ -140,11 +140,11 @@ class FieldAccessor { break; } default: - return cel_base::Status(cel_base::StatusCode::kInvalidArgument, + return absl::Status(absl::StatusCode::kInvalidArgument, "Unhandled C++ type conversion"); } - return cel_base::OkStatus(); + return absl::OkStatus(); } protected: @@ -322,7 +322,7 @@ class MessageRetrieverOp { } // namespace -cel_base::Status CreateValueFromSingleField(const google::protobuf::Message* msg, +absl::Status CreateValueFromSingleField(const google::protobuf::Message* msg, const FieldDescriptor* desc, google::protobuf::Arena* arena, CelValue* result) { @@ -330,7 +330,7 @@ cel_base::Status CreateValueFromSingleField(const google::protobuf::Message* msg return accessor.CreateValueFromFieldAccessor(arena, result); } -cel_base::Status CreateValueFromRepeatedField(const google::protobuf::Message* msg, +absl::Status CreateValueFromRepeatedField(const google::protobuf::Message* msg, const FieldDescriptor* desc, google::protobuf::Arena* arena, int index, CelValue* result) { @@ -338,7 +338,7 @@ cel_base::Status CreateValueFromRepeatedField(const google::protobuf::Message* m return accessor.CreateValueFromFieldAccessor(arena, result); } -cel_base::Status CreateValueFromMapValue(const google::protobuf::Message* msg, +absl::Status CreateValueFromMapValue(const google::protobuf::Message* msg, const FieldDescriptor* desc, const MapValueRef* value_ref, google::protobuf::Arena* arena, CelValue* result) { @@ -700,33 +700,32 @@ class RepeatedFieldSetter : public FieldSetter { // If value provided successfully, returns Ok. // arena Arena to use for allocations if needed. // result pointer to object to store value in. -::cel_base::Status SetValueToSingleField(const CelValue& value, - const FieldDescriptor* desc, - Message* msg) { +absl::Status SetValueToSingleField(const CelValue& value, + const FieldDescriptor* desc, Message* msg) { ScalarFieldSetter setter(msg, desc); return (setter.SetFieldFromCelValue(value)) - ? ::cel_base::OkStatus() - : ::cel_base::InvalidArgumentError(absl::Substitute( + ? absl::OkStatus() + : absl::InvalidArgumentError(absl::Substitute( "Could not assign supplied argument to message \"$0\" field " "\"$1\" of type $2: type was $3", msg->GetDescriptor()->name(), desc->name(), desc->type_name(), absl::StrCat(value.type()))); } -::cel_base::Status AddValueToRepeatedField(const CelValue& value, - const FieldDescriptor* desc, - Message* msg) { +absl::Status AddValueToRepeatedField(const CelValue& value, + const FieldDescriptor* desc, + Message* msg) { RepeatedFieldSetter setter(msg, desc); return (setter.SetFieldFromCelValue(value)) - ? ::cel_base::OkStatus() - : ::cel_base::InvalidArgumentError(absl::Substitute( + ? absl::OkStatus() + : absl::InvalidArgumentError(absl::Substitute( "Could not add supplied argument to message \"$0\" field " "\"$1\".", msg->GetDescriptor()->name(), desc->name())); } -::cel_base::Status AddValueToMapField(const CelValue& key, const CelValue& value, - const FieldDescriptor* desc, Message* msg) { +absl::Status AddValueToMapField(const CelValue& key, const CelValue& value, + const FieldDescriptor* desc, Message* msg) { auto entry_msg = msg->GetReflection()->AddMessage(msg, desc); auto key_field_desc = entry_msg->GetDescriptor()->FindFieldByNumber(1); auto value_field_desc = entry_msg->GetDescriptor()->FindFieldByNumber(2); @@ -735,20 +734,20 @@ ::cel_base::Status AddValueToMapField(const CelValue& key, const CelValue& value ScalarFieldSetter value_setter(entry_msg, value_field_desc); if (!key_setter.SetFieldFromCelValue(key)) { - return ::cel_base::InvalidArgumentError( + return absl::InvalidArgumentError( absl::Substitute("Could not assign supplied argument to message \"$0\" " "field \"$1\" map key.", msg->GetDescriptor()->name(), desc->name())); } if (!value_setter.SetFieldFromCelValue(value)) { - return ::cel_base::InvalidArgumentError( + return absl::InvalidArgumentError( absl::Substitute("Could not assign supplied argument to message \"$0\" " "field \"$1\" map value.", msg->GetDescriptor()->name(), desc->name())); } - return ::cel_base::OkStatus(); + return absl::OkStatus(); } } // namespace runtime diff --git a/eval/eval/field_access.h b/eval/eval/field_access.h index 45deb22f6..63ed38369 100644 --- a/eval/eval/field_access.h +++ b/eval/eval/field_access.h @@ -14,7 +14,7 @@ namespace runtime { // desc Descriptor of the field to access. // arena Arena object to allocate result on, if needed. // result pointer to CelValue to store the result in. -cel_base::Status CreateValueFromSingleField(const google::protobuf::Message* msg, +absl::Status CreateValueFromSingleField(const google::protobuf::Message* msg, const google::protobuf::FieldDescriptor* desc, google::protobuf::Arena* arena, CelValue* result); @@ -25,7 +25,7 @@ cel_base::Status CreateValueFromSingleField(const google::protobuf::Message* msg // arena Arena object to allocate result on, if needed. // index position in the repeated field. // result pointer to CelValue to store the result in. -cel_base::Status CreateValueFromRepeatedField(const google::protobuf::Message* msg, +absl::Status CreateValueFromRepeatedField(const google::protobuf::Message* msg, const google::protobuf::FieldDescriptor* desc, google::protobuf::Arena* arena, int index, CelValue* result); @@ -37,7 +37,7 @@ cel_base::Status CreateValueFromRepeatedField(const google::protobuf::Message* m // value_ref pointer to map value. // arena Arena object to allocate result on, if needed. // result pointer to CelValue to store the result in. -cel_base::Status CreateValueFromMapValue(const google::protobuf::Message* msg, +absl::Status CreateValueFromMapValue(const google::protobuf::Message* msg, const google::protobuf::FieldDescriptor* desc, const google::protobuf::MapValueRef* value_ref, google::protobuf::Arena* arena, CelValue* result); @@ -46,25 +46,25 @@ cel_base::Status CreateValueFromMapValue(const google::protobuf::Message* msg, // Returns status of the operation. // msg Message containing the field. // desc Descriptor of the field to access. -::cel_base::Status SetValueToSingleField(const CelValue& value, - const google::protobuf::FieldDescriptor* desc, - google::protobuf::Message* msg); +absl::Status SetValueToSingleField(const CelValue& value, + const google::protobuf::FieldDescriptor* desc, + google::protobuf::Message* msg); // Adds content of CelValue to repeated message field. // Returns status of the operation. // msg Message containing the field. // desc Descriptor of the field to access. -::cel_base::Status AddValueToRepeatedField(const CelValue& value, - const google::protobuf::FieldDescriptor* desc, - google::protobuf::Message* msg); +absl::Status AddValueToRepeatedField(const CelValue& value, + const google::protobuf::FieldDescriptor* desc, + google::protobuf::Message* msg); // Adds content of CelValue to repeated message field. // Returns status of the operation. // msg Message containing the field. // desc Descriptor of the field to access. -::cel_base::Status AddValueToMapField(const CelValue& key, const CelValue& value, - const google::protobuf::FieldDescriptor* desc, - google::protobuf::Message* msg); +absl::Status AddValueToMapField(const CelValue& key, const CelValue& value, + const google::protobuf::FieldDescriptor* desc, + google::protobuf::Message* msg); } // namespace runtime } // namespace expr diff --git a/eval/eval/function_step.cc b/eval/eval/function_step.cc index d0afdbc7b..76c787cce 100644 --- a/eval/eval/function_step.cc +++ b/eval/eval/function_step.cc @@ -5,11 +5,17 @@ #include #include +#include "google/protobuf/arena.h" #include "absl/strings/str_cat.h" #include "eval/eval/evaluator_core.h" +#include "eval/eval/expression_build_warning.h" #include "eval/eval/expression_step_base.h" #include "eval/public/cel_function_provider.h" #include "eval/public/cel_function_registry.h" +#include "eval/public/cel_value.h" +#include "eval/public/unknown_function_result_set.h" +#include "eval/public/unknown_set.h" +#include "base/status_macros.h" namespace google { namespace api { @@ -27,7 +33,10 @@ class AbstractFunctionStep : public ExpressionStepBase { AbstractFunctionStep(size_t num_arguments, int64_t expr_id) : ExpressionStepBase(expr_id), num_arguments_(num_arguments) {} - cel_base::Status Evaluate(ExecutionFrame* frame) const override; + absl::Status Evaluate(ExecutionFrame* frame) const override; + + absl::Status DoEvaluate(ExecutionFrame* frame, CelValue* result) const; + virtual cel_base::StatusOr ResolveFunction( absl::Span args, const ExecutionFrame* frame) const = 0; @@ -35,14 +44,18 @@ class AbstractFunctionStep : public ExpressionStepBase { size_t num_arguments_; }; -cel_base::Status AbstractFunctionStep::Evaluate(ExecutionFrame* frame) const { - if (!frame->value_stack().HasEnough(num_arguments_)) { - return cel_base::Status(cel_base::StatusCode::kInternal, "Value stack underflow"); - } - +absl::Status AbstractFunctionStep::DoEvaluate(ExecutionFrame* frame, + CelValue* result) const { // Create Span object that contains input arguments to the function. auto input_args = frame->value_stack().GetSpan(num_arguments_); + const UnknownSet* unknown_set = nullptr; + if (frame->enable_unknowns()) { + auto input_attrs = frame->value_stack().GetAttributeSpan(num_arguments_); + unknown_set = frame->unknowns_utility().MergeUnknowns( + input_args, input_attrs, /*initial_set=*/nullptr, /*use_partial=*/true); + } + // Derived class resolves to a single function overload or none. auto status = ResolveFunction(input_args, frame); if (!status.ok()) { @@ -50,15 +63,23 @@ cel_base::Status AbstractFunctionStep::Evaluate(ExecutionFrame* frame) const { } const CelFunction* matched_function = status.ValueOrDie(); - CelValue result = CelValue::CreateNull(); - // Overload found - if (matched_function != nullptr) { - cel_base::Status status = - matched_function->Evaluate(input_args, &result, frame->arena()); + if (matched_function != nullptr && unknown_set == nullptr) { + absl::Status status = + matched_function->Evaluate(input_args, result, frame->arena()); if (!status.ok()) { return status; } + if (frame->enable_unknown_function_results() && + IsUnknownFunctionResult(*result)) { + const auto* function_result = + google::protobuf::Arena::Create( + frame->arena(), matched_function->descriptor(), id(), + std::vector(input_args.begin(), input_args.end())); + const auto* unknown_set = google::protobuf::Arena::Create( + frame->arena(), UnknownFunctionResultSet(function_result)); + *result = CelValue::CreateUnknownSet(unknown_set); + } } else { // No matching overloads. // We should not treat absense of overloads as non-recoverable error. @@ -67,20 +88,40 @@ cel_base::Status AbstractFunctionStep::Evaluate(ExecutionFrame* frame) const { // should be propagated along execution path. for (const CelValue& arg : input_args) { if (arg.IsError()) { - result = arg; - break; + *result = arg; + return absl::OkStatus(); } } + + if (unknown_set) { + *result = CelValue::CreateUnknownSet(unknown_set); + return absl::OkStatus(); + } + // If no errors in input args, create new CelError. - if (!result.IsError()) { - result = CreateNoMatchingOverloadError(frame->arena()); + if (!result->IsError()) { + *result = CreateNoMatchingOverloadError(frame->arena()); } } - frame->value_stack().Pop(input_args.length()); + return absl::OkStatus(); +} + +absl::Status AbstractFunctionStep::Evaluate(ExecutionFrame* frame) const { + if (!frame->value_stack().HasEnough(num_arguments_)) { + return absl::Status(absl::StatusCode::kInternal, "Value stack underflow"); + } + + CelValue result; + auto status = DoEvaluate(frame, &result); + if (!status.ok()) { + return status; + } + + frame->value_stack().Pop(num_arguments_); frame->value_stack().Push(result); - return cel_base::OkStatus(); + return absl::OkStatus(); } class EagerFunctionStep : public AbstractFunctionStep { @@ -105,7 +146,7 @@ cel_base::StatusOr EagerFunctionStep::ResolveFunction( if (overload->MatchArguments(input_args)) { // More than one overload matches our arguments. if (matched_function != nullptr) { - return cel_base::Status(cel_base::StatusCode::kInternal, + return absl::Status(absl::StatusCode::kInternal, "Cannot resolve overloads"); } @@ -159,7 +200,7 @@ cel_base::StatusOr LazyFunctionStep::ResolveFunction( if (overload != nullptr && overload->MatchArguments(input_args)) { // More than one overload matches our arguments. if (matched_function != nullptr) { - return cel_base::Status(cel_base::StatusCode::kInternal, + return absl::Status(absl::StatusCode::kInternal, "Cannot resolve overloads"); } @@ -174,7 +215,8 @@ cel_base::StatusOr LazyFunctionStep::ResolveFunction( cel_base::StatusOr> CreateFunctionStep( const google::api::expr::v1alpha1::Expr::Call* call_expr, int64_t expr_id, - const CelFunctionRegistry& function_registry) { + const CelFunctionRegistry& function_registry, + BuilderWarnings* builder_warnings) { bool receiver_style = call_expr->has_target(); size_t num_args = call_expr->args_size() + (receiver_style ? 1 : 0); const std::string& name = call_expr->function(); @@ -192,15 +234,16 @@ cel_base::StatusOr> CreateFunctionStep( auto overloads = function_registry.FindOverloads(name, receiver_style, args); - if (!overloads.empty()) { - std::unique_ptr step = absl::make_unique( - std::move(overloads), num_args, expr_id); - return std::move(step); + // No overloads found. + if (overloads.empty()) { + RETURN_IF_ERROR(builder_warnings->AddWarning( + absl::Status(absl::StatusCode::kInvalidArgument, + "No overloads provided for FunctionStep creation"))); } - // No overloads found. - return ::cel_base::Status(cel_base::StatusCode::kInvalidArgument, - "No overloads provided for FunctionStep creation"); + std::unique_ptr step = absl::make_unique( + std::move(overloads), num_args, expr_id); + return std::move(step); } } // namespace runtime diff --git a/eval/eval/function_step.h b/eval/eval/function_step.h index 5ed34885d..12facbb63 100644 --- a/eval/eval/function_step.h +++ b/eval/eval/function_step.h @@ -2,6 +2,7 @@ #define THIRD_PARTY_CEL_CPP_EVAL_EVAL_FUNCTION_STEP_H_ #include "eval/eval/evaluator_core.h" +#include "eval/eval/expression_build_warning.h" #include "eval/public/activation.h" #include "eval/public/cel_function.h" #include "eval/public/cel_function_registry.h" @@ -16,7 +17,8 @@ namespace runtime { // Looks up function registry using data provided through Call parameter. cel_base::StatusOr> CreateFunctionStep( const google::api::expr::v1alpha1::Expr::Call* call, int64_t expr_id, - const CelFunctionRegistry& function_registry); + const CelFunctionRegistry& function_registry, + BuilderWarnings* builder_warnings); } // namespace runtime } // namespace expr diff --git a/eval/eval/function_step_test.cc b/eval/eval/function_step_test.cc index d4659e988..8b55ab721 100644 --- a/eval/eval/function_step_test.cc +++ b/eval/eval/function_step_test.cc @@ -6,7 +6,15 @@ #include "absl/memory/memory.h" #include "absl/strings/string_view.h" #include "eval/eval/evaluator_core.h" +#include "eval/eval/expression_build_warning.h" +#include "eval/eval/ident_step.h" +#include "eval/public/cel_attribute.h" #include "eval/public/cel_function.h" +#include "eval/public/cel_options.h" +#include "eval/public/cel_value.h" +#include "eval/public/unknown_function_result_set.h" +#include "eval/testutil/test_message.pb.h" +#include "base/status_macros.h" namespace google { namespace api { @@ -15,8 +23,10 @@ namespace runtime { namespace { +using testing::ElementsAre; using testing::Eq; using testing::Not; +using testing::UnorderedElementsAre; using google::api::expr::v1alpha1::Expr; @@ -26,6 +36,7 @@ int GetExprId() { return id; } +// Simple function that takes no arguments and returns a constant value. class ConstFunction : public CelFunction { public: explicit ConstFunction(const CelValue& value, absl::string_view name) @@ -42,24 +53,30 @@ class ConstFunction : public CelFunction { return call; } - cel_base::Status Evaluate(absl::Span args, CelValue* result, + absl::Status Evaluate(absl::Span args, CelValue* result, google::protobuf::Arena* arena) const override { if (!args.empty()) { - return cel_base::Status(cel_base::StatusCode::kInvalidArgument, + return absl::Status(absl::StatusCode::kInvalidArgument, "Bad arguments number"); } *result = value_; - return cel_base::OkStatus(); + return absl::OkStatus(); } private: CelValue value_; }; +enum class ShouldReturnUnknown : bool { kYes = true, kNo = false }; + class AddFunction : public CelFunction { public: - AddFunction() : CelFunction(CreateDescriptor()) {} + AddFunction() + : CelFunction(CreateDescriptor()), should_return_unknown_(false) {} + explicit AddFunction(ShouldReturnUnknown should_return_unknown) + : CelFunction(CreateDescriptor()), + should_return_unknown_(static_cast(should_return_unknown)) {} static CelFunctionDescriptor CreateDescriptor() { return CelFunctionDescriptor{ @@ -75,23 +92,56 @@ class AddFunction : public CelFunction { return call; } - cel_base::Status Evaluate(absl::Span args, CelValue* result, + absl::Status Evaluate(absl::Span args, CelValue* result, google::protobuf::Arena* arena) const override { if (args.size() != 2 || !args[0].IsInt64() || !args[1].IsInt64()) { - return cel_base::Status(cel_base::StatusCode::kInvalidArgument, + return absl::Status(absl::StatusCode::kInvalidArgument, "Mismatched arguments passed to method"); } + if (should_return_unknown_) { + *result = + CreateUnknownFunctionResultError(arena, "Add can't be resolved."); + return absl::OkStatus(); + } int64_t arg0 = args[0].Int64OrDie(); int64_t arg1 = args[1].Int64OrDie(); *result = CelValue::CreateInt64(arg0 + arg1); - return cel_base::OkStatus(); + return absl::OkStatus(); + } + + private: + bool should_return_unknown_; +}; + +class SinkFunction : public CelFunction { + public: + SinkFunction(CelValue::Type type) : CelFunction(CreateDescriptor(type)) {} + + static CelFunctionDescriptor CreateDescriptor(CelValue::Type type) { + return CelFunctionDescriptor{"Sink", false, {type}}; + } + + static Expr::Call MakeCall() { + Expr::Call call; + call.set_function("Sink"); + call.add_args(); + call.clear_target(); + return call; + } + + absl::Status Evaluate(absl::Span args, CelValue* result, + google::protobuf::Arena* arena) const override { + // Return value is ignored. + *result = CelValue::CreateInt64(0); + return absl::OkStatus(); } }; // Create and initialize a registry with some default functions. void AddDefaults(CelFunctionRegistry& registry) { + static UnknownSet* unknown_set = new UnknownSet(); EXPECT_TRUE(registry .Register(absl::make_unique( CelValue::CreateInt64(3), "Const3")) @@ -100,11 +150,60 @@ void AddDefaults(CelFunctionRegistry& registry) { .Register(absl::make_unique( CelValue::CreateInt64(2), "Const2")) .ok()); + EXPECT_TRUE(registry + .Register(absl::make_unique( + CelValue::CreateUnknownSet(unknown_set), "ConstUnknown")) + .ok()); EXPECT_TRUE(registry.Register(absl::make_unique()).ok()); + + EXPECT_TRUE( + registry.Register(absl::make_unique(CelValue::Type::kList)) + .ok()); + + EXPECT_TRUE( + registry.Register(absl::make_unique(CelValue::Type::kMap)) + .ok()); + + EXPECT_TRUE( + registry + .Register(absl::make_unique(CelValue::Type::kMessage)) + .ok()); } -TEST(FunctionStepTest, SimpleFunctionTest) { +// Test common functions with varying levels of unknown support. +class FunctionStepTest + : public testing::TestWithParam { + public: + // underlying expression impl moves path + std::unique_ptr GetExpression(ExecutionPath&& path) { + bool unknowns; + bool unknown_function_results; + switch (GetParam()) { + case UnknownProcessingOptions::kAttributeAndFunction: + unknowns = true; + unknown_function_results = true; + break; + case UnknownProcessingOptions::kAttributeOnly: + unknowns = true; + unknown_function_results = false; + break; + case UnknownProcessingOptions::kDisabled: + unknowns = false; + unknown_function_results = false; + break; + } + return absl::make_unique( + &dummy_expr_, std::move(path), 0, std::set(), unknowns, + unknown_function_results); + } + + private: + Expr dummy_expr_; +}; + +TEST_P(FunctionStepTest, SimpleFunctionTest) { ExecutionPath path; + BuilderWarnings warnings; CelFunctionRegistry registry; AddDefaults(registry); @@ -113,27 +212,28 @@ TEST(FunctionStepTest, SimpleFunctionTest) { Expr::Call call2 = ConstFunction::MakeCall("Const2"); Expr::Call add_call = AddFunction::MakeCall(); - auto step0_status = CreateFunctionStep(&call1, GetExprId(), registry); - auto step1_status = CreateFunctionStep(&call2, GetExprId(), registry); - auto step2_status = CreateFunctionStep(&add_call, GetExprId(), registry); + auto step0_status = + CreateFunctionStep(&call1, GetExprId(), registry, &warnings); + auto step1_status = + CreateFunctionStep(&call2, GetExprId(), registry, &warnings); + auto step2_status = + CreateFunctionStep(&add_call, GetExprId(), registry, &warnings); - ASSERT_TRUE(step0_status.ok()); - ASSERT_TRUE(step1_status.ok()); - ASSERT_TRUE(step2_status.ok()); + ASSERT_OK(step0_status); + ASSERT_OK(step1_status); + ASSERT_OK(step2_status); path.push_back(std::move(step0_status.ValueOrDie())); path.push_back(std::move(step1_status.ValueOrDie())); path.push_back(std::move(step2_status.ValueOrDie())); - auto dummy_expr = absl::make_unique(); - - CelExpressionFlatImpl impl(dummy_expr.get(), std::move(path), 0); + std::unique_ptr impl = GetExpression(std::move(path)); Activation activation; google::protobuf::Arena arena; - auto status = impl.Evaluate(activation, &arena); - EXPECT_TRUE(status.ok()); + auto status = impl->Evaluate(activation, &arena); + ASSERT_OK(status); auto value = status.ValueOrDie(); @@ -141,8 +241,9 @@ TEST(FunctionStepTest, SimpleFunctionTest) { EXPECT_THAT(value.Int64OrDie(), Eq(5)); } -TEST(FunctionStepTest, TestStackUnderflow) { +TEST_P(FunctionStepTest, TestStackUnderflow) { ExecutionPath path; + BuilderWarnings warnings; CelFunctionRegistry registry; AddDefaults(registry); @@ -152,40 +253,74 @@ TEST(FunctionStepTest, TestStackUnderflow) { Expr::Call call1 = ConstFunction::MakeCall("Const3"); Expr::Call add_call = AddFunction::MakeCall(); - auto step0_status = CreateFunctionStep(&call1, GetExprId(), registry); - auto step2_status = CreateFunctionStep(&add_call, GetExprId(), registry); + auto step0_status = + CreateFunctionStep(&call1, GetExprId(), registry, &warnings); + auto step2_status = + CreateFunctionStep(&add_call, GetExprId(), registry, &warnings); - ASSERT_TRUE(step0_status.ok()); - ASSERT_TRUE(step2_status.ok()); + ASSERT_OK(step0_status); + ASSERT_OK(step2_status); path.push_back(std::move(step0_status.ValueOrDie())); path.push_back(std::move(step2_status.ValueOrDie())); - auto dummy_expr = absl::make_unique(); - - CelExpressionFlatImpl impl(dummy_expr.get(), std::move(path), 0); + std::unique_ptr impl = GetExpression(std::move(path)); Activation activation; google::protobuf::Arena arena; - auto status = impl.Evaluate(activation, &arena); + auto status = impl->Evaluate(activation, &arena); EXPECT_FALSE(status.ok()); } -// Test factory method when empty overload set is provided. +// Test that creation fails if fail on warnings is set in the warnings +// collection. TEST(FunctionStepTest, TestNoOverloadsOnCreation) { - CelFunctionRegistry registry; // SetupRegistry(); + CelFunctionRegistry registry; + BuilderWarnings warnings(true); + Expr::Call call = ConstFunction::MakeCall("Const0"); // function step with empty overloads - auto step0_status = CreateFunctionStep(&call, GetExprId(), registry); + auto step0_status = + CreateFunctionStep(&call, GetExprId(), registry, &warnings); EXPECT_FALSE(step0_status.ok()); } +// Test that no overloads error is warned, actual error delayed to runtime by +// default. +TEST_P(FunctionStepTest, TestNoOverloadsOnCreationDelayedError) { + CelFunctionRegistry registry; + ExecutionPath path; + Expr::Call call = ConstFunction::MakeCall("Const0"); + BuilderWarnings warnings; + + // function step with empty overloads + auto step0_status = + CreateFunctionStep(&call, GetExprId(), registry, &warnings); + + EXPECT_TRUE(step0_status.ok()); + EXPECT_THAT(warnings.warnings(), testing::SizeIs(1)); + + path.push_back(std::move(step0_status.ValueOrDie())); + + std::unique_ptr impl = GetExpression(std::move(path)); + + Activation activation; + google::protobuf::Arena arena; + + auto status = impl->Evaluate(activation, &arena); + ASSERT_OK(status); + + auto value = status.ValueOrDie(); + ASSERT_TRUE(value.IsError()); +} + // Test situation when no overloads match input arguments during evaluation. -TEST(FunctionStepTest, TestNoMatchingOverloadsDuringEvaluation) { +TEST_P(FunctionStepTest, TestNoMatchingOverloadsDuringEvaluation) { ExecutionPath path; + BuilderWarnings warnings; CelFunctionRegistry registry; AddDefaults(registry); @@ -200,37 +335,39 @@ TEST(FunctionStepTest, TestNoMatchingOverloadsDuringEvaluation) { // Add expects {int64_t, int64_t} but it's {int64_t, uint64_t}. Expr::Call add_call = AddFunction::MakeCall(); - auto step0_status = CreateFunctionStep(&call1, GetExprId(), registry); - auto step1_status = CreateFunctionStep(&call2, GetExprId(), registry); - auto step2_status = CreateFunctionStep(&add_call, GetExprId(), registry); + auto step0_status = + CreateFunctionStep(&call1, GetExprId(), registry, &warnings); + auto step1_status = + CreateFunctionStep(&call2, GetExprId(), registry, &warnings); + auto step2_status = + CreateFunctionStep(&add_call, GetExprId(), registry, &warnings); - ASSERT_TRUE(step0_status.ok()); - ASSERT_TRUE(step1_status.ok()); - ASSERT_TRUE(step2_status.ok()); + ASSERT_OK(step0_status); + ASSERT_OK(step1_status); + ASSERT_OK(step2_status); path.push_back(std::move(step0_status.ValueOrDie())); path.push_back(std::move(step1_status.ValueOrDie())); path.push_back(std::move(step2_status.ValueOrDie())); - auto dummy_expr = absl::make_unique(); - - CelExpressionFlatImpl impl(dummy_expr.get(), std::move(path), 0); + std::unique_ptr impl = GetExpression(std::move(path)); Activation activation; google::protobuf::Arena arena; - auto status = impl.Evaluate(activation, &arena); - ASSERT_TRUE(status.ok()); + auto status = impl->Evaluate(activation, &arena); + ASSERT_OK(status); auto value = status.ValueOrDie(); ASSERT_TRUE(value.IsError()); } - // Test situation when no overloads match input arguments during evaluation // and at least one of arguments is error. -TEST(FunctionStepTest, TestNoMatchingOverloadsDuringEvaluationErrorForwarding) { +TEST_P(FunctionStepTest, + TestNoMatchingOverloadsDuringEvaluationErrorForwarding) { ExecutionPath path; + BuilderWarnings warnings; CelFunctionRegistry registry; AddDefaults(registry); @@ -252,27 +389,28 @@ TEST(FunctionStepTest, TestNoMatchingOverloadsDuringEvaluationErrorForwarding) { Expr::Call call2 = ConstFunction::MakeCall("ConstError2"); Expr::Call add_call = AddFunction::MakeCall(); - auto step0_status = CreateFunctionStep(&call1, GetExprId(), registry); - auto step1_status = CreateFunctionStep(&call2, GetExprId(), registry); - auto step2_status = CreateFunctionStep(&add_call, GetExprId(), registry); + auto step0_status = + CreateFunctionStep(&call1, GetExprId(), registry, &warnings); + auto step1_status = + CreateFunctionStep(&call2, GetExprId(), registry, &warnings); + auto step2_status = + CreateFunctionStep(&add_call, GetExprId(), registry, &warnings); - ASSERT_TRUE(step0_status.ok()); - ASSERT_TRUE(step1_status.ok()); - ASSERT_TRUE(step2_status.ok()); + ASSERT_OK(step0_status); + ASSERT_OK(step1_status); + ASSERT_OK(step2_status); path.push_back(std::move(step0_status.ValueOrDie())); path.push_back(std::move(step1_status.ValueOrDie())); path.push_back(std::move(step2_status.ValueOrDie())); - auto dummy_expr = absl::make_unique(); - - CelExpressionFlatImpl impl(dummy_expr.get(), std::move(path), 0); + std::unique_ptr impl = GetExpression(std::move(path)); Activation activation; google::protobuf::Arena arena; - auto status = impl.Evaluate(activation, &arena); - ASSERT_TRUE(status.ok()); + auto status = impl->Evaluate(activation, &arena); + ASSERT_OK(status); auto value = status.ValueOrDie(); @@ -280,49 +418,51 @@ TEST(FunctionStepTest, TestNoMatchingOverloadsDuringEvaluationErrorForwarding) { EXPECT_THAT(value.ErrorOrDie(), Eq(&error0)); } -TEST(FunctionStepTest, LazyFunctionTest) { +TEST_P(FunctionStepTest, LazyFunctionTest) { ExecutionPath path; Activation activation; CelFunctionRegistry registry; + BuilderWarnings warnings; auto register0_status = registry.RegisterLazyFunction(ConstFunction::CreateDescriptor("Const3")); - EXPECT_TRUE(register0_status.ok()); + ASSERT_OK(register0_status); auto insert0_status = activation.InsertFunction( absl::make_unique(CelValue::CreateInt64(3), "Const3")); - EXPECT_TRUE(insert0_status.ok()); + ASSERT_OK(insert0_status); auto register1_status = registry.RegisterLazyFunction(ConstFunction::CreateDescriptor("Const2")); - EXPECT_TRUE(register1_status.ok()); + ASSERT_OK(register1_status); auto insert1_status = activation.InsertFunction( absl::make_unique(CelValue::CreateInt64(2), "Const2")); - EXPECT_TRUE(insert1_status.ok()); - EXPECT_TRUE(registry.Register(absl::make_unique()).ok()); + ASSERT_OK(insert1_status); + ASSERT_OK(registry.Register(absl::make_unique())); Expr::Call call1 = ConstFunction::MakeCall("Const3"); Expr::Call call2 = ConstFunction::MakeCall("Const2"); Expr::Call add_call = AddFunction::MakeCall(); - auto step0_status = CreateFunctionStep(&call1, GetExprId(), registry); - auto step1_status = CreateFunctionStep(&call2, GetExprId(), registry); - auto step2_status = CreateFunctionStep(&add_call, GetExprId(), registry); + auto step0_status = + CreateFunctionStep(&call1, GetExprId(), registry, &warnings); + auto step1_status = + CreateFunctionStep(&call2, GetExprId(), registry, &warnings); + auto step2_status = + CreateFunctionStep(&add_call, GetExprId(), registry, &warnings); - ASSERT_TRUE(step0_status.ok()); - ASSERT_TRUE(step1_status.ok()); - ASSERT_TRUE(step2_status.ok()); + ASSERT_OK(step0_status); + ASSERT_OK(step1_status); + ASSERT_OK(step2_status); path.push_back(std::move(step0_status.ValueOrDie())); path.push_back(std::move(step1_status.ValueOrDie())); path.push_back(std::move(step2_status.ValueOrDie())); - auto dummy_expr = absl::make_unique(); - - CelExpressionFlatImpl impl(dummy_expr.get(), std::move(path), 0); + std::unique_ptr impl = GetExpression(std::move(path)); google::protobuf::Arena arena; - auto status = impl.Evaluate(activation, &arena); - EXPECT_TRUE(status.ok()); + auto status = impl->Evaluate(activation, &arena); + ASSERT_OK(status); auto value = status.ValueOrDie(); @@ -332,12 +472,14 @@ TEST(FunctionStepTest, LazyFunctionTest) { // Test situation when no overloads match input arguments during evaluation // and at least one of arguments is error. -TEST(FunctionStepTest, - TestNoMatchingOverloadsDuringEvaluationErrorForwardingLazy) { +TEST_P(FunctionStepTest, + TestNoMatchingOverloadsDuringEvaluationErrorForwardingLazy) { ExecutionPath path; Activation activation; google::protobuf::Arena arena; CelFunctionRegistry registry; + BuilderWarnings warnings; + AddDefaults(registry); CelError error0; @@ -346,46 +488,483 @@ TEST(FunctionStepTest, // Constants have ERROR type, while AddFunction expects INT. auto register0_status = registry.RegisterLazyFunction( ConstFunction::CreateDescriptor("ConstError1")); - ASSERT_TRUE(register0_status.ok()); + ASSERT_OK(register0_status); auto insert0_status = activation.InsertFunction(absl::make_unique( CelValue::CreateError(&error0), "ConstError1")); - ASSERT_TRUE(insert0_status.ok()); + ASSERT_OK(insert0_status); auto register1_status = registry.RegisterLazyFunction( ConstFunction::CreateDescriptor("ConstError2")); - ASSERT_TRUE(register1_status.ok()); + ASSERT_OK(register1_status); auto insert1_status = activation.InsertFunction(absl::make_unique( CelValue::CreateError(&error1), "ConstError2")); - ASSERT_TRUE(insert1_status.ok()); + ASSERT_OK(insert1_status); Expr::Call call1 = ConstFunction::MakeCall("ConstError1"); Expr::Call call2 = ConstFunction::MakeCall("ConstError2"); Expr::Call add_call = AddFunction::MakeCall(); - auto step0_status = CreateFunctionStep(&call1, GetExprId(), registry); - auto step1_status = CreateFunctionStep(&call2, GetExprId(), registry); - auto step2_status = CreateFunctionStep(&add_call, GetExprId(), registry); + auto step0_status = + CreateFunctionStep(&call1, GetExprId(), registry, &warnings); + auto step1_status = + CreateFunctionStep(&call2, GetExprId(), registry, &warnings); + auto step2_status = + CreateFunctionStep(&add_call, GetExprId(), registry, &warnings); - ASSERT_TRUE(step0_status.ok()); - ASSERT_TRUE(step1_status.ok()); - ASSERT_TRUE(step2_status.ok()); + ASSERT_OK(step0_status); + ASSERT_OK(step1_status); + ASSERT_OK(step2_status); path.push_back(std::move(step0_status.ValueOrDie())); path.push_back(std::move(step1_status.ValueOrDie())); path.push_back(std::move(step2_status.ValueOrDie())); - auto dummy_expr = absl::make_unique(); + std::unique_ptr impl = GetExpression(std::move(path)); + + auto status = impl->Evaluate(activation, &arena); + ASSERT_OK(status); + + auto value = status.ValueOrDie(); + + ASSERT_TRUE(value.IsError()); + EXPECT_THAT(value.ErrorOrDie(), Eq(&error0)); +} + +std::string TestNameFn(testing::TestParamInfo opt) { + switch (opt.param) { + case UnknownProcessingOptions::kDisabled: + return "disabled"; + case UnknownProcessingOptions::kAttributeOnly: + return "attribute_only"; + case UnknownProcessingOptions::kAttributeAndFunction: + return "attribute_and_function"; + } +} - CelExpressionFlatImpl impl(dummy_expr.get(), std::move(path), 0); +INSTANTIATE_TEST_SUITE_P( + UnknownSupport, FunctionStepTest, + testing::Values(UnknownProcessingOptions::kDisabled, + UnknownProcessingOptions::kAttributeOnly, + UnknownProcessingOptions::kAttributeAndFunction), + &TestNameFn); + +class FunctionStepTestUnknowns + : public testing::TestWithParam { + public: + std::unique_ptr GetExpression(ExecutionPath&& path) { + bool unknown_functions; + switch (GetParam()) { + case UnknownProcessingOptions::kAttributeAndFunction: + unknown_functions = true; + break; + default: + unknown_functions = false; + break; + } + return absl::make_unique(&expr, std::move(path), 0, + std::set(), + true, unknown_functions); + } + + private: + Expr expr; +}; + +TEST_P(FunctionStepTestUnknowns, PassedUnknownTest) { + ExecutionPath path; + BuilderWarnings warnings; + + CelFunctionRegistry registry; + AddDefaults(registry); + + Expr::Call call1 = ConstFunction::MakeCall("Const3"); + Expr::Call call2 = ConstFunction::MakeCall("ConstUnknown"); + Expr::Call add_call = AddFunction::MakeCall(); + + auto step0_status = + CreateFunctionStep(&call1, GetExprId(), registry, &warnings); + auto step1_status = + CreateFunctionStep(&call2, GetExprId(), registry, &warnings); + auto step2_status = + CreateFunctionStep(&add_call, GetExprId(), registry, &warnings); + + ASSERT_OK(step0_status); + ASSERT_OK(step1_status); + ASSERT_OK(step2_status); + + path.push_back(std::move(step0_status.ValueOrDie())); + path.push_back(std::move(step1_status.ValueOrDie())); + path.push_back(std::move(step2_status.ValueOrDie())); + + std::unique_ptr impl = GetExpression(std::move(path)); + + Activation activation; + google::protobuf::Arena arena; + + auto status = impl->Evaluate(activation, &arena); + ASSERT_OK(status); + + auto value = status.ValueOrDie(); + + ASSERT_TRUE(value.IsUnknownSet()); +} + +TEST_P(FunctionStepTestUnknowns, PartialUnknownHandlingTest) { + ExecutionPath path; + BuilderWarnings warnings; + + CelFunctionRegistry registry; + AddDefaults(registry); + + // Build the expression path that corresponds to CEL expression + // "sink(param)". + Expr::Ident ident1; + ident1.set_name("param"); + Expr::Call call1 = SinkFunction::MakeCall(); + + auto step0_status = CreateIdentStep(&ident1, GetExprId()); + auto step1_status = + CreateFunctionStep(&call1, GetExprId(), registry, &warnings); + + ASSERT_OK(step0_status); + ASSERT_OK(step1_status); + + path.push_back(std::move(step0_status.ValueOrDie())); + path.push_back(std::move(step1_status.ValueOrDie())); + + std::unique_ptr impl = GetExpression(std::move(path)); + + Activation activation; + TestMessage msg; + google::protobuf::Arena arena; + activation.InsertValue("param", CelValue::CreateMessage(&msg, &arena)); + CelAttributePattern pattern( + "param", + {CelAttributeQualifierPattern::Create(CelValue::CreateBool(true))}); + + // Set attribute pattern that marks attribute "param[true]" as unknown. + // It should result in "param" being handled as partially unknown, which is + // is handled as fully unknown when used as function input argument. + activation.set_unknown_attribute_patterns({pattern}); + + auto status = impl->Evaluate(activation, &arena); + ASSERT_OK(status); + + auto value = status.ValueOrDie(); + + ASSERT_TRUE(value.IsUnknownSet()); +} + +TEST_P(FunctionStepTestUnknowns, UnknownVsErrorPrecedenceTest) { + ExecutionPath path; + BuilderWarnings warnings; + + CelFunctionRegistry registry; + AddDefaults(registry); + + CelError error0; + CelValue error_value = CelValue::CreateError(&error0); + + ASSERT_TRUE( + registry + .Register(absl::make_unique(error_value, "ConstError")) + .ok()); + + Expr::Call call1 = ConstFunction::MakeCall("ConstError"); + Expr::Call call2 = ConstFunction::MakeCall("ConstUnknown"); + Expr::Call add_call = AddFunction::MakeCall(); + + auto step0_status = + CreateFunctionStep(&call1, GetExprId(), registry, &warnings); + auto step1_status = + CreateFunctionStep(&call2, GetExprId(), registry, &warnings); + auto step2_status = + CreateFunctionStep(&add_call, GetExprId(), registry, &warnings); + + ASSERT_OK(step0_status); + ASSERT_OK(step1_status); + ASSERT_OK(step2_status); + + path.push_back(std::move(step0_status.ValueOrDie())); + path.push_back(std::move(step1_status.ValueOrDie())); + path.push_back(std::move(step2_status.ValueOrDie())); + + std::unique_ptr impl = GetExpression(std::move(path)); + + Activation activation; + google::protobuf::Arena arena; + + auto status = impl->Evaluate(activation, &arena); + ASSERT_OK(status); + + auto value = status.ValueOrDie(); + + ASSERT_TRUE(value.IsError()); + // Making sure we propagate the error. + ASSERT_EQ(value.ErrorOrDie(), error_value.ErrorOrDie()); +} + +INSTANTIATE_TEST_SUITE_P( + UnknownFunctionSupport, FunctionStepTestUnknowns, + testing::Values(UnknownProcessingOptions::kAttributeOnly, + UnknownProcessingOptions::kAttributeAndFunction), + &TestNameFn); + +MATCHER_P2(IsAdd, a, b, "") { + const UnknownFunctionResult* result = arg; + return result->arguments().size() == 2 && + result->arguments().at(0).IsInt64() && + result->arguments().at(1).IsInt64() && + result->arguments().at(0).Int64OrDie() == a && + result->arguments().at(1).Int64OrDie() == b && + result->descriptor().name() == "_+_"; +} + +TEST(FunctionStepTestUnknownFunctionResults, CaptureArgs) { + ExecutionPath path; + BuilderWarnings warnings; + + CelFunctionRegistry registry; + + ASSERT_OK(registry.Register( + absl::make_unique(CelValue::CreateInt64(2), "Const2"))); + ASSERT_OK(registry.Register( + absl::make_unique(CelValue::CreateInt64(3), "Const3"))); + ASSERT_OK(registry.Register( + absl::make_unique(ShouldReturnUnknown::kYes))); + + Expr::Call call1 = ConstFunction::MakeCall("Const2"); + Expr::Call call2 = ConstFunction::MakeCall("Const3"); + Expr::Call add_call = AddFunction::MakeCall(); + + auto step0_status = + CreateFunctionStep(&call1, GetExprId(), registry, &warnings); + auto step1_status = + CreateFunctionStep(&call2, GetExprId(), registry, &warnings); + auto step2_status = + CreateFunctionStep(&add_call, GetExprId(), registry, &warnings); + + ASSERT_OK(step0_status); + ASSERT_OK(step1_status); + ASSERT_OK(step2_status); + + path.push_back(std::move(step0_status.ValueOrDie())); + path.push_back(std::move(step1_status.ValueOrDie())); + path.push_back(std::move(step2_status.ValueOrDie())); + + Expr dummy_expr; + + CelExpressionFlatImpl impl(&dummy_expr, std::move(path), 0, {}, true, true); + + Activation activation; + google::protobuf::Arena arena; + + auto status = impl.Evaluate(activation, &arena); + ASSERT_OK(status); + + auto value = status.ValueOrDie(); + + ASSERT_TRUE(value.IsUnknownSet()); + // Arguments captured. + EXPECT_THAT(value.UnknownSetOrDie() + ->unknown_function_results() + .unknown_function_results(), + ElementsAre(IsAdd(2, 3))); +} + +TEST(FunctionStepTestUnknownFunctionResults, MergeDownCaptureArgs) { + ExecutionPath path; + BuilderWarnings warnings; + + CelFunctionRegistry registry; + + ASSERT_OK(registry.Register( + absl::make_unique(CelValue::CreateInt64(2), "Const2"))); + ASSERT_OK(registry.Register( + absl::make_unique(CelValue::CreateInt64(3), "Const3"))); + ASSERT_OK(registry.Register( + absl::make_unique(ShouldReturnUnknown::kYes))); + + // Add(Add(2, 3), Add(2, 3)) + + Expr::Call call1 = ConstFunction::MakeCall("Const2"); + Expr::Call call2 = ConstFunction::MakeCall("Const3"); + Expr::Call add_call = AddFunction::MakeCall(); + + auto step0_status = + CreateFunctionStep(&call1, GetExprId(), registry, &warnings); + auto step1_status = + CreateFunctionStep(&call2, GetExprId(), registry, &warnings); + auto step2_status = + CreateFunctionStep(&add_call, GetExprId(), registry, &warnings); + auto step3_status = + CreateFunctionStep(&call1, GetExprId(), registry, &warnings); + auto step4_status = + CreateFunctionStep(&call2, GetExprId(), registry, &warnings); + auto step5_status = + CreateFunctionStep(&add_call, GetExprId(), registry, &warnings); + auto step6_status = + CreateFunctionStep(&add_call, GetExprId(), registry, &warnings); + + ASSERT_OK(step0_status); + ASSERT_OK(step1_status); + ASSERT_OK(step2_status); + ASSERT_OK(step3_status); + ASSERT_OK(step4_status); + ASSERT_OK(step5_status); + ASSERT_OK(step6_status); + + path.push_back(std::move(step0_status.ValueOrDie())); + path.push_back(std::move(step1_status.ValueOrDie())); + path.push_back(std::move(step2_status.ValueOrDie())); + path.push_back(std::move(step3_status.ValueOrDie())); + path.push_back(std::move(step4_status.ValueOrDie())); + path.push_back(std::move(step5_status.ValueOrDie())); + path.push_back(std::move(step6_status.ValueOrDie())); + + Expr dummy_expr; + + CelExpressionFlatImpl impl(&dummy_expr, std::move(path), 0, {}, true, true); + + Activation activation; + google::protobuf::Arena arena; auto status = impl.Evaluate(activation, &arena); - ASSERT_TRUE(status.ok()); + ASSERT_OK(status); + + auto value = status.ValueOrDie(); + + ASSERT_TRUE(value.IsUnknownSet()); + // Arguments captured. + EXPECT_THAT(value.UnknownSetOrDie() + ->unknown_function_results() + .unknown_function_results(), + ElementsAre(IsAdd(2, 3))); +} + +TEST(FunctionStepTestUnknownFunctionResults, MergeCaptureArgs) { + ExecutionPath path; + BuilderWarnings warnings; + + CelFunctionRegistry registry; + + ASSERT_OK(registry.Register( + absl::make_unique(CelValue::CreateInt64(2), "Const2"))); + ASSERT_OK(registry.Register( + absl::make_unique(CelValue::CreateInt64(3), "Const3"))); + ASSERT_OK(registry.Register( + absl::make_unique(ShouldReturnUnknown::kYes))); + + // Add(Add(2, 3), Add(3, 2)) + + Expr::Call call1 = ConstFunction::MakeCall("Const2"); + Expr::Call call2 = ConstFunction::MakeCall("Const3"); + Expr::Call add_call = AddFunction::MakeCall(); + + auto step0_status = + CreateFunctionStep(&call1, GetExprId(), registry, &warnings); + auto step1_status = + CreateFunctionStep(&call2, GetExprId(), registry, &warnings); + auto step2_status = + CreateFunctionStep(&add_call, GetExprId(), registry, &warnings); + auto step3_status = + CreateFunctionStep(&call2, GetExprId(), registry, &warnings); + auto step4_status = + CreateFunctionStep(&call1, GetExprId(), registry, &warnings); + auto step5_status = + CreateFunctionStep(&add_call, GetExprId(), registry, &warnings); + auto step6_status = + CreateFunctionStep(&add_call, GetExprId(), registry, &warnings); + + ASSERT_OK(step0_status); + ASSERT_OK(step1_status); + ASSERT_OK(step2_status); + ASSERT_OK(step3_status); + ASSERT_OK(step4_status); + ASSERT_OK(step5_status); + ASSERT_OK(step6_status); + + path.push_back(std::move(step0_status.ValueOrDie())); + path.push_back(std::move(step1_status.ValueOrDie())); + path.push_back(std::move(step2_status.ValueOrDie())); + path.push_back(std::move(step3_status.ValueOrDie())); + path.push_back(std::move(step4_status.ValueOrDie())); + path.push_back(std::move(step5_status.ValueOrDie())); + path.push_back(std::move(step6_status.ValueOrDie())); + + Expr dummy_expr; + + CelExpressionFlatImpl impl(&dummy_expr, std::move(path), 0, {}, true, true); + + Activation activation; + google::protobuf::Arena arena; + + auto status = impl.Evaluate(activation, &arena); + ASSERT_OK(status); + + auto value = status.ValueOrDie(); + + ASSERT_TRUE(value.IsUnknownSet()) << value.ErrorOrDie()->ToString(); + // Arguments captured. + EXPECT_THAT(value.UnknownSetOrDie() + ->unknown_function_results() + .unknown_function_results(), + UnorderedElementsAre(IsAdd(2, 3), IsAdd(3, 2))); +} + +TEST(FunctionStepTestUnknownFunctionResults, UnknownVsErrorPrecedenceTest) { + ExecutionPath path; + BuilderWarnings warnings; + + CelFunctionRegistry registry; + + CelError error0; + CelValue error_value = CelValue::CreateError(&error0); + UnknownSet unknown_set; + CelValue unknown_value = CelValue::CreateUnknownSet(&unknown_set); + + ASSERT_OK(registry.Register( + absl::make_unique(error_value, "ConstError"))); + ASSERT_OK(registry.Register( + absl::make_unique(unknown_value, "ConstUnknown"))); + ASSERT_OK(registry.Register( + absl::make_unique(ShouldReturnUnknown::kYes))); + + Expr::Call call1 = ConstFunction::MakeCall("ConstError"); + Expr::Call call2 = ConstFunction::MakeCall("ConstUnknown"); + Expr::Call add_call = AddFunction::MakeCall(); + + auto step0_status = + CreateFunctionStep(&call1, GetExprId(), registry, &warnings); + auto step1_status = + CreateFunctionStep(&call2, GetExprId(), registry, &warnings); + auto step2_status = + CreateFunctionStep(&add_call, GetExprId(), registry, &warnings); + + ASSERT_OK(step0_status); + ASSERT_OK(step1_status); + ASSERT_OK(step2_status); + + path.push_back(std::move(step0_status.ValueOrDie())); + path.push_back(std::move(step1_status.ValueOrDie())); + path.push_back(std::move(step2_status.ValueOrDie())); + + Expr dummy_expr; + + CelExpressionFlatImpl impl(&dummy_expr, std::move(path), 0, {}, true, true); + + Activation activation; + google::protobuf::Arena arena; + + auto status = impl.Evaluate(activation, &arena); + ASSERT_OK(status); auto value = status.ValueOrDie(); ASSERT_TRUE(value.IsError()); - EXPECT_THAT(value.ErrorOrDie(), Eq(&error0)); + // Making sure we propagate the error. + ASSERT_EQ(value.ErrorOrDie(), error_value.ErrorOrDie()); } } // namespace diff --git a/eval/eval/ident_step.cc b/eval/eval/ident_step.cc index 9a7565378..178941c71 100644 --- a/eval/eval/ident_step.cc +++ b/eval/eval/ident_step.cc @@ -1,6 +1,10 @@ #include "eval/eval/ident_step.h" -#include "eval/eval/expression_step_base.h" + +#include "google/protobuf/arena.h" #include "absl/strings/substitute.h" +#include "eval/eval/evaluator_core.h" +#include "eval/eval/expression_step_base.h" +#include "eval/public/unknown_attribute_set.h" namespace google { namespace api { @@ -13,41 +17,69 @@ class IdentStep : public ExpressionStepBase { IdentStep(absl::string_view name, int64_t expr_id) : ExpressionStepBase(expr_id), name_(name) {} - cel_base::Status Evaluate(ExecutionFrame* frame) const override; + absl::Status Evaluate(ExecutionFrame* frame) const override; private: + void DoEvaluate(ExecutionFrame* frame, CelValue* result, + AttributeTrail* trail) const; + std::string name_; }; -cel_base::Status IdentStep::Evaluate(ExecutionFrame* frame) const { - CelValue result; - auto it = frame->iter_vars().find(name_); - if (it != frame->iter_vars().end()) { - result = it->second; - } else { - auto value = frame->activation().FindValue(name_, frame->arena()); +void IdentStep::DoEvaluate(ExecutionFrame* frame, CelValue* result, + AttributeTrail* trail) const { + // Special case - iterator looked up in + if (frame->GetIterVar(name_, result)) { + return; + } + + auto value = frame->activation().FindValue(name_, frame->arena()); + { // We handle masked unknown paths for the sake of uniformity, although it is // better not to bind unknown values to activation in first place. + // TODO(issues/41) Deprecate this style of unknowns handling after + // Unknowns are properly supported. bool unknown_value = frame->activation().IsPathUnknown(name_); - if (!unknown_value) { - if (value.has_value()) { - result = value.value(); - } else { - result = CreateErrorValue( - frame->arena(), - absl::Substitute("No value with name \"$0\" found in Activation", - name_)); - } - } else { - result = CreateUnknownValueError(frame->arena(), name_); + if (unknown_value) { + *result = CreateUnknownValueError(frame->arena(), name_); + return; } } - frame->value_stack().Push(result); + if (frame->enable_unknowns()) { + google::api::expr::v1alpha1::Expr expr; + expr.mutable_ident_expr()->set_name(name_); + *trail = AttributeTrail(expr, frame->arena()); + + if (frame->unknowns_utility().CheckForUnknown(*trail, false)) { + auto unknown_set = google::protobuf::Arena::Create( + frame->arena(), UnknownAttributeSet({trail->attribute()})); + *result = CelValue::CreateUnknownSet(unknown_set); + return; + } + } + + if (value.has_value()) { + *result = value.value(); + } else { + *result = CreateErrorValue( + frame->arena(), + absl::Substitute("No value with name \"$0\" found in Activation", + name_)); + } +} + +absl::Status IdentStep::Evaluate(ExecutionFrame* frame) const { + CelValue result; + AttributeTrail trail; + + DoEvaluate(frame, &result, &trail); + + frame->value_stack().Push(result, trail); - return cel_base::OkStatus(); + return absl::OkStatus(); } } // namespace diff --git a/eval/eval/ident_step_test.cc b/eval/eval/ident_step_test.cc index 5b7a8a9f3..a77744cbe 100644 --- a/eval/eval/ident_step_test.cc +++ b/eval/eval/ident_step_test.cc @@ -1,9 +1,10 @@ #include "eval/eval/ident_step.h" -#include "eval/eval/evaluator_core.h" -#include "google/api/expr/v1alpha1/syntax.pb.h" +#include "google/api/expr/v1alpha1/syntax.pb.h" #include "gmock/gmock.h" #include "gtest/gtest.h" +#include "eval/eval/evaluator_core.h" +#include "base/status_macros.h" namespace google { namespace api { @@ -24,14 +25,14 @@ TEST(IdentStepTest, TestIdentStep) { ident_expr->set_name("name0"); auto step_status = CreateIdentStep(ident_expr, expr.id()); - ASSERT_TRUE(step_status.ok()); + ASSERT_OK(step_status); ExecutionPath path; path.push_back(std::move(step_status.ValueOrDie())); auto dummy_expr = absl::make_unique(); - CelExpressionFlatImpl impl(dummy_expr.get(), std::move(path), 0); + CelExpressionFlatImpl impl(dummy_expr.get(), std::move(path), 0, {}); Activation activation; Arena arena; @@ -39,7 +40,7 @@ TEST(IdentStepTest, TestIdentStep) { activation.InsertValue("name0", CelValue::CreateString(&value)); auto status0 = impl.Evaluate(activation, &arena); - ASSERT_TRUE(status0.ok()); + ASSERT_OK(status0); CelValue result = status0.ValueOrDie(); @@ -53,40 +54,40 @@ TEST(IdentStepTest, TestIdentStepNameNotFound) { ident_expr->set_name("name0"); auto step_status = CreateIdentStep(ident_expr, expr.id()); - ASSERT_TRUE(step_status.ok()); + ASSERT_OK(step_status); ExecutionPath path; path.push_back(std::move(step_status.ValueOrDie())); auto dummy_expr = absl::make_unique(); - CelExpressionFlatImpl impl(dummy_expr.get(), std::move(path), 0); + CelExpressionFlatImpl impl(dummy_expr.get(), std::move(path), 0, {}); Activation activation; Arena arena; std::string value("test"); auto status0 = impl.Evaluate(activation, &arena); - ASSERT_TRUE(status0.ok()); + ASSERT_OK(status0); CelValue result = status0.ValueOrDie(); ASSERT_TRUE(result.IsError()); } -TEST(IdentStepTest, TestIdentStepUnknownValue) { +TEST(IdentStepTest, TestIdentStepUnknownValueError) { Expr expr; auto ident_expr = expr.mutable_ident_expr(); ident_expr->set_name("name0"); auto step_status = CreateIdentStep(ident_expr, expr.id()); - ASSERT_TRUE(step_status.ok()); + ASSERT_OK(step_status); ExecutionPath path; path.push_back(std::move(step_status.ValueOrDie())); auto dummy_expr = absl::make_unique(); - CelExpressionFlatImpl impl(dummy_expr.get(), std::move(path), 0); + CelExpressionFlatImpl impl(dummy_expr.get(), std::move(path), 0, {}); Activation activation; Arena arena; @@ -94,7 +95,7 @@ TEST(IdentStepTest, TestIdentStepUnknownValue) { activation.InsertValue("name0", CelValue::CreateString(&value)); auto status0 = impl.Evaluate(activation, &arena); - ASSERT_TRUE(status0.ok()); + ASSERT_OK(status0); CelValue result = status0.ValueOrDie(); @@ -106,7 +107,7 @@ TEST(IdentStepTest, TestIdentStepUnknownValue) { activation.set_unknown_paths(unknown_mask); status0 = impl.Evaluate(activation, &arena); - ASSERT_TRUE(status0.ok()); + ASSERT_OK(status0); result = status0.ValueOrDie(); @@ -115,6 +116,50 @@ TEST(IdentStepTest, TestIdentStepUnknownValue) { EXPECT_THAT(GetUnknownPathsSetOrDie(result), Eq(std::set({"name0"}))); } +TEST(IdentStepTest, TestIdentStepUnknownAttribute) { + Expr expr; + auto ident_expr = expr.mutable_ident_expr(); + ident_expr->set_name("name0"); + + auto step_status = CreateIdentStep(ident_expr, expr.id()); + ASSERT_OK(step_status); + + ExecutionPath path; + path.push_back(std::move(step_status.ValueOrDie())); + + auto dummy_expr = absl::make_unique(); + + // Expression with unknowns enabled. + CelExpressionFlatImpl impl(dummy_expr.get(), std::move(path), 0, {}, true); + + Activation activation; + Arena arena; + std::string value("test"); + + activation.InsertValue("name0", CelValue::CreateString(&value)); + std::vector unknown_patterns; + unknown_patterns.push_back(CelAttributePattern("name_bad", {})); + + activation.set_unknown_attribute_patterns(unknown_patterns); + auto status0 = impl.Evaluate(activation, &arena); + ASSERT_OK(status0); + + CelValue result = status0.ValueOrDie(); + + ASSERT_TRUE(result.IsString()); + EXPECT_THAT(result.StringOrDie().value(), Eq("test")); + + unknown_patterns.push_back(CelAttributePattern("name0", {})); + + activation.set_unknown_attribute_patterns(unknown_patterns); + status0 = impl.Evaluate(activation, &arena); + ASSERT_OK(status0); + + result = status0.ValueOrDie(); + + ASSERT_TRUE(result.IsUnknownSet()); +} + } // namespace } // namespace runtime diff --git a/eval/eval/jump_step.cc b/eval/eval/jump_step.cc index 0988dbda1..908fea38c 100644 --- a/eval/eval/jump_step.cc +++ b/eval/eval/jump_step.cc @@ -14,7 +14,7 @@ class JumpStep : public JumpStepBase { JumpStep(absl::optional jump_offset, int64_t expr_id) : JumpStepBase(jump_offset, expr_id) {} - cel_base::Status Evaluate(ExecutionFrame* frame) const override { + absl::Status Evaluate(ExecutionFrame* frame) const override { return Jump(frame); } }; @@ -28,10 +28,10 @@ class CondJumpStep : public JumpStepBase { jump_condition_(jump_condition), leave_on_stack_(leave_on_stack) {} - cel_base::Status Evaluate(ExecutionFrame* frame) const override { + absl::Status Evaluate(ExecutionFrame* frame) const override { // Peek the top value if (!frame->value_stack().HasEnough(1)) { - return cel_base::Status(cel_base::StatusCode::kInternal, "Value stack underflow"); + return absl::Status(absl::StatusCode::kInternal, "Value stack underflow"); } CelValue value = frame->value_stack().Peek(); @@ -44,7 +44,7 @@ class CondJumpStep : public JumpStepBase { return Jump(frame); } - return cel_base::OkStatus(); + return absl::OkStatus(); } private: @@ -57,15 +57,16 @@ class BoolCheckJumpStep : public JumpStepBase { // Checks if the top value is a boolean: // - no-op if it is a boolean // - jump to the label if it is an error value + // - jump to the label if it is unknown value // - jump to the label if it is neither an error nor a boolean, pops it and // pushes "no matching overload" error BoolCheckJumpStep(absl::optional jump_offset, int64_t expr_id) : JumpStepBase(jump_offset, expr_id) {} - cel_base::Status Evaluate(ExecutionFrame* frame) const override { + absl::Status Evaluate(ExecutionFrame* frame) const override { // Peek the top value if (!frame->value_stack().HasEnough(1)) { - return cel_base::Status(cel_base::StatusCode::kInternal, "Value stack underflow"); + return absl::Status(absl::StatusCode::kInternal, "Value stack underflow"); } CelValue value = frame->value_stack().Peek(); @@ -74,13 +75,17 @@ class BoolCheckJumpStep : public JumpStepBase { return Jump(frame); } + if (value.IsUnknownSet()) { + return Jump(frame); + } + if (!value.IsBool()) { CelValue error_value = CreateNoMatchingOverloadError(frame->arena()); frame->value_stack().PopAndPush(error_value); return Jump(frame); } - return cel_base::OkStatus(); + return absl::OkStatus(); } }; @@ -109,7 +114,7 @@ cel_base::StatusOr> CreateJumpStep( // Factory method for Conditional Jump step. // Conditional Jump requires a value to sit on the stack. -// If this value is an error, a jump is performed. +// If this value is an error or unknown, a jump is performed. cel_base::StatusOr> CreateBoolCheckJumpStep( absl::optional jump_offset, int64_t expr_id) { std::unique_ptr step = @@ -118,6 +123,9 @@ cel_base::StatusOr> CreateBoolCheckJumpStep( return std::move(step); } +// TODO(issues/41) Make sure Unknowns are properly supported by ternary +// operation. + } // namespace runtime } // namespace expr } // namespace api diff --git a/eval/eval/jump_step.h b/eval/eval/jump_step.h index fd474b763..3515c8bb8 100644 --- a/eval/eval/jump_step.h +++ b/eval/eval/jump_step.h @@ -19,9 +19,9 @@ class JumpStepBase : public ExpressionStepBase { void set_jump_offset(int offset) { jump_offset_ = offset; } - cel_base::Status Jump(ExecutionFrame* frame) const { + absl::Status Jump(ExecutionFrame* frame) const { if (!jump_offset_.has_value()) { - return cel_base::Status(cel_base::StatusCode::kInternal, "Jump offset not set"); + return absl::Status(absl::StatusCode::kInternal, "Jump offset not set"); } return frame->JumpTo(jump_offset_.value()); } diff --git a/eval/eval/logic_step.cc b/eval/eval/logic_step.cc index 7103fb0ab..a267eefda 100644 --- a/eval/eval/logic_step.cc +++ b/eval/eval/logic_step.cc @@ -1,7 +1,9 @@ #include "eval/eval/logic_step.h" -#include "eval/eval/expression_step_base.h" #include "absl/strings/str_cat.h" +#include "eval/eval/expression_step_base.h" +#include "eval/public/cel_value.h" +#include "eval/public/unknown_attribute_set.h" namespace google { namespace api { @@ -20,10 +22,10 @@ class LogicalOpStep : public ExpressionStepBase { shortcircuit_ = (op_type_ == OpType::OR); } - cel_base::Status Evaluate(ExecutionFrame* frame) const override; + absl::Status Evaluate(ExecutionFrame* frame) const override; private: - cel_base::Status Calculate(ExecutionFrame* frame, absl::Span args, + absl::Status Calculate(ExecutionFrame* frame, absl::Span args, CelValue* result) const { bool bool_args[2]; bool has_bool_args[2]; @@ -32,7 +34,7 @@ class LogicalOpStep : public ExpressionStepBase { has_bool_args[i] = args[i].GetValue(bool_args + i); if (has_bool_args[i] && shortcircuit_ == bool_args[i]) { *result = CelValue::CreateBool(bool_args[i]); - return cel_base::OkStatus(); + return absl::OkStatus(); } } @@ -40,34 +42,51 @@ class LogicalOpStep : public ExpressionStepBase { switch (op_type_) { case OpType::AND: *result = CelValue::CreateBool(bool_args[0] && bool_args[1]); - return cel_base::OkStatus(); + return absl::OkStatus(); break; case OpType::OR: *result = CelValue::CreateBool(bool_args[0] || bool_args[1]); - return cel_base::OkStatus(); + return absl::OkStatus(); break; } } + // As opposed to regular function, logical operation treat Unknowns with + // higher precedence than error. This is due to the fact that after Unknown + // is resolved to actual value, it may shortcircuit and thus hide the error. + if (frame->enable_unknowns()) { + // Check if unknown? + const UnknownSet* unknown_set = + frame->unknowns_utility().MergeUnknowns(args, + /*initial_set=*/nullptr); + + if (unknown_set) { + *result = CelValue::CreateUnknownSet(unknown_set); + return absl::OkStatus(); + } + } + if (args[0].IsError()) { *result = args[0]; + return absl::OkStatus(); } else if (args[1].IsError()) { *result = args[1]; - } else { - *result = CreateNoMatchingOverloadError(frame->arena()); + return absl::OkStatus(); } - return cel_base::OkStatus(); + // Fallback. + *result = CreateNoMatchingOverloadError(frame->arena()); + return absl::OkStatus(); } const OpType op_type_; bool shortcircuit_; }; -cel_base::Status LogicalOpStep::Evaluate(ExecutionFrame* frame) const { +absl::Status LogicalOpStep::Evaluate(ExecutionFrame* frame) const { // Must have 2 or more values on the stack. if (!frame->value_stack().HasEnough(2)) { - return cel_base::Status(cel_base::StatusCode::kInternal, "Value stack underflow"); + return absl::Status(absl::StatusCode::kInternal, "Value stack underflow"); } // Create Span object that contains input arguments to the function. @@ -88,6 +107,9 @@ cel_base::Status LogicalOpStep::Evaluate(ExecutionFrame* frame) const { } // namespace +// TODO(issues/41) Make sure Unknowns are properly supported by ternary +// operation. + // Factory method for "And" Execution step cel_base::StatusOr> CreateAndStep(int64_t expr_id) { std::unique_ptr step = diff --git a/eval/eval/logic_step_test.cc b/eval/eval/logic_step_test.cc new file mode 100644 index 000000000..e9b410b1e --- /dev/null +++ b/eval/eval/logic_step_test.cc @@ -0,0 +1,319 @@ +#include "eval/eval/logic_step.h" + +#include "gmock/gmock.h" +#include "gtest/gtest.h" +#include "eval/eval/ident_step.h" +#include "eval/public/unknown_attribute_set.h" +#include "eval/public/unknown_set.h" +#include "base/status_macros.h" + +namespace google { +namespace api { +namespace expr { +namespace runtime { + +namespace { + +using google::api::expr::v1alpha1::Expr; + +using google::protobuf::Arena; +using testing::Eq; +class LogicStepTest : public testing::TestWithParam { + public: + absl::Status EvaluateLogic(CelValue arg0, CelValue arg1, bool is_or, + CelValue* result, bool enable_unknown) { + Expr expr0; + auto ident_expr0 = expr0.mutable_ident_expr(); + ident_expr0->set_name("name0"); + + Expr expr1; + auto ident_expr1 = expr1.mutable_ident_expr(); + ident_expr1->set_name("name1"); + + ExecutionPath path; + + auto step_status = CreateIdentStep(ident_expr0, expr0.id()); + if (!step_status.ok()) return step_status.status(); + + path.push_back(std::move(step_status).ValueOrDie()); + + step_status = CreateIdentStep(ident_expr1, expr1.id()); + if (!step_status.ok()) return step_status.status(); + + path.push_back(std::move(step_status).ValueOrDie()); + + step_status = (is_or) ? CreateOrStep(2) : CreateAndStep(2); + if (!step_status.ok()) return step_status.status(); + + path.push_back(std::move(step_status).ValueOrDie()); + + auto dummy_expr = absl::make_unique(); + + CelExpressionFlatImpl impl(dummy_expr.get(), std::move(path), 0, {}, + enable_unknown); + + Activation activation; + std::string value("test"); + + activation.InsertValue("name0", arg0); + activation.InsertValue("name1", arg1); + auto status0 = impl.Evaluate(activation, &arena_); + if (!status0.ok()) return status0.status(); + + *result = status0.ValueOrDie(); + return absl::OkStatus(); + } + + private: + Arena arena_; +}; + +TEST_P(LogicStepTest, TestAndLogic) { + CelValue result; + absl::Status status = + EvaluateLogic(CelValue::CreateBool(true), CelValue::CreateBool(true), + false, &result, GetParam()); + ASSERT_OK(status); + ASSERT_TRUE(result.IsBool()); + ASSERT_TRUE(result.BoolOrDie()); + + status = + EvaluateLogic(CelValue::CreateBool(true), CelValue::CreateBool(false), + false, &result, GetParam()); + ASSERT_OK(status); + ASSERT_TRUE(result.IsBool()); + ASSERT_FALSE(result.BoolOrDie()); + + status = + EvaluateLogic(CelValue::CreateBool(false), CelValue::CreateBool(true), + false, &result, GetParam()); + ASSERT_OK(status); + ASSERT_TRUE(result.IsBool()); + ASSERT_FALSE(result.BoolOrDie()); + + status = + EvaluateLogic(CelValue::CreateBool(false), CelValue::CreateBool(false), + false, &result, GetParam()); + ASSERT_OK(status); + ASSERT_TRUE(result.IsBool()); + ASSERT_FALSE(result.BoolOrDie()); +} + +TEST_P(LogicStepTest, TestOrLogic) { + CelValue result; + absl::Status status = + EvaluateLogic(CelValue::CreateBool(true), CelValue::CreateBool(true), + true, &result, GetParam()); + ASSERT_OK(status); + ASSERT_TRUE(result.IsBool()); + ASSERT_TRUE(result.BoolOrDie()); + + status = + EvaluateLogic(CelValue::CreateBool(true), CelValue::CreateBool(false), + true, &result, GetParam()); + ASSERT_OK(status); + ASSERT_TRUE(result.IsBool()); + ASSERT_TRUE(result.BoolOrDie()); + + status = EvaluateLogic(CelValue::CreateBool(false), + CelValue::CreateBool(true), true, &result, GetParam()); + ASSERT_OK(status); + ASSERT_TRUE(result.IsBool()); + ASSERT_TRUE(result.BoolOrDie()); + + status = + EvaluateLogic(CelValue::CreateBool(false), CelValue::CreateBool(false), + true, &result, GetParam()); + ASSERT_OK(status); + ASSERT_TRUE(result.IsBool()); + ASSERT_FALSE(result.BoolOrDie()); +} + +TEST_P(LogicStepTest, TestAndLogicErrorHandling) { + CelValue result; + CelError error; + CelValue error_value = CelValue::CreateError(&error); + absl::Status status = EvaluateLogic(error_value, CelValue::CreateBool(true), + false, &result, GetParam()); + ASSERT_OK(status); + ASSERT_TRUE(result.IsError()); + + status = EvaluateLogic(CelValue::CreateBool(true), error_value, false, + &result, GetParam()); + ASSERT_OK(status); + ASSERT_TRUE(result.IsError()); + + status = EvaluateLogic(CelValue::CreateBool(false), error_value, false, + &result, GetParam()); + ASSERT_OK(status); + ASSERT_TRUE(result.IsBool()); + ASSERT_FALSE(result.BoolOrDie()); + + status = EvaluateLogic(error_value, CelValue::CreateBool(false), false, + &result, GetParam()); + ASSERT_OK(status); + ASSERT_TRUE(result.IsBool()); + ASSERT_FALSE(result.BoolOrDie()); +} + +TEST_P(LogicStepTest, TestOrLogicErrorHandling) { + CelValue result; + CelError error; + CelValue error_value = CelValue::CreateError(&error); + absl::Status status = EvaluateLogic(error_value, CelValue::CreateBool(false), + true, &result, GetParam()); + ASSERT_OK(status); + ASSERT_TRUE(result.IsError()); + + status = EvaluateLogic(CelValue::CreateBool(false), error_value, true, + &result, GetParam()); + ASSERT_OK(status); + ASSERT_TRUE(result.IsError()); + + status = EvaluateLogic(CelValue::CreateBool(true), error_value, true, &result, + GetParam()); + ASSERT_OK(status); + ASSERT_TRUE(result.IsBool()); + ASSERT_TRUE(result.BoolOrDie()); + + status = EvaluateLogic(error_value, CelValue::CreateBool(true), true, &result, + GetParam()); + ASSERT_OK(status); + ASSERT_TRUE(result.IsBool()); + ASSERT_TRUE(result.BoolOrDie()); +} + +TEST_F(LogicStepTest, TestAndLogicUnknownHandling) { + CelValue result; + UnknownSet unknown_set; + CelError cel_error; + CelValue unknown_value = CelValue::CreateUnknownSet(&unknown_set); + CelValue error_value = CelValue::CreateError(&cel_error); + absl::Status status = EvaluateLogic(unknown_value, CelValue::CreateBool(true), + false, &result, true); + ASSERT_OK(status); + ASSERT_TRUE(result.IsUnknownSet()); + + status = EvaluateLogic(CelValue::CreateBool(true), unknown_value, false, + &result, true); + ASSERT_OK(status); + ASSERT_TRUE(result.IsUnknownSet()); + + status = EvaluateLogic(CelValue::CreateBool(false), unknown_value, false, + &result, true); + ASSERT_OK(status); + ASSERT_TRUE(result.IsBool()); + ASSERT_FALSE(result.BoolOrDie()); + + status = EvaluateLogic(unknown_value, CelValue::CreateBool(false), false, + &result, true); + ASSERT_OK(status); + ASSERT_TRUE(result.IsBool()); + ASSERT_FALSE(result.BoolOrDie()); + + status = EvaluateLogic(error_value, unknown_value, false, &result, true); + ASSERT_OK(status); + ASSERT_TRUE(result.IsUnknownSet()); + + status = EvaluateLogic(unknown_value, error_value, false, &result, true); + ASSERT_OK(status); + ASSERT_TRUE(result.IsUnknownSet()); + + Expr expr0; + auto ident_expr0 = expr0.mutable_ident_expr(); + ident_expr0->set_name("name0"); + + Expr expr1; + auto ident_expr1 = expr1.mutable_ident_expr(); + ident_expr1->set_name("name1"); + + CelAttribute attr0(expr0, {}), attr1(expr1, {}); + UnknownAttributeSet unknown_attr_set0({&attr0}); + UnknownAttributeSet unknown_attr_set1({&attr1}); + UnknownSet unknown_set0(unknown_attr_set0); + UnknownSet unknown_set1(unknown_attr_set1); + + EXPECT_THAT(unknown_attr_set0.attributes().size(), Eq(1)); + EXPECT_THAT(unknown_attr_set1.attributes().size(), Eq(1)); + + status = EvaluateLogic(CelValue::CreateUnknownSet(&unknown_set0), + CelValue::CreateUnknownSet(&unknown_set1), false, + &result, true); + ASSERT_OK(status); + ASSERT_TRUE(result.IsUnknownSet()); + ASSERT_THAT( + result.UnknownSetOrDie()->unknown_attributes().attributes().size(), + Eq(2)); +} + +TEST_F(LogicStepTest, TestOrLogicUnknownHandling) { + CelValue result; + UnknownSet unknown_set; + CelError cel_error; + CelValue unknown_value = CelValue::CreateUnknownSet(&unknown_set); + CelValue error_value = CelValue::CreateError(&cel_error); + absl::Status status = EvaluateLogic( + unknown_value, CelValue::CreateBool(false), true, &result, true); + ASSERT_OK(status); + ASSERT_TRUE(result.IsUnknownSet()); + + status = EvaluateLogic(CelValue::CreateBool(false), unknown_value, true, + &result, true); + ASSERT_OK(status); + ASSERT_TRUE(result.IsUnknownSet()); + + status = EvaluateLogic(CelValue::CreateBool(true), unknown_value, true, + &result, true); + ASSERT_OK(status); + ASSERT_TRUE(result.IsBool()); + ASSERT_TRUE(result.BoolOrDie()); + + status = EvaluateLogic(unknown_value, CelValue::CreateBool(true), true, + &result, true); + ASSERT_OK(status); + ASSERT_TRUE(result.IsBool()); + ASSERT_TRUE(result.BoolOrDie()); + + status = EvaluateLogic(unknown_value, error_value, true, &result, true); + ASSERT_OK(status); + ASSERT_TRUE(result.IsUnknownSet()); + + status = EvaluateLogic(error_value, unknown_value, true, &result, true); + ASSERT_OK(status); + ASSERT_TRUE(result.IsUnknownSet()); + + Expr expr0; + auto ident_expr0 = expr0.mutable_ident_expr(); + ident_expr0->set_name("name0"); + + Expr expr1; + auto ident_expr1 = expr1.mutable_ident_expr(); + ident_expr1->set_name("name1"); + + CelAttribute attr0(expr0, {}), attr1(expr1, {}); + UnknownAttributeSet unknown_attr_set0({&attr0}); + UnknownAttributeSet unknown_attr_set1({&attr1}); + + UnknownSet unknown_set0(unknown_attr_set0); + UnknownSet unknown_set1(unknown_attr_set1); + + EXPECT_THAT(unknown_attr_set0.attributes().size(), Eq(1)); + EXPECT_THAT(unknown_attr_set1.attributes().size(), Eq(1)); + + status = EvaluateLogic(CelValue::CreateUnknownSet(&unknown_set0), + CelValue::CreateUnknownSet(&unknown_set1), true, + &result, true); + ASSERT_OK(status); + ASSERT_TRUE(result.IsUnknownSet()); + ASSERT_THAT( + result.UnknownSetOrDie()->unknown_attributes().attributes().size(), + Eq(2)); +} + +INSTANTIATE_TEST_SUITE_P(LogicStepTest, LogicStepTest, testing::Bool()); +} // namespace + +} // namespace runtime +} // namespace expr +} // namespace api +} // namespace google diff --git a/eval/eval/select_step.cc b/eval/eval/select_step.cc index 60707bf68..effae14e0 100644 --- a/eval/eval/select_step.cc +++ b/eval/eval/select_step.cc @@ -1,9 +1,11 @@ #include "eval/eval/select_step.h" + +#include "absl/strings/str_cat.h" +#include "eval/eval/evaluator_core.h" #include "eval/eval/expression_step_base.h" #include "eval/eval/field_access.h" #include "eval/eval/field_backed_list_impl.h" #include "eval/eval/field_backed_map_impl.h" -#include "absl/strings/str_cat.h" namespace google { namespace api { @@ -12,9 +14,9 @@ namespace runtime { namespace { -using google::protobuf::Reflection; using google::protobuf::Descriptor; using google::protobuf::FieldDescriptor; +using google::protobuf::Reflection; // SelectStep performs message field access specified by Expr::Select // message. @@ -27,10 +29,10 @@ class SelectStep : public ExpressionStepBase { test_field_presence_(test_field_presence), select_path_(select_path) {} - cel_base::Status Evaluate(ExecutionFrame* frame) const override; + absl::Status Evaluate(ExecutionFrame* frame) const override; private: - cel_base::Status CreateValueFromField(const google::protobuf::Message* message, + absl::Status CreateValueFromField(const google::protobuf::Message* msg, google::protobuf::Arena* arena, CelValue* result) const; @@ -39,7 +41,7 @@ class SelectStep : public ExpressionStepBase { std::string select_path_; }; -cel_base::Status SelectStep::CreateValueFromField(const google::protobuf::Message* msg, +absl::Status SelectStep::CreateValueFromField(const google::protobuf::Message* msg, google::protobuf::Arena* arena, CelValue* result) const { const Reflection* reflection = msg->GetReflection(); @@ -48,36 +50,42 @@ cel_base::Status SelectStep::CreateValueFromField(const google::protobuf::Messag if (field_desc == nullptr) { *result = CreateNoSuchFieldError(arena); - return cel_base::OkStatus(); + return absl::OkStatus(); } if (field_desc->is_map()) { *result = CelValue::CreateMap(google::protobuf::Arena::Create( arena, msg, field_desc, arena)); - return cel_base::OkStatus(); + return absl::OkStatus(); } if (field_desc->is_repeated()) { *result = CelValue::CreateList(google::protobuf::Arena::Create( arena, msg, field_desc, arena)); - return cel_base::OkStatus(); + return absl::OkStatus(); } if (test_field_presence_) { *result = CelValue::CreateBool(reflection->HasField(*msg, field_desc)); - return cel_base::OkStatus(); + return absl::OkStatus(); } return CreateValueFromSingleField(msg, field_desc, arena, result); } -cel_base::Status SelectStep::Evaluate(ExecutionFrame* frame) const { +absl::Status SelectStep::Evaluate(ExecutionFrame* frame) const { if (!frame->value_stack().HasEnough(1)) { - return cel_base::Status(cel_base::StatusCode::kInternal, + return absl::Status(absl::StatusCode::kInternal, "No arguments supplied for Select-type expression"); } - CelValue arg = frame->value_stack().Peek(); + const CelValue& arg = frame->value_stack().Peek(); + const AttributeTrail& trail = frame->value_stack().PeekAttribute(); + + CelValue result; + AttributeTrail result_trail; // Non-empty select path - check if value mapped to unknown. bool unknown_value = false; + // TODO(issues/41) deprecate this path after proper support of unknown is + // implemented if (!select_path_.empty()) { unknown_value = frame->activation().IsPathUnknown(select_path_); } @@ -87,24 +95,36 @@ cel_base::Status SelectStep::Evaluate(ExecutionFrame* frame) const { case CelValue::Type::kMessage: { const google::protobuf::Message* msg = arg.MessageOrDie(); + if (frame->enable_unknowns()) { + result_trail = trail.Step(&field_, frame->arena()); + if (frame->unknowns_utility().CheckForUnknown(result_trail, + /*use_partial=*/false)) { + auto unknown_set = google::protobuf::Arena::Create( + frame->arena(), UnknownAttributeSet({result_trail.attribute()})); + result = CelValue::CreateUnknownSet(unknown_set); + frame->value_stack().PopAndPush(result, result_trail); + return absl::OkStatus(); + } + } + if (msg == nullptr) { CelValue error_value = CreateErrorValue(frame->arena(), "Message is NULL"); - frame->value_stack().PopAndPush(error_value); - return cel_base::OkStatus(); + frame->value_stack().PopAndPush(error_value, result_trail); + return absl::OkStatus(); } if (unknown_value) { CelValue error_value = CreateUnknownValueError(frame->arena(), select_path_); - frame->value_stack().PopAndPush(error_value); - return cel_base::OkStatus(); + frame->value_stack().PopAndPush(error_value, result_trail); + return absl::OkStatus(); } - cel_base::Status status = CreateValueFromField(msg, frame->arena(), &arg); + absl::Status status = CreateValueFromField(msg, frame->arena(), &result); if (status.ok()) { - frame->value_stack().PopAndPush(arg); + frame->value_stack().PopAndPush(result, result_trail); } return status; @@ -115,42 +135,57 @@ cel_base::Status SelectStep::Evaluate(ExecutionFrame* frame) const { if (cel_map == nullptr) { CelValue error_value = CreateErrorValue(frame->arena(), "Map is NULL"); frame->value_stack().PopAndPush(error_value); - return cel_base::OkStatus(); + return absl::OkStatus(); } if (unknown_value) { CelValue error_value = CreateErrorValue( frame->arena(), absl::StrCat("Unknown value ", select_path_)); frame->value_stack().PopAndPush(error_value); - return cel_base::OkStatus(); + return absl::OkStatus(); } auto lookup_result = (*cel_map)[CelValue::CreateString(&field_)]; // Test only Select expression. if (test_field_presence_) { - arg = CelValue::CreateBool(lookup_result.has_value()); - frame->value_stack().PopAndPush(arg); - return cel_base::OkStatus(); + result = CelValue::CreateBool(lookup_result.has_value()); + frame->value_stack().PopAndPush(result); + return absl::OkStatus(); + } + + if (frame->enable_unknowns()) { + result_trail = trail.Step(&field_, frame->arena()); + if (frame->unknowns_utility().CheckForUnknown(result_trail, false)) { + auto unknown_set = google::protobuf::Arena::Create( + frame->arena(), UnknownAttributeSet({result_trail.attribute()})); + result = CelValue::CreateUnknownSet(unknown_set); + frame->value_stack().PopAndPush(result, result_trail); + return absl::OkStatus(); + } } // If object is not found, we return Error, per CEL specification. if (lookup_result) { - arg = lookup_result.value(); + result = lookup_result.value(); } else { - arg = CreateNoSuchKeyError(frame->arena(), field_); + result = CreateNoSuchKeyError(frame->arena(), field_); } - frame->value_stack().PopAndPush(arg); + frame->value_stack().PopAndPush(result, result_trail); - return ::cel_base::OkStatus(); + return absl::OkStatus(); + } + case CelValue::Type::kUnknownSet: { + // Parent is unknown already, bubble it up. + return absl::OkStatus(); } case CelValue::Type::kError: { // If argument is CelError, we propagate it forward. // It is already on the top of the stack. - return ::cel_base::OkStatus(); + return absl::OkStatus(); } default: - return cel_base::Status(cel_base::StatusCode::kInvalidArgument, + return absl::Status(absl::StatusCode::kInvalidArgument, "Applying SELECT to non-message type"); } } diff --git a/eval/eval/select_step_test.cc b/eval/eval/select_step_test.cc index 4baf6b398..2a5cfa8f2 100644 --- a/eval/eval/select_step_test.cc +++ b/eval/eval/select_step_test.cc @@ -5,8 +5,11 @@ #include "gtest/gtest.h" #include "eval/eval/container_backed_map_impl.h" #include "eval/eval/ident_step.h" +#include "eval/public/cel_attribute.h" +#include "eval/public/unknown_attribute_set.h" #include "eval/testutil/test_message.pb.h" #include "testutil/util.h" +#include "base/status_macros.h" namespace google { namespace api { @@ -24,7 +27,8 @@ using google::api::expr::v1alpha1::Expr; cel_base::StatusOr RunExpression(const CelValue target, absl::string_view field, bool test, google::protobuf::Arena* arena, - absl::string_view unknown_path) { + absl::string_view unknown_path, + bool enable_unknowns) { ExecutionPath path; Expr dummy_expr; @@ -50,7 +54,8 @@ cel_base::StatusOr RunExpression(const CelValue target, path.push_back(std::move(step0_status.ValueOrDie())); path.push_back(std::move(step1_status.ValueOrDie())); - CelExpressionFlatImpl cel_expr(&dummy_expr, std::move(path), 0); + CelExpressionFlatImpl cel_expr(&dummy_expr, std::move(path), 0, {}, + enable_unknowns); Activation activation; activation.InsertValue("target", target); @@ -60,50 +65,57 @@ cel_base::StatusOr RunExpression(const CelValue target, cel_base::StatusOr RunExpression(const TestMessage* message, absl::string_view field, bool test, google::protobuf::Arena* arena, - absl::string_view unknown_path) { + absl::string_view unknown_path, + bool enable_unknowns) { return RunExpression(CelValue::CreateMessage(message, arena), field, test, - arena, unknown_path); + arena, unknown_path, enable_unknowns); } cel_base::StatusOr RunExpression(const TestMessage* message, absl::string_view field, bool test, - google::protobuf::Arena* arena) { - return RunExpression(message, field, test, arena, ""); + google::protobuf::Arena* arena, + bool enable_unknowns) { + return RunExpression(message, field, test, arena, "", enable_unknowns); } cel_base::StatusOr RunExpression(const CelMap* map_value, absl::string_view field, bool test, google::protobuf::Arena* arena, - absl::string_view unknown_path) { + absl::string_view unknown_path, + bool enable_unknowns) { return RunExpression(CelValue::CreateMap(map_value), field, test, arena, - unknown_path); + unknown_path, enable_unknowns); } cel_base::StatusOr RunExpression(const CelMap* map_value, absl::string_view field, bool test, - google::protobuf::Arena* arena) { - return RunExpression(map_value, field, test, arena, ""); + google::protobuf::Arena* arena, + bool enable_unknowns) { + return RunExpression(map_value, field, test, arena, "", enable_unknowns); } -TEST(SelectStepTest, SelectMessageIsNull) { +class SelectStepTest : public testing::TestWithParam {}; + +TEST_P(SelectStepTest, SelectMessageIsNull) { google::protobuf::Arena arena; auto run_status = RunExpression(static_cast(nullptr), - "bool_value", true, &arena); - ASSERT_TRUE(run_status.ok()); + "bool_value", true, &arena, GetParam()); + ASSERT_OK(run_status); CelValue result = run_status.ValueOrDie(); ASSERT_TRUE(result.IsError()); } -TEST(SelectStepTest, PresenseIsFalseTest) { +TEST_P(SelectStepTest, PresenseIsFalseTest) { TestMessage message; google::protobuf::Arena arena; - auto run_status = RunExpression(&message, "bool_value", true, &arena); - ASSERT_TRUE(run_status.ok()); + auto run_status = + RunExpression(&message, "bool_value", true, &arena, GetParam()); + ASSERT_OK(run_status); CelValue result = run_status.ValueOrDie(); @@ -111,14 +123,15 @@ TEST(SelectStepTest, PresenseIsFalseTest) { EXPECT_EQ(result.BoolOrDie(), false); } -TEST(SelectStepTest, PresenseIsTrueTest) { +TEST_P(SelectStepTest, PresenseIsTrueTest) { TestMessage message; message.set_bool_value(true); google::protobuf::Arena arena; - auto run_status = RunExpression(&message, "bool_value", true, &arena); - ASSERT_TRUE(run_status.ok()); + auto run_status = + RunExpression(&message, "bool_value", true, &arena, GetParam()); + ASSERT_OK(run_status); CelValue result = run_status.ValueOrDie(); @@ -126,7 +139,7 @@ TEST(SelectStepTest, PresenseIsTrueTest) { EXPECT_EQ(result.BoolOrDie(), true); } -TEST(SelectStepTest, MapPresenseIsFalseTest) { +TEST_P(SelectStepTest, MapPresenseIsFalseTest) { std::string key1 = "key1"; std::vector> key_values{ {CelValue::CreateString(&key1), CelValue::CreateInt64(1)}}; @@ -136,14 +149,15 @@ TEST(SelectStepTest, MapPresenseIsFalseTest) { google::protobuf::Arena arena; - auto run_status = RunExpression(map_value.get(), "key2", true, &arena); + auto run_status = + RunExpression(map_value.get(), "key2", true, &arena, GetParam()); CelValue result = run_status.ValueOrDie(); ASSERT_TRUE(result.IsBool()); EXPECT_EQ(result.BoolOrDie(), false); } -TEST(SelectStepTest, MapPresenseIsTrueTest) { +TEST_P(SelectStepTest, MapPresenseIsTrueTest) { std::string key1 = "key1"; std::vector> key_values{ {CelValue::CreateString(&key1), CelValue::CreateInt64(1)}}; @@ -153,34 +167,56 @@ TEST(SelectStepTest, MapPresenseIsTrueTest) { google::protobuf::Arena arena; - auto run_status = RunExpression(map_value.get(), "key1", true, &arena); + auto run_status = + RunExpression(map_value.get(), "key1", true, &arena, GetParam()); + CelValue result = run_status.ValueOrDie(); + + ASSERT_TRUE(result.IsBool()); + EXPECT_EQ(result.BoolOrDie(), true); +} + +TEST(SelectStepTest, MapPresenseIsTrueWithUnknownTest) { + UnknownSet unknown_set; + std::string key1 = "key1"; + std::vector> key_values{ + {CelValue::CreateString(&key1), + CelValue::CreateUnknownSet(&unknown_set)}}; + + auto map_value = CreateContainerBackedMap( + absl::Span>(key_values)); + + google::protobuf::Arena arena; + + auto run_status = RunExpression(map_value.get(), "key1", true, &arena, true); CelValue result = run_status.ValueOrDie(); ASSERT_TRUE(result.IsBool()); EXPECT_EQ(result.BoolOrDie(), true); } -TEST(SelectStepTest, FieldIsNotPresentInProtoTest) { +TEST_P(SelectStepTest, FieldIsNotPresentInProtoTest) { TestMessage message; google::protobuf::Arena arena; - auto run_status = RunExpression(&message, "fake_field", false, &arena); - ASSERT_TRUE(run_status.ok()); + auto run_status = + RunExpression(&message, "fake_field", false, &arena, GetParam()); + ASSERT_OK(run_status); CelValue result = run_status.ValueOrDie(); ASSERT_TRUE(result.IsError()); - EXPECT_THAT(result.ErrorOrDie()->code(), Eq(cel_base::StatusCode::kNotFound)); + EXPECT_THAT(result.ErrorOrDie()->code(), Eq(absl::StatusCode::kNotFound)); } -TEST(SelectStepTest, FieldIsNotSetTest) { +TEST_P(SelectStepTest, FieldIsNotSetTest) { TestMessage message; google::protobuf::Arena arena; - auto run_status = RunExpression(&message, "bool_value", false, &arena); - ASSERT_TRUE(run_status.ok()); + auto run_status = + RunExpression(&message, "bool_value", false, &arena, GetParam()); + ASSERT_OK(run_status); CelValue result = run_status.ValueOrDie(); @@ -188,14 +224,15 @@ TEST(SelectStepTest, FieldIsNotSetTest) { EXPECT_EQ(result.BoolOrDie(), false); } -TEST(SelectStepTest, SimpleBoolTest) { +TEST_P(SelectStepTest, SimpleBoolTest) { TestMessage message; message.set_bool_value(true); google::protobuf::Arena arena; - auto run_status = RunExpression(&message, "bool_value", false, &arena); - ASSERT_TRUE(run_status.ok()); + auto run_status = + RunExpression(&message, "bool_value", false, &arena, GetParam()); + ASSERT_OK(run_status); CelValue result = run_status.ValueOrDie(); @@ -203,14 +240,15 @@ TEST(SelectStepTest, SimpleBoolTest) { EXPECT_EQ(result.BoolOrDie(), true); } -TEST(SelectStepTest, SimpleInt32Test) { +TEST_P(SelectStepTest, SimpleInt32Test) { TestMessage message; message.set_int32_value(1); google::protobuf::Arena arena; - auto run_status = RunExpression(&message, "int32_value", false, &arena); - ASSERT_TRUE(run_status.ok()); + auto run_status = + RunExpression(&message, "int32_value", false, &arena, GetParam()); + ASSERT_OK(run_status); CelValue result = run_status.ValueOrDie(); @@ -218,14 +256,15 @@ TEST(SelectStepTest, SimpleInt32Test) { EXPECT_EQ(result.Int64OrDie(), 1); } -TEST(SelectStepTest, SimpleInt64Test) { +TEST_P(SelectStepTest, SimpleInt64Test) { TestMessage message; message.set_int64_value(1); google::protobuf::Arena arena; - auto run_status = RunExpression(&message, "int64_value", false, &arena); - ASSERT_TRUE(run_status.ok()); + auto run_status = + RunExpression(&message, "int64_value", false, &arena, GetParam()); + ASSERT_OK(run_status); CelValue result = run_status.ValueOrDie(); @@ -233,14 +272,15 @@ TEST(SelectStepTest, SimpleInt64Test) { EXPECT_EQ(result.Int64OrDie(), 1); } -TEST(SelectStepTest, SimpleUInt32Test) { +TEST_P(SelectStepTest, SimpleUInt32Test) { TestMessage message; message.set_uint32_value(1); google::protobuf::Arena arena; - auto run_status = RunExpression(&message, "uint32_value", false, &arena); - ASSERT_TRUE(run_status.ok()); + auto run_status = + RunExpression(&message, "uint32_value", false, &arena, GetParam()); + ASSERT_OK(run_status); CelValue result = run_status.ValueOrDie(); @@ -248,14 +288,15 @@ TEST(SelectStepTest, SimpleUInt32Test) { EXPECT_EQ(result.Uint64OrDie(), 1); } -TEST(SelectStepTest, SimpleUint64Test) { +TEST_P(SelectStepTest, SimpleUint64Test) { TestMessage message; message.set_uint64_value(1); google::protobuf::Arena arena; - auto run_status = RunExpression(&message, "uint64_value", false, &arena); - ASSERT_TRUE(run_status.ok()); + auto run_status = + RunExpression(&message, "uint64_value", false, &arena, GetParam()); + ASSERT_OK(run_status); CelValue result = run_status.ValueOrDie(); @@ -263,15 +304,16 @@ TEST(SelectStepTest, SimpleUint64Test) { EXPECT_EQ(result.Uint64OrDie(), 1); } -TEST(SelectStepTest, SimpleStringTest) { +TEST_P(SelectStepTest, SimpleStringTest) { TestMessage message; std::string value = "test"; message.set_string_value(value); google::protobuf::Arena arena; - auto run_status = RunExpression(&message, "string_value", false, &arena); - ASSERT_TRUE(run_status.ok()); + auto run_status = + RunExpression(&message, "string_value", false, &arena, GetParam()); + ASSERT_OK(run_status); CelValue result = run_status.ValueOrDie(); @@ -280,15 +322,16 @@ TEST(SelectStepTest, SimpleStringTest) { } -TEST(SelectStepTest, SimpleBytesTest) { +TEST_P(SelectStepTest, SimpleBytesTest) { TestMessage message; std::string value = "test"; message.set_bytes_value(value); google::protobuf::Arena arena; - auto run_status = RunExpression(&message, "bytes_value", false, &arena); - ASSERT_TRUE(run_status.ok()); + auto run_status = + RunExpression(&message, "bytes_value", false, &arena, GetParam()); + ASSERT_OK(run_status); CelValue result = run_status.ValueOrDie(); @@ -296,7 +339,7 @@ TEST(SelectStepTest, SimpleBytesTest) { EXPECT_EQ(result.BytesOrDie().value(), "test"); } -TEST(SelectStepTest, SimpleMessageTest) { +TEST_P(SelectStepTest, SimpleMessageTest) { TestMessage message; TestMessage* message2 = message.mutable_message_value(); @@ -305,8 +348,9 @@ TEST(SelectStepTest, SimpleMessageTest) { google::protobuf::Arena arena; - auto run_status = RunExpression(&message, "message_value", false, &arena); - ASSERT_TRUE(run_status.ok()); + auto run_status = + RunExpression(&message, "message_value", false, &arena, GetParam()); + ASSERT_OK(run_status); CelValue result = run_status.ValueOrDie(); @@ -314,15 +358,16 @@ TEST(SelectStepTest, SimpleMessageTest) { EXPECT_THAT(*message2, EqualsProto(*result.MessageOrDie())); } -TEST(SelectStepTest, SimpleEnumTest) { +TEST_P(SelectStepTest, SimpleEnumTest) { TestMessage message; message.set_enum_value(TestMessage::TEST_ENUM_1); google::protobuf::Arena arena; - auto run_status = RunExpression(&message, "enum_value", false, &arena); - ASSERT_TRUE(run_status.ok()); + auto run_status = + RunExpression(&message, "enum_value", false, &arena, GetParam()); + ASSERT_OK(run_status); CelValue result = run_status.ValueOrDie(); @@ -330,7 +375,7 @@ TEST(SelectStepTest, SimpleEnumTest) { EXPECT_THAT(result.Int64OrDie(), Eq(TestMessage::TEST_ENUM_1)); } -TEST(SelectStepTest, SimpleListTest) { +TEST_P(SelectStepTest, SimpleListTest) { TestMessage message; message.add_int32_list(1); @@ -338,8 +383,9 @@ TEST(SelectStepTest, SimpleListTest) { google::protobuf::Arena arena; - auto run_status = RunExpression(&message, "int32_list", false, &arena); - ASSERT_TRUE(run_status.ok()); + auto run_status = + RunExpression(&message, "int32_list", false, &arena, GetParam()); + ASSERT_OK(run_status); CelValue result = run_status.ValueOrDie(); @@ -350,7 +396,7 @@ TEST(SelectStepTest, SimpleListTest) { EXPECT_THAT(cel_list->size(), Eq(2)); } -TEST(SelectStepTest, SimpleMapTest) { +TEST_P(SelectStepTest, SimpleMapTest) { TestMessage message; auto map_field = message.mutable_string_int32_map(); (*map_field)["test0"] = 1; @@ -358,8 +404,9 @@ TEST(SelectStepTest, SimpleMapTest) { google::protobuf::Arena arena; - auto run_status = RunExpression(&message, "string_int32_map", false, &arena); - ASSERT_TRUE(run_status.ok()); + auto run_status = + RunExpression(&message, "string_int32_map", false, &arena, GetParam()); + ASSERT_OK(run_status); CelValue result = run_status.ValueOrDie(); @@ -370,7 +417,7 @@ TEST(SelectStepTest, SimpleMapTest) { EXPECT_THAT(cel_map->size(), Eq(2)); } -TEST(SelectStepTest, MapSimpleInt32Test) { +TEST_P(SelectStepTest, MapSimpleInt32Test) { std::string key1 = "key1"; std::string key2 = "key2"; std::vector> key_values{ @@ -382,8 +429,9 @@ TEST(SelectStepTest, MapSimpleInt32Test) { google::protobuf::Arena arena; - auto run_status = RunExpression(map_value.get(), "key1", false, &arena); - ASSERT_TRUE(run_status.ok()); + auto run_status = + RunExpression(map_value.get(), "key1", false, &arena, GetParam()); + ASSERT_OK(run_status); CelValue result = run_status.ValueOrDie(); @@ -392,7 +440,7 @@ TEST(SelectStepTest, MapSimpleInt32Test) { } // Test Select behavior, when expression to select from is an Error. -TEST(SelectStepTest, CelErrorAsArgument) { +TEST_P(SelectStepTest, CelErrorAsArgument) { ExecutionPath path; Expr dummy_expr; @@ -416,19 +464,20 @@ TEST(SelectStepTest, CelErrorAsArgument) { CelError error; google::protobuf::Arena arena; - CelExpressionFlatImpl cel_expr(&dummy_expr, std::move(path), 0); + CelExpressionFlatImpl cel_expr(&dummy_expr, std::move(path), 0, {}, + GetParam()); Activation activation; activation.InsertValue("message", CelValue::CreateError(&error)); auto status = cel_expr.Evaluate(activation, &arena); - ASSERT_TRUE(status.ok()); + ASSERT_OK(status); auto result = status.ValueOrDie(); ASSERT_TRUE(result.IsError()); EXPECT_THAT(result.ErrorOrDie(), Eq(&error)); } -TEST(SelectStepTest, UnknownValueProducesError) { +TEST_P(SelectStepTest, UnknownValueProducesError) { TestMessage message; message.set_bool_value(true); google::protobuf::Arena arena; @@ -447,18 +496,19 @@ TEST(SelectStepTest, UnknownValueProducesError) { auto step1_status = CreateSelectStep(select, dummy_expr.id(), "message.bool_value"); - ASSERT_TRUE(step0_status.ok()); - ASSERT_TRUE(step1_status.ok()); + ASSERT_OK(step0_status); + ASSERT_OK(step1_status); path.push_back(std::move(step0_status.ValueOrDie())); path.push_back(std::move(step1_status.ValueOrDie())); - CelExpressionFlatImpl cel_expr(&dummy_expr, std::move(path), 0); + CelExpressionFlatImpl cel_expr(&dummy_expr, std::move(path), 0, {}, + GetParam()); Activation activation; activation.InsertValue("message", CelValue::CreateMessage(&message, &arena)); auto eval_status0 = cel_expr.Evaluate(activation, &arena); - ASSERT_TRUE(eval_status0.ok()); + ASSERT_OK(eval_status0); CelValue result = eval_status0.ValueOrDie(); @@ -471,7 +521,7 @@ TEST(SelectStepTest, UnknownValueProducesError) { activation.set_unknown_paths(mask); auto eval_status1 = cel_expr.Evaluate(activation, &arena); - ASSERT_TRUE(eval_status1.ok()); + ASSERT_OK(eval_status1); result = eval_status1.ValueOrDie(); @@ -481,6 +531,124 @@ TEST(SelectStepTest, UnknownValueProducesError) { Eq(std::set({"message.bool_value"}))); } +TEST(SelectStepTest, UnknownPatternResolvesToUnknown) { + TestMessage message; + message.set_bool_value(true); + google::protobuf::Arena arena; + ExecutionPath path; + + Expr dummy_expr; + + auto select = dummy_expr.mutable_select_expr(); + select->set_field("bool_value"); + select->set_test_only(false); + Expr* expr0 = select->mutable_operand(); + + auto ident = expr0->mutable_ident_expr(); + ident->set_name("message"); + auto step0_status = CreateIdentStep(ident, expr0->id()); + auto step1_status = + CreateSelectStep(select, dummy_expr.id(), "message.bool_value"); + + ASSERT_OK(step0_status); + ASSERT_OK(step1_status); + + path.push_back(std::move(step0_status.ValueOrDie())); + path.push_back(std::move(step1_status.ValueOrDie())); + + CelExpressionFlatImpl cel_expr(&dummy_expr, std::move(path), 0, {}, true); + + { + std::vector unknown_patterns; + Activation activation; + activation.InsertValue("message", + CelValue::CreateMessage(&message, &arena)); + activation.set_unknown_attribute_patterns(unknown_patterns); + + auto eval_status0 = cel_expr.Evaluate(activation, &arena); + ASSERT_OK(eval_status0); + + CelValue result = eval_status0.ValueOrDie(); + + ASSERT_TRUE(result.IsBool()); + EXPECT_EQ(result.BoolOrDie(), true); + } + + const std::string kSegmentCorrect1 = "bool_value"; + const std::string kSegmentIncorrect = "message_value"; + + { + std::vector unknown_patterns; + unknown_patterns.push_back(CelAttributePattern("message", {})); + Activation activation; + activation.InsertValue("message", + CelValue::CreateMessage(&message, &arena)); + activation.set_unknown_attribute_patterns(unknown_patterns); + + auto eval_status0 = cel_expr.Evaluate(activation, &arena); + ASSERT_OK(eval_status0); + + CelValue result = eval_status0.ValueOrDie(); + + ASSERT_TRUE(result.IsUnknownSet()); + } + + { + std::vector unknown_patterns; + unknown_patterns.push_back(CelAttributePattern( + "message", {CelAttributeQualifierPattern::Create( + CelValue::CreateString(&kSegmentCorrect1))})); + Activation activation; + activation.InsertValue("message", + CelValue::CreateMessage(&message, &arena)); + activation.set_unknown_attribute_patterns(unknown_patterns); + + auto eval_status0 = cel_expr.Evaluate(activation, &arena); + ASSERT_OK(eval_status0); + + CelValue result = eval_status0.ValueOrDie(); + + ASSERT_TRUE(result.IsUnknownSet()); + } + + { + std::vector unknown_patterns; + unknown_patterns.push_back(CelAttributePattern( + "message", {CelAttributeQualifierPattern::CreateWildcard()})); + Activation activation; + activation.InsertValue("message", + CelValue::CreateMessage(&message, &arena)); + activation.set_unknown_attribute_patterns(unknown_patterns); + + auto eval_status0 = cel_expr.Evaluate(activation, &arena); + ASSERT_OK(eval_status0); + + CelValue result = eval_status0.ValueOrDie(); + + ASSERT_TRUE(result.IsUnknownSet()); + } + + { + std::vector unknown_patterns; + unknown_patterns.push_back(CelAttributePattern( + "message", {CelAttributeQualifierPattern::Create( + CelValue::CreateString(&kSegmentIncorrect))})); + Activation activation; + activation.InsertValue("message", + CelValue::CreateMessage(&message, &arena)); + activation.set_unknown_attribute_patterns(unknown_patterns); + + auto eval_status0 = cel_expr.Evaluate(activation, &arena); + ASSERT_OK(eval_status0); + + CelValue result = eval_status0.ValueOrDie(); + + ASSERT_TRUE(result.IsBool()); + EXPECT_EQ(result.BoolOrDie(), true); + } +} + +INSTANTIATE_TEST_SUITE_P(SelectStepTest, SelectStepTest, testing::Bool()); } // namespace } // namespace runtime } // namespace expr diff --git a/eval/eval/unknowns_utility.cc b/eval/eval/unknowns_utility.cc new file mode 100644 index 000000000..4c9d122fe --- /dev/null +++ b/eval/eval/unknowns_utility.cc @@ -0,0 +1,96 @@ +#include "eval/eval/unknowns_utility.h" + +#include "absl/status/status.h" +#include "eval/public/cel_value.h" +#include "eval/public/unknown_attribute_set.h" +#include "eval/public/unknown_set.h" +#include "base/statusor.h" + +namespace google { +namespace api { +namespace expr { +namespace runtime { + +using google::protobuf::Arena; + +// Checks whether particular corresponds to any patterns that define unknowns. +bool UnknownsUtility::CheckForUnknown(const AttributeTrail& trail, + bool use_partial) const { + if (trail.empty()) { + return false; + } + for (const auto& pattern : *unknown_patterns_) { + auto current_match = pattern.IsMatch(*trail.attribute()); + if (current_match == CelAttributePattern::MatchType::FULL || + (use_partial && + current_match == CelAttributePattern::MatchType::PARTIAL)) { + return true; + } + } + return false; +} + +// Creates merged UnknownAttributeSet. +// Scans over the args collection, merges any UnknownSets found in +// it together with initial_set (if initial_set is not null). +// Returns pointer to merged set or nullptr, if there were no sets to merge. +const UnknownSet* UnknownsUtility::MergeUnknowns( + absl::Span args, const UnknownSet* initial_set) const { + const UnknownSet* result = initial_set; + + for (const auto& value : args) { + if (!value.IsUnknownSet()) continue; + + auto current_set = value.UnknownSetOrDie(); + if (result == nullptr) { + result = current_set; + } else { + result = Arena::Create(arena_, *result, *current_set); + } + } + + return result; +} + +// Creates merged UnknownAttributeSet. +// Scans over the args collection, determines if there matches to unknown +// patterns, merges attributes together with those from initial_set +// (if initial_set is not null). +// Returns pointer to merged set or nullptr, if there were no sets to merge. +UnknownAttributeSet UnknownsUtility::CheckForUnknowns( + absl::Span args, bool use_partial) const { + std::vector unknown_attrs; + + for (auto trail : args) { + if (CheckForUnknown(trail, use_partial)) { + unknown_attrs.push_back(trail.attribute()); + } + } + + return UnknownAttributeSet(unknown_attrs); +} + +// Creates merged UnknownAttributeSet. +// Merges together attributes from UnknownAttributeSets found in the args +// collection, attributes from attr that match unknown pattern +// patterns, and attributes from initial_set +// (if initial_set is not null). +// Returns pointer to merged set or nullptr, if there were no sets to merge. +const UnknownSet* UnknownsUtility::MergeUnknowns( + absl::Span args, absl::Span attrs, + const UnknownSet* initial_set, bool use_partial) const { + UnknownAttributeSet attr_set = CheckForUnknowns(attrs, use_partial); + if (!attr_set.attributes().empty()) { + if (initial_set != nullptr) { + initial_set = + Arena::Create(arena_, *initial_set, UnknownSet(attr_set)); + } else { + initial_set = Arena::Create(arena_, attr_set); + } + } + return MergeUnknowns(args, initial_set); +} +} // namespace runtime +} // namespace expr +} // namespace api +} // namespace google diff --git a/eval/eval/unknowns_utility.h b/eval/eval/unknowns_utility.h new file mode 100644 index 000000000..5f0f7df6a --- /dev/null +++ b/eval/eval/unknowns_utility.h @@ -0,0 +1,68 @@ +#ifndef THIRD_PARTY_CEL_CPP_EVAL_EVAL_UNKNOWNS_UTILITY_H_ +#define THIRD_PARTY_CEL_CPP_EVAL_EVAL_UNKNOWNS_UTILITY_H_ + +#include +#include + +#include "google/api/expr/v1alpha1/syntax.pb.h" +#include "google/protobuf/arena.h" +#include "absl/types/optional.h" +#include "eval/eval/attribute_trail.h" +#include "eval/public/activation.h" +#include "eval/public/cel_attribute.h" +#include "eval/public/cel_expression.h" +#include "eval/public/cel_value.h" +#include "eval/public/unknown_attribute_set.h" +#include "eval/public/unknown_set.h" + +namespace google { +namespace api { +namespace expr { +namespace runtime { + +// Helper class for handling unknowns logic. Provides helpers for merging +// unknown sets from arguments on the stack and for identifying unknown +// attributes based on the patterns for a given Evaluation. +class UnknownsUtility { + public: + UnknownsUtility(const std::vector* unknown_patterns, + google::protobuf::Arena* arena) + : unknown_patterns_(unknown_patterns), arena_(arena) {} + + // Checks whether particular corresponds to any patterns that define unknowns. + bool CheckForUnknown(const AttributeTrail& trail, bool use_partial) const; + + // Creates merged UnknownAttributeSet. + // Scans over the args collection, determines if there matches to unknown + // patterns and returns the (possibly empty) collection. + UnknownAttributeSet CheckForUnknowns(absl::Span args, + bool use_partial) const; + + // Creates merged UnknownSet. + // Scans over the args collection, merges any UnknownAttributeSets found in + // it together with initial_set (if initial_set is not null). + // Returns pointer to merged set or nullptr, if there were no sets to merge. + const UnknownSet* MergeUnknowns(absl::Span args, + const UnknownSet* initial_set) const; + + // Creates merged UnknownSet. + // Merges together attributes from UnknownSets found in the args + // collection, attributes from attr that match unknown pattern + // patterns, and attributes from initial_set + // (if initial_set is not null). + // Returns pointer to merged set or nullptr, if there were no sets to merge. + const UnknownSet* MergeUnknowns(absl::Span args, + absl::Span attrs, + const UnknownSet* initial_set, + bool use_partial) const; + + private: + const std::vector* unknown_patterns_; + google::protobuf::Arena* arena_; +}; +} // namespace runtime +} // namespace expr +} // namespace api +} // namespace google + +#endif // THIRD_PARTY_CEL_CPP_EVAL_EVAL_UNKNOWNS_UTILITY_H_ diff --git a/eval/eval/unknowns_utility_test.cc b/eval/eval/unknowns_utility_test.cc new file mode 100644 index 000000000..52fd546bb --- /dev/null +++ b/eval/eval/unknowns_utility_test.cc @@ -0,0 +1,145 @@ +#include "eval/eval/unknowns_utility.h" + +#include "google/api/expr/v1alpha1/syntax.pb.h" +#include "gmock/gmock.h" +#include "gtest/gtest.h" +#include "eval/public/cel_attribute.h" +#include "eval/public/cel_value.h" +#include "eval/public/unknown_attribute_set.h" +#include "eval/public/unknown_set.h" + +namespace google { +namespace api { +namespace expr { +namespace runtime { + +using testing::Eq; +using testing::IsNull; +using testing::NotNull; +using testing::SizeIs; +using testing::UnorderedPointwise; + +TEST(UnknownsUtilityTest, UnknownsUtilityCheckUnknowns) { + google::protobuf::Arena arena; + std::vector patterns = { + CelAttributePattern("unknown0", {CelAttributeQualifierPattern::Create( + CelValue::CreateInt64(1))}), + CelAttributePattern("unknown0", {CelAttributeQualifierPattern::Create( + CelValue::CreateInt64(2))}), + CelAttributePattern("unknown1", {}), + CelAttributePattern("unknown2", {}), + }; + UnknownsUtility utility(&patterns, &arena); + // no match for void trail + ASSERT_FALSE(utility.CheckForUnknown(AttributeTrail(), true)); + ASSERT_FALSE(utility.CheckForUnknown(AttributeTrail(), false)); + + google::api::expr::v1alpha1::Expr unknown_expr0; + unknown_expr0.mutable_ident_expr()->set_name("unknown0"); + + AttributeTrail unknown_trail0(unknown_expr0, &arena); + + { ASSERT_FALSE(utility.CheckForUnknown(unknown_trail0, false)); } + + { ASSERT_TRUE(utility.CheckForUnknown(unknown_trail0, true)); } + + { + ASSERT_TRUE(utility.CheckForUnknown( + unknown_trail0.Step( + CelAttributeQualifier::Create(CelValue::CreateInt64(1)), &arena), + false)); + } + + { + ASSERT_TRUE(utility.CheckForUnknown( + unknown_trail0.Step( + CelAttributeQualifier::Create(CelValue::CreateInt64(1)), &arena), + true)); + } +} + +TEST(UnknownsUtilityTest, UnknownsUtilityMergeUnknownsFromValues) { + google::protobuf::Arena arena; + + google::api::expr::v1alpha1::Expr unknown_expr0; + unknown_expr0.mutable_ident_expr()->set_name("unknown0"); + + google::api::expr::v1alpha1::Expr unknown_expr1; + unknown_expr1.mutable_ident_expr()->set_name("unknown1"); + + google::api::expr::v1alpha1::Expr unknown_expr2; + unknown_expr2.mutable_ident_expr()->set_name("unknown2"); + + std::vector patterns; + + CelAttribute attribute0(unknown_expr0, {}); + CelAttribute attribute1(unknown_expr1, {}); + CelAttribute attribute2(unknown_expr2, {}); + + UnknownsUtility utility(&patterns, &arena); + + UnknownSet unknown_set0(UnknownAttributeSet({&attribute0})); + UnknownSet unknown_set1(UnknownAttributeSet({&attribute1})); + UnknownSet unknown_set2(UnknownAttributeSet({&attribute1, &attribute2})); + std::vector values = { + CelValue::CreateUnknownSet(&unknown_set0), + CelValue::CreateUnknownSet(&unknown_set1), + CelValue::CreateBool(true), + CelValue::CreateInt64(1), + }; + + const UnknownSet* unknown_set = utility.MergeUnknowns(values, nullptr); + ASSERT_THAT(unknown_set, NotNull()); + ASSERT_THAT(unknown_set->unknown_attributes().attributes(), + UnorderedPointwise(Eq(), std::vector{ + &attribute0, &attribute1})); + + unknown_set = utility.MergeUnknowns(values, &unknown_set2); + ASSERT_THAT(unknown_set, NotNull()); + ASSERT_THAT( + unknown_set->unknown_attributes().attributes(), + UnorderedPointwise(Eq(), std::vector{ + &attribute0, &attribute1, &attribute2})); +} + +TEST(UnknownsUtilityTest, UnknownsUtilityCheckForUnknownsFromAttributes) { + google::protobuf::Arena arena; + + std::vector patterns = { + CelAttributePattern("unknown0", + {CelAttributeQualifierPattern::CreateWildcard()}), + }; + + google::api::expr::v1alpha1::Expr unknown_expr0; + unknown_expr0.mutable_ident_expr()->set_name("unknown0"); + + google::api::expr::v1alpha1::Expr unknown_expr1; + unknown_expr1.mutable_ident_expr()->set_name("unknown1"); + + AttributeTrail trail0(unknown_expr0, &arena); + AttributeTrail trail1(unknown_expr1, &arena); + + CelAttribute attribute1(unknown_expr1, {}); + UnknownSet unknown_set1(UnknownAttributeSet({&attribute1})); + + UnknownsUtility utility(&patterns, &arena); + + UnknownSet unknown_attr_set(utility.CheckForUnknowns( + { + AttributeTrail(), // To make sure we handle empty trail gracefully. + trail0.Step(CelAttributeQualifier::Create(CelValue::CreateInt64(1)), + &arena), + trail0.Step(CelAttributeQualifier::Create(CelValue::CreateInt64(2)), + &arena), + }, + false)); + + UnknownSet unknown_set(unknown_set1, unknown_attr_set); + + ASSERT_THAT(unknown_set.unknown_attributes().attributes(), SizeIs(3)); +} + +} // namespace runtime +} // namespace expr +} // namespace api +} // namespace google diff --git a/eval/public/BUILD b/eval/public/BUILD index 878c2712e..c01383f73 100644 --- a/eval/public/BUILD +++ b/eval/public/BUILD @@ -26,9 +26,10 @@ cc_library( copts = ["-std=c++14"], deps = [ ":cel_value_internal", - "//base:status", + "//base:statusor", "//internal:proto_util", "@com_google_absl//absl/container:node_hash_map", + "@com_google_absl//absl/status", "@com_google_absl//absl/strings", "@com_google_absl//absl/synchronization", "@com_google_absl//absl/types:optional", @@ -47,7 +48,7 @@ cc_library( deps = [ ":cel_value", ":cel_value_internal", - "//base:status", + "//base:statusor", "//internal:proto_util", "@com_google_absl//absl/strings", "@com_google_absl//absl/types:optional", @@ -77,9 +78,7 @@ cc_library( copts = ["-std=c++14"], deps = [ ":cel_attribute", - "//base:status", "@com_google_absl//absl/container:flat_hash_set", - "@com_google_absl//absl/strings", ], ) @@ -93,11 +92,12 @@ cc_library( ], copts = ["-std=c++14"], deps = [ + ":cel_attribute", ":cel_function", ":cel_value", ":cel_value_producer", - "//base:status", "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/status", "@com_google_absl//absl/strings", "@com_google_protobuf//:protobuf", ], @@ -148,7 +148,8 @@ cc_library( deps = [ ":cel_function", ":cel_function_registry", - "//base:status", + "//base:statusor", + "@com_google_absl//absl/status", "@com_google_absl//absl/strings", "@com_google_protobuf//:protobuf", ], @@ -194,7 +195,7 @@ cc_library( ":cel_function_adapter", ":cel_function_registry", ":cel_options", - "//base:status", + "//base:status_macros", "//eval/eval:container_backed_list_impl", "@com_google_absl//absl/strings", "@com_google_protobuf//:protobuf", @@ -328,7 +329,7 @@ cc_library( copts = ["-std=c++14"], deps = [ ":cel_value", - "//base:status", + "@com_google_absl//absl/status", "@com_google_absl//absl/strings", "@com_google_protobuf//:protobuf", ], @@ -359,6 +360,7 @@ cc_test( deps = [ ":cel_value", ":unknown_attribute_set", + ":unknown_set", "//eval/testutil:test_message_cc_proto", "//testutil:util", "@com_github_google_googletest//:gtest_main", @@ -392,6 +394,7 @@ cc_test( deps = [ ":activation", ":cel_function", + "//base:status_macros", "@com_github_google_googletest//:gtest_main", ], ) @@ -418,6 +421,7 @@ cc_test( deps = [ ":activation", ":activation_bind_helper", + "//base:status_macros", "//eval/testutil:test_message_cc_proto", "@com_github_google_googletest//:gtest_main", "@com_google_protobuf//:protobuf", @@ -433,6 +437,7 @@ cc_test( deps = [ ":cel_function", ":cel_function_provider", + "//base:status_macros", "@com_github_google_googletest//:gtest_main", ], ) @@ -447,6 +452,7 @@ cc_test( ":cel_function", ":cel_function_provider", ":cel_function_registry", + "//base:status_macros", "@com_github_google_googletest//:gtest_main", ], ) @@ -462,6 +468,7 @@ cc_test( ":cel_function", ":cel_function_adapter", ":cel_value", + "//base:status_macros", "@com_github_google_googletest//:gtest_main", ], ) @@ -479,6 +486,7 @@ cc_test( ":cel_builtins", ":cel_expr_builder_factory", ":cel_function_registry", + "//base:status_macros", "@com_github_google_googletest//:gtest_main", "@com_google_absl//absl/strings", "@com_google_googleapis//google/api/expr/v1alpha1:syntax_cc_proto", @@ -498,6 +506,7 @@ cc_test( ":cel_function_registry", ":cel_value", ":extension_func_registrar", + "//base:status_macros", "@com_github_google_googletest//:gtest_main", "@com_google_absl//absl/strings", "@com_google_protobuf//:protobuf", @@ -530,8 +539,6 @@ cc_test( ":cel_value", ":unknown_attribute_set", "@com_github_google_googletest//:gtest_main", - "@com_google_absl//absl/strings", - "@com_google_protobuf//:protobuf", ], ) @@ -545,7 +552,7 @@ cc_test( deps = [ ":cel_value", ":value_export_util", - "//base:status", + "//base:status_macros", "//eval/eval:container_backed_list_impl", "//eval/eval:container_backed_map_impl", "//eval/testutil:test_message_cc_proto", @@ -555,3 +562,61 @@ cc_test( "@com_google_protobuf//:protobuf", ], ) + +cc_library( + name = "unknown_function_result_set", + srcs = ["unknown_function_result_set.cc"], + hdrs = ["unknown_function_result_set.h"], + copts = ["-std=c++14"], + deps = [ + ":cel_function", + ":cel_options", + ":cel_value", + "@com_google_googleapis//google/api/expr/v1alpha1:syntax_cc_proto", + ], +) + +cc_test( + name = "unknown_function_result_set_test", + size = "small", + srcs = [ + "unknown_function_result_set_test.cc", + ], + copts = ["-std=c++14"], + deps = [ + ":cel_function", + ":cel_value", + ":unknown_function_result_set", + "//eval/eval:container_backed_list_impl", + "//eval/eval:container_backed_map_impl", + "@com_github_google_googletest//:gtest_main", + "@com_google_absl//absl/time", + "@com_google_absl//absl/types:span", + "@com_google_protobuf//:protobuf", + ], +) + +cc_library( + name = "unknown_set", + hdrs = ["unknown_set.h"], + copts = ["-std=c++14"], + deps = [ + ":unknown_attribute_set", + ":unknown_function_result_set", + ], +) + +cc_test( + name = "unknown_set_test", + srcs = ["unknown_set_test.cc"], + copts = ["-std=c++14"], + deps = [ + ":cel_attribute", + ":unknown_attribute_set", + ":unknown_function_result_set", + ":unknown_set", + "@com_github_google_googletest//:gtest_main", + "@com_google_googleapis//google/api/expr/v1alpha1:syntax_cc_proto", + "@com_google_protobuf//:protobuf", + ], +) diff --git a/eval/public/activation.cc b/eval/public/activation.cc index 3047cc4bd..ecd95ee13 100644 --- a/eval/public/activation.cc +++ b/eval/public/activation.cc @@ -4,10 +4,9 @@ #include #include +#include "absl/status/status.h" #include "absl/strings/string_view.h" #include "eval/public/cel_function.h" -#include "base/canonical_errors.h" -#include "base/status.h" namespace google { namespace api { @@ -26,16 +25,16 @@ absl::optional Activation::FindValue(absl::string_view name, return entry->second.RetrieveValue(arena); } -cel_base::Status Activation::InsertFunction(std::unique_ptr function) { +absl::Status Activation::InsertFunction(std::unique_ptr function) { auto& overloads = function_map_[function->descriptor().name()]; for (const auto& overload : overloads) { if (overload->descriptor().ShapeMatches(function->descriptor())) { - return cel_base::InvalidArgumentError( + return absl::InvalidArgumentError( "Function with same shape already defined in activation"); } } overloads.emplace_back(std::move(function)); - return cel_base::OkStatus(); + return absl::OkStatus(); } std::vector Activation::FindFunctionOverloads( diff --git a/eval/public/activation.h b/eval/public/activation.h index 7d38f0ba8..f6b200ec6 100644 --- a/eval/public/activation.h +++ b/eval/public/activation.h @@ -3,11 +3,13 @@ #include #include +#include #include "google/protobuf/field_mask.pb.h" #include "google/protobuf/util/field_mask_util.h" #include "absl/container/flat_hash_map.h" #include "absl/strings/string_view.h" +#include "eval/public/cel_attribute.h" #include "eval/public/cel_function.h" #include "eval/public/cel_value.h" #include "eval/public/cel_value_producer.h" @@ -39,6 +41,14 @@ class BaseActivation { // Check whether a select path is unknown. virtual bool IsPathUnknown(absl::string_view) const = 0; + // Return FieldMask defining the list of unknown paths. + virtual const google::protobuf::FieldMask unknown_paths() const = 0; + + // Return the collection of attribute patterns that determine "unknown" + // values. + virtual const std::vector& unknown_attribute_patterns() + const = 0; + virtual ~BaseActivation() {} }; @@ -68,7 +78,7 @@ class Activation : public BaseActivation { // Insert a function into the activation (ie a lazily bound function). Returns // a status if the name and shape of the function matches another one that has // already been bound. - cel_base::Status InsertFunction(std::unique_ptr function); + absl::Status InsertFunction(std::unique_ptr function); // Insert value into Activation. void InsertValue(absl::string_view name, const CelValue& value); @@ -98,10 +108,24 @@ class Activation : public BaseActivation { } // Return FieldMask defining the list of unknown paths. - const google::protobuf::FieldMask unknown_paths() const { + const google::protobuf::FieldMask unknown_paths() const override { return unknown_paths_; } + // Sets the collection of attribute patterns that will be recognized as + // "unknown" values during expression evaluation. + void set_unknown_attribute_patterns( + std::vector unknown_attribute_patterns) { + unknown_attribute_patterns_ = std::move(unknown_attribute_patterns); + } + + // Return the collection of attribute patterns that determine "unknown" + // values. + const std::vector& unknown_attribute_patterns() + const override { + return unknown_attribute_patterns_; + } + private: class ValueEntry { public: @@ -140,7 +164,9 @@ class Activation : public BaseActivation { absl::flat_hash_map>> function_map_; + // TODO(issues/41) deprecate when unknowns support is done. google::protobuf::FieldMask unknown_paths_; + std::vector unknown_attribute_patterns_; }; } // namespace runtime diff --git a/eval/public/activation_bind_helper.cc b/eval/public/activation_bind_helper.cc index 0b06272bf..532acc08a 100644 --- a/eval/public/activation_bind_helper.cc +++ b/eval/public/activation_bind_helper.cc @@ -16,17 +16,17 @@ using google::protobuf::Message; using google::protobuf::FieldDescriptor; using google::protobuf::Descriptor; -cel_base::Status CreateValueFromField(const google::protobuf::Message* msg, +absl::Status CreateValueFromField(const google::protobuf::Message* msg, const FieldDescriptor* field_desc, google::protobuf::Arena* arena, CelValue* result) { if (field_desc->is_map()) { *result = CelValue::CreateMap(google::protobuf::Arena::Create( arena, msg, field_desc, arena)); - return cel_base::OkStatus(); + return absl::OkStatus(); } else if (field_desc->is_repeated()) { *result = CelValue::CreateList(google::protobuf::Arena::Create( arena, msg, field_desc, arena)); - return cel_base::OkStatus(); + return absl::OkStatus(); } else { return CreateValueFromSingleField(msg, field_desc, arena, result); } @@ -34,8 +34,8 @@ cel_base::Status CreateValueFromField(const google::protobuf::Message* msg, } // namespace -::cel_base::Status BindProtoToActivation(const Message* message, Arena* arena, - Activation* activation) { +absl::Status BindProtoToActivation(const Message* message, Arena* arena, + Activation* activation) { // TODO(issues/24): Improve the utilities to bind dynamic values as well. const Descriptor* desc = message->GetDescriptor(); const google::protobuf::Reflection* reflection = message->GetReflection(); @@ -56,7 +56,7 @@ ::cel_base::Status BindProtoToActivation(const Message* message, Arena* arena, activation->InsertValue(field_desc->name(), value); } - return ::cel_base::OkStatus(); + return absl::OkStatus(); } } // namespace runtime diff --git a/eval/public/activation_bind_helper.h b/eval/public/activation_bind_helper.h index 40a9c3351..a45ac9833 100644 --- a/eval/public/activation_bind_helper.h +++ b/eval/public/activation_bind_helper.h @@ -31,9 +31,9 @@ namespace runtime { // "name", with string value of "John Doe" // "age", with int value of 42. // -::cel_base::Status BindProtoToActivation(const google::protobuf::Message* message, - google::protobuf::Arena* arena, - Activation* activation); +absl::Status BindProtoToActivation(const google::protobuf::Message* message, + google::protobuf::Arena* arena, + Activation* activation); } // namespace runtime } // namespace expr diff --git a/eval/public/activation_bind_helper_test.cc b/eval/public/activation_bind_helper_test.cc index 7c00dcee1..d549d77a4 100644 --- a/eval/public/activation_bind_helper_test.cc +++ b/eval/public/activation_bind_helper_test.cc @@ -1,10 +1,10 @@ #include "eval/public/activation_bind_helper.h" -#include "eval/public/activation.h" - -#include "eval/testutil/test_message.pb.h" #include "gmock/gmock.h" #include "gtest/gtest.h" +#include "eval/public/activation.h" +#include "eval/testutil/test_message.pb.h" +#include "base/status_macros.h" namespace google { namespace api { @@ -21,7 +21,7 @@ TEST(ActivationBindHelperTest, TestSingleBoolBind) { Activation activation; - ASSERT_TRUE(BindProtoToActivation(&message, &arena, &activation).ok()); + ASSERT_OK(BindProtoToActivation(&message, &arena, &activation)); auto result = activation.FindValue("bool_value", &arena); @@ -41,7 +41,7 @@ TEST(ActivationBindHelperTest, TestSingleInt32Bind) { Activation activation; - ASSERT_TRUE(BindProtoToActivation(&message, &arena, &activation).ok()); + ASSERT_OK(BindProtoToActivation(&message, &arena, &activation)); auto result = activation.FindValue("int32_value", &arena); diff --git a/eval/public/activation_test.cc b/eval/public/activation_test.cc index 6094f5b38..4927cf353 100644 --- a/eval/public/activation_test.cc +++ b/eval/public/activation_test.cc @@ -3,6 +3,7 @@ #include "gmock/gmock.h" #include "gtest/gtest.h" #include "eval/public/cel_function.h" +#include "base/status_macros.h" namespace google { namespace api { @@ -33,10 +34,10 @@ class ConstCelFunction : public CelFunction { explicit ConstCelFunction(const CelFunctionDescriptor& desc) : CelFunction(desc) {} - cel_base::Status Evaluate(absl::Span args, CelValue* output, + absl::Status Evaluate(absl::Span args, CelValue* output, google::protobuf::Arena* arena) const override { *output = CelValue::CreateInt64(42); - return cel_base::OkStatus(); + return absl::OkStatus(); } }; @@ -110,7 +111,7 @@ TEST(ActivationTest, CheckInsertFunction) { Activation activation; auto insert_status = activation.InsertFunction( std::make_unique("ConstFunc")); - EXPECT_TRUE(insert_status.ok()); + EXPECT_OK(insert_status); auto overloads = activation.FindFunctionOverloads("ConstFunc"); EXPECT_THAT(overloads, @@ -118,7 +119,7 @@ TEST(ActivationTest, CheckInsertFunction) { &CelFunction::descriptor, Property(&CelFunctionDescriptor::name, Eq("ConstFunc"))))); - cel_base::Status status = activation.InsertFunction( + absl::Status status = activation.InsertFunction( std::make_unique("ConstFunc")); EXPECT_THAT(std::string(status.message()), @@ -134,10 +135,10 @@ TEST(ActivationTest, CheckRemoveFunction) { auto insert_status = activation.InsertFunction(std::make_unique( CelFunctionDescriptor{"ConstFunc", false, {CelValue::Type::kInt64}})); - EXPECT_TRUE(insert_status.ok()); + EXPECT_OK(insert_status); insert_status = activation.InsertFunction(std::make_unique( CelFunctionDescriptor{"ConstFunc", false, {CelValue::Type::kUint64}})); - EXPECT_TRUE(insert_status.ok()); + EXPECT_OK(insert_status); auto overloads = activation.FindFunctionOverloads("ConstFunc"); EXPECT_THAT( diff --git a/eval/public/builtin_func_registrar.cc b/eval/public/builtin_func_registrar.cc index 2a723f933..91578b2fe 100644 --- a/eval/public/builtin_func_registrar.cc +++ b/eval/public/builtin_func_registrar.cc @@ -4,13 +4,13 @@ #include "google/protobuf/util/time_util.h" #include "absl/strings/match.h" +#include "absl/strings/numbers.h" #include "absl/strings/str_cat.h" #include "eval/eval/container_backed_list_impl.h" #include "eval/public/cel_builtins.h" #include "eval/public/cel_function_adapter.h" #include "eval/public/cel_function_registry.h" #include "re2/re2.h" -#include "base/canonical_errors.h" namespace google { namespace api { @@ -261,9 +261,9 @@ CelValue Equal(Arena* arena, CelValue t1, CelValue t2) { // // Registers all equality functions for template parameters type. template -::cel_base::Status RegisterEqualityFunctionsForType(CelFunctionRegistry* registry) { +absl::Status RegisterEqualityFunctionsForType(CelFunctionRegistry* registry) { // Inequality - ::cel_base::Status status = + absl::Status status = FunctionAdapter::CreateAndRegister( builtin::kInequal, false, Inequal, registry); if (!status.ok()) return status; @@ -276,9 +276,8 @@ ::cel_base::Status RegisterEqualityFunctionsForType(CelFunctionRegistry* registr // Registers all comparison functions for template parameter type. template -::cel_base::Status RegisterComparisonFunctionsForType( - CelFunctionRegistry* registry) { - ::cel_base::Status status = RegisterEqualityFunctionsForType(registry); +absl::Status RegisterComparisonFunctionsForType(CelFunctionRegistry* registry) { + absl::Status status = RegisterEqualityFunctionsForType(registry); if (!status.ok()) return status; // Less than @@ -301,7 +300,7 @@ ::cel_base::Status RegisterComparisonFunctionsForType( builtin::kGreaterOrEqual, false, GreaterThanOrEqual, registry); if (!status.ok()) return status; - return ::cel_base::OkStatus(); + return absl::OkStatus(); } // Template functions providing arithmetic operations @@ -382,9 +381,8 @@ CelValue Modulo(Arena* arena, uint64_t value, uint64_t value2) { // Helper method // Registers all arithmetic functions for template parameter type. template -::cel_base::Status RegisterArithmeticFunctionsForType( - CelFunctionRegistry* registry) { - cel_base::Status status = FunctionAdapter::CreateAndRegister( +absl::Status RegisterArithmeticFunctionsForType(CelFunctionRegistry* registry) { + absl::Status status = FunctionAdapter::CreateAndRegister( builtin::kAdd, false, Add, registry); if (!status.ok()) return status; @@ -488,19 +486,19 @@ const CelList* ConcatList(Arena* arena, const CelList* value1, } // Timestamp -const cel_base::Status FindTimeBreakdown(absl::Time timestamp, absl::string_view tz, +const absl::Status FindTimeBreakdown(absl::Time timestamp, absl::string_view tz, absl::TimeZone::CivilInfo* breakdown) { absl::TimeZone time_zone; if (!tz.empty()) { bool found = absl::LoadTimeZone(std::string(tz), &time_zone); if (!found) { - return cel_base::InvalidArgumentError("Invalid timezone"); + return absl::InvalidArgumentError("Invalid timezone"); } } *breakdown = time_zone.At(timestamp); - return cel_base::OkStatus(); + return absl::OkStatus(); } CelValue GetTimeBreakdownPart( @@ -523,7 +521,7 @@ CelValue CreateTimestampFromString(Arena* arena, if (!absl::ParseTime(absl::RFC3339_full, std::string(time_str.value()), &ts, nullptr)) { return CreateErrorValue(arena, "String to Timestamp conversion failed", - cel_base::StatusCode::kInvalidArgument); + absl::StatusCode::kInvalidArgument); } return CelValue::CreateTimestamp(ts); } @@ -570,7 +568,7 @@ CelValue GetDayOfWeek(Arena* arena, absl::Time timestamp, absl::string_view tz) { return GetTimeBreakdownPart( arena, timestamp, tz, [](const absl::TimeZone::CivilInfo& breakdown) { - absl::Weekday weekday = absl::GetWeekday(absl::CivilDay(breakdown.cs)); + absl::Weekday weekday = absl::GetWeekday(breakdown.cs); // get day of week from the date in UTC, zero-based, zero for Sunday, // based on GetDayOfWeek CEL function definition. @@ -615,7 +613,7 @@ CelValue CreateDurationFromString(Arena* arena, absl::Duration d; if (!absl::ParseDuration(std::string(dur_str.value()), &d)) { return CreateErrorValue(arena, "String to Duration conversion failed", - cel_base::StatusCode::kInvalidArgument); + absl::StatusCode::kInvalidArgument); } return CelValue::CreateDuration(d); @@ -654,28 +652,9 @@ bool StringStartsWith(Arena*, CelValue::StringHolder value, return absl::StartsWith(value.value(), prefix.value()); } -} // namespace - -::cel_base::Status RegisterBuiltinFunctions(CelFunctionRegistry* registry, - const InterpreterOptions& options) { - // logical NOT - cel_base::Status status = FunctionAdapter::CreateAndRegister( - builtin::kNot, false, [](Arena*, bool value) -> bool { return !value; }, - registry); - if (!status.ok()) return status; - - // Negation group - status = FunctionAdapter::CreateAndRegister( - builtin::kNeg, false, [](Arena*, int64_t value) -> int64_t { return -value; }, - registry); - if (!status.ok()) return status; - - status = FunctionAdapter::CreateAndRegister( - builtin::kNeg, false, - [](Arena*, double value) -> double { return -value; }, registry); - if (!status.ok()) return status; - - status = RegisterComparisonFunctionsForType(registry); +absl::Status RegisterComparisonFunctions(CelFunctionRegistry* registry, + const InterpreterOptions& options) { + auto status = RegisterComparisonFunctionsForType(registry); if (!status.ok()) return status; status = RegisterComparisonFunctionsForType(registry); @@ -708,6 +687,247 @@ ::cel_base::Status RegisterBuiltinFunctions(CelFunctionRegistry* registry, status = RegisterEqualityFunctionsForType(registry); if (!status.ok()) return status; + return absl::OkStatus(); +} + +absl::Status RegisterStringFunctions(CelFunctionRegistry* registry, + const InterpreterOptions& options) { + auto status = + FunctionAdapter:: + CreateAndRegister(builtin::kStringContains, false, StringContains, + registry); + if (!status.ok()) return status; + + status = + FunctionAdapter:: + CreateAndRegister(builtin::kStringContains, true, StringContains, + registry); + if (!status.ok()) return status; + + status = + FunctionAdapter:: + CreateAndRegister(builtin::kStringEndsWith, false, StringEndsWith, + registry); + if (!status.ok()) return status; + + status = + FunctionAdapter:: + CreateAndRegister(builtin::kStringEndsWith, true, StringEndsWith, + registry); + if (!status.ok()) return status; + + status = + FunctionAdapter:: + CreateAndRegister(builtin::kStringStartsWith, false, StringStartsWith, + registry); + if (!status.ok()) return status; + + status = + FunctionAdapter:: + CreateAndRegister(builtin::kStringStartsWith, true, StringStartsWith, + registry); + if (!status.ok()) return status; + + return absl::OkStatus(); +} + +absl::Status RegisterTimestampFunctions(CelFunctionRegistry* registry, + const InterpreterOptions& options) { + // Timestamp + // + // timestamp() conversion from string.. + auto status = + FunctionAdapter::CreateAndRegister( + builtin::kTimestamp, false, CreateTimestampFromString, registry); + if (!status.ok()) return status; + + status = FunctionAdapter:: + CreateAndRegister( + builtin::kFullYear, true, + [](Arena* arena, absl::Time ts, CelValue::StringHolder tz) + -> CelValue { return GetFullYear(arena, ts, tz.value()); }, + registry); + if (!status.ok()) return status; + + status = FunctionAdapter::CreateAndRegister( + builtin::kFullYear, true, + [](Arena* arena, absl::Time ts) -> CelValue { + return GetFullYear(arena, ts, ""); + }, + registry); + if (!status.ok()) return status; + + status = FunctionAdapter:: + CreateAndRegister( + builtin::kMonth, true, + [](Arena* arena, absl::Time ts, CelValue::StringHolder tz) + -> CelValue { return GetMonth(arena, ts, tz.value()); }, + registry); + if (!status.ok()) return status; + + status = FunctionAdapter::CreateAndRegister( + builtin::kMonth, true, + [](Arena* arena, absl::Time ts) -> CelValue { + return GetMonth(arena, ts, ""); + }, + registry); + if (!status.ok()) return status; + + status = FunctionAdapter:: + CreateAndRegister( + builtin::kDayOfYear, true, + [](Arena* arena, absl::Time ts, CelValue::StringHolder tz) + -> CelValue { return GetDayOfYear(arena, ts, tz.value()); }, + registry); + if (!status.ok()) return status; + + status = FunctionAdapter::CreateAndRegister( + builtin::kDayOfYear, true, + [](Arena* arena, absl::Time ts) -> CelValue { + return GetDayOfYear(arena, ts, ""); + }, + registry); + if (!status.ok()) return status; + + status = FunctionAdapter:: + CreateAndRegister( + builtin::kDayOfMonth, true, + [](Arena* arena, absl::Time ts, CelValue::StringHolder tz) + -> CelValue { return GetDayOfMonth(arena, ts, tz.value()); }, + registry); + if (!status.ok()) return status; + + status = FunctionAdapter::CreateAndRegister( + builtin::kDayOfMonth, true, + [](Arena* arena, absl::Time ts) -> CelValue { + return GetDayOfMonth(arena, ts, ""); + }, + registry); + if (!status.ok()) return status; + + status = FunctionAdapter:: + CreateAndRegister( + builtin::kDate, true, + [](Arena* arena, absl::Time ts, CelValue::StringHolder tz) + -> CelValue { return GetDate(arena, ts, tz.value()); }, + registry); + if (!status.ok()) return status; + + status = FunctionAdapter::CreateAndRegister( + builtin::kDate, true, + [](Arena* arena, absl::Time ts) -> CelValue { + return GetDate(arena, ts, ""); + }, + registry); + if (!status.ok()) return status; + + status = FunctionAdapter:: + CreateAndRegister( + builtin::kDayOfWeek, true, + [](Arena* arena, absl::Time ts, CelValue::StringHolder tz) + -> CelValue { return GetDayOfWeek(arena, ts, tz.value()); }, + registry); + if (!status.ok()) return status; + + status = FunctionAdapter::CreateAndRegister( + builtin::kDayOfWeek, true, + [](Arena* arena, absl::Time ts) -> CelValue { + return GetDayOfWeek(arena, ts, ""); + }, + registry); + if (!status.ok()) return status; + + status = FunctionAdapter:: + CreateAndRegister( + builtin::kHours, true, + [](Arena* arena, absl::Time ts, CelValue::StringHolder tz) + -> CelValue { return GetHours(arena, ts, tz.value()); }, + registry); + if (!status.ok()) return status; + + status = FunctionAdapter::CreateAndRegister( + builtin::kHours, true, + [](Arena* arena, absl::Time ts) -> CelValue { + return GetHours(arena, ts, ""); + }, + registry); + if (!status.ok()) return status; + + status = FunctionAdapter:: + CreateAndRegister( + builtin::kMinutes, true, + [](Arena* arena, absl::Time ts, CelValue::StringHolder tz) + -> CelValue { return GetMinutes(arena, ts, tz.value()); }, + registry); + if (!status.ok()) return status; + + status = FunctionAdapter::CreateAndRegister( + builtin::kMinutes, true, + [](Arena* arena, absl::Time ts) -> CelValue { + return GetMinutes(arena, ts, ""); + }, + registry); + if (!status.ok()) return status; + + status = FunctionAdapter:: + CreateAndRegister( + builtin::kSeconds, true, + [](Arena* arena, absl::Time ts, CelValue::StringHolder tz) + -> CelValue { return GetSeconds(arena, ts, tz.value()); }, + registry); + if (!status.ok()) return status; + + status = FunctionAdapter::CreateAndRegister( + builtin::kSeconds, true, + [](Arena* arena, absl::Time ts) -> CelValue { + return GetSeconds(arena, ts, ""); + }, + registry); + if (!status.ok()) return status; + + status = FunctionAdapter:: + CreateAndRegister( + builtin::kMilliseconds, true, + [](Arena* arena, absl::Time ts, CelValue::StringHolder tz) + -> CelValue { return GetMilliseconds(arena, ts, tz.value()); }, + registry); + if (!status.ok()) return status; + + status = FunctionAdapter::CreateAndRegister( + builtin::kMilliseconds, true, + [](Arena* arena, absl::Time ts) -> CelValue { + return GetMilliseconds(arena, ts, ""); + }, + registry); + if (!status.ok()) return status; + + return absl::OkStatus(); +} + +} // namespace + +absl::Status RegisterBuiltinFunctions(CelFunctionRegistry* registry, + const InterpreterOptions& options) { + // logical NOT + absl::Status status = FunctionAdapter::CreateAndRegister( + builtin::kNot, false, [](Arena*, bool value) -> bool { return !value; }, + registry); + if (!status.ok()) return status; + + // Negation group + status = FunctionAdapter::CreateAndRegister( + builtin::kNeg, false, [](Arena*, int64_t value) -> int64_t { return -value; }, + registry); + if (!status.ok()) return status; + + status = FunctionAdapter::CreateAndRegister( + builtin::kNeg, false, + [](Arena*, double value) -> double { return -value; }, registry); + if (!status.ok()) return status; + + status = RegisterComparisonFunctions(registry, options); + if (!status.ok()) return status; + // Logical AND // This implementation is used when short-circuiting is off. status = FunctionAdapter::CreateAndRegister( @@ -1128,11 +1348,11 @@ ::cel_base::Status RegisterBuiltinFunctions(CelFunctionRegistry* registry, RE2 re2(regex.value().data()); if (max_size > 0 && re2.ProgramSize() > max_size) { return CreateErrorValue(arena, "exceeded RE2 max program size", - cel_base::StatusCode::kInvalidArgument); + absl::StatusCode::kInvalidArgument); } if (!re2.ok()) { return CreateErrorValue(arena, "invalid_argument", - cel_base::StatusCode::kInvalidArgument); + absl::StatusCode::kInvalidArgument); } return CelValue::CreateBool(RE2::PartialMatch(re2::StringPiece(target.value().data(), target.value().size()), re2)); }; @@ -1151,40 +1371,7 @@ ::cel_base::Status RegisterBuiltinFunctions(CelFunctionRegistry* registry, if (!status.ok()) return status; } - status = - FunctionAdapter:: - CreateAndRegister(builtin::kStringContains, false, StringContains, - registry); - if (!status.ok()) return status; - - status = - FunctionAdapter:: - CreateAndRegister(builtin::kStringContains, true, StringContains, - registry); - if (!status.ok()) return status; - - status = - FunctionAdapter:: - CreateAndRegister(builtin::kStringEndsWith, false, StringEndsWith, - registry); - if (!status.ok()) return status; - - status = - FunctionAdapter:: - CreateAndRegister(builtin::kStringEndsWith, true, StringEndsWith, - registry); - if (!status.ok()) return status; - - status = - FunctionAdapter:: - CreateAndRegister(builtin::kStringStartsWith, false, StringStartsWith, - registry); - if (!status.ok()) return status; - - status = - FunctionAdapter:: - CreateAndRegister(builtin::kStringStartsWith, true, StringStartsWith, - registry); + status = RegisterStringFunctions(registry, options); if (!status.ok()) return status; // Modulo @@ -1196,171 +1383,7 @@ ::cel_base::Status RegisterBuiltinFunctions(CelFunctionRegistry* registry, builtin::kModulo, false, Modulo, registry); if (!status.ok()) return status; - // Timestamp - // - // timestamp() conversion from string.. - status = FunctionAdapter::CreateAndRegister( - builtin::kTimestamp, false, CreateTimestampFromString, registry); - if (!status.ok()) return status; - - status = FunctionAdapter:: - CreateAndRegister( - builtin::kFullYear, true, - [](Arena* arena, absl::Time ts, CelValue::StringHolder tz) - -> CelValue { return GetFullYear(arena, ts, tz.value()); }, - registry); - if (!status.ok()) return status; - - status = FunctionAdapter::CreateAndRegister( - builtin::kFullYear, true, - [](Arena* arena, absl::Time ts) -> CelValue { - return GetFullYear(arena, ts, ""); - }, - registry); - if (!status.ok()) return status; - - status = FunctionAdapter:: - CreateAndRegister( - builtin::kMonth, true, - [](Arena* arena, absl::Time ts, CelValue::StringHolder tz) - -> CelValue { return GetMonth(arena, ts, tz.value()); }, - registry); - if (!status.ok()) return status; - - status = FunctionAdapter::CreateAndRegister( - builtin::kMonth, true, - [](Arena* arena, absl::Time ts) -> CelValue { - return GetMonth(arena, ts, ""); - }, - registry); - if (!status.ok()) return status; - - status = FunctionAdapter:: - CreateAndRegister( - builtin::kDayOfYear, true, - [](Arena* arena, absl::Time ts, CelValue::StringHolder tz) - -> CelValue { return GetDayOfYear(arena, ts, tz.value()); }, - registry); - if (!status.ok()) return status; - - status = FunctionAdapter::CreateAndRegister( - builtin::kDayOfYear, true, - [](Arena* arena, absl::Time ts) -> CelValue { - return GetDayOfYear(arena, ts, ""); - }, - registry); - if (!status.ok()) return status; - - status = FunctionAdapter:: - CreateAndRegister( - builtin::kDayOfMonth, true, - [](Arena* arena, absl::Time ts, CelValue::StringHolder tz) - -> CelValue { return GetDayOfMonth(arena, ts, tz.value()); }, - registry); - if (!status.ok()) return status; - - status = FunctionAdapter::CreateAndRegister( - builtin::kDayOfMonth, true, - [](Arena* arena, absl::Time ts) -> CelValue { - return GetDayOfMonth(arena, ts, ""); - }, - registry); - if (!status.ok()) return status; - - status = FunctionAdapter:: - CreateAndRegister( - builtin::kDate, true, - [](Arena* arena, absl::Time ts, CelValue::StringHolder tz) - -> CelValue { return GetDate(arena, ts, tz.value()); }, - registry); - if (!status.ok()) return status; - - status = FunctionAdapter::CreateAndRegister( - builtin::kDate, true, - [](Arena* arena, absl::Time ts) -> CelValue { - return GetDate(arena, ts, ""); - }, - registry); - if (!status.ok()) return status; - - status = FunctionAdapter:: - CreateAndRegister( - builtin::kDayOfWeek, true, - [](Arena* arena, absl::Time ts, CelValue::StringHolder tz) - -> CelValue { return GetDayOfWeek(arena, ts, tz.value()); }, - registry); - if (!status.ok()) return status; - - status = FunctionAdapter::CreateAndRegister( - builtin::kDayOfWeek, true, - [](Arena* arena, absl::Time ts) -> CelValue { - return GetDayOfWeek(arena, ts, ""); - }, - registry); - if (!status.ok()) return status; - - status = FunctionAdapter:: - CreateAndRegister( - builtin::kHours, true, - [](Arena* arena, absl::Time ts, CelValue::StringHolder tz) - -> CelValue { return GetHours(arena, ts, tz.value()); }, - registry); - if (!status.ok()) return status; - - status = FunctionAdapter::CreateAndRegister( - builtin::kHours, true, - [](Arena* arena, absl::Time ts) -> CelValue { - return GetHours(arena, ts, ""); - }, - registry); - if (!status.ok()) return status; - - status = FunctionAdapter:: - CreateAndRegister( - builtin::kMinutes, true, - [](Arena* arena, absl::Time ts, CelValue::StringHolder tz) - -> CelValue { return GetMinutes(arena, ts, tz.value()); }, - registry); - if (!status.ok()) return status; - - status = FunctionAdapter::CreateAndRegister( - builtin::kMinutes, true, - [](Arena* arena, absl::Time ts) -> CelValue { - return GetMinutes(arena, ts, ""); - }, - registry); - if (!status.ok()) return status; - - status = FunctionAdapter:: - CreateAndRegister( - builtin::kSeconds, true, - [](Arena* arena, absl::Time ts, CelValue::StringHolder tz) - -> CelValue { return GetSeconds(arena, ts, tz.value()); }, - registry); - if (!status.ok()) return status; - - status = FunctionAdapter::CreateAndRegister( - builtin::kSeconds, true, - [](Arena* arena, absl::Time ts) -> CelValue { - return GetSeconds(arena, ts, ""); - }, - registry); - if (!status.ok()) return status; - - status = FunctionAdapter:: - CreateAndRegister( - builtin::kMilliseconds, true, - [](Arena* arena, absl::Time ts, CelValue::StringHolder tz) - -> CelValue { return GetMilliseconds(arena, ts, tz.value()); }, - registry); - if (!status.ok()) return status; - - status = FunctionAdapter::CreateAndRegister( - builtin::kMilliseconds, true, - [](Arena* arena, absl::Time ts) -> CelValue { - return GetMilliseconds(arena, ts, ""); - }, - registry); + status = RegisterTimestampFunctions(registry, options); if (!status.ok()) return status; // type conversion to int @@ -1385,6 +1408,20 @@ ::cel_base::Status RegisterBuiltinFunctions(CelFunctionRegistry* registry, registry); if (!status.ok()) return status; + status = FunctionAdapter::CreateAndRegister( + builtin::kInt, false, + [](Arena* arena, CelValue::StringHolder s) { + int64_t result; + if (absl::SimpleAtoi(s.value(), &result)) { + return CelValue::CreateInt64(result); + } else { + return CreateErrorValue(arena, "doesn't convert to a string", + absl::StatusCode::kInvalidArgument); + } + }, + registry); + if (!status.ok()) return status; + // duration // duration() conversion from string.. @@ -1473,7 +1510,7 @@ ::cel_base::Status RegisterBuiltinFunctions(CelFunctionRegistry* registry, if (!status.ok()) return status; } - return ::cel_base::OkStatus(); + return absl::OkStatus(); } } // namespace runtime diff --git a/eval/public/builtin_func_registrar.h b/eval/public/builtin_func_registrar.h index 332487d87..2cf906857 100644 --- a/eval/public/builtin_func_registrar.h +++ b/eval/public/builtin_func_registrar.h @@ -10,7 +10,7 @@ namespace api { namespace expr { namespace runtime { -cel_base::Status RegisterBuiltinFunctions( +absl::Status RegisterBuiltinFunctions( CelFunctionRegistry* registry, const InterpreterOptions& options = InterpreterOptions()); diff --git a/eval/public/builtin_func_test.cc b/eval/public/builtin_func_test.cc index 183a7840c..497ede135 100644 --- a/eval/public/builtin_func_test.cc +++ b/eval/public/builtin_func_test.cc @@ -8,6 +8,7 @@ #include "eval/public/cel_builtins.h" #include "eval/public/cel_expr_builder_factory.h" #include "eval/public/cel_function_registry.h" +#include "base/status_macros.h" namespace google { namespace api { @@ -31,7 +32,7 @@ class BuiltinsTest : public ::testing::Test { protected: BuiltinsTest() {} - void SetUp() override { ASSERT_TRUE(RegisterBuiltinFunctions(®istry_).ok()); } + void SetUp() override { ASSERT_OK(RegisterBuiltinFunctions(®istry_)); } // Helper method. Looks up in registry and tests comparison operation. void PerformRun(absl::string_view operation, absl::optional target, @@ -68,18 +69,18 @@ class BuiltinsTest : public ::testing::Test { CreateCelExpressionBuilder(options); // Builtin registration. - ASSERT_TRUE(RegisterBuiltinFunctions(builder->GetRegistry(), options).ok()); + ASSERT_OK(RegisterBuiltinFunctions(builder->GetRegistry(), options)); // Create CelExpression from AST (Expr object). auto cel_expression_status = builder->CreateExpression(&expr, &source_info); - ASSERT_TRUE(cel_expression_status.ok()); + ASSERT_OK(cel_expression_status); auto cel_expression = std::move(cel_expression_status.ValueOrDie()); auto eval_status = cel_expression->Evaluate(activation, &arena_); - ASSERT_TRUE(eval_status.ok()); + ASSERT_OK(eval_status); *result = eval_status.ValueOrDie(); } @@ -895,7 +896,7 @@ TEST_F(BuiltinsTest, MapInt64Index) { ASSERT_TRUE(result_value.IsError()); EXPECT_THAT(result_value.ErrorOrDie()->code(), - Eq(cel_base::StatusCode::kNotFound)); + Eq(absl::StatusCode::kNotFound)); EXPECT_TRUE(CheckNoSuchKeyError(result_value)); } @@ -924,7 +925,7 @@ TEST_F(BuiltinsTest, MapUint64Index) { ASSERT_TRUE(result_value.IsError()); EXPECT_THAT(result_value.ErrorOrDie()->code(), - Eq(cel_base::StatusCode::kNotFound)); + Eq(absl::StatusCode::kNotFound)); EXPECT_TRUE(CheckNoSuchKeyError(result_value)); } @@ -955,7 +956,7 @@ TEST_F(BuiltinsTest, MapStringIndex) { ASSERT_TRUE(result_value.IsError()); EXPECT_THAT(result_value.ErrorOrDie()->code(), - Eq(cel_base::StatusCode::kNotFound)); + Eq(absl::StatusCode::kNotFound)); EXPECT_TRUE(CheckNoSuchKeyError(result_value)); } @@ -1368,6 +1369,23 @@ TEST_F(BuiltinsTest, MatchesMaxSize) { EXPECT_TRUE(result_value.IsError()); } +TEST_F(BuiltinsTest, StringToInt) { + std::string target = "-42"; + std::vector args = {CelValue::CreateString(&target)}; + CelValue result_value; + ASSERT_NO_FATAL_FAILURE(PerformRun(builtin::kInt, {}, args, &result_value)); + ASSERT_TRUE(result_value.IsInt64()); + EXPECT_EQ(result_value.Int64OrDie(), -42); +} + +TEST_F(BuiltinsTest, StringToIntNonInt) { + std::string target = "not_a_number"; + std::vector args = {CelValue::CreateString(&target)}; + CelValue result_value; + ASSERT_NO_FATAL_FAILURE(PerformRun(builtin::kInt, {}, args, &result_value)); + ASSERT_TRUE(result_value.IsError()); +} + TEST_F(BuiltinsTest, IntToString) { std::vector args = {CelValue::CreateInt64(-42)}; CelValue result_value; diff --git a/eval/public/cel_expr_builder_factory.cc b/eval/public/cel_expr_builder_factory.cc index 21be23507..715045c6b 100644 --- a/eval/public/cel_expr_builder_factory.cc +++ b/eval/public/cel_expr_builder_factory.cc @@ -1,5 +1,7 @@ #include "eval/public/cel_expr_builder_factory.h" + #include "eval/compiler/flat_expr_builder.h" +#include "eval/public/cel_options.h" namespace google { namespace api { @@ -15,6 +17,19 @@ std::unique_ptr CreateCelExpressionBuilder( builder->set_enable_comprehension(options.enable_comprehension); builder->set_comprehension_max_iterations( options.comprehension_max_iterations); + builder->set_fail_on_warnings(options.fail_on_warnings); + + switch (options.unknown_processing) { + case UnknownProcessingOptions::kAttributeAndFunction: + builder->set_enable_unknown_function_results(true); + builder->set_enable_unknowns(true); + break; + case UnknownProcessingOptions::kAttributeOnly: + builder->set_enable_unknowns(true); + break; + case UnknownProcessingOptions::kDisabled: + break; + } return std::move(builder); } diff --git a/eval/public/cel_expression.h b/eval/public/cel_expression.h index 74815369e..403c3f39a 100644 --- a/eval/public/cel_expression.h +++ b/eval/public/cel_expression.h @@ -22,13 +22,23 @@ namespace runtime { // then the order of the callback invocations is guaranteed to correspond // the order of variable sub-elements (e.g. the order of elements returned // by Comprehension.iter_range). -using CelEvaluationListener = std::function; +// An opaque state used for evaluation of a cell expression. +class CelEvaluationState { + public: + virtual ~CelEvaluationState() = default; +}; + // Base interface for expression evaluating objects. class CelExpression { public: - virtual ~CelExpression() {} + virtual ~CelExpression() = default; + + // Initializes the state + virtual std::unique_ptr InitializeState( + google::protobuf::Arena* arena) const = 0; // Evaluates expression and returns value. // activation contains bindings from parameter names to values @@ -37,10 +47,24 @@ class CelExpression { virtual cel_base::StatusOr Evaluate(const BaseActivation& activation, google::protobuf::Arena* arena) const = 0; + // Evaluates expression and returns value. + // activation contains bindings from parameter names to values + // state must be non-null and created prior to calling Evaluate by + // InitializeState. + virtual cel_base::StatusOr Evaluate( + const BaseActivation& activation, CelEvaluationState* state) const = 0; + // Trace evaluates expression calling the callback on each sub-tree. virtual cel_base::StatusOr Trace( const BaseActivation& activation, google::protobuf::Arena* arena, CelEvaluationListener callback) const = 0; + + // Trace evaluates expression calling the callback on each sub-tree. + // state must be non-null and created prior to calling Evaluate by + // InitializeState. + virtual cel_base::StatusOr Trace( + const BaseActivation& activation, CelEvaluationState* state, + CelEvaluationListener callback) const = 0; }; // Base class for Expression Builder implementations @@ -60,6 +84,14 @@ class CelExpressionBuilder { const google::api::expr::v1alpha1::Expr* expr, const google::api::expr::v1alpha1::SourceInfo* source_info) const = 0; + // Creates CelExpression object from AST tree. + // expr specifies root of AST tree. + // non-fatal build warnings are written to warnings if encountered. + virtual cel_base::StatusOr> CreateExpression( + const google::api::expr::v1alpha1::Expr* expr, + const google::api::expr::v1alpha1::SourceInfo* source_info, + std::vector* warnings) const = 0; + // CelFunction registry. Extension function should be registered with it // prior to expression creation. CelFunctionRegistry* GetRegistry() const { return registry_.get(); } @@ -70,12 +102,12 @@ class CelExpressionBuilder { } // Add Enum to the list of resolvable by the builder. - void addResolvableEnum(const google::protobuf::EnumDescriptor* enum_descriptor) { + void AddResolvableEnum(const google::protobuf::EnumDescriptor* enum_descriptor) { resolvable_enums_.emplace(enum_descriptor); } // Remove Enum from the list of resolvable by the builder. - void removeResolvableEnum(const google::protobuf::EnumDescriptor* enum_descriptor) { + void RemoveResolvableEnum(const google::protobuf::EnumDescriptor* enum_descriptor) { resolvable_enums_.erase(enum_descriptor); } diff --git a/eval/public/cel_function.h b/eval/public/cel_function.h index ba6876c08..7e1dc0275 100644 --- a/eval/public/cel_function.h +++ b/eval/public/cel_function.h @@ -75,9 +75,9 @@ class CelFunction { // zero). When former happens, error Status is returned and *result is // not changed. In case of business logic error, returned Status is Ok, and // error is provided as CelValue - wrapped CelError in *result. - virtual ::cel_base::Status Evaluate(absl::Span arguments, - CelValue* result, - google::protobuf::Arena* arena) const = 0; + virtual absl::Status Evaluate(absl::Span arguments, + CelValue* result, + google::protobuf::Arena* arena) const = 0; // Determines whether instance of CelFunction is applicable to // arguments supplied. diff --git a/eval/public/cel_function_adapter.h b/eval/public/cel_function_adapter.h index ec4ad2d06..855af377b 100644 --- a/eval/public/cel_function_adapter.h +++ b/eval/public/cel_function_adapter.h @@ -4,10 +4,10 @@ #include "google/protobuf/duration.pb.h" #include "google/protobuf/timestamp.pb.h" +#include "absl/status/status.h" #include "absl/strings/str_cat.h" #include "eval/public/cel_function.h" #include "eval/public/cel_function_registry.h" -#include "base/status.h" #include "base/statusor.h" namespace google { @@ -91,8 +91,8 @@ class FunctionAdapter : public CelFunction { std::vector arg_types; if (!internal::AddType<0, Arguments...>(&arg_types)) { - return cel_base::Status( - cel_base::StatusCode::kInternal, + return absl::Status( + absl::StatusCode::kInternal, absl::StrCat("Failed to create adapter for ", name, ": failed to determine input parameter type")); } @@ -105,7 +105,7 @@ class FunctionAdapter : public CelFunction { // Creates function handler and attempts to register it with // supplied function registry. - static cel_base::Status CreateAndRegister( + static absl::Status CreateAndRegister( absl::string_view name, bool receiver_type, std::function handler, CelFunctionRegistry* registry) { @@ -119,26 +119,26 @@ class FunctionAdapter : public CelFunction { #if defined(__clang_major_version__) && __clang_major_version__ >= 8 && !defined(__APPLE__) template - inline cel_base::Status RunWrap(absl::Span arguments, + inline absl::Status RunWrap(absl::Span arguments, std::tuple<::google::protobuf::Arena*, Arguments...> input, CelValue* result, ::google::protobuf::Arena* arena) const { if (!ConvertFromValue(arguments[arg_index], &std::get(input))) { - return cel_base::Status(cel_base::StatusCode::kInvalidArgument, + return absl::Status(absl::StatusCode::kInvalidArgument, "Type conversion failed"); } return RunWrap(arguments, input, result, arena); } template <> - inline cel_base::Status RunWrap( + inline absl::Status RunWrap( absl::Span, std::tuple<::google::protobuf::Arena*, Arguments...> input, CelValue* result, ::google::protobuf::Arena* arena) const { return CreateReturnValue(absl::apply(handler_, input), arena, result); } #else - inline cel_base::Status RunWrap(std::function func, + inline absl::Status RunWrap(std::function func, const absl::Span argset, ::google::protobuf::Arena* arena, CelValue* result, int arg_index) const { @@ -146,13 +146,13 @@ class FunctionAdapter : public CelFunction { } template - inline cel_base::Status RunWrap(std::function func, + inline absl::Status RunWrap(std::function func, const absl::Span argset, ::google::protobuf::Arena* arena, CelValue* result, int arg_index) const { Arg argument; if (!ConvertFromValue(argset[arg_index], &argument)) { - return cel_base::Status(cel_base::StatusCode::kInvalidArgument, + return absl::Status(absl::StatusCode::kInvalidArgument, "Type conversion failed"); } @@ -166,11 +166,10 @@ class FunctionAdapter : public CelFunction { } #endif - ::cel_base::Status Evaluate(absl::Span arguments, - CelValue* result, - ::google::protobuf::Arena* arena) const override { + absl::Status Evaluate(absl::Span arguments, CelValue* result, + ::google::protobuf::Arena* arena) const override { if (arguments.size() != sizeof...(Arguments)) { - return cel_base::Status(cel_base::StatusCode::kInternal, + return absl::Status(absl::StatusCode::kInternal, "Argument number mismatch"); } @@ -201,91 +200,91 @@ class FunctionAdapter : public CelFunction { } // CreateReturnValue method wraps evaluation result with CelValue. - static cel_base::Status CreateReturnValue(bool value, ::google::protobuf::Arena*, + static absl::Status CreateReturnValue(bool value, ::google::protobuf::Arena*, CelValue* result) { *result = CelValue::CreateBool(value); - return cel_base::OkStatus(); + return absl::OkStatus(); } - static cel_base::Status CreateReturnValue(int64_t value, ::google::protobuf::Arena*, + static absl::Status CreateReturnValue(int64_t value, ::google::protobuf::Arena*, CelValue* result) { *result = CelValue::CreateInt64(value); - return cel_base::OkStatus(); + return absl::OkStatus(); } - static cel_base::Status CreateReturnValue(uint64_t value, ::google::protobuf::Arena*, + static absl::Status CreateReturnValue(uint64_t value, ::google::protobuf::Arena*, CelValue* result) { *result = CelValue::CreateUint64(value); - return cel_base::OkStatus(); + return absl::OkStatus(); } - static cel_base::Status CreateReturnValue(double value, ::google::protobuf::Arena*, + static absl::Status CreateReturnValue(double value, ::google::protobuf::Arena*, CelValue* result) { *result = CelValue::CreateDouble(value); - return cel_base::OkStatus(); + return absl::OkStatus(); } - static cel_base::Status CreateReturnValue(CelValue::StringHolder value, + static absl::Status CreateReturnValue(CelValue::StringHolder value, ::google::protobuf::Arena*, CelValue* result) { *result = CelValue::CreateString(value); - return cel_base::OkStatus(); + return absl::OkStatus(); } - static cel_base::Status CreateReturnValue(CelValue::BytesHolder value, + static absl::Status CreateReturnValue(CelValue::BytesHolder value, ::google::protobuf::Arena*, CelValue* result) { *result = CelValue::CreateBytes(value); - return cel_base::OkStatus(); + return absl::OkStatus(); } - static cel_base::Status CreateReturnValue(const ::google::protobuf::Message* value, + static absl::Status CreateReturnValue(const ::google::protobuf::Message* value, ::google::protobuf::Arena* arena, CelValue* result) { if (value == nullptr) { - return cel_base::Status(cel_base::StatusCode::kInvalidArgument, + return absl::Status(absl::StatusCode::kInvalidArgument, "Null Message pointer returned"); } *result = CelValue::CreateMessage(value, arena); - return cel_base::OkStatus(); + return absl::OkStatus(); } - static cel_base::Status CreateReturnValue(const CelList* value, ::google::protobuf::Arena*, + static absl::Status CreateReturnValue(const CelList* value, ::google::protobuf::Arena*, CelValue* result) { if (value == nullptr) { - return cel_base::Status(cel_base::StatusCode::kInvalidArgument, + return absl::Status(absl::StatusCode::kInvalidArgument, "Null CelList pointer returned"); } *result = CelValue::CreateList(value); - return cel_base::OkStatus(); + return absl::OkStatus(); } - static cel_base::Status CreateReturnValue(const CelMap* value, ::google::protobuf::Arena*, + static absl::Status CreateReturnValue(const CelMap* value, ::google::protobuf::Arena*, CelValue* result) { if (value == nullptr) { - return cel_base::Status(cel_base::StatusCode::kInvalidArgument, + return absl::Status(absl::StatusCode::kInvalidArgument, "Null CelMap pointer returned"); } *result = CelValue::CreateMap(value); - return cel_base::OkStatus(); + return absl::OkStatus(); } - static cel_base::Status CreateReturnValue(const CelError* value, ::google::protobuf::Arena*, + static absl::Status CreateReturnValue(const CelError* value, ::google::protobuf::Arena*, CelValue* result) { if (value == nullptr) { - return cel_base::Status(cel_base::StatusCode::kInvalidArgument, + return absl::Status(absl::StatusCode::kInvalidArgument, "Null CelError pointer returned"); } *result = CelValue::CreateError(value); - return cel_base::OkStatus(); + return absl::OkStatus(); } - static cel_base::Status CreateReturnValue(const CelValue& value, ::google::protobuf::Arena*, + static absl::Status CreateReturnValue(const CelValue& value, ::google::protobuf::Arena*, CelValue* result) { *result = value; - return cel_base::OkStatus(); + return absl::OkStatus(); } template - static cel_base::Status CreateReturnValue(const cel_base::StatusOr& value, + static absl::Status CreateReturnValue(const cel_base::StatusOr& value, ::google::protobuf::Arena*, CelValue*) { if (!value) { return value.status(); diff --git a/eval/public/cel_function_adapter_test.cc b/eval/public/cel_function_adapter_test.cc index 89715c65e..8a6e1b50f 100644 --- a/eval/public/cel_function_adapter_test.cc +++ b/eval/public/cel_function_adapter_test.cc @@ -2,6 +2,7 @@ #include "gmock/gmock.h" #include "gtest/gtest.h" +#include "base/status_macros.h" namespace google { namespace api { @@ -15,7 +16,7 @@ TEST(CelFunctionAdapterTest, TestAdapterNoArg) { auto func_status = FunctionAdapter::Create("const", false, func); - ASSERT_TRUE(func_status.ok()); + ASSERT_OK(func_status); auto cel_func = std::move(func_status.ValueOrDie()); @@ -25,7 +26,7 @@ TEST(CelFunctionAdapterTest, TestAdapterNoArg) { google::protobuf::Arena arena; auto eval_status = cel_func->Evaluate(args, &result, &arena); - ASSERT_TRUE(eval_status.ok()); + ASSERT_OK(eval_status); ASSERT_TRUE( result.IsInt64()); // Obvious failure, for educational purposes only. @@ -37,7 +38,7 @@ TEST(CelFunctionAdapterTest, TestAdapterOneArg) { auto func_status = FunctionAdapter::Create("_++_", false, func); - ASSERT_TRUE(func_status.ok()); + ASSERT_OK(func_status); auto cel_func = std::move(func_status.ValueOrDie()); @@ -51,7 +52,7 @@ TEST(CelFunctionAdapterTest, TestAdapterOneArg) { auto eval_status = cel_func->Evaluate(args, &result, &arena); - ASSERT_TRUE(eval_status.ok()); + ASSERT_OK(eval_status); ASSERT_TRUE(result.IsInt64()); EXPECT_EQ(result.Int64OrDie(), 100); @@ -65,7 +66,7 @@ TEST(CelFunctionAdapterTest, TestAdapterTwoArgs) { auto func_status = FunctionAdapter::Create("_++_", false, func); - ASSERT_TRUE(func_status.ok()); + ASSERT_OK(func_status); auto cel_func = std::move(func_status.ValueOrDie()); @@ -80,7 +81,7 @@ TEST(CelFunctionAdapterTest, TestAdapterTwoArgs) { auto eval_status = cel_func->Evaluate(args, &result, &arena); - ASSERT_TRUE(eval_status.ok()); + ASSERT_OK(eval_status); ASSERT_TRUE(result.IsInt64()); EXPECT_EQ(result.Int64OrDie(), 42); @@ -100,7 +101,7 @@ TEST(CelFunctionAdapterTest, TestAdapterThreeArgs) { FunctionAdapter::Create("concat", false, func); - ASSERT_TRUE(func_status.ok()); + ASSERT_OK(func_status); auto cel_func = std::move(func_status.ValueOrDie()); @@ -120,7 +121,7 @@ TEST(CelFunctionAdapterTest, TestAdapterThreeArgs) { auto eval_status = cel_func->Evaluate(args, &result, &arena); - ASSERT_TRUE(eval_status.ok()); + ASSERT_OK(eval_status); ASSERT_TRUE(result.IsString()); EXPECT_EQ(result.StringOrDie().value(), "123"); @@ -139,7 +140,7 @@ TEST(CelFunctionAdapterTest, TestTypeDeductionForCelValueBasicTypes) { absl::Duration, absl::Time, const CelList*, const CelMap*, const CelError*>::Create("dummy_func", false, func); - ASSERT_TRUE(func_status.ok()); + ASSERT_OK(func_status); auto cel_func = std::move(func_status.ValueOrDie()); diff --git a/eval/public/cel_function_provider.cc b/eval/public/cel_function_provider.cc index 42ce62e34..a695c65b9 100644 --- a/eval/public/cel_function_provider.cc +++ b/eval/public/cel_function_provider.cc @@ -22,7 +22,7 @@ class ActivationFunctionProviderImpl : public CelFunctionProvider { for (const CelFunction* overload : overloads) { if (overload->descriptor().ShapeMatches(descriptor)) { if (matching_overload != nullptr) { - return cel_base::Status(cel_base::StatusCode::kInvalidArgument, + return absl::Status(absl::StatusCode::kInvalidArgument, "Couldn't resolve function."); } matching_overload = overload; diff --git a/eval/public/cel_function_provider_test.cc b/eval/public/cel_function_provider_test.cc index 80ffbbd9b..0f2d1ff41 100644 --- a/eval/public/cel_function_provider_test.cc +++ b/eval/public/cel_function_provider_test.cc @@ -2,6 +2,7 @@ #include "gmock/gmock.h" #include "gtest/gtest.h" +#include "base/status_macros.h" namespace google { namespace api { @@ -19,9 +20,9 @@ class ConstCelFunction : public CelFunction { ConstCelFunction() : CelFunction({"ConstFunction", false, {}}) {} explicit ConstCelFunction(const CelFunctionDescriptor& desc) : CelFunction(desc) {} - cel_base::Status Evaluate(absl::Span args, CelValue* output, + absl::Status Evaluate(absl::Span args, CelValue* output, google::protobuf::Arena* arena) const override { - return cel_base::Status(cel_base::StatusCode::kUnimplemented, "Not Implemented"); + return absl::Status(absl::StatusCode::kUnimplemented, "Not Implemented"); } }; @@ -31,7 +32,7 @@ TEST(CreateActivationFunctionProviderTest, NoOverloadFound) { auto func = provider->GetFunction({"LazyFunc", false, {}}, activation); - ASSERT_TRUE(func.status().ok()); + ASSERT_OK(func.status()); EXPECT_THAT(func.ValueOrDie(), Eq(nullptr)); } @@ -42,11 +43,11 @@ TEST(CreateActivationFunctionProviderTest, OverloadFound) { auto status = activation.InsertFunction(std::make_unique(desc)); - EXPECT_TRUE(status.ok()); + EXPECT_OK(status); auto func = provider->GetFunction(desc, activation); - ASSERT_TRUE(func.status().ok()); + ASSERT_OK(func.status()); EXPECT_THAT(func.ValueOrDie(), Ne(nullptr)); } @@ -60,9 +61,9 @@ TEST(CreateActivationFunctionProviderTest, AmbiguousLookup) { auto status = activation.InsertFunction(std::make_unique(desc1)); - EXPECT_TRUE(status.ok()); + EXPECT_OK(status); status = activation.InsertFunction(std::make_unique(desc2)); - EXPECT_TRUE(status.ok()); + EXPECT_OK(status); auto func = provider->GetFunction(match_desc, activation); diff --git a/eval/public/cel_function_registry.cc b/eval/public/cel_function_registry.cc index 44cd8efcc..34202afe4 100644 --- a/eval/public/cel_function_registry.cc +++ b/eval/public/cel_function_registry.cc @@ -5,27 +5,27 @@ namespace api { namespace expr { namespace runtime { -cel_base::Status CelFunctionRegistry::Register( +absl::Status CelFunctionRegistry::Register( std::unique_ptr function) { const CelFunctionDescriptor& descriptor = function->descriptor(); if (DescriptorRegistered(descriptor)) { - return cel_base::Status( - cel_base::StatusCode::kAlreadyExists, + return absl::Status( + absl::StatusCode::kAlreadyExists, "CelFunction with specified parameters already registered"); } auto& overloads = functions_[descriptor.name()]; overloads.static_overloads.push_back(std::move(function)); - return cel_base::OkStatus(); + return absl::OkStatus(); } -cel_base::Status CelFunctionRegistry::RegisterLazyFunction( +absl::Status CelFunctionRegistry::RegisterLazyFunction( const CelFunctionDescriptor& descriptor, std::unique_ptr factory) { if (DescriptorRegistered(descriptor)) { - return cel_base::Status( - cel_base::StatusCode::kAlreadyExists, + return absl::Status( + absl::StatusCode::kAlreadyExists, "CelFunction with specified parameters already registered"); } auto& overloads = functions_[descriptor.name()]; @@ -33,7 +33,7 @@ cel_base::Status CelFunctionRegistry::RegisterLazyFunction( descriptor, std::move(factory)); overloads.lazy_overloads.push_back(std::move(entry)); - return cel_base::OkStatus(); + return absl::OkStatus(); } std::vector CelFunctionRegistry::FindOverloads( diff --git a/eval/public/cel_function_registry.h b/eval/public/cel_function_registry.h index fbeafac4e..de1d64bcc 100644 --- a/eval/public/cel_function_registry.h +++ b/eval/public/cel_function_registry.h @@ -26,18 +26,18 @@ class CelFunctionRegistry { // passed to registry. // Function registration should be performed prior to // CelExpression creation. - cel_base::Status Register(std::unique_ptr function); + absl::Status Register(std::unique_ptr function); // Register a lazily provided function. CelFunctionProvider is used to get // a CelFunction ptr at evaluation time. The registry takes ownership of the // factory. - cel_base::Status RegisterLazyFunction( + absl::Status RegisterLazyFunction( const CelFunctionDescriptor& descriptor, std::unique_ptr factory); // Register a lazily provided function. This overload uses a default provider // that delegates to the activation at evaluation time. - cel_base::Status RegisterLazyFunction(const CelFunctionDescriptor& descriptor) { + absl::Status RegisterLazyFunction(const CelFunctionDescriptor& descriptor) { return RegisterLazyFunction(descriptor, CreateActivationFunctionProvider()); } diff --git a/eval/public/cel_function_registry_test.cc b/eval/public/cel_function_registry_test.cc index 2a9297a24..a8cc6a97f 100644 --- a/eval/public/cel_function_registry_test.cc +++ b/eval/public/cel_function_registry_test.cc @@ -6,6 +6,7 @@ #include "gtest/gtest.h" #include "eval/public/cel_function.h" #include "eval/public/cel_function_provider.h" +#include "base/status_macros.h" namespace google { namespace api { @@ -38,11 +39,11 @@ class ConstCelFunction : public CelFunction { return {"ConstFunction", false, {}}; } - cel_base::Status Evaluate(absl::Span args, CelValue* output, + absl::Status Evaluate(absl::Span args, CelValue* output, google::protobuf::Arena* arena) const override { *output = CelValue::CreateInt64(42); - return cel_base::OkStatus(); + return absl::OkStatus(); } }; @@ -52,12 +53,12 @@ TEST(CelFunctionRegistryTest, InsertAndRetrieveLazyFunction) { Activation activation; auto register_status = registry.RegisterLazyFunction( lazy_function_desc, std::make_unique()); - EXPECT_TRUE(register_status.ok()); + EXPECT_OK(register_status); const auto providers = registry.FindLazyOverloads("LazyFunction", false, {}); EXPECT_THAT(providers, testing::SizeIs(1)); auto func = providers[0]->GetFunction(lazy_function_desc, activation); - ASSERT_TRUE(func.status().ok()); + ASSERT_OK(func.status()); EXPECT_THAT(func.ValueOrDie(), Eq(nullptr)); } @@ -69,9 +70,9 @@ TEST(CelFunctionRegistryTest, LazyAndStaticFunctionShareDescriptorSpace) { CelFunctionDescriptor desc = ConstCelFunction::MakeDescriptor(); auto register_status = registry.RegisterLazyFunction( desc, std::make_unique()); - EXPECT_TRUE(register_status.ok()); + EXPECT_OK(register_status); - cel_base::Status status = registry.Register(std::make_unique()); + absl::Status status = registry.Register(std::make_unique()); EXPECT_FALSE(status.ok()); } @@ -81,8 +82,8 @@ TEST(CelFunctionRegistryTest, ListFunctions) { auto register_status = registry.RegisterLazyFunction( lazy_function_desc, std::make_unique()); - EXPECT_TRUE(register_status.ok()); - EXPECT_TRUE(registry.Register(std::make_unique()).ok()); + EXPECT_OK(register_status); + EXPECT_OK(registry.Register(std::make_unique())); auto registered_functions = registry.ListFunctions(); @@ -95,15 +96,15 @@ TEST(CelFunctionRegistryTest, DefaultLazyProvider) { CelFunctionDescriptor lazy_function_desc{"LazyFunction", false, {}}; CelFunctionRegistry registry; Activation activation; - EXPECT_TRUE(registry.RegisterLazyFunction(lazy_function_desc).ok()); + EXPECT_OK(registry.RegisterLazyFunction(lazy_function_desc)); auto insert_status = activation.InsertFunction( std::make_unique(lazy_function_desc)); - EXPECT_TRUE(insert_status.ok()); + EXPECT_OK(insert_status); const auto providers = registry.FindLazyOverloads("LazyFunction", false, {}); EXPECT_THAT(providers, testing::SizeIs(1)); auto func = providers[0]->GetFunction(lazy_function_desc, activation); - ASSERT_TRUE(func.status().ok()); + ASSERT_OK(func.status()); EXPECT_THAT(func.ValueOrDie(), Property(&CelFunction::descriptor, Property(&CelFunctionDescriptor::name, Eq("LazyFunction")))); diff --git a/eval/public/cel_options.h b/eval/public/cel_options.h index 46cbfb8d2..daff7f9dd 100644 --- a/eval/public/cel_options.h +++ b/eval/public/cel_options.h @@ -8,8 +8,23 @@ namespace api { namespace expr { namespace runtime { +// Options for unknown processing. +enum class UnknownProcessingOptions { + // No unknown processing. + kDisabled, + // Only attributes supported. + kAttributeOnly, + // Attributes and functions supported. Function results are dependent on the + // logic for handling unknown_attributes, so clients must opt in to both. + kAttributeAndFunction +}; + // Interpreter options for controlling evaluation and builtin functions. struct InterpreterOptions { + // Level of unknown support enabled. + UnknownProcessingOptions unknown_processing = + UnknownProcessingOptions::kDisabled; + // Enable short-circuiting of the logical operator evaluation. If enabled, // AND, OR, and TERNARY do not evaluate the entire expression once the the // resulting value is known from the left-hand side. @@ -20,6 +35,8 @@ struct InterpreterOptions { // Enable constant folding during the expression creation. If enabled, // an arena must be provided for constant generation. + // Note that expression tracing applies a modified expression if this option + // is enabled. bool constant_folding = false; google::protobuf::Arena* constant_arena = nullptr; @@ -50,6 +67,9 @@ struct InterpreterOptions { // Enable list membership overload. bool enable_list_contains = true; + + // Treat builder warnings as fatal errors. + bool fail_on_warnings = true; }; } // namespace runtime diff --git a/eval/public/cel_value.cc b/eval/public/cel_value.cc index b51e77698..000f3b2b4 100644 --- a/eval/public/cel_value.cc +++ b/eval/public/cel_value.cc @@ -4,6 +4,7 @@ #include "google/protobuf/struct.pb.h" #include "google/protobuf/wrappers.pb.h" #include "absl/container/node_hash_map.h" +#include "absl/status/status.h" #include "absl/strings/substitute.h" #include "absl/synchronization/mutex.h" #include "internal/proto_util.h" @@ -40,6 +41,8 @@ constexpr char kErrNoMatchingOverload[] = "No matching overloads found"; constexpr char kErrNoSuchKey[] = "Key not found in map"; constexpr absl::string_view kErrUnknownValue = "Unknown value "; constexpr absl::string_view kPayloadUrlUnknownPath = "unknown_path"; +constexpr absl::string_view kPayloadUrlUnknownFunctionResult = + "cel_is_unknown_function_result"; // Forward declaration for google.protobuf.Value CelValue ValueFromMessage(const Value* value, Arena* arena); @@ -368,7 +371,7 @@ std::string CelValue::TypeName(Type value_type) { case Type::kMap: return "CelMap"; case Type::kUnknownSet: - return "UnknownAttributeSet"; + return "UnknownSet"; case Type::kError: return "CelError"; default: @@ -377,14 +380,14 @@ std::string CelValue::TypeName(Type value_type) { } CelValue CreateErrorValue(Arena* arena, absl::string_view message, - cel_base::StatusCode error_code, int) { + absl::StatusCode error_code, int) { CelError* error = Arena::Create(arena, error_code, message); return CelValue::CreateError(error); } CelValue CreateNoMatchingOverloadError(google::protobuf::Arena* arena) { return CreateErrorValue(arena, kErrNoMatchingOverload, - cel_base::StatusCode::kUnknown); + absl::StatusCode::kUnknown); } bool CheckNoMatchingOverloadError(CelValue value) { @@ -393,11 +396,11 @@ bool CheckNoMatchingOverloadError(CelValue value) { } CelValue CreateNoSuchFieldError(google::protobuf::Arena* arena) { - return CreateErrorValue(arena, "no_such_field", cel_base::StatusCode::kNotFound); + return CreateErrorValue(arena, "no_such_field", absl::StatusCode::kNotFound); } CelValue CreateNoSuchKeyError(google::protobuf::Arena* arena, absl::string_view) { - return CreateErrorValue(arena, kErrNoSuchKey, cel_base::StatusCode::kNotFound); + return CreateErrorValue(arena, kErrNoSuchKey, absl::StatusCode::kNotFound); } bool CheckNoSuchKeyError(CelValue value) { @@ -407,9 +410,9 @@ bool CheckNoSuchKeyError(CelValue value) { CelValue CreateUnknownValueError(google::protobuf::Arena* arena, absl::string_view unknown_path) { CelError* error = - Arena::Create(arena, cel_base::StatusCode::kUnavailable, + Arena::Create(arena, absl::StatusCode::kUnavailable, absl::StrCat(kErrUnknownValue, unknown_path)); - error->SetPayload(kPayloadUrlUnknownPath, cel_base::StatusCord(unknown_path)); + error->SetPayload(kPayloadUrlUnknownPath, absl::Cord(unknown_path)); return CelValue::CreateError(error); } @@ -417,7 +420,7 @@ bool IsUnknownValueError(const CelValue& value) { // TODO(issues/41): replace with the implementation of go/cel-known-unknowns if (!value.IsError()) return false; const CelError* error = value.ErrorOrDie(); - if (error && error->code() == cel_base::StatusCode::kUnavailable) { + if (error && error->code() == absl::StatusCode::kUnavailable) { auto path = error->GetPayload(kPayloadUrlUnknownPath); return path.has_value(); } @@ -427,7 +430,7 @@ bool IsUnknownValueError(const CelValue& value) { std::set GetUnknownPathsSetOrDie(const CelValue& value) { // TODO(issues/41): replace with the implementation of go/cel-known-unknowns const CelError* error = value.ErrorOrDie(); - if (error && error->code() == cel_base::StatusCode::kUnavailable) { + if (error && error->code() == absl::StatusCode::kUnavailable) { auto path = error->GetPayload(kPayloadUrlUnknownPath); if (path.has_value()) return {std::string(path.value())}; } @@ -435,6 +438,27 @@ std::set GetUnknownPathsSetOrDie(const CelValue& value) { return {}; } +CelValue CreateUnknownFunctionResultError(google::protobuf::Arena* arena, + absl::string_view help_message) { + CelError* error = Arena::Create( + arena, absl::StatusCode::kUnavailable, + absl::StrCat("Unknown function result: ", help_message)); + error->SetPayload(kPayloadUrlUnknownFunctionResult, absl::Cord("true")); + return CelValue::CreateError(error); +} + +bool IsUnknownFunctionResult(const CelValue& value) { + if (!value.IsError()) { + return false; + } + const CelError* error = value.ErrorOrDie(); + if (error == nullptr || error->code() != absl::StatusCode::kUnavailable) { + return false; + } + auto payload = error->GetPayload(kPayloadUrlUnknownFunctionResult); + return payload.has_value() && payload.value() == "true"; +} + } // namespace runtime } // namespace expr } // namespace api diff --git a/eval/public/cel_value.h b/eval/public/cel_value.h index 173da9af3..16d1ce8cd 100644 --- a/eval/public/cel_value.h +++ b/eval/public/cel_value.h @@ -32,11 +32,12 @@ namespace api { namespace expr { namespace runtime { -using CelError = cel_base::Status; +using CelError = absl::Status; +// Break cyclic depdendencies for container types. class CelList; class CelMap; -class UnknownAttributeSet; +class UnknownSet; class CelValue { public: @@ -102,7 +103,7 @@ class CelValue { using ValueHolder = internal::ValueHolder< bool, int64_t, uint64_t, double, StringHolder, BytesHolder, const google::protobuf::Message *, absl::Duration, absl::Time, const CelList *, - const CelMap *, const UnknownAttributeSet *, const CelError *>; + const CelMap *, const UnknownSet *, const CelError *>; public: // Metafunction providing positions corresponding to specific @@ -123,7 +124,7 @@ class CelValue { kTimestamp = IndexOf::value, kList = IndexOf::value, kMap = IndexOf::value, - kUnknownSet = IndexOf::value, + kUnknownSet = IndexOf::value, kError = IndexOf::value, kAny // Special value. Used in function descriptors. }; @@ -197,7 +198,7 @@ class CelValue { return CelValue(value); } - static CelValue CreateUnknownSet(const UnknownAttributeSet *value) { + static CelValue CreateUnknownSet(const UnknownSet *value) { CheckNullPointer(value, Type::kUnknownSet); return CelValue(value); } @@ -270,8 +271,8 @@ class CelValue { // Returns stored const UnknownAttributeSet * value. // Fails if stored value type is not const UnknownAttributeSet *. - const UnknownAttributeSet *UnknownSetOrDie() const { - return GetValueOrDie(Type::kUnknownSet); + const UnknownSet *UnknownSetOrDie() const { + return GetValueOrDie(Type::kUnknownSet); } // Returns stored const CelError * value. @@ -304,7 +305,7 @@ class CelValue { bool IsMap() const { return value_.is(); } - bool IsUnknownSet() const { return value_.is(); } + bool IsUnknownSet() const { return value_.is(); } bool IsError() const { return value_.is(); } @@ -420,7 +421,7 @@ class CelMap { // parsed from. -1, if the position can not be determined. CelValue CreateErrorValue( google::protobuf::Arena *arena, absl::string_view message, - cel_base::StatusCode error_code = cel_base::StatusCode::kUnknown, + absl::StatusCode error_code = absl::StatusCode::kUnknown, int position = -1); CelValue CreateNoMatchingOverloadError(google::protobuf::Arena *arena); @@ -440,6 +441,18 @@ CelValue CreateUnknownValueError(google::protobuf::Arena *arena, // encountered a value marked as unknown in Activation unknown_paths. bool IsUnknownValueError(const CelValue &value); +// Returns error indicating the result of the function is unknown. This is used +// as a signal to create an unknown set if unknown function handling is opted +// into. +CelValue CreateUnknownFunctionResultError(google::protobuf::Arena *arena, + absl::string_view help_message); + +// Returns true if this is unknown value error indicating that evaluation +// called an extension function whose value is unknown for the given args. +// This is used as a signal to convert to an UnknownSet if the behavior is opted +// into. +bool IsUnknownFunctionResult(const CelValue &value); + // Returns set of unknown paths for unknown value error. The value must be // unknown error, see IsUnknownValueError() above, or it dies. std::set GetUnknownPathsSetOrDie(const CelValue &value); diff --git a/eval/public/cel_value_test.cc b/eval/public/cel_value_test.cc index 6b82e77c9..735dc6a00 100644 --- a/eval/public/cel_value_test.cc +++ b/eval/public/cel_value_test.cc @@ -6,6 +6,7 @@ #include "gmock/gmock.h" #include "gtest/gtest.h" #include "eval/public/unknown_attribute_set.h" +#include "eval/public/unknown_set.h" #include "eval/testutil/test_message.pb.h" #include "testutil/util.h" @@ -94,7 +95,7 @@ TEST(CelValueTest, TestType) { CelValue value_timestamp2 = CelValue::CreateMessage(&msg_timestamp, &arena); EXPECT_THAT(value_timestamp2.type(), Eq(CelValue::Type::kTimestamp)); - UnknownAttributeSet unknown_set; + UnknownSet unknown_set; CelValue value_unknown = CelValue::CreateUnknownSet(&unknown_set); EXPECT_THAT(value_unknown.type(), Eq(CelValue::Type::kUnknownSet)); } @@ -132,7 +133,7 @@ int CountTypeMatch(const CelValue& value) { const CelError* value_error; count += (value.GetValue(&value_error)) ? 1 : 0; - const UnknownAttributeSet* value_unknown; + const UnknownSet* value_unknown; count += (value.GetValue(&value_unknown)) ? 1 : 0; return count; @@ -215,7 +216,6 @@ TEST(CelValueTest, TestString) { TEST(CelValueTest, TestBytes) { constexpr char kTestStr0[] = "test0"; std::string v = kTestStr0; - absl::string_view sv(v); CelValue value = CelValue::CreateBytes(&v); // CelValue value = CelValue::CreateString("test"); @@ -304,14 +304,14 @@ TEST(CelValueTest, TestMap) { // This test verifies CelValue support of Unknown type. TEST(CelValueTest, TestUnknownSet) { - UnknownAttributeSet unknown_set; + UnknownSet unknown_set; CelValue value = CelValue::CreateUnknownSet(&unknown_set); EXPECT_TRUE(value.IsUnknownSet()); EXPECT_THAT(value.UnknownSetOrDie(), Eq(&unknown_set)); // test template getter - const UnknownAttributeSet* value2; + const UnknownSet* value2; EXPECT_TRUE(value.GetValue(&value2)); EXPECT_THAT(value2, Eq(&unknown_set)); EXPECT_THAT(CountTypeMatch(value), Eq(1)); @@ -585,6 +585,14 @@ TEST(CelValueTest, TestBytesWrapper) { EXPECT_EQ(value.BytesOrDie().value(), wrapper.value()); } +TEST(CelValueTest, UnknownFunctionResultErrors) { + ::google::protobuf::Arena arena; + + CelValue value = CreateUnknownFunctionResultError(&arena, "message"); + EXPECT_TRUE(value.IsError()); + EXPECT_TRUE(IsUnknownFunctionResult(value)); +} + } // namespace runtime } // namespace expr } // namespace api diff --git a/eval/public/extension_func_registrar.cc b/eval/public/extension_func_registrar.cc index 2611c24a7..1ac79e12c 100644 --- a/eval/public/extension_func_registrar.cc +++ b/eval/public/extension_func_registrar.cc @@ -7,8 +7,8 @@ namespace api { namespace expr { namespace runtime { -::cel_base::Status RegisterExtensionFunctions(CelFunctionRegistry*) { - return ::cel_base::OkStatus(); +absl::Status RegisterExtensionFunctions(CelFunctionRegistry*) { + return absl::OkStatus(); } } // namespace runtime diff --git a/eval/public/extension_func_registrar.h b/eval/public/extension_func_registrar.h index 4eb68cefb..dcd5ebc4e 100644 --- a/eval/public/extension_func_registrar.h +++ b/eval/public/extension_func_registrar.h @@ -10,7 +10,7 @@ namespace expr { namespace runtime { // Register generic/widely used extension functions. -cel_base::Status RegisterExtensionFunctions(CelFunctionRegistry* registry); +absl::Status RegisterExtensionFunctions(CelFunctionRegistry* registry); } // namespace runtime } // namespace expr diff --git a/eval/public/extension_func_test.cc b/eval/public/extension_func_test.cc index 35b9c32c8..58f45633e 100644 --- a/eval/public/extension_func_test.cc +++ b/eval/public/extension_func_test.cc @@ -5,6 +5,7 @@ #include "eval/public/builtin_func_registrar.h" #include "eval/public/cel_function_registry.h" #include "eval/public/extension_func_registrar.h" +#include "base/status_macros.h" namespace google { namespace api { @@ -23,8 +24,8 @@ class ExtensionTest : public ::testing::Test { ExtensionTest() {} void SetUp() override { - ASSERT_TRUE(RegisterBuiltinFunctions(®istry_).ok()); - ASSERT_TRUE(RegisterExtensionFunctions(®istry_).ok()); + ASSERT_OK(RegisterBuiltinFunctions(®istry_)); + ASSERT_OK(RegisterExtensionFunctions(®istry_)); } // Helper method to test string startsWith() function @@ -49,7 +50,7 @@ class ExtensionTest : public ::testing::Test { absl::Span arg_span(&args[0], args.size()); auto status = func->Evaluate(arg_span, &result_value, &arena); - ASSERT_TRUE(status.ok()); + ASSERT_OK(status); ASSERT_TRUE(result_value.IsBool()); ASSERT_EQ(result_value.BoolOrDie(), result); } @@ -79,7 +80,7 @@ class ExtensionTest : public ::testing::Test { absl::Span arg_span(&args[0], args.size()); auto status = func->Evaluate(arg_span, result, arena); - ASSERT_TRUE(status.ok()); + ASSERT_OK(status); } // Helper method to test duration() function @@ -95,7 +96,7 @@ class ExtensionTest : public ::testing::Test { absl::Span arg_span(&args[0], args.size()); auto status = func->Evaluate(arg_span, result, arena); - ASSERT_TRUE(status.ok()); + ASSERT_OK(status); } // Function registry object @@ -170,7 +171,7 @@ TEST_F(ExtensionTest, TestTimestampFromString) { // Invalid timestamp - empty string. EXPECT_NO_FATAL_FAILURE(PerformTimestampConversion(&arena, "", &result)); ASSERT_TRUE(result.IsError()); - ASSERT_EQ(result.ErrorOrDie()->code(), cel_base::StatusCode::kInvalidArgument); + ASSERT_EQ(result.ErrorOrDie()->code(), absl::StatusCode::kInvalidArgument); // Invalid timestamp. EXPECT_NO_FATAL_FAILURE( @@ -203,7 +204,7 @@ TEST_F(ExtensionTest, TestDurationFromString) { // Invalid duration - empty string. EXPECT_NO_FATAL_FAILURE(PerformDurationConversion(&arena, "", &result)); ASSERT_TRUE(result.IsError()); - ASSERT_EQ(result.ErrorOrDie()->code(), cel_base::StatusCode::kInvalidArgument); + ASSERT_EQ(result.ErrorOrDie()->code(), absl::StatusCode::kInvalidArgument); // Invalid duration. EXPECT_NO_FATAL_FAILURE(PerformDurationConversion(&arena, "100", &result)); diff --git a/eval/public/unknown_attribute_set.h b/eval/public/unknown_attribute_set.h index 87eedbe46..b3abdeeb2 100644 --- a/eval/public/unknown_attribute_set.h +++ b/eval/public/unknown_attribute_set.h @@ -19,29 +19,31 @@ class UnknownAttributeSet { UnknownAttributeSet& operator=(const UnknownAttributeSet& other) = default; UnknownAttributeSet() {} - UnknownAttributeSet( - const std::vector>& attributes) { + UnknownAttributeSet(const std::vector& attributes) { attributes_.reserve(attributes.size()); for (const auto& attr : attributes) { Add(attr); } } - std::vector> attributes() const { - return attributes_; + UnknownAttributeSet(const UnknownAttributeSet& set1, + const UnknownAttributeSet& set2) + : attributes_(set1.attributes()) { + attributes_.reserve(set1.attributes().size() + set2.attributes().size()); + for (const auto& attr : set2.attributes()) { + Add(attr); + } } + std::vector attributes() const { return attributes_; } + static UnknownAttributeSet Merge(const UnknownAttributeSet& set1, const UnknownAttributeSet& set2) { - UnknownAttributeSet attr_set = set1; - for (const auto& attr : set2.attributes()) { - attr_set.Add(attr); - } - return attr_set; + return UnknownAttributeSet(set1, set2); } private: - void Add(std::shared_ptr attribute) { + void Add(const CelAttribute* attribute) { if (!attribute) { return; } @@ -54,7 +56,7 @@ class UnknownAttributeSet { } // Attribute container. - std::vector> attributes_; + std::vector attributes_; }; } // namespace runtime diff --git a/eval/public/unknown_attribute_set_test.cc b/eval/public/unknown_attribute_set_test.cc index 7b98d6266..775628f4a 100644 --- a/eval/public/unknown_attribute_set_test.cc +++ b/eval/public/unknown_attribute_set_test.cc @@ -33,7 +33,7 @@ TEST(UnknownAttributeSetTest, TestCreate) { CelAttributeQualifier::Create(CelValue::CreateUint64(2)), CelAttributeQualifier::Create(CelValue::CreateBool(true))})); - UnknownAttributeSet unknown_set({cel_attr}); + UnknownAttributeSet unknown_set({cel_attr.get()}); EXPECT_THAT(unknown_set.attributes().size(), Eq(1)); EXPECT_THAT(*(unknown_set.attributes()[0]), Eq(*cel_attr)); } @@ -74,8 +74,8 @@ TEST(UnknownAttributeSetTest, TestMergeSets) { CelAttributeQualifier::Create(CelValue::CreateUint64(2)), CelAttributeQualifier::Create(CelValue::CreateBool(false))})); - UnknownAttributeSet unknown_set1({cel_attr1, cel_attr2}); - UnknownAttributeSet unknown_set2({cel_attr1_copy, cel_attr3}); + UnknownAttributeSet unknown_set1({cel_attr1.get(), cel_attr2.get()}); + UnknownAttributeSet unknown_set2({cel_attr1_copy.get(), cel_attr3.get()}); UnknownAttributeSet unknown_set3 = UnknownAttributeSet::Merge(unknown_set1, unknown_set2); diff --git a/eval/public/unknown_function_result_set.cc b/eval/public/unknown_function_result_set.cc new file mode 100644 index 000000000..6e484da78 --- /dev/null +++ b/eval/public/unknown_function_result_set.cc @@ -0,0 +1,161 @@ +#include "eval/public/unknown_function_result_set.h" + +#include + +#include "eval/public/cel_function.h" +#include "eval/public/cel_options.h" +#include "eval/public/cel_value.h" + +namespace google { +namespace api { +namespace expr { +namespace runtime { +namespace { + +// Forward declare. +bool CelValueEqual(const CelValue lhs, const CelValue rhs); + +// Default to operator== +template +bool CelValueEqualImpl(T lhs, T rhs) { + return lhs == rhs; +} + +// List equality specialization. Test that the lists are in-order elementwise +// equal. +template <> +bool CelValueEqualImpl(const CelList* lhs, const CelList* rhs) { + if (lhs->size() != rhs->size()) { + return false; + } + for (int i = 0; i < rhs->size(); i++) { + if (!CelValueEqual(lhs->operator[](i), rhs->operator[](i))) { + return false; + } + } + return true; +} + +// Map equality specialization. Compare that two maps have exactly the same +// key/value pairs. +template <> +bool CelValueEqualImpl(const CelMap* lhs, const CelMap* rhs) { + if (lhs->size() != rhs->size()) { + return false; + } + const CelList* key_set = rhs->ListKeys(); + for (int i = 0; i < key_set->size(); i++) { + CelValue key = key_set->operator[](i); + CelValue rhs_value = rhs->operator[](key).value(); + auto maybe_lhs_value = lhs->operator[](key); + if (!maybe_lhs_value.has_value()) { + return false; + } + if (!CelValueEqual(maybe_lhs_value.value(), rhs_value)) { + return false; + } + } + return true; +} + +// Visitor for implementing comparing the underlying value that two CelValues +// are wrapping. The visitor unwraps the lhs then tries to get the rhs +// underlying value if it is the same type as the lhs. +struct LhsCompareVisitor { + CelValue rhs; + + LhsCompareVisitor(CelValue rhs) : rhs(rhs) {} + + template + bool operator()(T lhs_value) { + T rhs_value; + bool is_same_type = rhs.GetValue(&rhs_value); + if (!is_same_type) { + return false; + } + return CelValueEqualImpl(lhs_value, rhs_value); + } +}; + +// This is a slightly different implementation than provided for the cel +// evaluator. Differences are: +// +// - this implementation doesn't need to support error forwarding in the same +// way -- this should only be used for situations when we can invoke the +// function. i.e. the function must specify that it consumes errors and/or +// unknown sets for them to appear in the arg list. +// - this implementation defines equality between messages based on ptr identity +bool CelValueEqual(const CelValue lhs, const CelValue rhs) { + if (lhs.type() != rhs.type()) { + return false; + } + return lhs.Visit(LhsCompareVisitor(rhs)); +} + +// Tests that two descriptors are equal (name, receiver call style, arg types). +// +// Argument type Any is not treated specially. For example: +// {"f", false, {kAny}} != {"f", false, {kInt64}} +bool DescriptorEqual(const CelFunctionDescriptor& lhs, + const CelFunctionDescriptor& rhs) { + if (lhs.name() != rhs.name()) { + return false; + } + + if (lhs.receiver_style() != rhs.receiver_style()) { + return false; + } + + if (lhs.types() != rhs.types()) { + return false; + } + + return true; +} + +} // namespace + +bool UnknownFunctionResult::IsEqualTo( + const UnknownFunctionResult& other) const { + if (!DescriptorEqual(descriptor_, other.descriptor())) { + return false; + } + + if (arguments_.size() != other.arguments().size()) { + return false; + } + + for (size_t i = 0; i < arguments_.size(); i++) { + if (!CelValueEqual(arguments_[i], other.arguments()[i])) { + return false; + } + } + + return true; +} + +// Implementation for merge constructor. +UnknownFunctionResultSet::UnknownFunctionResultSet( + const UnknownFunctionResultSet& lhs, const UnknownFunctionResultSet& rhs) + : unknown_function_results_(lhs.unknown_function_results()) { + unknown_function_results_.reserve(lhs.unknown_function_results().size() + + rhs.unknown_function_results().size()); + for (const UnknownFunctionResult* call : rhs.unknown_function_results()) { + Add(call); + } +} + +void UnknownFunctionResultSet::Add(const UnknownFunctionResult* result) { + for (const UnknownFunctionResult* existing_result : + unknown_function_results()) { + if (result->IsEqualTo(*existing_result)) { + return; + } + } + unknown_function_results_.push_back(result); +} + +} // namespace runtime +} // namespace expr +} // namespace api +} // namespace google diff --git a/eval/public/unknown_function_result_set.h b/eval/public/unknown_function_result_set.h new file mode 100644 index 000000000..27c6cbc9b --- /dev/null +++ b/eval/public/unknown_function_result_set.h @@ -0,0 +1,75 @@ +#ifndef THIRD_PARTY_CEL_CPP_EVAL_PUBLIC_UNKNOWN_FUNCTION_RESULT_SET_H_ +#define THIRD_PARTY_CEL_CPP_EVAL_PUBLIC_UNKNOWN_FUNCTION_RESULT_SET_H_ + +#include + +#include "google/api/expr/v1alpha1/syntax.pb.h" +#include "eval/public/cel_function.h" + +namespace google { +namespace api { +namespace expr { +namespace runtime { + +// Represents a function result that is unknown at the time of execution. This +// allows for lazy evaluation of expensive functions. +class UnknownFunctionResult { + public: + UnknownFunctionResult(const CelFunctionDescriptor& descriptor, int64_t expr_id, + const std::vector& arguments) + : descriptor_(descriptor), expr_id_(expr_id), arguments_(arguments) {} + + // The descriptor of the called function that return Unknown. + const CelFunctionDescriptor& descriptor() const { return descriptor_; } + + // The id of the |Expr| that triggered the function call step. Provided + // informationally -- if two different |Expr|s generate the same unknown call, + // they will be treated as the same unknown function result. + int64_t call_expr_id() const { return expr_id_; } + + // The arguments of the function call that generated the unknown. + const std::vector& arguments() const { return arguments_; } + + // Equality operator provided for set semantics. + // Compares descriptor then arguments elementwise. + bool IsEqualTo(const UnknownFunctionResult& other) const; + + private: + CelFunctionDescriptor descriptor_; + int64_t expr_id_; + std::vector arguments_; +}; + +// Represents a collection of unknown function results at a particular point in +// execution. Execution should advance further if this set of unknowns are +// provided. It may not advance if only a subset are provided. +// Set semantics use |IsEqualTo()| defined on |UnknownFunctionResult|. +class UnknownFunctionResultSet { + public: + // Empty set + UnknownFunctionResultSet() {} + + // Merge constructor -- effectively union(lhs, rhs). + UnknownFunctionResultSet(const UnknownFunctionResultSet& lhs, + const UnknownFunctionResultSet& rhs); + + // Initialize with a single UnknownFunctionResult. + UnknownFunctionResultSet(const UnknownFunctionResult* initial) + : unknown_function_results_{initial} {} + + const std::vector& unknown_function_results() + const { + return unknown_function_results_; + } + + private: + std::vector unknown_function_results_; + void Add(const UnknownFunctionResult* result); +}; + +} // namespace runtime +} // namespace expr +} // namespace api +} // namespace google + +#endif // THIRD_PARTY_CEL_CPP_EVAL_PUBLIC_UNKNOWN_FUNCTION_RESULT_SET_H_ diff --git a/eval/public/unknown_function_result_set_test.cc b/eval/public/unknown_function_result_set_test.cc new file mode 100644 index 000000000..35d438a62 --- /dev/null +++ b/eval/public/unknown_function_result_set_test.cc @@ -0,0 +1,487 @@ +#include "eval/public/unknown_function_result_set.h" + +#include + +#include + +#include "google/protobuf/duration.pb.h" +#include "google/protobuf/empty.pb.h" +#include "google/protobuf/struct.pb.h" +#include "google/protobuf/timestamp.pb.h" +#include "google/protobuf/arena.h" +#include "gmock/gmock.h" +#include "gtest/gtest.h" +#include "absl/time/clock.h" +#include "absl/time/time.h" +#include "absl/types/span.h" +#include "eval/eval/container_backed_list_impl.h" +#include "eval/eval/container_backed_map_impl.h" +#include "eval/public/cel_function.h" +#include "eval/public/cel_value.h" +namespace google { +namespace api { +namespace expr { +namespace runtime { +namespace { + +using ::google::protobuf::ListValue; +using ::google::protobuf::Struct; +using ::google::protobuf::Arena; +using testing::Eq; +using testing::SizeIs; + +CelFunctionDescriptor kTwoInt("TwoInt", false, + {CelValue::Type::kInt64, CelValue::Type::kInt64}); + +CelFunctionDescriptor kOneInt("OneInt", false, {CelValue::Type::kInt64}); + +TEST(UnknownFunctionResult, ArgumentCapture) { + UnknownFunctionResult call1( + kTwoInt, /*expr_id=*/0, + {CelValue::CreateInt64(1), CelValue::CreateInt64(2)}); + + EXPECT_THAT(call1.arguments(), SizeIs(2)); + EXPECT_THAT(call1.arguments().at(0).Int64OrDie(), Eq(1)); +} + +TEST(UnknownFunctionResult, Equals) { + UnknownFunctionResult call1( + kTwoInt, /*expr_id=*/0, + {CelValue::CreateInt64(1), CelValue::CreateInt64(2)}); + + UnknownFunctionResult call2( + kTwoInt, /*expr_id=*/0, + {CelValue::CreateInt64(1), CelValue::CreateInt64(2)}); + + EXPECT_TRUE(call1.IsEqualTo(call2)); + + UnknownFunctionResult call3(kOneInt, /*expr_id=*/0, + {CelValue::CreateInt64(1)}); + + UnknownFunctionResult call4(kOneInt, /*expr_id=*/0, + {CelValue::CreateInt64(1)}); + + EXPECT_TRUE(call3.IsEqualTo(call4)); +} + +TEST(UnknownFunctionResult, InequalDescriptor) { + UnknownFunctionResult call1( + kTwoInt, /*expr_id=*/0, + {CelValue::CreateInt64(1), CelValue::CreateInt64(2)}); + + UnknownFunctionResult call2(kOneInt, /*expr_id=*/0, + {CelValue::CreateInt64(1)}); + + EXPECT_FALSE(call1.IsEqualTo(call2)); + + CelFunctionDescriptor one_uint("OneInt", false, {CelValue::Type::kUint64}); + + UnknownFunctionResult call3(kOneInt, /*expr_id=*/0, + {CelValue::CreateInt64(1)}); + + UnknownFunctionResult call4(one_uint, /*expr_id=*/0, + {CelValue::CreateUint64(1)}); + + EXPECT_FALSE(call3.IsEqualTo(call4)); +} + +TEST(UnknownFunctionResult, InequalArgs) { + UnknownFunctionResult call1( + kTwoInt, /*expr_id=*/0, + {CelValue::CreateInt64(1), CelValue::CreateInt64(2)}); + + UnknownFunctionResult call2( + kTwoInt, /*expr_id=*/0, + {CelValue::CreateInt64(1), CelValue::CreateInt64(3)}); + + EXPECT_FALSE(call1.IsEqualTo(call2)); + + UnknownFunctionResult call3( + kTwoInt, /*expr_id=*/0, + {CelValue::CreateInt64(1), CelValue::CreateInt64(2)}); + + UnknownFunctionResult call4(kTwoInt, /*expr_id=*/0, + {CelValue::CreateInt64(1)}); + + EXPECT_FALSE(call3.IsEqualTo(call4)); +} + +TEST(UnknownFunctionResult, ListsEqual) { + ContainerBackedListImpl cel_list_1(std::vector{ + CelValue::CreateInt64(1), CelValue::CreateInt64(2)}); + + ContainerBackedListImpl cel_list_2(std::vector{ + CelValue::CreateInt64(1), CelValue::CreateInt64(2)}); + + CelFunctionDescriptor desc("OneList", false, {CelValue::Type::kList}); + + UnknownFunctionResult call1(desc, /*expr_id=*/0, + {CelValue::CreateList(&cel_list_1)}); + UnknownFunctionResult call2(desc, /*expr_id=*/0, + {CelValue::CreateList(&cel_list_2)}); + + // [1, 2] == [1, 2] + EXPECT_TRUE(call1.IsEqualTo(call2)); +} + +TEST(UnknownFunctionResult, ListsDifferentSizes) { + ContainerBackedListImpl cel_list_1(std::vector{ + CelValue::CreateInt64(1), CelValue::CreateInt64(2)}); + + ContainerBackedListImpl cel_list_2(std::vector{ + CelValue::CreateInt64(1), + CelValue::CreateInt64(2), + CelValue::CreateInt64(3), + }); + + CelFunctionDescriptor desc("OneList", false, {CelValue::Type::kList}); + + UnknownFunctionResult call1(desc, /*expr_id=*/0, + {CelValue::CreateList(&cel_list_1)}); + UnknownFunctionResult call2(desc, /*expr_id=*/0, + {CelValue::CreateList(&cel_list_2)}); + + // [1, 2] == [1, 2, 3] + EXPECT_FALSE(call1.IsEqualTo(call2)); +} + +TEST(UnknownFunctionResult, ListsDifferentMembers) { + ContainerBackedListImpl cel_list_1(std::vector{ + CelValue::CreateInt64(1), CelValue::CreateInt64(2)}); + + ContainerBackedListImpl cel_list_2(std::vector{ + CelValue::CreateInt64(2), CelValue::CreateInt64(2)}); + + CelFunctionDescriptor desc("OneList", false, {CelValue::Type::kList}); + + UnknownFunctionResult call1(desc, /*expr_id=*/0, + {CelValue::CreateList(&cel_list_1)}); + UnknownFunctionResult call2(desc, /*expr_id=*/0, + {CelValue::CreateList(&cel_list_2)}); + + // [1, 2] == [2, 2] + EXPECT_FALSE(call1.IsEqualTo(call2)); +} + +TEST(UnknownFunctionResult, MapsEqual) { + std::vector> values{ + {CelValue::CreateInt64(1), CelValue::CreateInt64(2)}, + {CelValue::CreateInt64(2), CelValue::CreateInt64(4)}}; + + auto cel_map_1 = CreateContainerBackedMap(absl::MakeSpan(values)); + auto cel_map_2 = CreateContainerBackedMap(absl::MakeSpan(values)); + + CelFunctionDescriptor desc("OneMap", false, {CelValue::Type::kMap}); + + UnknownFunctionResult call1(desc, /*expr_id=*/0, + {CelValue::CreateMap(cel_map_1.get())}); + UnknownFunctionResult call2(desc, /*expr_id=*/0, + {CelValue::CreateMap(cel_map_2.get())}); + + // {1: 2, 2: 4} == {1: 2, 2: 4} + EXPECT_TRUE(call1.IsEqualTo(call2)); +} + +TEST(UnknownFunctionResult, MapsDifferentSizes) { + std::vector> values{ + {CelValue::CreateInt64(1), CelValue::CreateInt64(2)}, + {CelValue::CreateInt64(2), CelValue::CreateInt64(4)}}; + + auto cel_map_1 = CreateContainerBackedMap(absl::MakeSpan(values)); + + std::vector> values2{ + {CelValue::CreateInt64(1), CelValue::CreateInt64(2)}, + {CelValue::CreateInt64(2), CelValue::CreateInt64(4)}, + {CelValue::CreateInt64(3), CelValue::CreateInt64(6)}}; + + auto cel_map_2 = CreateContainerBackedMap(absl::MakeSpan(values2)); + + CelFunctionDescriptor desc("OneMap", false, {CelValue::Type::kMap}); + + UnknownFunctionResult call1(desc, /*expr_id=*/0, + {CelValue::CreateMap(cel_map_1.get())}); + UnknownFunctionResult call2(desc, /*expr_id=*/0, + {CelValue::CreateMap(cel_map_2.get())}); + + // {1: 2, 2: 4} == {1: 2, 2: 4, 3: 6} + EXPECT_FALSE(call1.IsEqualTo(call2)); +} + +TEST(UnknownFunctionResult, MapsDifferentElements) { + std::vector> values{ + {CelValue::CreateInt64(1), CelValue::CreateInt64(2)}, + {CelValue::CreateInt64(2), CelValue::CreateInt64(4)}, + {CelValue::CreateInt64(3), CelValue::CreateInt64(6)}}; + + auto cel_map_1 = CreateContainerBackedMap(absl::MakeSpan(values)); + + std::vector> values2{ + {CelValue::CreateInt64(1), CelValue::CreateInt64(2)}, + {CelValue::CreateInt64(2), CelValue::CreateInt64(4)}, + {CelValue::CreateInt64(4), CelValue::CreateInt64(8)}}; + + auto cel_map_2 = CreateContainerBackedMap(absl::MakeSpan(values2)); + + std::vector> values3{ + {CelValue::CreateInt64(1), CelValue::CreateInt64(2)}, + {CelValue::CreateInt64(2), CelValue::CreateInt64(4)}, + {CelValue::CreateInt64(3), CelValue::CreateInt64(5)}}; + + auto cel_map_3 = CreateContainerBackedMap(absl::MakeSpan(values3)); + + CelFunctionDescriptor desc("OneMap", false, {CelValue::Type::kMap}); + + UnknownFunctionResult call1(desc, /*expr_id=*/0, + {CelValue::CreateMap(cel_map_1.get())}); + UnknownFunctionResult call2(desc, /*expr_id=*/0, + {CelValue::CreateMap(cel_map_2.get())}); + UnknownFunctionResult call3(desc, /*expr_id=*/0, + {CelValue::CreateMap(cel_map_3.get())}); + + // {1: 2, 2: 4, 3: 6} == {1: 2, 2: 4, 4: 8} + EXPECT_FALSE(call1.IsEqualTo(call2)); + // {1: 2, 2: 4, 3: 6} == {1: 2, 2: 4, 3: 5} + EXPECT_FALSE(call1.IsEqualTo(call3)); +} + +TEST(UnknownFunctionResult, Messages) { + protobuf::Empty message1; + protobuf::Empty message2; + google::protobuf::Arena arena; + + CelFunctionDescriptor desc("OneMessage", false, {CelValue::Type::kMessage}); + + UnknownFunctionResult call1(desc, /*expr_id=*/0, + {CelValue::CreateMessage(&message1, &arena)}); + UnknownFunctionResult call2(desc, /*expr_id=*/0, + {CelValue::CreateMessage(&message2, &arena)}); + UnknownFunctionResult call3(desc, /*expr_id=*/0, + {CelValue::CreateMessage(&message1, &arena)}); + + // &message1 == &message2 + EXPECT_FALSE(call1.IsEqualTo(call2)); + + // &message1 == &message1 + EXPECT_TRUE(call1.IsEqualTo(call3)); +} + +TEST(UnknownFunctionResult, AnyDescriptor) { + CelFunctionDescriptor anyDesc("OneAny", false, {CelValue::Type::kAny}); + + UnknownFunctionResult callAnyInt1(anyDesc, /*expr_id=*/0, + {CelValue::CreateInt64(2)}); + UnknownFunctionResult callInt(kOneInt, /*expr_id=*/0, + {CelValue::CreateInt64(2)}); + + UnknownFunctionResult callAnyInt2(anyDesc, /*expr_id=*/0, + {CelValue::CreateInt64(2)}); + UnknownFunctionResult callAnyUint(anyDesc, /*expr_id=*/0, + {CelValue::CreateUint64(2)}); + + EXPECT_FALSE(callAnyInt1.IsEqualTo(callInt)); + EXPECT_FALSE(callAnyInt1.IsEqualTo(callAnyUint)); + EXPECT_TRUE(callAnyInt1.IsEqualTo(callAnyInt2)); +} + +TEST(UnknownFunctionResult, Strings) { + CelFunctionDescriptor desc("OneString", false, {CelValue::Type::kString}); + + UnknownFunctionResult callStringSmile(desc, /*expr_id=*/0, + {CelValue::CreateStringView("😁")}); + UnknownFunctionResult callStringFrown(desc, /*expr_id=*/0, + {CelValue::CreateStringView("🙁")}); + UnknownFunctionResult callStringSmile2(desc, /*expr_id=*/0, + {CelValue::CreateStringView("😁")}); + + EXPECT_TRUE(callStringSmile.IsEqualTo(callStringSmile2)); + EXPECT_FALSE(callStringSmile.IsEqualTo(callStringFrown)); +} + +TEST(UnknownFunctionResult, DurationHandling) { + google::protobuf::Arena arena; + absl::Duration duration1 = absl::Seconds(5); + protobuf::Duration duration2; + duration2.set_seconds(5); + + CelFunctionDescriptor durationDesc("OneDuration", false, + {CelValue::Type::kDuration}); + + UnknownFunctionResult callDuration1(durationDesc, /*expr_id=*/0, + {CelValue::CreateDuration(duration1)}); + UnknownFunctionResult callDuration2( + durationDesc, /*expr_id=*/0, + {CelValue::CreateMessage(&duration2, &arena)}); + UnknownFunctionResult callDuration3(durationDesc, /*expr_id=*/0, + {CelValue::CreateDuration(&duration2)}); + + EXPECT_TRUE(callDuration1.IsEqualTo(callDuration2)); + EXPECT_TRUE(callDuration1.IsEqualTo(callDuration3)); +} + +TEST(UnknownFunctionResult, TimestampHandling) { + google::protobuf::Arena arena; + absl::Time ts1 = absl::FromUnixMillis(1000); + protobuf::Timestamp ts2; + ts2.set_seconds(1); + + CelFunctionDescriptor timestampDesc("OneTimestamp", false, + {CelValue::Type::kTimestamp}); + + UnknownFunctionResult callTimestamp1(timestampDesc, /*expr_id=*/0, + {CelValue::CreateTimestamp(ts1)}); + UnknownFunctionResult callTimestamp2(timestampDesc, /*expr_id=*/0, + {CelValue::CreateMessage(&ts2, &arena)}); + UnknownFunctionResult callTimestamp3(timestampDesc, /*expr_id=*/0, + {CelValue::CreateTimestamp(&ts2)}); + + EXPECT_TRUE(callTimestamp1.IsEqualTo(callTimestamp2)); + EXPECT_TRUE(callTimestamp1.IsEqualTo(callTimestamp3)); +} + +// This tests that the conversion and different map backing implementations are +// compatible with the equality tests. +TEST(UnknownFunctionResult, ProtoStructTreatedAsMap) { + Arena arena; + + const std::vector kFields = {"field1", "field2", "field3"}; + + Struct value_struct; + + auto& value1 = (*value_struct.mutable_fields())[kFields[0]]; + value1.set_bool_value(true); + + auto& value2 = (*value_struct.mutable_fields())[kFields[1]]; + value2.set_number_value(1.0); + + auto& value3 = (*value_struct.mutable_fields())[kFields[2]]; + value3.set_string_value("test"); + + CelValue proto_struct = CelValue::CreateMessage(&value_struct, &arena); + ASSERT_TRUE(proto_struct.IsMap()); + + std::vector> values{ + {CelValue::CreateStringView(kFields[2]), + CelValue::CreateStringView("test")}, + {CelValue::CreateStringView(kFields[1]), CelValue::CreateDouble(1.0)}, + {CelValue::CreateStringView(kFields[0]), CelValue::CreateBool(true)}}; + + auto backing_map = CreateContainerBackedMap(absl::MakeSpan(values)); + + CelValue cel_map = CelValue::CreateMap(backing_map.get()); + + CelFunctionDescriptor desc("OneMap", false, {CelValue::Type::kMap}); + + UnknownFunctionResult proto_struct_result(desc, /*expr_id=*/0, + {proto_struct}); + UnknownFunctionResult cel_map_result(desc, /*expr_id=*/0, {cel_map}); + + EXPECT_TRUE(proto_struct_result.IsEqualTo(cel_map_result)); +} + +// This tests that the conversion and different map backing implementations are +// compatible with the equality tests. +TEST(UnknownFunctionResult, ProtoListTreatedAsList) { + Arena arena; + + ListValue list_value; + + list_value.add_values()->set_bool_value(true); + list_value.add_values()->set_number_value(1.0); + list_value.add_values()->set_string_value("test"); + + CelValue proto_list = CelValue::CreateMessage(&list_value, &arena); + ASSERT_TRUE(proto_list.IsList()); + + std::vector list_values{CelValue::CreateBool(true), + CelValue::CreateDouble(1.0), + CelValue::CreateStringView("test")}; + + ContainerBackedListImpl list_backing(list_values); + + CelValue cel_list = CelValue::CreateList(&list_backing); + + CelFunctionDescriptor desc("OneList", false, {CelValue::Type::kList}); + + UnknownFunctionResult proto_list_result(desc, /*expr_id=*/0, {proto_list}); + UnknownFunctionResult cel_list_result(desc, /*expr_id=*/0, {cel_list}); + + EXPECT_TRUE(cel_list_result.IsEqualTo(proto_list_result)); +} + +TEST(UnknownFunctionResult, NestedProtoTypes) { + Arena arena; + + ListValue list_value; + + list_value.add_values()->set_bool_value(true); + list_value.add_values()->set_number_value(1.0); + list_value.add_values()->set_string_value("test"); + + std::vector list_values{CelValue::CreateBool(true), + CelValue::CreateDouble(1.0), + CelValue::CreateStringView("test")}; + + ContainerBackedListImpl list_backing(list_values); + + CelValue cel_list = CelValue::CreateList(&list_backing); + + Struct value_struct; + + *(value_struct.mutable_fields()->operator[]("field").mutable_list_value()) = + list_value; + + std::vector> values{ + {CelValue::CreateStringView("field"), cel_list}}; + + auto backing_map = CreateContainerBackedMap(absl::MakeSpan(values)); + + CelValue cel_map = CelValue::CreateMap(backing_map.get()); + CelValue proto_map = CelValue::CreateMessage(&value_struct, &arena); + + CelFunctionDescriptor desc("OneMap", false, {CelValue::Type::kMap}); + + UnknownFunctionResult cel_map_result(desc, /*expr_id=*/0, {cel_map}); + UnknownFunctionResult proto_struct_result(desc, /*expr_id=*/0, {proto_map}); + + EXPECT_TRUE(proto_struct_result.IsEqualTo(cel_map_result)); +} + +UnknownFunctionResult MakeUnknown(int64_t i) { + return UnknownFunctionResult(kOneInt, /*expr_id=*/0, + {CelValue::CreateInt64(i)}); +} + +testing::Matcher UnknownMatches( + const UnknownFunctionResult& obj) { + return testing::Truly([&](const UnknownFunctionResult* to_match) { + return obj.IsEqualTo(*to_match); + }); +} + +TEST(UnknownFunctionResultSet, Merge) { + UnknownFunctionResult a = MakeUnknown(1); + UnknownFunctionResult b = MakeUnknown(2); + UnknownFunctionResult c = MakeUnknown(3); + UnknownFunctionResult d = MakeUnknown(1); + + UnknownFunctionResultSet a1(&a); + UnknownFunctionResultSet b1(&b); + UnknownFunctionResultSet c1(&c); + UnknownFunctionResultSet d1(&d); + + UnknownFunctionResultSet ab(a1, b1); + UnknownFunctionResultSet cd(c1, d1); + + UnknownFunctionResultSet merged(ab, cd); + + EXPECT_THAT(merged.unknown_function_results(), SizeIs(3)); + EXPECT_THAT(merged.unknown_function_results(), + testing::UnorderedElementsAre( + UnknownMatches(a), UnknownMatches(b), UnknownMatches(c))); +} + +} // namespace +} // namespace runtime +} // namespace expr +} // namespace api +} // namespace google diff --git a/eval/public/unknown_set.h b/eval/public/unknown_set.h new file mode 100644 index 000000000..3b7168afe --- /dev/null +++ b/eval/public/unknown_set.h @@ -0,0 +1,52 @@ +#ifndef THIRD_PARTY_CEL_CPP_EVAL_PUBLIC_UNKNOWN_SET_H_ +#define THIRD_PARTY_CEL_CPP_EVAL_PUBLIC_UNKNOWN_SET_H_ + +#include "eval/public/unknown_attribute_set.h" +#include "eval/public/unknown_function_result_set.h" + +namespace google { +namespace api { +namespace expr { +namespace runtime { + +// Class representing a collection of unknowns from a single evaluation pass of +// a CEL expression. +class UnknownSet { + public: + // Initilization specifying subcontainers + explicit UnknownSet( + const google::api::expr::runtime::UnknownAttributeSet& attrs) + : unknown_attributes_(attrs) {} + explicit UnknownSet(const UnknownFunctionResultSet& function_results) + : unknown_function_results_(function_results) {} + UnknownSet(const UnknownAttributeSet& attrs, + const UnknownFunctionResultSet& function_results) + : unknown_attributes_(attrs), + unknown_function_results_(function_results) {} + // Initialization for empty set + UnknownSet() {} + // Merge constructor + UnknownSet(const UnknownSet& set1, const UnknownSet& set2) + : unknown_attributes_(set1.unknown_attributes(), + set2.unknown_attributes()), + unknown_function_results_(set1.unknown_function_results(), + set2.unknown_function_results()) {} + + const UnknownAttributeSet& unknown_attributes() const { + return unknown_attributes_; + } + const UnknownFunctionResultSet& unknown_function_results() const { + return unknown_function_results_; + } + + private: + UnknownAttributeSet unknown_attributes_; + UnknownFunctionResultSet unknown_function_results_; +}; + +} // namespace runtime +} // namespace expr +} // namespace api +} // namespace google + +#endif // THIRD_PARTY_CEL_CPP_EVAL_PUBLIC_UNKNOWN_SET_H_ diff --git a/eval/public/unknown_set_test.cc b/eval/public/unknown_set_test.cc new file mode 100644 index 000000000..3e4c06cda --- /dev/null +++ b/eval/public/unknown_set_test.cc @@ -0,0 +1,129 @@ +#include "eval/public/unknown_set.h" + +#include "google/api/expr/v1alpha1/syntax.pb.h" +#include "google/protobuf/arena.h" +#include "gmock/gmock.h" +#include "gtest/gtest.h" +#include "eval/public/cel_attribute.h" +#include "eval/public/unknown_attribute_set.h" +#include "eval/public/unknown_function_result_set.h" + +namespace google { +namespace api { +namespace expr { +namespace runtime { +namespace { + +using ::google::protobuf::Arena; +using testing::IsEmpty; +using testing::UnorderedElementsAre; + +UnknownFunctionResultSet MakeFunctionResult(Arena* arena, int64_t id) { + CelFunctionDescriptor desc("OneInt", false, {CelValue::Type::kInt64}); + std::vector call_args{CelValue::CreateInt64(id)}; + const auto* function_result = Arena::Create( + arena, desc, /*expr_id=*/0, call_args); + return UnknownFunctionResultSet(function_result); +} + +UnknownAttributeSet MakeAttribute(Arena* arena, int64_t id) { + google::api::expr::v1alpha1::Expr expr; + expr.mutable_ident_expr()->set_name("x"); + + std::vector attr_trail{ + CelAttributeQualifier::Create(CelValue::CreateInt64(id))}; + + const auto* attr = Arena::Create(arena, expr, attr_trail); + return UnknownAttributeSet({attr}); +} + +MATCHER_P(UnknownAttributeIs, id, "") { + const CelAttribute* attr = arg; + if (attr->qualifier_path().size() != 1) { + return false; + } + auto maybe_qualifier = attr->qualifier_path()[0].GetInt64Key(); + if (!maybe_qualifier.has_value()) { + return false; + } + return maybe_qualifier.value() == id; +} + +MATCHER_P(UnknownFunctionResultIs, id, "") { + const UnknownFunctionResult* result = arg; + if (result->arguments().size() != 1) { + return false; + } + if (!result->arguments()[0].IsInt64()) { + return false; + } + return result->arguments()[0].Int64OrDie() == id; +} + +TEST(UnknownSet, AttributesMerge) { + Arena arena; + UnknownSet a(MakeAttribute(&arena, 1)); + UnknownSet b(MakeAttribute(&arena, 2)); + UnknownSet c(MakeAttribute(&arena, 2)); + UnknownSet d(a, b); + UnknownSet e(c, d); + + EXPECT_THAT( + d.unknown_attributes().attributes(), + UnorderedElementsAre(UnknownAttributeIs(1), UnknownAttributeIs(2))); + EXPECT_THAT( + e.unknown_attributes().attributes(), + UnorderedElementsAre(UnknownAttributeIs(1), UnknownAttributeIs(2))); +} + +TEST(UnknownSet, FunctionsMerge) { + Arena arena; + + UnknownSet a(MakeFunctionResult(&arena, 1)); + UnknownSet b(MakeFunctionResult(&arena, 2)); + UnknownSet c(MakeFunctionResult(&arena, 2)); + UnknownSet d(a, b); + UnknownSet e(c, d); + + EXPECT_THAT(d.unknown_function_results().unknown_function_results(), + UnorderedElementsAre(UnknownFunctionResultIs(1), + UnknownFunctionResultIs(2))); + EXPECT_THAT(e.unknown_function_results().unknown_function_results(), + UnorderedElementsAre(UnknownFunctionResultIs(1), + UnknownFunctionResultIs(2))); +} + +TEST(UnknownSet, DefaultEmpty) { + UnknownSet empty_set; + EXPECT_THAT(empty_set.unknown_attributes().attributes(), IsEmpty()); + EXPECT_THAT(empty_set.unknown_function_results().unknown_function_results(), + IsEmpty()); +} + +TEST(UnknownSet, MixedMerges) { + Arena arena; + + UnknownSet a(MakeAttribute(&arena, 1), MakeFunctionResult(&arena, 1)); + UnknownSet b(MakeFunctionResult(&arena, 2)); + UnknownSet c(MakeAttribute(&arena, 2)); + UnknownSet d(a, b); + UnknownSet e(c, d); + + EXPECT_THAT(d.unknown_attributes().attributes(), + UnorderedElementsAre(UnknownAttributeIs(1))); + EXPECT_THAT(d.unknown_function_results().unknown_function_results(), + UnorderedElementsAre(UnknownFunctionResultIs(1), + UnknownFunctionResultIs(2))); + EXPECT_THAT( + e.unknown_attributes().attributes(), + UnorderedElementsAre(UnknownAttributeIs(1), UnknownAttributeIs(2))); + EXPECT_THAT(e.unknown_function_results().unknown_function_results(), + UnorderedElementsAre(UnknownFunctionResultIs(1), + UnknownFunctionResultIs(2))); +} + +} // namespace +} // namespace runtime +} // namespace expr +} // namespace api +} // namespace google diff --git a/eval/public/value_export_util.cc b/eval/public/value_export_util.cc index b557d6804..de72c135e 100644 --- a/eval/public/value_export_util.cc +++ b/eval/public/value_export_util.cc @@ -4,7 +4,6 @@ #include "google/protobuf/util/time_util.h" #include "absl/strings/escaping.h" #include "absl/strings/str_cat.h" -#include "base/canonical_errors.h" namespace google { namespace api { @@ -20,7 +19,7 @@ using google::protobuf::FieldDescriptor; using google::protobuf::Message; using google::protobuf::util::TimeUtil; -cel_base::Status KeyAsString(const CelValue& value, std::string* key) { +absl::Status KeyAsString(const CelValue& value, std::string* key) { switch (value.type()) { case CelValue::Type::kInt64: { *key = absl::StrCat(value.Int64OrDie()); @@ -35,16 +34,18 @@ cel_base::Status KeyAsString(const CelValue& value, std::string* key) { value.StringOrDie().value().size()); break; } - default: { return cel_base::InvalidArgumentError("Unsupported map type"); } + default: { + return absl::InvalidArgumentError("Unsupported map type"); + } } - return cel_base::OkStatus(); + return absl::OkStatus(); } // Export content of CelValue as google.protobuf.Value. -cel_base::Status ExportAsProtoValue(const CelValue& in_value, Value* out_value) { +absl::Status ExportAsProtoValue(const CelValue& in_value, Value* out_value) { if (in_value.IsNull()) { out_value->set_null_value(google::protobuf::NULL_VALUE); - return cel_base::OkStatus(); + return absl::OkStatus(); } switch (in_value.type()) { case CelValue::Type::kBool: { @@ -92,13 +93,13 @@ cel_base::Status ExportAsProtoValue(const CelValue& in_value, Value* out_value) auto status = google::protobuf::util::MessageToJsonString(*in_value.MessageOrDie(), &json, json_options); if (!status.ok()) { - return cel_base::InternalError(status.ToString()); + return absl::InternalError(status.ToString()); } google::protobuf::util::JsonParseOptions json_parse_options; status = google::protobuf::util::JsonStringToMessage(json, out_value, json_parse_options); if (!status.ok()) { - return cel_base::InternalError(status.ToString()); + return absl::InternalError(status.ToString()); } break; } @@ -135,9 +136,11 @@ cel_base::Status ExportAsProtoValue(const CelValue& in_value, Value* out_value) } break; } - default: { return cel_base::InvalidArgumentError("Unsupported value type"); } + default: { + return absl::InvalidArgumentError("Unsupported value type"); + } } - return cel_base::OkStatus(); + return absl::OkStatus(); } } // namespace runtime diff --git a/eval/public/value_export_util.h b/eval/public/value_export_util.h index 39018b3d7..6fbf9f8c4 100644 --- a/eval/public/value_export_util.h +++ b/eval/public/value_export_util.h @@ -1,10 +1,9 @@ #ifndef THIRD_PARTY_CEL_CPP_EVAL_PUBLIC_VALUE_EXPORT_UTIL_H_ #define THIRD_PARTY_CEL_CPP_EVAL_PUBLIC_VALUE_EXPORT_UTIL_H_ -#include "eval/public/cel_value.h" - #include "google/protobuf/struct.pb.h" -#include "base/status.h" +#include "absl/status/status.h" +#include "eval/public/cel_value.h" namespace google { namespace api { @@ -16,7 +15,7 @@ namespace runtime { // - exports integer values as doubles (Value.number_value); // - exports integer keys in maps as strings; // - handles Duration and Timestamp as generic messages. -cel_base::Status ExportAsProtoValue(const CelValue &in_value, +absl::Status ExportAsProtoValue(const CelValue &in_value, google::protobuf::Value *out_value); } // namespace runtime diff --git a/eval/public/value_export_util_test.cc b/eval/public/value_export_util_test.cc index e988ebe2d..7dd7773bd 100644 --- a/eval/public/value_export_util_test.cc +++ b/eval/public/value_export_util_test.cc @@ -9,6 +9,7 @@ #include "eval/eval/container_backed_map_impl.h" #include "eval/testutil/test_message.pb.h" #include "testutil/util.h" +#include "base/status_macros.h" namespace google { namespace api { @@ -27,7 +28,7 @@ using google::protobuf::Arena; TEST(ValueExportUtilTest, ConvertBoolValue) { CelValue cel_value = CelValue::CreateBool(true); Value value; - EXPECT_TRUE(ExportAsProtoValue(cel_value, &value).ok()); + EXPECT_OK(ExportAsProtoValue(cel_value, &value)); EXPECT_EQ(value.kind_case(), Value::KindCase::kBoolValue); EXPECT_EQ(value.bool_value(), true); } @@ -35,7 +36,7 @@ TEST(ValueExportUtilTest, ConvertBoolValue) { TEST(ValueExportUtilTest, ConvertInt64Value) { CelValue cel_value = CelValue::CreateInt64(-1); Value value; - EXPECT_TRUE(ExportAsProtoValue(cel_value, &value).ok()); + EXPECT_OK(ExportAsProtoValue(cel_value, &value)); EXPECT_EQ(value.kind_case(), Value::KindCase::kNumberValue); EXPECT_DOUBLE_EQ(value.number_value(), -1); } @@ -43,7 +44,7 @@ TEST(ValueExportUtilTest, ConvertInt64Value) { TEST(ValueExportUtilTest, ConvertUint64Value) { CelValue cel_value = CelValue::CreateUint64(1); Value value; - EXPECT_TRUE(ExportAsProtoValue(cel_value, &value).ok()); + EXPECT_OK(ExportAsProtoValue(cel_value, &value)); EXPECT_EQ(value.kind_case(), Value::KindCase::kNumberValue); EXPECT_DOUBLE_EQ(value.number_value(), 1); } @@ -51,7 +52,7 @@ TEST(ValueExportUtilTest, ConvertUint64Value) { TEST(ValueExportUtilTest, ConvertDoubleValue) { CelValue cel_value = CelValue::CreateDouble(1.3); Value value; - EXPECT_TRUE(ExportAsProtoValue(cel_value, &value).ok()); + EXPECT_OK(ExportAsProtoValue(cel_value, &value)); EXPECT_EQ(value.kind_case(), Value::KindCase::kNumberValue); EXPECT_DOUBLE_EQ(value.number_value(), 1.3); } @@ -60,7 +61,7 @@ TEST(ValueExportUtilTest, ConvertStringValue) { std::string test = "test"; CelValue cel_value = CelValue::CreateString(&test); Value value; - EXPECT_TRUE(ExportAsProtoValue(cel_value, &value).ok()); + EXPECT_OK(ExportAsProtoValue(cel_value, &value)); EXPECT_EQ(value.kind_case(), Value::KindCase::kStringValue); EXPECT_EQ(value.string_value(), "test"); } @@ -69,7 +70,7 @@ TEST(ValueExportUtilTest, ConvertBytesValue) { std::string test = "test"; CelValue cel_value = CelValue::CreateBytes(&test); Value value; - EXPECT_TRUE(ExportAsProtoValue(cel_value, &value).ok()); + EXPECT_OK(ExportAsProtoValue(cel_value, &value)); EXPECT_EQ(value.kind_case(), Value::KindCase::kStringValue); // Check that the result is BASE64 encoded. EXPECT_EQ(value.string_value(), "dGVzdA=="); @@ -81,7 +82,7 @@ TEST(ValueExportUtilTest, ConvertDurationValue) { duration.set_nanos(3); CelValue cel_value = CelValue::CreateDuration(&duration); Value value; - EXPECT_TRUE(ExportAsProtoValue(cel_value, &value).ok()); + EXPECT_OK(ExportAsProtoValue(cel_value, &value)); EXPECT_EQ(value.kind_case(), Value::KindCase::kStringValue); EXPECT_EQ(value.string_value(), "2.000000003s"); } @@ -92,7 +93,7 @@ TEST(ValueExportUtilTest, ConvertTimestampValue) { timestamp.set_nanos(3); CelValue cel_value = CelValue::CreateTimestamp(×tamp); Value value; - EXPECT_TRUE(ExportAsProtoValue(cel_value, &value).ok()); + EXPECT_OK(ExportAsProtoValue(cel_value, &value)); EXPECT_EQ(value.kind_case(), Value::KindCase::kStringValue); EXPECT_EQ(value.string_value(), "2001-09-09T01:46:40.000000003Z"); } @@ -103,7 +104,7 @@ TEST(ValueExportUtilTest, ConvertStructMessage) { Arena arena; CelValue cel_value = CelValue::CreateMessage(&struct_msg, &arena); Value value; - EXPECT_TRUE(ExportAsProtoValue(cel_value, &value).ok()); + EXPECT_OK(ExportAsProtoValue(cel_value, &value)); EXPECT_EQ(value.kind_case(), Value::KindCase::kStructValue); EXPECT_THAT(value.struct_value(), testutil::EqualsProto(struct_msg)); } @@ -116,7 +117,7 @@ TEST(ValueExportUtilTest, ConvertValueMessage) { Arena arena; CelValue cel_value = CelValue::CreateMessage(&value_in, &arena); Value value_out; - EXPECT_TRUE(ExportAsProtoValue(cel_value, &value_out).ok()); + EXPECT_OK(ExportAsProtoValue(cel_value, &value_out)); EXPECT_THAT(value_in, testutil::EqualsProto(value_out)); } @@ -127,7 +128,7 @@ TEST(ValueExportUtilTest, ConvertListValueMessage) { Arena arena; CelValue cel_value = CelValue::CreateMessage(&list_value, &arena); Value value_out; - EXPECT_TRUE(ExportAsProtoValue(cel_value, &value_out).ok()); + EXPECT_OK(ExportAsProtoValue(cel_value, &value_out)); EXPECT_THAT(list_value, testutil::EqualsProto(value_out.list_value())); } @@ -140,7 +141,7 @@ TEST(ValueExportUtilTest, ConvertRepeatedBoolValue) { msg->add_bool_list(false); CelValue cel_value = CelValue::CreateMessage(msg, &arena); - EXPECT_TRUE(ExportAsProtoValue(cel_value, &value).ok()); + EXPECT_OK(ExportAsProtoValue(cel_value, &value)); EXPECT_EQ(value.kind_case(), Value::KindCase::kStructValue); Value list_value = value.struct_value().fields().at("bool_list"); @@ -159,7 +160,7 @@ TEST(ValueExportUtilTest, ConvertRepeatedInt32Value) { msg->add_int32_list(3); CelValue cel_value = CelValue::CreateMessage(msg, &arena); - EXPECT_TRUE(ExportAsProtoValue(cel_value, &value).ok()); + EXPECT_OK(ExportAsProtoValue(cel_value, &value)); EXPECT_EQ(value.kind_case(), Value::KindCase::kStructValue); Value list_value = value.struct_value().fields().at("int32_list"); @@ -178,7 +179,7 @@ TEST(ValueExportUtilTest, ConvertRepeatedInt64Value) { msg->add_int64_list(3); CelValue cel_value = CelValue::CreateMessage(msg, &arena); - EXPECT_TRUE(ExportAsProtoValue(cel_value, &value).ok()); + EXPECT_OK(ExportAsProtoValue(cel_value, &value)); EXPECT_EQ(value.kind_case(), Value::KindCase::kStructValue); Value list_value = value.struct_value().fields().at("int64_list"); @@ -197,7 +198,7 @@ TEST(ValueExportUtilTest, ConvertRepeatedUint64Value) { msg->add_uint64_list(3); CelValue cel_value = CelValue::CreateMessage(msg, &arena); - EXPECT_TRUE(ExportAsProtoValue(cel_value, &value).ok()); + EXPECT_OK(ExportAsProtoValue(cel_value, &value)); EXPECT_EQ(value.kind_case(), Value::KindCase::kStructValue); Value list_value = value.struct_value().fields().at("uint64_list"); @@ -216,7 +217,7 @@ TEST(ValueExportUtilTest, ConvertRepeatedDoubleValue) { msg->add_double_list(3); CelValue cel_value = CelValue::CreateMessage(msg, &arena); - EXPECT_TRUE(ExportAsProtoValue(cel_value, &value).ok()); + EXPECT_OK(ExportAsProtoValue(cel_value, &value)); EXPECT_EQ(value.kind_case(), Value::KindCase::kStructValue); Value list_value = value.struct_value().fields().at("double_list"); @@ -235,7 +236,7 @@ TEST(ValueExportUtilTest, ConvertRepeatedStringValue) { msg->add_string_list("test2"); CelValue cel_value = CelValue::CreateMessage(msg, &arena); - EXPECT_TRUE(ExportAsProtoValue(cel_value, &value).ok()); + EXPECT_OK(ExportAsProtoValue(cel_value, &value)); EXPECT_EQ(value.kind_case(), Value::KindCase::kStructValue); Value list_value = value.struct_value().fields().at("string_list"); @@ -254,7 +255,7 @@ TEST(ValueExportUtilTest, ConvertRepeatedBytesValue) { msg->add_bytes_list("test2"); CelValue cel_value = CelValue::CreateMessage(msg, &arena); - EXPECT_TRUE(ExportAsProtoValue(cel_value, &value).ok()); + EXPECT_OK(ExportAsProtoValue(cel_value, &value)); EXPECT_EQ(value.kind_case(), Value::KindCase::kStructValue); Value list_value = value.struct_value().fields().at("bytes_list"); @@ -274,7 +275,7 @@ TEST(ValueExportUtilTest, ConvertCelList) { CelList *cel_list = Arena::Create(&arena, values); CelValue cel_value = CelValue::CreateList(cel_list); - EXPECT_TRUE(ExportAsProtoValue(cel_value, &value).ok()); + EXPECT_OK(ExportAsProtoValue(cel_value, &value)); EXPECT_EQ(value.kind_case(), Value::KindCase::kListValue); EXPECT_DOUBLE_EQ(value.list_value().values(0).number_value(), 2); @@ -299,7 +300,7 @@ TEST(ValueExportUtilTest, ConvertCelMapWithStringKey) { absl::Span>(map_entries)); CelValue cel_value = CelValue::CreateMap(cel_map.get()); - EXPECT_TRUE(ExportAsProtoValue(cel_value, &value).ok()); + EXPECT_OK(ExportAsProtoValue(cel_value, &value)); EXPECT_EQ(value.kind_case(), Value::KindCase::kStructValue); const auto &fields = value.struct_value().fields(); @@ -326,7 +327,7 @@ TEST(ValueExportUtilTest, ConvertCelMapWithInt64Key) { absl::Span>(map_entries)); CelValue cel_value = CelValue::CreateMap(cel_map.get()); - EXPECT_TRUE(ExportAsProtoValue(cel_value, &value).ok()); + EXPECT_OK(ExportAsProtoValue(cel_value, &value)); EXPECT_EQ(value.kind_case(), Value::KindCase::kStructValue); const auto &fields = value.struct_value().fields(); diff --git a/eval/tests/BUILD b/eval/tests/BUILD index 273f1a2eb..ce8a80822 100644 --- a/eval/tests/BUILD +++ b/eval/tests/BUILD @@ -41,6 +41,7 @@ cc_test( ], copts = ["-std=c++14"], deps = [ + "//base:status_macros", "//eval/public:activation", "//eval/public:builtin_func_registrar", "//eval/public:cel_expr_builder_factory", @@ -48,12 +49,35 @@ cc_test( "//eval/public:cel_value", "//eval/testutil:test_message_cc_proto", "@com_github_google_googletest//:gtest_main", - "@com_google_absl//absl/strings", "@com_google_googleapis//google/api/expr/v1alpha1:syntax_cc_proto", "@com_google_protobuf//:protobuf", ], ) +cc_test( + name = "unknowns_end_to_end_test", + size = "small", + srcs = [ + "unknowns_end_to_end_test.cc", + ], + copts = ["-std=c++14"], + deps = [ + "//base:status_macros", + "//eval/public:activation", + "//eval/public:builtin_func_registrar", + "//eval/public:cel_attribute", + "//eval/public:cel_expr_builder_factory", + "//eval/public:cel_expression", + "//eval/public:cel_function", + "//eval/public:cel_options", + "//eval/public:cel_value", + "//eval/public:unknown_set", + "@com_github_google_googletest//:gtest_main", + "@com_google_absl//absl/strings", + "@com_google_protobuf//:protobuf", + ], +) + proto_library( name = "request_context_protos", srcs = [ @@ -65,3 +89,15 @@ cc_proto_library( name = "request_context_cc_proto", deps = [":request_context_protos"], ) + +cc_library( + name = "mock_cel_expression", + testonly = 1, + hdrs = ["mock_cel_expression.h"], + copts = ["-std=c++14"], + deps = [ + "//eval/public:activation", + "//eval/public:cel_expression", + "@com_github_google_googletest//:gtest_main", + ], +) diff --git a/eval/tests/end_to_end_test.cc b/eval/tests/end_to_end_test.cc index 7c267b94e..c95f919b3 100644 --- a/eval/tests/end_to_end_test.cc +++ b/eval/tests/end_to_end_test.cc @@ -8,6 +8,7 @@ #include "eval/public/cel_expression.h" #include "eval/public/cel_value.h" #include "eval/testutil/test_message.pb.h" +#include "base/status_macros.h" namespace google { namespace api { @@ -49,12 +50,12 @@ TEST(EndToEndTest, SimpleOnePlusOne) { std::unique_ptr builder = CreateCelExpressionBuilder(); // Builtin registration. - ASSERT_TRUE(RegisterBuiltinFunctions(builder->GetRegistry()).ok()); + ASSERT_OK(RegisterBuiltinFunctions(builder->GetRegistry())); // Create CelExpression from AST (Expr object). auto cel_expression_status = builder->CreateExpression(&expr, &source_info); - ASSERT_TRUE(cel_expression_status.ok()); + ASSERT_OK(cel_expression_status); auto cel_expression = std::move(cel_expression_status.ValueOrDie()); @@ -68,7 +69,7 @@ TEST(EndToEndTest, SimpleOnePlusOne) { // Run evaluation. auto eval_status = cel_expression->Evaluate(activation, &arena); - ASSERT_TRUE(eval_status.ok()); + ASSERT_OK(eval_status); CelValue result = eval_status.ValueOrDie(); @@ -133,12 +134,12 @@ TEST(EndToEndTest, EmptyStringCompare) { std::unique_ptr builder = CreateCelExpressionBuilder(); // Builtin registration. - ASSERT_TRUE(RegisterBuiltinFunctions(builder->GetRegistry()).ok()); + ASSERT_OK(RegisterBuiltinFunctions(builder->GetRegistry())); // Create CelExpression from AST (Expr object). auto cel_expression_status = builder->CreateExpression(&expr, &source_info); - ASSERT_TRUE(cel_expression_status.ok()); + ASSERT_OK(cel_expression_status); auto cel_expression = std::move(cel_expression_status.ValueOrDie()); @@ -158,7 +159,7 @@ TEST(EndToEndTest, EmptyStringCompare) { // Run evaluation. auto eval_status = cel_expression->Evaluate(activation, &arena); - ASSERT_TRUE(eval_status.ok()); + ASSERT_OK(eval_status); CelValue result = eval_status.ValueOrDie(); diff --git a/eval/tests/mock_cel_expression.h b/eval/tests/mock_cel_expression.h new file mode 100644 index 000000000..ba0ab2041 --- /dev/null +++ b/eval/tests/mock_cel_expression.h @@ -0,0 +1,44 @@ +#ifndef THIRD_PARTY_CEL_CPP_EVAL_TESTS_MOCK_CEL_EXPRESION_H_ +#define THIRD_PARTY_CEL_CPP_EVAL_TESTS_MOCK_CEL_EXPRESION_H_ + +#include + +#include "gmock/gmock.h" +#include "eval/public/activation.h" +#include "eval/public/cel_expression.h" + +namespace google { +namespace api { +namespace expr { +namespace runtime { + +class MockCelExpression : public CelExpression { + public: + MOCK_CONST_METHOD1(InitializeState, + std::unique_ptr(google::protobuf::Arena* arena)); + + MOCK_CONST_METHOD2( + Evaluate, ::cel_base::StatusOr(const BaseActivation& activation, + google::protobuf::Arena* arena)); + + MOCK_CONST_METHOD2( + Evaluate, ::cel_base::StatusOr(const BaseActivation& activation, + CelEvaluationState* state)); + + MOCK_CONST_METHOD3( + Trace, ::cel_base::StatusOr(const BaseActivation& activation, + google::protobuf::Arena* arena, + CelEvaluationListener callback)); + + MOCK_CONST_METHOD3( + Trace, ::cel_base::StatusOr(const BaseActivation& activation, + CelEvaluationState* state, + CelEvaluationListener callback)); +}; + +} // namespace runtime +} // namespace expr +} // namespace api +} // namespace google + +#endif // THIRD_PARTY_CEL_CPP_EVAL_TESTS_MOCK_CEL_EXPRESION_H_ diff --git a/eval/tests/unknowns_end_to_end_test.cc b/eval/tests/unknowns_end_to_end_test.cc new file mode 100644 index 000000000..0a5edefe3 --- /dev/null +++ b/eval/tests/unknowns_end_to_end_test.cc @@ -0,0 +1,305 @@ +#include + +#include "google/protobuf/arena.h" +#include "google/protobuf/text_format.h" +#include "gmock/gmock.h" +#include "gtest/gtest.h" +#include "absl/strings/string_view.h" +#include "eval/public/activation.h" +#include "eval/public/builtin_func_registrar.h" +#include "eval/public/cel_attribute.h" +#include "eval/public/cel_expr_builder_factory.h" +#include "eval/public/cel_expression.h" +#include "eval/public/cel_function.h" +#include "eval/public/cel_options.h" +#include "eval/public/cel_value.h" +#include "eval/public/unknown_set.h" +#include "base/status_macros.h" + +namespace google { +namespace api { +namespace expr { +namespace runtime { +namespace { + +using ::google::protobuf::Arena; +using testing::ElementsAre; + +// var1 > 3 && F1('arg1') || var2 > 3 && F2('arg2') +constexpr char kExprTextproto[] = R"pb( + id: 13 + call_expr { + function: "_||_" + args { + id: 6 + call_expr { + function: "_&&_" + args { + id: 2 + call_expr { + function: "_>_" + args { + id: 1 + ident_expr { name: "var1" } + } + args { + id: 3 + const_expr { int64_value: 3 } + } + } + } + args { + id: 4 + call_expr { + function: "F1" + args { + id: 5 + const_expr { string_value: "arg1" } + } + } + } + } + } + args { + id: 12 + call_expr { + function: "_&&_" + args { + id: 8 + call_expr { + function: "_>_" + args { + id: 7 + ident_expr { name: "var2" } + } + args { + id: 9 + const_expr { int64_value: 3 } + } + } + } + args { + id: 10 + call_expr { + function: "F2" + args { + id: 11 + const_expr { string_value: "arg2" } + } + } + } + } + } + })pb"; + +enum class FunctionResponse { kUnknown, kTrue, kFalse }; + +CelFunctionDescriptor CreateDescriptor(absl::string_view name) { + return CelFunctionDescriptor(std::string(name), false, + {CelValue::Type::kString}); +} + +class FunctionImpl : public CelFunction { + public: + FunctionImpl(absl::string_view name, FunctionResponse response) + : CelFunction(CreateDescriptor(name)), response_(response) {} + + absl::Status Evaluate(absl::Span arguments, CelValue* result, + Arena* arena) const override { + switch (response_) { + case FunctionResponse::kUnknown: + *result = CreateUnknownFunctionResultError(arena, "help message"); + break; + case FunctionResponse::kTrue: + *result = CelValue::CreateBool(true); + break; + case FunctionResponse::kFalse: + *result = CelValue::CreateBool(false); + break; + } + return absl::OkStatus(); + } + + private: + FunctionResponse response_; +}; + +// Text fixture for unknowns. Holds on to state needed for execution to work +// correctly. +class UnknownsTest : public testing::Test { + public: + void PrepareBuilder(UnknownProcessingOptions opts) { + InterpreterOptions options; + options.unknown_processing = opts; + builder_ = CreateCelExpressionBuilder(options); + ASSERT_OK(RegisterBuiltinFunctions(builder_->GetRegistry())); + ASSERT_OK( + builder_->GetRegistry()->RegisterLazyFunction(CreateDescriptor("F1"))); + ASSERT_OK( + builder_->GetRegistry()->RegisterLazyFunction(CreateDescriptor("F2"))); + ASSERT_TRUE(google::protobuf::TextFormat::ParseFromString(kExprTextproto, &expr_)) + << "error parsing expr"; + } + + protected: + Arena arena_; + Activation activation_; + std::unique_ptr builder_; + google::api::expr::v1alpha1::Expr expr_; +}; + +MATCHER_P2(FunctionCallIs, fn_name, fn_arg, "") { + const UnknownFunctionResult* result = arg; + return result->arguments().size() == 1 && result->arguments()[0].IsString() && + result->arguments()[0].StringOrDie().value() == fn_arg && + result->descriptor().name() == fn_name; +} + +MATCHER_P(AttributeIs, attr, "") { + const CelAttribute* result = arg; + return result->variable().ident_expr().name() == attr; +} + +TEST_F(UnknownsTest, NoUnknowns) { + PrepareBuilder(UnknownProcessingOptions::kDisabled); + // activation_.set_unknown_attribute_patterns({CelAttributePattern("var1", + // {})}); + activation_.InsertValue("var1", CelValue::CreateInt64(3)); + activation_.InsertValue("var2", CelValue::CreateInt64(5)); + ASSERT_OK(activation_.InsertFunction( + std::make_unique("F1", FunctionResponse::kFalse))); + ASSERT_OK(activation_.InsertFunction( + std::make_unique("F2", FunctionResponse::kTrue))); + + // var1 > 3 && F1('arg1') || var2 > 3 && F2('arg2') + auto plan = builder_->CreateExpression(&expr_, nullptr); + ASSERT_OK(plan); + + auto maybe_response = plan.ValueOrDie()->Evaluate(activation_, &arena_); + ASSERT_OK(maybe_response); + CelValue response = maybe_response.ValueOrDie(); + + ASSERT_TRUE(response.IsBool()); + EXPECT_TRUE(response.BoolOrDie()); +} + +TEST_F(UnknownsTest, UnknownAttributes) { + PrepareBuilder(UnknownProcessingOptions::kAttributeOnly); + activation_.set_unknown_attribute_patterns({CelAttributePattern("var1", {})}); + activation_.InsertValue("var2", CelValue::CreateInt64(3)); + ASSERT_OK(activation_.InsertFunction( + std::make_unique("F1", FunctionResponse::kTrue))); + ASSERT_OK(activation_.InsertFunction( + std::make_unique("F2", FunctionResponse::kFalse))); + + // var1 > 3 && F1('arg1') || var2 > 3 && F2('arg2') + auto plan = builder_->CreateExpression(&expr_, nullptr); + ASSERT_OK(plan); + + auto maybe_response = plan.ValueOrDie()->Evaluate(activation_, &arena_); + ASSERT_OK(maybe_response); + CelValue response = maybe_response.ValueOrDie(); + + ASSERT_TRUE(response.IsUnknownSet()); + EXPECT_THAT(response.UnknownSetOrDie()->unknown_attributes().attributes(), + ElementsAre(AttributeIs("var1"))); +} + +TEST_F(UnknownsTest, UnknownAttributesPruning) { + PrepareBuilder(UnknownProcessingOptions::kAttributeOnly); + activation_.set_unknown_attribute_patterns({CelAttributePattern("var1", {})}); + activation_.InsertValue("var2", CelValue::CreateInt64(5)); + ASSERT_OK(activation_.InsertFunction( + std::make_unique("F1", FunctionResponse::kTrue))); + ASSERT_OK(activation_.InsertFunction( + std::make_unique("F2", FunctionResponse::kTrue))); + + // var1 > 3 && F1('arg1') || var2 > 3 && F2('arg2') + auto plan = builder_->CreateExpression(&expr_, nullptr); + ASSERT_OK(plan); + + auto maybe_response = plan.ValueOrDie()->Evaluate(activation_, &arena_); + ASSERT_OK(maybe_response); + CelValue response = maybe_response.ValueOrDie(); + + ASSERT_TRUE(response.IsBool()); + EXPECT_TRUE(response.BoolOrDie()); +} + +TEST_F(UnknownsTest, UnknownFunctionsWithoutOptionError) { + PrepareBuilder(UnknownProcessingOptions::kAttributeOnly); + activation_.InsertValue("var1", CelValue::CreateInt64(5)); + activation_.InsertValue("var2", CelValue::CreateInt64(3)); + ASSERT_OK(activation_.InsertFunction( + std::make_unique("F1", FunctionResponse::kUnknown))); + ASSERT_OK(activation_.InsertFunction( + std::make_unique("F2", FunctionResponse::kFalse))); + + // var1 > 3 && F1('arg1') || var2 > 3 && F2('arg2') + auto plan = builder_->CreateExpression(&expr_, nullptr); + ASSERT_OK(plan); + + auto maybe_response = plan.ValueOrDie()->Evaluate(activation_, &arena_); + ASSERT_OK(maybe_response); + CelValue response = maybe_response.ValueOrDie(); + + ASSERT_TRUE(response.IsError()); + EXPECT_EQ(response.ErrorOrDie()->code(), absl::StatusCode::kUnavailable); +} + +TEST_F(UnknownsTest, UnknownFunctions) { + PrepareBuilder(UnknownProcessingOptions::kAttributeAndFunction); + activation_.InsertValue("var1", CelValue::CreateInt64(5)); + activation_.InsertValue("var2", CelValue::CreateInt64(5)); + ASSERT_OK(activation_.InsertFunction( + std::make_unique("F1", FunctionResponse::kUnknown))); + ASSERT_OK(activation_.InsertFunction( + std::make_unique("F2", FunctionResponse::kFalse))); + + // var1 > 3 && F1('arg1') || var2 > 3 && F2('arg2') + auto plan = builder_->CreateExpression(&expr_, nullptr); + ASSERT_OK(plan); + + auto maybe_response = plan.ValueOrDie()->Evaluate(activation_, &arena_); + ASSERT_OK(maybe_response); + CelValue response = maybe_response.ValueOrDie(); + + ASSERT_TRUE(response.IsUnknownSet()) << response.ErrorOrDie()->ToString(); + EXPECT_THAT(response.UnknownSetOrDie() + ->unknown_function_results() + .unknown_function_results(), + ElementsAre(FunctionCallIs("F1", "arg1"))); +} + +TEST_F(UnknownsTest, UnknownsMerge) { + PrepareBuilder(UnknownProcessingOptions::kAttributeAndFunction); + activation_.InsertValue("var1", CelValue::CreateInt64(5)); + activation_.set_unknown_attribute_patterns({CelAttributePattern("var2", {})}); + + ASSERT_OK(activation_.InsertFunction( + std::make_unique("F1", FunctionResponse::kUnknown))); + ASSERT_OK(activation_.InsertFunction( + std::make_unique("F2", FunctionResponse::kTrue))); + + // var1 > 3 && F1('arg1') || var2 > 3 && F2('arg2') + auto plan = builder_->CreateExpression(&expr_, nullptr); + ASSERT_OK(plan); + + auto maybe_response = plan.ValueOrDie()->Evaluate(activation_, &arena_); + ASSERT_OK(maybe_response); + CelValue response = maybe_response.ValueOrDie(); + + ASSERT_TRUE(response.IsUnknownSet()) << response.ErrorOrDie()->ToString(); + EXPECT_THAT(response.UnknownSetOrDie() + ->unknown_function_results() + .unknown_function_results(), + ElementsAre(FunctionCallIs("F1", "arg1"))); + EXPECT_THAT(response.UnknownSetOrDie()->unknown_attributes().attributes(), + ElementsAre(AttributeIs("var2"))); +} + +} // namespace +} // namespace runtime +} // namespace expr +} // namespace api +} // namespace google diff --git a/internal/cel_printer.h b/internal/cel_printer.h index d86729a07..11a696390 100644 --- a/internal/cel_printer.h +++ b/internal/cel_printer.h @@ -35,9 +35,7 @@ struct RawString { */ struct ScalarPrinter { inline std::string operator()(std::nullptr_t) { return "null"; } - inline std::string operator()(bool value) { - return value ? "true" : "false"; - } + inline std::string operator()(bool value) { return value ? "true" : "false"; } std::string operator()(absl::Time value); std::string operator()(absl::Duration value); @@ -83,8 +81,8 @@ struct ForwardingPrinter { // If the value defines a ToDebugString function, call it. template - specialize_ifd().ToDebugString())> operator()( - T&& value) { + specialize_ifd().ToDebugString())> + operator()(T&& value) { return std::string(value.ToDebugString()); } }; diff --git a/internal/hash_util.cc b/internal/hash_util.cc index c3b914bf7..c44bd347e 100644 --- a/internal/hash_util.cc +++ b/internal/hash_util.cc @@ -7,7 +7,9 @@ namespace api { namespace expr { namespace internal { -std::size_t HashImpl(const std::string& value, specialize) { return StdHash(value); } +std::size_t HashImpl(const std::string& value, specialize) { + return StdHash(value); +} std::size_t HashImpl(const google::rpc::Status& value, specialize) { std::size_t hash = Hash(value.code()); diff --git a/parser/BUILD b/parser/BUILD index 3b634e599..367694156 100644 --- a/parser/BUILD +++ b/parser/BUILD @@ -27,7 +27,8 @@ cc_library( ":cel_cc_parser", ":macro", ":visitor", - "//base:status", + "//base:status_macros", + "//base:statusor", "@antlr4_runtimes//:cpp", "@com_google_absl//absl/types:optional", "@com_google_googleapis//google/api/expr/v1alpha1:syntax_cc_proto", diff --git a/parser/parser.cc b/parser/parser.cc index 20fae9c04..bed72a4ca 100644 --- a/parser/parser.cc +++ b/parser/parser.cc @@ -5,7 +5,6 @@ #include "parser/cel_grammar.inc/cel_grammar/CelParser.h" #include "parser/visitor.h" #include "antlr4-runtime.h" -#include "base/canonical_errors.h" namespace google { namespace api { @@ -47,15 +46,15 @@ cel_base::StatusOr ParseWithMacros(const std::string& expression, try { root = parser.start(); } catch (ParseCancellationException& e) { - return cel_base::CancelledError(e.what()); + return absl::CancelledError(e.what()); } catch (std::exception& e) { - return cel_base::AbortedError(e.what()); + return absl::AbortedError(e.what()); } Expr expr = visitor.visit(root).as(); if (visitor.hasErrored()) { - return cel_base::InvalidArgumentError(visitor.errorMessage()); + return absl::InvalidArgumentError(visitor.errorMessage()); } // root is deleted as part of the parser context diff --git a/parser/parser_test.cc b/parser/parser_test.cc index a61e309d9..1467c95f4 100644 --- a/parser/parser_test.cc +++ b/parser/parser_test.cc @@ -1,5 +1,6 @@ #include "parser/parser.h" +#include #include #include "gmock/gmock.h" @@ -1132,6 +1133,7 @@ class ExpressionTest : public testing::TestWithParam {}; TEST_P(ExpressionTest, Parse) { const TestInfo& test_info = GetParam(); + /* ::testing::internal::ColoredPrintf(::testing::internal::COLOR_GREEN, "[ ]"); ::testing::internal::ColoredPrintf(::testing::internal::COLOR_DEFAULT, @@ -1141,6 +1143,7 @@ TEST_P(ExpressionTest, Parse) { ::testing::internal::ColoredPrintf( ::testing::internal::COLOR_DEFAULT, "%s\n", !test_info.E.empty() ? " (error expected)" : ""); + */ auto result = Parse(test_info.I); if (test_info.E.empty()) { diff --git a/protoutil/type_registry.cc b/protoutil/type_registry.cc index 5e82bb941..ba5c7854a 100644 --- a/protoutil/type_registry.cc +++ b/protoutil/type_registry.cc @@ -100,8 +100,9 @@ class ProtoStrList final : public BaseProtoRefList { common::Value Get(std::size_t index) const override { std::string scratch; - const std::string& value = msg_->GetReflection()->GetRepeatedStringReference( - *msg_, field_, index, &scratch); + const std::string& value = + msg_->GetReflection()->GetRepeatedStringReference(*msg_, field_, index, + &scratch); if (&value == &scratch) { return common::Value::From(value); } diff --git a/testutil/test_data_io.cc b/testutil/test_data_io.cc index 9566b942e..9e486c00d 100644 --- a/testutil/test_data_io.cc +++ b/testutil/test_data_io.cc @@ -67,8 +67,8 @@ std::unique_ptr OpenForWrite( return nullptr; } } -std::string GetTestCaseFileName(absl::string_view dir, absl::string_view test_name, - bool binary) { +std::string GetTestCaseFileName(absl::string_view dir, + absl::string_view test_name, bool binary) { return absl::StrCat(dir, test_name, binary ? kBinaryPbExt : kTextPbExt); } diff --git a/testutil/test_data_util.cc b/testutil/test_data_util.cc index 06d9f8c2f..3a1a10928 100644 --- a/testutil/test_data_util.cc +++ b/testutil/test_data_util.cc @@ -121,7 +121,7 @@ bool rep_as(F value) { template std::string MakeName(absl::string_view type, T&& value, - absl::string_view name = "") { + absl::string_view name = "") { if (name.empty()) { return absl::StrCat(type, "(", value, ")"); }