diff --git a/testing/testrunner/BUILD b/testing/testrunner/BUILD index 3b1e2f552..975b5884d 100644 --- a/testing/testrunner/BUILD +++ b/testing/testrunner/BUILD @@ -17,8 +17,10 @@ cc_library( "//eval/public:cel_expression", "//runtime", "@com_google_absl//absl/base:nullability", + "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/memory", "@com_google_cel_spec//proto/cel/expr:checked_cc_proto", + "@com_google_cel_spec//proto/cel/expr:value_cc_proto", ], ) @@ -96,6 +98,7 @@ cc_test( "//runtime", "//runtime:runtime_builder", "//runtime:standard_runtime_builder_factory", + "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/flags:flag", "@com_google_absl//absl/log:absl_check", "@com_google_absl//absl/status:status_matchers", diff --git a/testing/testrunner/cel_test_context.h b/testing/testrunner/cel_test_context.h index 903d2a8ea..335f25aa4 100644 --- a/testing/testrunner/cel_test_context.h +++ b/testing/testrunner/cel_test_context.h @@ -17,10 +17,13 @@ #include #include +#include #include #include "cel/expr/checked.pb.h" +#include "cel/expr/value.pb.h" #include "absl/base/nullability.h" +#include "absl/container/flat_hash_map.h" #include "absl/memory/memory.h" #include "compiler/compiler.h" #include "eval/public/cel_expression.h" @@ -37,6 +40,16 @@ struct CelTestContextOptions { // input or output values are themselves CEL expressions that need to be // resolved at runtime or cel expression source is raw string or cel file. std::unique_ptr compiler = nullptr; + + // A map of variable names to values that provides default bindings for the + // evaluation. + // + // These bindings can be considered context-wide defaults. If a variable name + // exists in both these custom bindings and in a specific TestCase's input, + // the value from the TestCase will take precedence and override this one. + // This logic is handled by the test runner when it constructs the final + // activation. + absl::flat_hash_map custom_bindings; }; // The context class for a CEL test, holding configurations needed to evaluate @@ -97,6 +110,11 @@ class CelTestContext { : nullptr; } + const absl::flat_hash_map& + custom_bindings() const { + return cel_test_context_options_.custom_bindings; + } + private: // Delete copy and move constructors. CelTestContext(const CelTestContext&) = delete; diff --git a/testing/testrunner/runner_lib.cc b/testing/testrunner/runner_lib.cc index 00232c9c8..2b0375f91 100644 --- a/testing/testrunner/runner_lib.cc +++ b/testing/testrunner/runner_lib.cc @@ -56,11 +56,12 @@ using ::cel::expr::conformance::test::TestCase; using ::cel::expr::conformance::test::TestOutput; using ::cel::expr::CheckedExpr; using ::google::api::expr::runtime::CelExpression; -using ::google::api::expr::runtime::CelValue; using ::google::api::expr::runtime::ValueToCelValue; -using ValueProto = ::cel::expr::Value; using ::google::api::expr::runtime::Activation; +using LegacyCelValue = ::google::api::expr::runtime::CelValue; +using ValueProto = ::cel::expr::Value; + absl::StatusOr ReadFileToString(absl::string_view file_path) { std::ifstream file_stream{std::string(file_path)}; if (!file_stream.is_open()) { @@ -131,7 +132,7 @@ absl::StatusOr EvalWithLegacyBindings( CEL_ASSIGN_OR_RETURN(std::unique_ptr sub_expression, builder->CreateExpression(&checked_expr)); - CEL_ASSIGN_OR_RETURN(CelValue legacy_result, + CEL_ASSIGN_OR_RETURN(LegacyCelValue legacy_result, sub_expression->Evaluate(activation, arena)); ValueProto result_proto; @@ -177,23 +178,60 @@ absl::StatusOr ResolveInputValue(const InputValue& input_value, } } -absl::StatusOr CreateModernActivationFromBindings( +absl::Status AddCustomBindingsToModernActivation(const CelTestContext& context, + cel::Activation& activation, + google::protobuf::Arena* arena) { + for (const auto& binding : context.custom_bindings()) { + CEL_ASSIGN_OR_RETURN(cel::Value value, + FromExprValue(/*value_proto=*/binding.second, + GetDescriptorPool(context), + GetMessageFactory(context), arena)); + activation.InsertOrAssignValue(/*name=*/binding.first, value); + } + return absl::OkStatus(); +} + +absl::Status AddTestCaseBindingsToModernActivation( const TestCase& test_case, const CelTestContext& context, - google::protobuf::Arena* arena) { - cel::Activation activation; + cel::Activation& activation, google::protobuf::Arena* arena) { for (const auto& binding : test_case.input()) { CEL_ASSIGN_OR_RETURN( - Value value, + cel::Value value, ResolveInputValue(/*input_value=*/binding.second, context, arena)); activation.InsertOrAssignValue(/*name=*/binding.first, std::move(value)); } - return activation; + return absl::OkStatus(); } -absl::StatusOr CreateLegacyActivationFromBindings( +absl::StatusOr CreateModernActivationFromBindings( const TestCase& test_case, const CelTestContext& context, google::protobuf::Arena* arena) { - Activation activation; + cel::Activation activation; + + CEL_RETURN_IF_ERROR( + AddCustomBindingsToModernActivation(context, activation, arena)); + + CEL_RETURN_IF_ERROR(AddTestCaseBindingsToModernActivation(test_case, context, + activation, arena)); + + return activation; +} + +absl::Status AddCustomBindingsToLegacyActivation(const CelTestContext& context, + Activation& activation, + google::protobuf::Arena* arena) { + for (const auto& binding : context.custom_bindings()) { + CEL_ASSIGN_OR_RETURN( + LegacyCelValue value, + ValueToCelValue(/*value_proto=*/binding.second, arena)); + activation.InsertValue(/*name=*/binding.first, value); + } + return absl::OkStatus(); +} + +absl::Status AddTestCaseBindingsToLegacyActivation( + const TestCase& test_case, const CelTestContext& context, + Activation& activation, google::protobuf::Arena* arena) { auto* message_factory = GetMessageFactory(context); auto* descriptor_pool = GetDescriptorPool(context); for (const auto& binding : test_case.input()) { @@ -203,9 +241,24 @@ absl::StatusOr CreateLegacyActivationFromBindings( CEL_ASSIGN_OR_RETURN(ValueProto value_proto, ToExprValue(resolved_cel_value, descriptor_pool, message_factory, arena)); - CEL_ASSIGN_OR_RETURN(CelValue value, ValueToCelValue(value_proto, arena)); + CEL_ASSIGN_OR_RETURN(LegacyCelValue value, + ValueToCelValue(value_proto, arena)); activation.InsertValue(/*name=*/binding.first, value); } + return absl::OkStatus(); +} + +absl::StatusOr CreateLegacyActivationFromBindings( + const TestCase& test_case, const CelTestContext& context, + google::protobuf::Arena* arena) { + Activation activation; + + CEL_RETURN_IF_ERROR( + AddCustomBindingsToLegacyActivation(context, activation, arena)); + + CEL_RETURN_IF_ERROR(AddTestCaseBindingsToLegacyActivation(test_case, context, + activation, arena)); + return activation; } diff --git a/testing/testrunner/runner_lib_test.cc b/testing/testrunner/runner_lib_test.cc index c95097f21..f63952b2c 100644 --- a/testing/testrunner/runner_lib_test.cc +++ b/testing/testrunner/runner_lib_test.cc @@ -18,6 +18,7 @@ #include #include "gtest/gtest-spi.h" +#include "absl/container/flat_hash_map.h" #include "absl/flags/flag.h" #include "absl/log/absl_check.h" #include "absl/status/status_matchers.h" @@ -57,6 +58,7 @@ using ::cel::expr::conformance::proto3::TestAllTypes; using ::cel::expr::conformance::test::TestCase; using ::cel::expr::CheckedExpr; using ::google::api::expr::runtime::CelExpressionBuilder; +using ValueProto = ::cel::expr::Value; template T ParseTextProtoOrDie(absl::string_view text_proto) { @@ -190,7 +192,8 @@ TEST_P(TestRunnerParamTest, BasicTestReportsFailure) { CelExpressionSource::FromCheckedExpr( std::move(checked_expr))})); TestRunner test_runner(std::move(context)); - EXPECT_NONFATAL_FAILURE(test_runner.RunTest(test_case), "bool_value: true"); + EXPECT_NONFATAL_FAILURE(test_runner.RunTest(test_case), + "bool_value: true"); // expected true got false } TEST_P(TestRunnerParamTest, DynamicInputAndOutputReportsSuccess) { @@ -248,7 +251,8 @@ TEST_P(TestRunnerParamTest, DynamicInputAndOutputReportsFailure) { std::move(checked_expr)), .compiler = std::move(compiler)})); TestRunner test_runner(std::move(context)); - EXPECT_NONFATAL_FAILURE(test_runner.RunTest(test_case), "int64_value: 5"); + EXPECT_NONFATAL_FAILURE(test_runner.RunTest(test_case), + "int64_value: 5"); // expected 5 got 10 } TEST_P(TestRunnerParamTest, RawExpressionWithCompilerReportsSuccess) { @@ -296,7 +300,8 @@ TEST_P(TestRunnerParamTest, RawExpressionWithCompilerReportsFailure) { CelExpressionSource::FromRawExpression("x - y"), .compiler = std::move(compiler)})); TestRunner test_runner(std::move(context)); - EXPECT_NONFATAL_FAILURE(test_runner.RunTest(test_case), "int64_value: 7"); + EXPECT_NONFATAL_FAILURE(test_runner.RunTest(test_case), + "int64_value: 7"); // expected 7 got 100 } TEST_P(TestRunnerParamTest, CelFileWithCompilerReportsSuccess) { @@ -350,7 +355,67 @@ TEST_P(TestRunnerParamTest, CelFileWithCompilerReportsFailure) { CelExpressionSource::FromCelFile(cel_file_path), .compiler = std::move(compiler)})); TestRunner test_runner(std::move(context)); - EXPECT_NONFATAL_FAILURE(test_runner.RunTest(test_case), "int64_value: 7"); + EXPECT_NONFATAL_FAILURE(test_runner.RunTest(test_case), + "int64_value: 7"); // expected 7 got 123 +} + +TEST_P(TestRunnerParamTest, BasicTestWithCustomBindingsSucceeds) { + ASSERT_OK_AND_ASSIGN(cel::ValidationResult validation_result, + DefaultCompiler().Compile("x + y")); + CheckedExpr checked_expr; + ASSERT_THAT(cel::AstToCheckedExpr(*validation_result.GetAst(), &checked_expr), + absl_testing::IsOk()); + + TestCase test_case = ParseTextProtoOrDie(R"pb( + input { + key: "x" + value { value { int64_value: 10 } } + } + output { result_value { int64_value: 15 } } + )pb"); + + absl::flat_hash_map bindings; + bindings["y"] = ParseTextProtoOrDie(R"pb(int64_value: 5)pb"); + + ASSERT_OK_AND_ASSIGN( + auto context, CreateTestContext( + /*options=*/{.expression_source = + CelExpressionSource::FromCheckedExpr( + std::move(checked_expr)), + .custom_bindings = std::move(bindings)})); + TestRunner test_runner(std::move(context)); + + EXPECT_NO_FATAL_FAILURE(test_runner.RunTest(test_case)); +} + +TEST_P(TestRunnerParamTest, BasicTestWithCustomBindingsReportsFailure) { + ASSERT_OK_AND_ASSIGN(cel::ValidationResult validation_result, + DefaultCompiler().Compile("x + y")); + CheckedExpr checked_expr; + ASSERT_THAT(cel::AstToCheckedExpr(*validation_result.GetAst(), &checked_expr), + absl_testing::IsOk()); + + TestCase test_case = ParseTextProtoOrDie(R"pb( + input { + key: "x" + value { value { int64_value: 10 } } + } + output { result_value { int64_value: 999 } } + )pb"); + + absl::flat_hash_map bindings; + bindings["y"] = ParseTextProtoOrDie(R"pb(int64_value: 5)pb"); + + ASSERT_OK_AND_ASSIGN( + auto context, CreateTestContext( + /*options=*/{.expression_source = + CelExpressionSource::FromCheckedExpr( + std::move(checked_expr)), + .custom_bindings = std::move(bindings)})); + TestRunner test_runner(std::move(context)); + + EXPECT_NONFATAL_FAILURE(test_runner.RunTest(test_case), + "int64_value: 15"); // expected 15 got 999. } INSTANTIATE_TEST_SUITE_P(TestRunnerTests, TestRunnerParamTest,