diff --git a/conformance/BUILD b/conformance/BUILD index a52f56019..2b3d92bfa 100644 --- a/conformance/BUILD +++ b/conformance/BUILD @@ -119,6 +119,7 @@ cc_library( srcs = ["run.cc"], deps = [ ":service", + ":utils", "//internal:testing_no_main", "@com_google_absl//absl/flags:flag", "@com_google_absl//absl/log:absl_check", @@ -138,6 +139,20 @@ cc_library( alwayslink = True, ) +cc_library( + name = "utils", + testonly = True, + hdrs = ["utils.h"], + deps = [ + "//internal:testing", + "@com_google_absl//absl/log:absl_check", + "@com_google_cel_spec//proto/cel/expr:checked_cc_proto", + "@com_google_cel_spec//proto/cel/expr:value_cc_proto", + "@com_google_googleapis//google/api/expr/v1alpha1:checked_cc_proto", + "@com_google_protobuf//:protobuf", + ], +) + _ALL_TESTS = [ "@com_google_cel_spec//tests/simple:testdata/basic.textproto", "@com_google_cel_spec//tests/simple:testdata/bindings_ext.textproto", diff --git a/conformance/run.cc b/conformance/run.cc index d61e24fb9..ac6151671 100644 --- a/conformance/run.cc +++ b/conformance/run.cc @@ -19,7 +19,6 @@ // conformance tests; as well as integrating better with C++ testing // infrastructure. -#include #include #include #include @@ -47,13 +46,12 @@ #include "absl/strings/strip.h" #include "absl/types/span.h" #include "conformance/service.h" +#include "conformance/utils.h" #include "internal/testing.h" #include "cel/expr/conformance/test/simple.pb.h" #include "google/protobuf/io/zero_copy_stream_impl.h" #include "google/protobuf/message.h" #include "google/protobuf/text_format.h" -#include "google/protobuf/util/field_comparator.h" -#include "google/protobuf/util/message_differencer.h" ABSL_FLAG(bool, opt, false, "Enable optimizations (constant folding)"); ABSL_FLAG( @@ -78,93 +76,11 @@ using google::api::expr::conformance::v1alpha1::EvalRequest; using google::api::expr::conformance::v1alpha1::EvalResponse; using google::api::expr::conformance::v1alpha1::ParseRequest; using google::api::expr::conformance::v1alpha1::ParseResponse; -using google::protobuf::TextFormat; -using google::protobuf::util::DefaultFieldComparator; -using google::protobuf::util::MessageDifferencer; google::rpc::Code ToGrpcCode(absl::StatusCode code) { return static_cast(code); } -std::string DescribeMessage(const google::protobuf::Message& message) { - std::string string; - ABSL_CHECK(TextFormat::PrintToString(message, &string)); - if (string.empty()) { - string = "\"\"\n"; - } - return string; -} - -MATCHER_P(MatchesConformanceValue, expected, "") { - static auto* kFieldComparator = []() { - auto* field_comparator = new DefaultFieldComparator(); - field_comparator->set_treat_nan_as_equal(true); - return field_comparator; - }(); - static auto* kDifferencer = []() { - auto* differencer = new MessageDifferencer(); - differencer->set_message_field_comparison(MessageDifferencer::EQUIVALENT); - differencer->set_field_comparator(kFieldComparator); - const auto* descriptor = cel::expr::MapValue::descriptor(); - const auto* entries_field = descriptor->FindFieldByName("entries"); - const auto* key_field = - entries_field->message_type()->FindFieldByName("key"); - differencer->TreatAsMap(entries_field, key_field); - return differencer; - }(); - - const cel::expr::ExprValue& got = arg; - const cel::expr::Value& want = expected; - - cel::expr::ExprValue test_value; - (*test_value.mutable_value()) = want; - - if (kDifferencer->Compare(got, test_value)) { - return true; - } - (*result_listener) << "got: " << DescribeMessage(got); - (*result_listener) << "\n"; - (*result_listener) << "wanted: " << DescribeMessage(test_value); - return false; -} - -MATCHER_P(ResultTypeMatches, expected, "") { - static auto* kDifferencer = []() { - auto* differencer = new MessageDifferencer(); - differencer->set_message_field_comparison(MessageDifferencer::EQUIVALENT); - return differencer; - }(); - - const cel::expr::Type& want = expected; - const google::api::expr::v1alpha1::CheckedExpr& checked_expr = arg; - - int64_t root_id = checked_expr.expr().id(); - auto it = checked_expr.type_map().find(root_id); - - if (it == checked_expr.type_map().end()) { - (*result_listener) << "type map does not contain root id: " << root_id; - return false; - } - - auto got_versioned = it->second; - std::string serialized; - cel::expr::Type got; - if (!got_versioned.SerializeToString(&serialized) || - !got.ParseFromString(serialized)) { - (*result_listener) << "type cannot be converted from versioned type: " - << DescribeMessage(got_versioned); - return false; - } - - if (kDifferencer->Compare(got, want)) { - return true; - } - (*result_listener) << "got: " << DescribeMessage(got); - (*result_listener) << "\n"; - (*result_listener) << "wanted: " << DescribeMessage(want); - return false; -} - bool ShouldSkipTest(absl::Span tests_to_skip, absl::string_view name) { for (absl::string_view test_to_skip : tests_to_skip) { @@ -245,7 +161,8 @@ class ConformanceTest : public testing::Test { ASSERT_TRUE(test_.has_typed_result()) << "test must specify a typed result if check_only is set"; EXPECT_THAT(eval_request.checked_expr(), - ResultTypeMatches(test_.typed_result().deduced_type())); + cel_conformance::ResultTypeMatches( + test_.typed_result().deduced_type())); return; } @@ -263,7 +180,8 @@ class ConformanceTest : public testing::Test { ABSL_CHECK(eval_response.result().SerializePartialToCord(&serialized)); cel::expr::ExprValue test_value; ABSL_CHECK(test_value.ParsePartialFromCord(serialized)); - EXPECT_THAT(test_value, MatchesConformanceValue(test_.value())); + EXPECT_THAT(test_value, + cel_conformance::MatchesConformanceValue(test_.value())); break; } case SimpleTest::kTypedResult: { @@ -273,10 +191,11 @@ class ConformanceTest : public testing::Test { ABSL_CHECK(eval_response.result().SerializePartialToCord(&serialized)); cel::expr::ExprValue test_value; ABSL_CHECK(test_value.ParsePartialFromCord(serialized)); - EXPECT_THAT(test_value, - MatchesConformanceValue(test_.typed_result().result())); + EXPECT_THAT(test_value, cel_conformance::MatchesConformanceValue( + test_.typed_result().result())); EXPECT_THAT(eval_request.checked_expr(), - ResultTypeMatches(test_.typed_result().deduced_type())); + cel_conformance::ResultTypeMatches( + test_.typed_result().deduced_type())); break; } case SimpleTest::kEvalError: diff --git a/conformance/utils.h b/conformance/utils.h new file mode 100644 index 000000000..e01114125 --- /dev/null +++ b/conformance/utils.h @@ -0,0 +1,118 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef THIRD_PARTY_CEL_CPP_CONFORMANCE_UTILS_H_ +#define THIRD_PARTY_CEL_CPP_CONFORMANCE_UTILS_H_ + +#include +#include + +#include "cel/expr/checked.pb.h" +#include "cel/expr/eval.pb.h" +#include "google/api/expr/v1alpha1/checked.pb.h" +#include "google/api/expr/v1alpha1/syntax.pb.h" +#include "cel/expr/value.pb.h" +#include "absl/log/absl_check.h" +#include "internal/testing.h" +#include "google/protobuf/message.h" +#include "google/protobuf/text_format.h" +#include "google/protobuf/util/field_comparator.h" +#include "google/protobuf/util/message_differencer.h" + +namespace cel_conformance { + +inline std::string DescribeMessage(const google::protobuf::Message& message) { + std::string string; + ABSL_CHECK(google::protobuf::TextFormat::PrintToString(message, &string)); + if (string.empty()) { + string = "\"\"\n"; + } + return string; +} + +MATCHER_P(MatchesConformanceValue, expected, "") { + static auto* kFieldComparator = []() { + auto* field_comparator = new google::protobuf::util::DefaultFieldComparator(); + field_comparator->set_treat_nan_as_equal(true); + return field_comparator; + }(); + static auto* kDifferencer = []() { + auto* differencer = new google::protobuf::util::MessageDifferencer(); + differencer->set_message_field_comparison( + google::protobuf::util::MessageDifferencer::EQUIVALENT); + differencer->set_field_comparator(kFieldComparator); + const auto* descriptor = cel::expr::MapValue::descriptor(); + const auto* entries_field = descriptor->FindFieldByName("entries"); + const auto* key_field = + entries_field->message_type()->FindFieldByName("key"); + differencer->TreatAsMap(entries_field, key_field); + return differencer; + }(); + + const cel::expr::ExprValue& got = arg; + const cel::expr::Value& want = expected; + + cel::expr::ExprValue test_value; + (*test_value.mutable_value()) = want; + + if (kDifferencer->Compare(got, test_value)) { + return true; + } + (*result_listener) << "got: " << DescribeMessage(got); + (*result_listener) << "\n"; + (*result_listener) << "wanted: " << DescribeMessage(test_value); + return false; +} + +MATCHER_P(ResultTypeMatches, expected, "") { + static auto* kDifferencer = []() { + auto* differencer = new google::protobuf::util::MessageDifferencer(); + differencer->set_message_field_comparison( + google::protobuf::util::MessageDifferencer::EQUIVALENT); + return differencer; + }(); + + const cel::expr::Type& want = expected; + const google::api::expr::v1alpha1::CheckedExpr& checked_expr = arg; + + int64_t root_id = checked_expr.expr().id(); + auto it = checked_expr.type_map().find(root_id); + + if (it == checked_expr.type_map().end()) { + (*result_listener) << "type map does not contain root id: " << root_id; + return false; + } + + auto got_versioned = it->second; + std::string serialized; + cel::expr::Type got; + if (!got_versioned.SerializeToString(&serialized) || + !got.ParseFromString(serialized)) { + (*result_listener) << "type cannot be converted from versioned type: " + << DescribeMessage(got_versioned); + return false; + } + + if (kDifferencer->Compare(got, want)) { + return true; + } + (*result_listener) << "got: " << DescribeMessage(got); + (*result_listener) << "\n"; + (*result_listener) << "wanted: " << DescribeMessage(want); + return false; +} + +} // namespace cel_conformance + +#endif // THIRD_PARTY_CEL_CPP_CONFORMANCE_UTILS_H_