diff --git a/testing/testrunner/BUILD b/testing/testrunner/BUILD index 536adaf4d..10b906028 100644 --- a/testing/testrunner/BUILD +++ b/testing/testrunner/BUILD @@ -13,6 +13,7 @@ cc_library( hdrs = ["cel_test_context.h"], deps = [ ":cel_expression_source", + "//common:value", "//compiler", "//eval/public:cel_expression", "//runtime", @@ -20,7 +21,6 @@ cc_library( "@com_google_absl//absl/base:nullability", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/memory", - "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", "@com_google_cel_spec//proto/cel/expr:checked_cc_proto", "@com_google_cel_spec//proto/cel/expr:value_cc_proto", diff --git a/testing/testrunner/cel_test_context.h b/testing/testrunner/cel_test_context.h index 176cc18b1..aa6aab3ac 100644 --- a/testing/testrunner/cel_test_context.h +++ b/testing/testrunner/cel_test_context.h @@ -26,6 +26,7 @@ #include "absl/container/flat_hash_map.h" #include "absl/memory/memory.h" #include "absl/status/statusor.h" +#include "common/value.h" #include "compiler/compiler.h" #include "eval/public/cel_expression.h" #include "runtime/activation.h" @@ -42,6 +43,10 @@ class CelTestContext { using CelActivationFactoryFn = std::function( const cel::expr::conformance::test::TestCase& test_case, google::protobuf::Arena* arena)>; + using AssertFn = std::function; // Creates a CelTestContext using a `CelExpressionBuilder`. // @@ -127,6 +132,12 @@ class CelTestContext { return activation_factory_; } + // Allows the runner to inject a custom assertion function. If not set, the + // default assertion logic in TestRunner will be used. + void SetAssertFn(AssertFn assert_fn) { assert_fn_ = std::move(assert_fn); } + + const AssertFn& assert_fn() const { return assert_fn_; } + private: // Delete copy and move constructors. CelTestContext(const CelTestContext&) = delete; @@ -173,6 +184,7 @@ class CelTestContext { std::unique_ptr runtime_; CelActivationFactoryFn activation_factory_; + AssertFn assert_fn_; }; } // namespace cel::test diff --git a/testing/testrunner/runner_lib.cc b/testing/testrunner/runner_lib.cc index b6279ab3c..ae09de255 100644 --- a/testing/testrunner/runner_lib.cc +++ b/testing/testrunner/runner_lib.cc @@ -340,6 +340,10 @@ void TestRunner::AssertError(const cel::Value& computed, void TestRunner::Assert(const cel::Value& computed, const TestCase& test_case, google::protobuf::Arena* arena) { + if (test_context_->assert_fn()) { + test_context_->assert_fn()(computed, test_case, arena); + return; + } TestOutput output = test_case.output(); if (output.has_result_value() || output.has_result_expr()) { AssertValue(computed, output, arena); diff --git a/testing/testrunner/runner_lib_test.cc b/testing/testrunner/runner_lib_test.cc index 0ddb62ca7..804826b6c 100644 --- a/testing/testrunner/runner_lib_test.cc +++ b/testing/testrunner/runner_lib_test.cc @@ -652,6 +652,35 @@ TEST(TestRunnerStandaloneTest, BasicTestWithActivationFactorySucceeds) { EXPECT_NO_FATAL_FAILURE(test_runner.RunTest(test_case)); } +TEST(TestRunnerStandaloneTest, CustomAssertFnIsUsed) { + // Compile the expression. + ASSERT_OK_AND_ASSIGN(cel::ValidationResult validation_result, + DefaultCompiler().Compile("1 + 1")); + CheckedExpr checked_expr; + ASSERT_THAT(cel::AstToCheckedExpr(*validation_result.GetAst(), &checked_expr), + absl_testing::IsOk()); + // Create a runtime. + ASSERT_OK_AND_ASSIGN(std::unique_ptr runtime, + CreateTestRuntime()); + // Set the output to a value that would fail the default assertion. + TestCase test_case = ParseTextProtoOrDie(R"pb( + output { result_value { int64_value: 102 } } + )pb"); + std::unique_ptr context = + CelTestContext::CreateFromRuntime(std::move(runtime)); + + context->SetAssertFn([&](const cel::Value& computed, + const TestCase& test_case, google::protobuf::Arena* arena) { + ASSERT_TRUE(computed.Is()); + EXPECT_EQ(computed.As().value(), 2); + }); + + context->SetExpressionSource( + CelExpressionSource::FromCheckedExpr(std::move(checked_expr))); + TestRunner test_runner(std::move(context)); + EXPECT_NO_FATAL_FAILURE(test_runner.RunTest(test_case)); +} + TEST(CoverageTest, RuntimeCoverage) { ASSERT_OK_AND_ASSIGN( std::unique_ptr compiler_builder,