From 92adcfb94ade7a31eeb9c5595a1609df2f4a1476 Mon Sep 17 00:00:00 2001 From: Muhammad Askri Date: Thu, 17 Jul 2025 11:59:10 -0700 Subject: [PATCH] Adding the decls and checker lib to CEL regex extensions PiperOrigin-RevId: 784263673 --- eval/compiler/flat_expr_builder.h | 2 + extensions/BUILD | 17 +++++- extensions/regex_ext.cc | 90 +++++++++++++++++++++++++++---- extensions/regex_ext.h | 31 +++++++++-- extensions/regex_ext_test.cc | 84 +++++++++++++++++++++++++++-- 5 files changed, 203 insertions(+), 21 deletions(-) diff --git a/eval/compiler/flat_expr_builder.h b/eval/compiler/flat_expr_builder.h index 758865769..50c0bd9b0 100644 --- a/eval/compiler/flat_expr_builder.h +++ b/eval/compiler/flat_expr_builder.h @@ -93,6 +93,8 @@ class FlatExprBuilder { // `optional_type` handling is needed. void enable_optional_types() { enable_optional_types_ = true; } + bool optional_types_enabled() const { return enable_optional_types_; } + private: const cel::TypeProvider& GetTypeProvider() const; diff --git a/extensions/BUILD b/extensions/BUILD index f127e1eed..c448f5366 100644 --- a/extensions/BUILD +++ b/extensions/BUILD @@ -666,13 +666,22 @@ cc_library( srcs = ["regex_ext.cc"], hdrs = ["regex_ext.h"], deps = [ + "//checker:type_checker_builder", + "//checker/internal:builtins_arena", + "//common:decl", + "//common:type", "//common:value", + "//compiler", "//eval/public:cel_function_registry", "//eval/public:cel_options", + "//internal:casts", "//internal:status_macros", "//runtime:function_adapter", "//runtime:function_registry", - "//runtime:runtime_options", + "//runtime:runtime_builder", + "//runtime/internal:runtime_friend_access", + "//runtime/internal:runtime_impl", + "@com_google_absl//absl/base:no_destructor", "@com_google_absl//absl/base:nullability", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", @@ -688,8 +697,12 @@ cc_test( srcs = ["regex_ext_test.cc"], deps = [ ":regex_ext", + "//checker:standard_library", + 
"//checker:validation_result", "//common:value", "//common:value_testing", + "//compiler", + "//compiler:compiler_factory", "//extensions/protobuf:runtime_adapter", "//internal:status_macros", "//internal:testing", @@ -699,12 +712,12 @@ cc_test( "//runtime:activation", "//runtime:optional_types", "//runtime:reference_resolver", - "//runtime:runtime_builder", "//runtime:runtime_options", "//runtime:standard_runtime_builder_factory", "@com_google_absl//absl/status", "@com_google_absl//absl/status:status_matchers", "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings:string_view", "@com_google_protobuf//:protobuf", ], ) diff --git a/extensions/regex_ext.cc b/extensions/regex_ext.cc index 54cb3e24d..c2766c2c2 100644 --- a/extensions/regex_ext.cc +++ b/extensions/regex_ext.cc @@ -16,22 +16,30 @@ #include #include -#include #include #include +#include "absl/base/no_destructor.h" #include "absl/base/nullability.h" #include "absl/status/status.h" #include "absl/status/statusor.h" #include "absl/strings/str_format.h" #include "absl/strings/string_view.h" +#include "checker/internal/builtins_arena.h" +#include "checker/type_checker_builder.h" +#include "common/decl.h" +#include "common/type.h" #include "common/value.h" +#include "compiler/compiler.h" #include "eval/public/cel_function_registry.h" #include "eval/public/cel_options.h" +#include "internal/casts.h" #include "internal/status_macros.h" #include "runtime/function_adapter.h" #include "runtime/function_registry.h" -#include "runtime/runtime_options.h" +#include "runtime/internal/runtime_friend_access.h" +#include "runtime/internal/runtime_impl.h" +#include "runtime/runtime_builder.h" #include "google/protobuf/arena.h" #include "google/protobuf/descriptor.h" #include "google/protobuf/message.h" @@ -40,6 +48,8 @@ namespace cel::extensions { namespace { +using ::cel::checker_internal::BuiltinsArena; + Value Extract(const StringValue& target, const StringValue& regex, const 
google::protobuf::DescriptorPool* ABSL_NONNULL descriptor_pool, google::protobuf::MessageFactory* ABSL_NONNULL message_factory, @@ -222,8 +232,6 @@ Value ReplaceN(const StringValue& target, const StringValue& regex, return StringValue::From(std::move(output), arena); } -} // namespace - absl::Status RegisterRegexExtensionFunctions(FunctionRegistry& registry) { CEL_RETURN_IF_ERROR( (BinaryFunctionAdapter, StringValue, StringValue>:: @@ -244,10 +252,61 @@ absl::Status RegisterRegexExtensionFunctions(FunctionRegistry& registry) { return absl::OkStatus(); } -absl::Status RegisterRegexExtensionFunctions(FunctionRegistry& registry, - const RuntimeOptions& options) { - if (options.enable_regex) { - CEL_RETURN_IF_ERROR(RegisterRegexExtensionFunctions(registry)); +const Type& OptionalStringType() { + static absl::NoDestructor kInstance( + OptionalType(BuiltinsArena(), StringType())); + return *kInstance; +} + +const Type& ListStringType() { + static absl::NoDestructor kInstance( + ListType(BuiltinsArena(), StringType())); + return *kInstance; +} + +absl::Status RegisterRegexCheckerDecls(TypeCheckerBuilder& builder) { + CEL_ASSIGN_OR_RETURN( + FunctionDecl extract_decl, + MakeFunctionDecl( + "regex.extract", + MakeOverloadDecl("regex_extract_string_string", OptionalStringType(), + StringType(), StringType()))); + + CEL_ASSIGN_OR_RETURN( + FunctionDecl extract_all_decl, + MakeFunctionDecl( + "regex.extractAll", + MakeOverloadDecl("regex_extractAll_string_string", ListStringType(), + StringType(), StringType()))); + + CEL_ASSIGN_OR_RETURN( + FunctionDecl replace_decl, + MakeFunctionDecl( + "regex.replace", + MakeOverloadDecl("regex_replace_string_string_string", StringType(), + StringType(), StringType(), StringType()), + MakeOverloadDecl("regex_replace_string_string_string_int", + StringType(), StringType(), StringType(), + StringType(), IntType()))); + + CEL_RETURN_IF_ERROR(builder.AddFunction(extract_decl)); + CEL_RETURN_IF_ERROR(builder.AddFunction(extract_all_decl)); + 
CEL_RETURN_IF_ERROR(builder.AddFunction(replace_decl)); + return absl::OkStatus(); +} + +} // namespace + +absl::Status RegisterRegexExtensionFunctions(RuntimeBuilder& builder) { + auto& runtime = cel::internal::down_cast( + runtime_internal::RuntimeFriendAccess::GetMutableRuntime(builder)); + if (!runtime.expr_builder().optional_types_enabled()) { + return absl::InvalidArgumentError( + "regex extensions requires the optional types to be enabled"); + } + if (runtime.expr_builder().options().enable_regex) { + CEL_RETURN_IF_ERROR( + RegisterRegexExtensionFunctions(builder.function_registry())); } return absl::OkStatus(); } @@ -255,9 +314,18 @@ absl::Status RegisterRegexExtensionFunctions(FunctionRegistry& registry, absl::Status RegisterRegexExtensionFunctions( google::api::expr::runtime::CelFunctionRegistry* registry, const google::api::expr::runtime::InterpreterOptions& options) { - return RegisterRegexExtensionFunctions( - registry->InternalGetRegistry(), - google::api::expr::runtime::ConvertToRuntimeOptions(options)); + if (options.enable_regex) { + return RegisterRegexExtensionFunctions(registry->InternalGetRegistry()); + } + return absl::OkStatus(); +} + +CheckerLibrary RegexExtCheckerLibrary() { + return {.id = "cel.lib.ext.regex", .configure = RegisterRegexCheckerDecls}; +} + +CompilerLibrary RegexExtCompilerLibrary() { + return CompilerLibrary::FromCheckerLibrary(RegexExtCheckerLibrary()); } } // namespace cel::extensions diff --git a/extensions/regex_ext.h b/extensions/regex_ext.h index 29018779b..b5da5c588 100644 --- a/extensions/regex_ext.h +++ b/extensions/regex_ext.h @@ -76,10 +76,11 @@ #define THIRD_PARTY_CEL_CPP_EXTENSIONS_REGEX_EXT_H_ #include "absl/status/status.h" +#include "checker/type_checker_builder.h" +#include "compiler/compiler.h" #include "eval/public/cel_function_registry.h" #include "eval/public/cel_options.h" -#include "runtime/function_registry.h" -#include "runtime/runtime_options.h" +#include "runtime/runtime_builder.h" namespace 
cel::extensions { @@ -87,8 +88,30 @@ namespace cel::extensions { absl::Status RegisterRegexExtensionFunctions( google::api::expr::runtime::CelFunctionRegistry* registry, const google::api::expr::runtime::InterpreterOptions& options); -absl::Status RegisterRegexExtensionFunctions(FunctionRegistry& registry, - const RuntimeOptions& options); +absl::Status RegisterRegexExtensionFunctions(RuntimeBuilder& builder); + +// Type check declarations for the regex extension library. +// Provides decls for the following functions: +// +// regex.replace(target: str, pattern: str, replacement: str) -> str +// +// regex.replace(target: str, pattern: str, replacement: str, count: int) -> str +// +// regex.extract(target: str, pattern: str) -> optional<str> +// +// regex.extractAll(target: str, pattern: str) -> list<str> +CheckerLibrary RegexExtCheckerLibrary(); + +// Provides decls for the following functions: +// +// regex.replace(target: str, pattern: str, replacement: str) -> str +// +// regex.replace(target: str, pattern: str, replacement: str, count: int) -> str +// +// regex.extract(target: str, pattern: str) -> optional<str> +// +// regex.extractAll(target: str, pattern: str) -> list<str> +CompilerLibrary RegexExtCompilerLibrary(); } // namespace cel::extensions #endif // THIRD_PARTY_CEL_CPP_EXTENSIONS_REGEX_EXT_H_ diff --git a/extensions/regex_ext_test.cc b/extensions/regex_ext_test.cc index c626045ea..42971e880 100644 --- a/extensions/regex_ext_test.cc +++ b/extensions/regex_ext_test.cc @@ -22,8 +22,13 @@ #include "absl/status/status.h" #include "absl/status/status_matchers.h" #include "absl/status/statusor.h" +#include "absl/strings/string_view.h" +#include "checker/standard_library.h" +#include "checker/validation_result.h" #include "common/value.h" #include "common/value_testing.h" +#include "compiler/compiler.h" +#include "compiler/compiler_factory.h" #include "extensions/protobuf/runtime_adapter.h" #include "internal/status_macros.h" #include "internal/testing.h" @@ -33,7 +38,6 @@ 
#include "runtime/optional_types.h" #include "runtime/reference_resolver.h" #include "runtime/runtime.h" -#include "runtime/runtime_builder.h" #include "runtime/runtime_options.h" #include "runtime/standard_runtime_builder_factory.h" #include "google/protobuf/arena.h" @@ -84,9 +88,7 @@ class RegexExtTest : public TestWithParam { EnableReferenceResolver(builder, ReferenceResolverEnabled::kAlways), IsOk()); ASSERT_THAT(EnableOptionalTypes(builder), IsOk()); - ASSERT_THAT( - RegisterRegexExtensionFunctions(builder.function_registry(), options), - IsOk()); + ASSERT_THAT(RegisterRegexExtensionFunctions(builder), IsOk()); ASSERT_OK_AND_ASSIGN(runtime_, std::move(builder).Build()); } @@ -103,6 +105,23 @@ class RegexExtTest : public TestWithParam { std::unique_ptr runtime_; }; +TEST_F(RegexExtTest, BuildFailsWithoutOptionalSupport) { + RuntimeOptions options; + options.enable_regex = true; + options.enable_qualified_type_identifiers = true; + + ASSERT_OK_AND_ASSIGN(auto builder, + CreateStandardRuntimeBuilder( + internal::GetTestingDescriptorPool(), options)); + ASSERT_THAT( + EnableReferenceResolver(builder, ReferenceResolverEnabled::kAlways), + IsOk()); + // Optional types are NOT enabled. 
+ ASSERT_THAT(RegisterRegexExtensionFunctions(builder), + StatusIs(absl::StatusCode::kInvalidArgument, + HasSubstr("regex extensions requires the optional types " + "to be enabled"))); +} std::vector regexTestCases() { return { // Tests for extract Function @@ -121,6 +140,11 @@ std::vector regexTestCases() { "regex.extract('hello world', 'goodbye (.*)')"}, {EvaluationType::kOptionalNone, "regex.extract('HELLO', 'hello')"}, {EvaluationType::kOptionalNone, R"(regex.extract('', r'\w+'))"}, + {EvaluationType::kBoolTrue, + "regex.extract('4122345432', '22').orValue('777') == '22'"}, + {EvaluationType::kBoolTrue, + "regex.extract('4122345432', '22').or(optional.of('777')) == " + "optional.of('22')"}, // Tests for extractAll Function {EvaluationType::kBoolTrue, @@ -328,5 +352,57 @@ TEST_P(RegexExtTest, RegexExtTests) { INSTANTIATE_TEST_SUITE_P(RegexExtTest, RegexExtTest, ValuesIn(regexTestCases())); + +struct RegexCheckerTestCase { + std::string expr_string; + std::string error_substr; +}; + +class RegexExtCheckerLibraryTest : public TestWithParam { + public: + void SetUp() override { + // Arrange: Configure the compiler. + // Add the regex checker library to the compiler builder. + ASSERT_OK_AND_ASSIGN(std::unique_ptr compiler_builder, + NewCompilerBuilder(descriptor_pool_)); + ASSERT_THAT(compiler_builder->AddLibrary(StandardCheckerLibrary()), IsOk()); + ASSERT_THAT(compiler_builder->AddLibrary(RegexExtCompilerLibrary()), + IsOk()); + ASSERT_OK_AND_ASSIGN(compiler_, std::move(*compiler_builder).Build()); + } + + const google::protobuf::DescriptorPool* descriptor_pool_ = + internal::GetTestingDescriptorPool(); + std::unique_ptr compiler_; +}; + +TEST_P(RegexExtCheckerLibraryTest, RegexExtTypeCheckerTests) { + // Act & Assert: Compile the expression and validate the result. 
+ ASSERT_OK_AND_ASSIGN(ValidationResult result, + compiler_->Compile(GetParam().expr_string)); + absl::string_view error_substr = GetParam().error_substr; + EXPECT_EQ(result.IsValid(), error_substr.empty()); + + if (!error_substr.empty()) { + EXPECT_THAT(result.FormatError(), HasSubstr(error_substr)); + } +} + +std::vector createRegexCheckerParams() { + return { + {"regex.replace('abc', 'a', 's') == 'sbc'"}, + {"regex.replace('abc', 'a', 's') == 121", + "found no matching overload for '_==_' applied to '(string, int)"}, + {"regex.replace('abc', 'j', '1', 2) == 9.0", + "found no matching overload for '_==_' applied to '(string, double)"}, + {"regex.extractAll('banananana', '(ana)') == ['ana', 'ana']"}, + {"regex.extract('foo bar', 'f') == 121", + "found no matching overload for '_==_' applied to " + "'(optional_type(string), int)'"}, + }; +} + +INSTANTIATE_TEST_SUITE_P(RegexExtCheckerLibraryTest, RegexExtCheckerLibraryTest, + ValuesIn(createRegexCheckerParams())); } // namespace } // namespace cel::extensions