diff --git a/testing/testrunner/BUILD b/testing/testrunner/BUILD index ffc2040f6..536adaf4d 100644 --- a/testing/testrunner/BUILD +++ b/testing/testrunner/BUILD @@ -16,11 +16,16 @@ cc_library( "//compiler", "//eval/public:cel_expression", "//runtime", + "//runtime:activation", "@com_google_absl//absl/base:nullability", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/memory", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", "@com_google_cel_spec//proto/cel/expr:checked_cc_proto", "@com_google_cel_spec//proto/cel/expr:value_cc_proto", + "@com_google_cel_spec//proto/cel/expr/conformance/test:suite_cc_proto", + "@com_google_protobuf//:protobuf", ], ) @@ -89,6 +94,7 @@ cc_test( "//common:ast_proto", "//common:decl", "//common:type", + "//common:value", "//compiler", "//compiler:compiler_factory", "//compiler:standard_library", @@ -99,6 +105,7 @@ cc_test( "//internal:testing", "//internal:testing_descriptor_pool", "//runtime", + "//runtime:activation", "//runtime:runtime_builder", "//runtime:standard_runtime_builder_factory", "@com_google_absl//absl/container:flat_hash_map", diff --git a/testing/testrunner/cel_test_context.h b/testing/testrunner/cel_test_context.h index beaecb4ef..176cc18b1 100644 --- a/testing/testrunner/cel_test_context.h +++ b/testing/testrunner/cel_test_context.h @@ -15,6 +15,7 @@ #ifndef THIRD_PARTY_CEL_CPP_TOOLS_TESTRUNNER_CEL_TEST_CONTEXT_H_ #define THIRD_PARTY_CEL_CPP_TOOLS_TESTRUNNER_CEL_TEST_CONTEXT_H_ +#include #include #include #include @@ -24,16 +25,24 @@ #include "absl/base/nullability.h" #include "absl/container/flat_hash_map.h" #include "absl/memory/memory.h" +#include "absl/status/statusor.h" #include "compiler/compiler.h" #include "eval/public/cel_expression.h" +#include "runtime/activation.h" #include "runtime/runtime.h" #include "testing/testrunner/cel_expression_source.h" +#include "cel/expr/conformance/test/suite.pb.h" +#include "google/protobuf/arena.h" namespace cel::test { // The context class for a CEL test, holding configurations needed to evaluate // compiled CEL expressions. class CelTestContext { public: + using CelActivationFactoryFn = std::function( + const cel::expr::conformance::test::TestCase& test_case, + google::protobuf::Arena* arena)>; + // Creates a CelTestContext using a `CelExpressionBuilder`. // // The `CelExpressionBuilder` helps in setting up the environment for @@ -107,6 +116,17 @@ class CelTestContext { custom_bindings_ = std::move(custom_bindings); } + // Allows the runner to inject a custom activation factory. If not set, an + // empty activation will be used. Custom bindings and test case inputs will + // be added to the activation returned by the factory. + void SetActivationFactory(CelActivationFactoryFn activation_factory) { + activation_factory_ = std::move(activation_factory); + } + + const CelActivationFactoryFn& activation_factory() const { + return activation_factory_; + } + private: // Delete copy and move constructors. CelTestContext(const CelTestContext&) = delete; @@ -151,6 +171,8 @@ class CelTestContext { // needed to generate Program. Users should either provide a runtime, or the // CelExpressionBuilder. std::unique_ptr runtime_; + + CelActivationFactoryFn activation_factory_; }; } // namespace cel::test diff --git a/testing/testrunner/runner_lib.cc b/testing/testrunner/runner_lib.cc index aabbaefb6..b6279ab3c 100644 --- a/testing/testrunner/runner_lib.cc +++ b/testing/testrunner/runner_lib.cc @@ -203,11 +203,20 @@ absl::Status AddTestCaseBindingsToModernActivation( return absl::OkStatus(); } +absl::StatusOr GetActivation(const CelTestContext& context, + const TestCase& test_case, + google::protobuf::Arena* arena) { + if (context.activation_factory() != nullptr) { + return context.activation_factory()(test_case, arena); + } + return cel::Activation(); +} + absl::StatusOr CreateModernActivationFromBindings( const TestCase& test_case, const CelTestContext& context, google::protobuf::Arena* arena) { - cel::Activation activation; - + CEL_ASSIGN_OR_RETURN(cel::Activation activation, + GetActivation(context, test_case, arena)); CEL_RETURN_IF_ERROR( AddCustomBindingsToModernActivation(context, activation, arena)); diff --git a/testing/testrunner/runner_lib_test.cc b/testing/testrunner/runner_lib_test.cc index daf7859a4..0ddb62ca7 100644 --- a/testing/testrunner/runner_lib_test.cc +++ b/testing/testrunner/runner_lib_test.cc @@ -30,6 +30,7 @@ #include "common/ast_proto.h" #include "common/decl.h" #include "common/type.h" +#include "common/value.h" #include "compiler/compiler.h" #include "compiler/compiler_factory.h" #include "compiler/standard_library.h" @@ -39,6 +40,7 @@ #include "internal/status_macros.h" #include "internal/testing.h" #include "internal/testing_descriptor_pool.h" +#include "runtime/activation.h" #include "runtime/runtime.h" #include "runtime/runtime_builder.h" #include "runtime/standard_runtime_builder_factory.h" @@ -47,6 +49,7 @@ #include "testing/testrunner/coverage_index.h" #include "cel/expr/conformance/proto3/test_all_types.pb.h" #include "cel/expr/conformance/test/suite.pb.h" +#include "google/protobuf/arena.h" #include "google/protobuf/descriptor.h" #include "google/protobuf/message.h" #include "google/protobuf/text_format.h" @@ -610,6 +613,45 @@ TEST(TestRunnerStandaloneTest, BasicTestFailsWhenExpectingErrorButGotValue) { "Expected error but got value"); } +TEST(TestRunnerStandaloneTest, BasicTestWithActivationFactorySucceeds) { + ASSERT_OK_AND_ASSIGN(cel::ValidationResult validation_result, + DefaultCompiler().Compile("x + y")); + CheckedExpr checked_expr; + ASSERT_THAT(cel::AstToCheckedExpr(*validation_result.GetAst(), &checked_expr), + absl_testing::IsOk()); + + ASSERT_OK_AND_ASSIGN(std::unique_ptr runtime, + CreateTestRuntime()); + std::unique_ptr context = + CelTestContext::CreateFromRuntime(std::move(runtime)); + context->SetActivationFactory( + [](const TestCase& test_case, + google::protobuf::Arena* arena) -> absl::StatusOr { + cel::Activation activation; + activation.InsertOrAssignValue("x", cel::IntValue(10)); + activation.InsertOrAssignValue("y", cel::IntValue(5)); + return activation; + }); + context->SetExpressionSource( + CelExpressionSource::FromCheckedExpr(std::move(checked_expr))); + + TestCase test_case = ParseTextProtoOrDie(R"pb( + output { result_value { int64_value: 15 } } + )pb"); + TestRunner test_runner(std::move(context)); + EXPECT_NO_FATAL_FAILURE(test_runner.RunTest(test_case)); + + // Input bindings should override values set by the activation factory. + test_case = ParseTextProtoOrDie(R"pb( + input { + key: "x" + value { value { int64_value: 4 } } + } + output { result_value { int64_value: 9 } } + )pb"); + EXPECT_NO_FATAL_FAILURE(test_runner.RunTest(test_case)); +} + TEST(CoverageTest, RuntimeCoverage) { ASSERT_OK_AND_ASSIGN( std::unique_ptr compiler_builder,