Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions eval/compiler/flat_expr_builder.h
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,8 @@ class FlatExprBuilder {
// `optional_type` handling is needed.
void enable_optional_types() { enable_optional_types_ = true; }

// Returns the flag that enable_optional_types() sets, i.e. whether
// `optional_type` handling has been requested on this builder.
bool optional_types_enabled() const { return enable_optional_types_; }

private:
const cel::TypeProvider& GetTypeProvider() const;

Expand Down
17 changes: 15 additions & 2 deletions extensions/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -666,13 +666,22 @@ cc_library(
srcs = ["regex_ext.cc"],
hdrs = ["regex_ext.h"],
deps = [
"//checker:type_checker_builder",
"//checker/internal:builtins_arena",
"//common:decl",
"//common:type",
"//common:value",
"//compiler",
"//eval/public:cel_function_registry",
"//eval/public:cel_options",
"//internal:casts",
"//internal:status_macros",
"//runtime:function_adapter",
"//runtime:function_registry",
"//runtime:runtime_options",
"//runtime:runtime_builder",
"//runtime/internal:runtime_friend_access",
"//runtime/internal:runtime_impl",
"@com_google_absl//absl/base:no_destructor",
"@com_google_absl//absl/base:nullability",
"@com_google_absl//absl/status",
"@com_google_absl//absl/status:statusor",
Expand All @@ -688,8 +697,12 @@ cc_test(
srcs = ["regex_ext_test.cc"],
deps = [
":regex_ext",
"//checker:standard_library",
"//checker:validation_result",
"//common:value",
"//common:value_testing",
"//compiler",
"//compiler:compiler_factory",
"//extensions/protobuf:runtime_adapter",
"//internal:status_macros",
"//internal:testing",
Expand All @@ -699,12 +712,12 @@ cc_test(
"//runtime:activation",
"//runtime:optional_types",
"//runtime:reference_resolver",
"//runtime:runtime_builder",
"//runtime:runtime_options",
"//runtime:standard_runtime_builder_factory",
"@com_google_absl//absl/status",
"@com_google_absl//absl/status:status_matchers",
"@com_google_absl//absl/status:statusor",
"@com_google_absl//absl/strings:string_view",
"@com_google_protobuf//:protobuf",
],
)
Expand Down
90 changes: 79 additions & 11 deletions extensions/regex_ext.cc
Original file line number Diff line number Diff line change
Expand Up @@ -16,22 +16,30 @@

#include <algorithm>
#include <cstdint>
#include <memory>
#include <string>
#include <utility>

#include "absl/base/no_destructor.h"
#include "absl/base/nullability.h"
#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "absl/strings/str_format.h"
#include "absl/strings/string_view.h"
#include "checker/internal/builtins_arena.h"
#include "checker/type_checker_builder.h"
#include "common/decl.h"
#include "common/type.h"
#include "common/value.h"
#include "compiler/compiler.h"
#include "eval/public/cel_function_registry.h"
#include "eval/public/cel_options.h"
#include "internal/casts.h"
#include "internal/status_macros.h"
#include "runtime/function_adapter.h"
#include "runtime/function_registry.h"
#include "runtime/runtime_options.h"
#include "runtime/internal/runtime_friend_access.h"
#include "runtime/internal/runtime_impl.h"
#include "runtime/runtime_builder.h"
#include "google/protobuf/arena.h"
#include "google/protobuf/descriptor.h"
#include "google/protobuf/message.h"
Expand All @@ -40,6 +48,8 @@
namespace cel::extensions {
namespace {

using ::cel::checker_internal::BuiltinsArena;

Value Extract(const StringValue& target, const StringValue& regex,
const google::protobuf::DescriptorPool* ABSL_NONNULL descriptor_pool,
google::protobuf::MessageFactory* ABSL_NONNULL message_factory,
Expand Down Expand Up @@ -222,8 +232,6 @@ Value ReplaceN(const StringValue& target, const StringValue& regex,
return StringValue::From(std::move(output), arena);
}

} // namespace

absl::Status RegisterRegexExtensionFunctions(FunctionRegistry& registry) {
CEL_RETURN_IF_ERROR(
(BinaryFunctionAdapter<absl::StatusOr<Value>, StringValue, StringValue>::
Expand All @@ -244,20 +252,80 @@ absl::Status RegisterRegexExtensionFunctions(FunctionRegistry& registry) {
return absl::OkStatus();
}

absl::Status RegisterRegexExtensionFunctions(FunctionRegistry& registry,
const RuntimeOptions& options) {
if (options.enable_regex) {
CEL_RETURN_IF_ERROR(RegisterRegexExtensionFunctions(registry));
// Returns the process-wide `optional_type(string)` checker type.
//
// The instance is created on first use from the checker's builtins arena and
// is wrapped in absl::NoDestructor, so it is never torn down.
const Type& OptionalStringType() {
  static const absl::NoDestructor<Type> optional_of_string(
      OptionalType(BuiltinsArena(), StringType()));
  return *optional_of_string.get();
}

// Returns the process-wide `list(string)` checker type.
//
// The instance is created on first use from the checker's builtins arena and
// is wrapped in absl::NoDestructor, so it is never torn down.
const Type& ListStringType() {
  static const absl::NoDestructor<Type> list_of_string(
      ListType(BuiltinsArena(), StringType()));
  return *list_of_string.get();
}

// Adds the type-check declarations for the regex extension to `builder`.
//
// All three function declarations are constructed first; only if every one of
// them is built successfully are they added to the builder. The first error
// encountered (while building or adding) is returned.
absl::Status RegisterRegexCheckerDecls(TypeCheckerBuilder& builder) {
  // regex.extract(string, string) -> optional_type(string)
  CEL_ASSIGN_OR_RETURN(
      FunctionDecl extract_fn,
      MakeFunctionDecl("regex.extract",
                       MakeOverloadDecl("regex_extract_string_string",
                                        OptionalStringType(), StringType(),
                                        StringType())));

  // regex.extractAll(string, string) -> list(string)
  CEL_ASSIGN_OR_RETURN(
      FunctionDecl extract_all_fn,
      MakeFunctionDecl("regex.extractAll",
                       MakeOverloadDecl("regex_extractAll_string_string",
                                        ListStringType(), StringType(),
                                        StringType())));

  // regex.replace(string, string, string) -> string
  // regex.replace(string, string, string, int) -> string
  CEL_ASSIGN_OR_RETURN(
      FunctionDecl replace_fn,
      MakeFunctionDecl(
          "regex.replace",
          MakeOverloadDecl("regex_replace_string_string_string", StringType(),
                           StringType(), StringType(), StringType()),
          MakeOverloadDecl("regex_replace_string_string_string_int",
                           StringType(), StringType(), StringType(),
                           StringType(), IntType())));

  // Publish the declarations in the same order they were built.
  CEL_RETURN_IF_ERROR(builder.AddFunction(extract_fn));
  CEL_RETURN_IF_ERROR(builder.AddFunction(extract_all_fn));
  CEL_RETURN_IF_ERROR(builder.AddFunction(replace_fn));
  return absl::OkStatus();
}

} // namespace

// Registers the regex extension functions on the runtime being assembled by
// `builder`.
//
// Fails with InvalidArgument when optional types have not been enabled on the
// underlying expression builder. When the runtime's `enable_regex` option is
// off, nothing is registered and OK is returned.
absl::Status RegisterRegexExtensionFunctions(RuntimeBuilder& builder) {
  auto& runtime_impl =
      cel::internal::down_cast<runtime_internal::RuntimeImpl&>(
          runtime_internal::RuntimeFriendAccess::GetMutableRuntime(builder));
  auto&& expr_builder = runtime_impl.expr_builder();

  // regex.extract returns an optional, so optional support is a hard
  // prerequisite.
  if (!expr_builder.optional_types_enabled()) {
    return absl::InvalidArgumentError(
        "regex extensions requires the optional types to be enabled");
  }

  if (!expr_builder.options().enable_regex) {
    return absl::OkStatus();
  }

  CEL_RETURN_IF_ERROR(
      RegisterRegexExtensionFunctions(builder.function_registry()));
  return absl::OkStatus();
}

// Legacy-API entry point: registers the regex extension functions on the
// function registry wrapped by `registry`.
//
// Registration happens only when `options.enable_regex` is set, which is the
// same gating the RuntimeBuilder overload applies; otherwise the call is a
// no-op that returns OK. `registry` is dereferenced unconditionally and so
// must not be null.
absl::Status RegisterRegexExtensionFunctions(
    google::api::expr::runtime::CelFunctionRegistry* registry,
    const google::api::expr::runtime::InterpreterOptions& options) {
  // Regex support disabled: leave the registry untouched.
  if (!options.enable_regex) {
    return absl::OkStatus();
  }
  return RegisterRegexExtensionFunctions(registry->InternalGetRegistry());
}

// Builds the checker library for the regex extension: identified as
// "cel.lib.ext.regex" and configured by RegisterRegexCheckerDecls.
CheckerLibrary RegexExtCheckerLibrary() {
  CheckerLibrary library{.id = "cel.lib.ext.regex",
                         .configure = RegisterRegexCheckerDecls};
  return library;
}

// Builds the compiler library for the regex extension by wrapping the checker
// library returned from RegexExtCheckerLibrary().
CompilerLibrary RegexExtCompilerLibrary() {
  CheckerLibrary checker_library = RegexExtCheckerLibrary();
  return CompilerLibrary::FromCheckerLibrary(std::move(checker_library));
}

} // namespace cel::extensions
31 changes: 27 additions & 4 deletions extensions/regex_ext.h
Original file line number Diff line number Diff line change
Expand Up @@ -76,19 +76,42 @@
#define THIRD_PARTY_CEL_CPP_EXTENSIONS_REGEX_EXT_H_

#include "absl/status/status.h"
#include "checker/type_checker_builder.h"
#include "compiler/compiler.h"
#include "eval/public/cel_function_registry.h"
#include "eval/public/cel_options.h"
#include "runtime/function_registry.h"
#include "runtime/runtime_options.h"
#include "runtime/runtime_builder.h"

namespace cel::extensions {

// Register extension functions for regular expressions.
absl::Status RegisterRegexExtensionFunctions(
google::api::expr::runtime::CelFunctionRegistry* registry,
const google::api::expr::runtime::InterpreterOptions& options);
absl::Status RegisterRegexExtensionFunctions(FunctionRegistry& registry,
const RuntimeOptions& options);
// Registers the regex extension functions on the function registry of
// `builder`.
//
// Returns an InvalidArgument error if optional types have not been enabled on
// the runtime being built. The functions are only registered when the
// runtime's `enable_regex` option is set; otherwise the call succeeds without
// registering anything.
absl::Status RegisterRegexExtensionFunctions(RuntimeBuilder& builder);

// Type check declarations for the regex extension library.
// Provides decls for the following functions:
//
// regex.replace(target: str, pattern: str, replacement: str) -> str
//
// regex.replace(target: str, pattern: str, replacement: str, count: int) -> str
//
// regex.extract(target: str, pattern: str) -> optional<str>
//
// regex.extractAll(target: str, pattern: str) -> list<str>
CheckerLibrary RegexExtCheckerLibrary();

// Compiler library for the regex extension library, built from
// RegexExtCheckerLibrary().
// Provides decls for the following functions:
//
// regex.replace(target: str, pattern: str, replacement: str) -> str
//
// regex.replace(target: str, pattern: str, replacement: str, count: int) -> str
//
// regex.extract(target: str, pattern: str) -> optional<str>
//
// regex.extractAll(target: str, pattern: str) -> list<str>
CompilerLibrary RegexExtCompilerLibrary();

} // namespace cel::extensions
#endif // THIRD_PARTY_CEL_CPP_EXTENSIONS_REGEX_EXT_H_
84 changes: 80 additions & 4 deletions extensions/regex_ext_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,13 @@
#include "absl/status/status.h"
#include "absl/status/status_matchers.h"
#include "absl/status/statusor.h"
#include "absl/strings/string_view.h"
#include "checker/standard_library.h"
#include "checker/validation_result.h"
#include "common/value.h"
#include "common/value_testing.h"
#include "compiler/compiler.h"
#include "compiler/compiler_factory.h"
#include "extensions/protobuf/runtime_adapter.h"
#include "internal/status_macros.h"
#include "internal/testing.h"
Expand All @@ -33,7 +38,6 @@
#include "runtime/optional_types.h"
#include "runtime/reference_resolver.h"
#include "runtime/runtime.h"
#include "runtime/runtime_builder.h"
#include "runtime/runtime_options.h"
#include "runtime/standard_runtime_builder_factory.h"
#include "google/protobuf/arena.h"
Expand Down Expand Up @@ -84,9 +88,7 @@ class RegexExtTest : public TestWithParam<RegexExtTestCase> {
EnableReferenceResolver(builder, ReferenceResolverEnabled::kAlways),
IsOk());
ASSERT_THAT(EnableOptionalTypes(builder), IsOk());
ASSERT_THAT(
RegisterRegexExtensionFunctions(builder.function_registry(), options),
IsOk());
ASSERT_THAT(RegisterRegexExtensionFunctions(builder), IsOk());
ASSERT_OK_AND_ASSIGN(runtime_, std::move(builder).Build());
}

Expand All @@ -103,6 +105,23 @@ class RegexExtTest : public TestWithParam<RegexExtTestCase> {
std::unique_ptr<const Runtime> runtime_;
};

// Registering the regex extension on a runtime builder that never enabled
// optional types must be rejected with InvalidArgument.
TEST_F(RegexExtTest, BuildFailsWithoutOptionalSupport) {
  RuntimeOptions runtime_options;
  runtime_options.enable_qualified_type_identifiers = true;
  runtime_options.enable_regex = true;

  ASSERT_OK_AND_ASSIGN(
      auto runtime_builder,
      CreateStandardRuntimeBuilder(internal::GetTestingDescriptorPool(),
                                   runtime_options));
  ASSERT_THAT(EnableReferenceResolver(runtime_builder,
                                      ReferenceResolverEnabled::kAlways),
              IsOk());

  // EnableOptionalTypes() is deliberately not called on this builder.
  const absl::Status registration_status =
      RegisterRegexExtensionFunctions(runtime_builder);
  ASSERT_THAT(registration_status,
              StatusIs(absl::StatusCode::kInvalidArgument,
                       HasSubstr("regex extensions requires the optional types "
                                 "to be enabled")));
}
std::vector<RegexExtTestCase> regexTestCases() {
return {
// Tests for extract Function
Expand All @@ -121,6 +140,11 @@ std::vector<RegexExtTestCase> regexTestCases() {
"regex.extract('hello world', 'goodbye (.*)')"},
{EvaluationType::kOptionalNone, "regex.extract('HELLO', 'hello')"},
{EvaluationType::kOptionalNone, R"(regex.extract('', r'\w+'))"},
{EvaluationType::kBoolTrue,
"regex.extract('4122345432', '22').orValue('777') == '22'"},
{EvaluationType::kBoolTrue,
"regex.extract('4122345432', '22').or(optional.of('777')) == "
"optional.of('22')"},

// Tests for extractAll Function
{EvaluationType::kBoolTrue,
Expand Down Expand Up @@ -328,5 +352,57 @@ TEST_P(RegexExtTest, RegexExtTests) {

// Instantiates the parameterized RegexExtTest suite with the cases returned
// by regexTestCases().
INSTANTIATE_TEST_SUITE_P(RegexExtTest, RegexExtTest,
ValuesIn(regexTestCases()));

// One type-checker test case for the regex extension declarations.
struct RegexCheckerTestCase {
// CEL expression handed to the compiler.
std::string expr_string;
// Substring expected in the formatted validation error. Left empty when the
// expression is expected to type-check successfully.
std::string error_substr;
};

// Fixture for type-checking expressions that use the regex extension. Each
// test gets a compiler configured with the standard checker library plus the
// regex compiler library.
class RegexExtCheckerLibraryTest : public TestWithParam<RegexCheckerTestCase> {
 public:
  // Builds `compiler_` from a fresh compiler builder over the testing
  // descriptor pool.
  void SetUp() override {
    ASSERT_OK_AND_ASSIGN(std::unique_ptr<CompilerBuilder> owned_builder,
                         NewCompilerBuilder(descriptor_pool_));
    CompilerBuilder& builder = *owned_builder;

    // Standard library first, then the regex extension under test.
    ASSERT_THAT(builder.AddLibrary(StandardCheckerLibrary()), IsOk());
    ASSERT_THAT(builder.AddLibrary(RegexExtCompilerLibrary()), IsOk());

    ASSERT_OK_AND_ASSIGN(compiler_, std::move(builder).Build());
  }

  // Descriptor pool the compiler is built against.
  const google::protobuf::DescriptorPool* descriptor_pool_ =
      internal::GetTestingDescriptorPool();
  // Compiler under test; populated by SetUp().
  std::unique_ptr<Compiler> compiler_;
};

// Compiles the parameter's expression and checks that it is valid exactly
// when no error substring is expected; otherwise the formatted error must
// contain that substring.
TEST_P(RegexExtCheckerLibraryTest, RegexExtTypeCheckerTests) {
  const RegexCheckerTestCase& test_case = GetParam();

  ASSERT_OK_AND_ASSIGN(ValidationResult result,
                       compiler_->Compile(test_case.expr_string));

  absl::string_view error_substr = test_case.error_substr;
  EXPECT_EQ(result.IsValid(), error_substr.empty());

  // Nothing more to verify for cases that are expected to type-check.
  if (error_substr.empty()) {
    return;
  }
  EXPECT_THAT(result.FormatError(), HasSubstr(error_substr));
}

std::vector<RegexCheckerTestCase> createRegexCheckerParams() {
return {
{"regex.replace('abc', 'a', 's') == 'sbc'"},
{"regex.replace('abc', 'a', 's') == 121",
"found no matching overload for '_==_' applied to '(string, int)"},
{"regex.replace('abc', 'j', '1', 2) == 9.0",
"found no matching overload for '_==_' applied to '(string, double)"},
{"regex.extractAll('banananana', '(ana)') == ['ana', 'ana']"},
{"regex.extract('foo bar', 'f') == 121",
"found no matching overload for '_==_' applied to "
"'(optional_type(string), int)'"},
};
}

// Instantiates the parameterized checker-library suite with the cases
// returned by createRegexCheckerParams().
INSTANTIATE_TEST_SUITE_P(RegexExtCheckerLibraryTest, RegexExtCheckerLibraryTest,
ValuesIn(createRegexCheckerParams()));
} // namespace
} // namespace cel::extensions