From 7737ed163f0aecb3609f186a342da03e2b32dde2 Mon Sep 17 00:00:00 2001 From: Jonathan Tatum Date: Mon, 3 Nov 2025 16:17:53 -0800 Subject: [PATCH] Add is_contextual flag to FunctionDescriptor. This is used to mark a function as impure or context dependent. This blocks constant folding from attempting to evaluate the function. PiperOrigin-RevId: 827691703 --- common/function_descriptor.h | 62 +++++++++++++++---- eval/compiler/constant_folding.cc | 11 ++++ runtime/BUILD | 6 +- runtime/constant_folding.cc | 5 +- runtime/constant_folding_test.cc | 96 ++++++++++++++++++++++++++++-- runtime/function_adapter.h | 68 +++++++++++++++++---- runtime/register_function_helper.h | 12 +++- 7 files changed, 226 insertions(+), 34 deletions(-) diff --git a/common/function_descriptor.h b/common/function_descriptor.h index 9c1f8a5bd..75c61e13a 100644 --- a/common/function_descriptor.h +++ b/common/function_descriptor.h @@ -26,22 +26,52 @@ namespace cel { +struct FunctionDescriptorOptions { + // If true (strict, default), error or unknown arguments are propagated + // instead of calling the function. if false (non-strict), the function may + // receive error or unknown values as arguments. + bool is_strict = true; + + // Whether the function is impure or context-sensitive. + // + // Impure functions depend on state other than the arguments received during + // the CEL expression evaluation or have visible side effects. This breaks + // some of the assumptions of the CEL evaluation model. This flag is used as a + // hint to the planner that some optimizations are not safe or not effective. + bool is_contextual = false; +}; + // Coarsely describes a function for the purpose of runtime resolution of // overloads. class FunctionDescriptor final { public: FunctionDescriptor(absl::string_view name, bool receiver_style, - std::vector types, bool is_strict = true) - : impl_(std::make_shared(name, receiver_style, std::move(types), - is_strict)) {} + std::vector types, bool is_strict) + : impl_(std::make_shared( + name, std::move(types), receiver_style, + FunctionDescriptorOptions{is_strict, + /*is_contextual=*/false})) {} + + FunctionDescriptor(absl::string_view name, bool receiver_style, + std::vector types, bool is_strict, + bool is_contextual) + : impl_(std::make_shared( + name, std::move(types), receiver_style, + FunctionDescriptorOptions{is_strict, is_contextual})) {} + + FunctionDescriptor(absl::string_view name, bool is_receiver_style, + std::vector types, + FunctionDescriptorOptions options = {}) + : impl_(std::make_shared(name, std::move(types), is_receiver_style, + options)) {} // Function name. const std::string& name() const { return impl_->name; } // Whether function is receiver style i.e. true means arg0.name(args[1:]...). - bool receiver_style() const { return impl_->receiver_style; } + bool receiver_style() const { return impl_->is_receiver_style; } - // The argmument types the function accepts. + // The argument types the function accepts. // // TODO(uncreated-issue/17): make this kinds const std::vector& types() const { return impl_->types; } @@ -49,7 +79,15 @@ class FunctionDescriptor final { // if true (strict, default), error or unknown arguments are propagated // instead of calling the function. if false (non-strict), the function may // receive error or unknown values as arguments. - bool is_strict() const { return impl_->is_strict; } + bool is_strict() const { return impl_->options.is_strict; } + + // Whether the function is contextual (impure). + // + // Contextual functions depend on state other than the arguments received in + // the CEL expression evaluation or have visible side effects. This breaks + // some of the assumptions of CEL. This flag is used as a hint to the planner + // that some optimizations are not safe or not effective. + bool is_contextual() const { return impl_->options.is_contextual; } // Helper for matching a descriptor. This tests that the shape is the same -- // |other| accepts the same number and types of arguments and is the same call @@ -65,17 +103,17 @@ class FunctionDescriptor final { private: struct Impl final { - Impl(absl::string_view name, bool receiver_style, std::vector types, - bool is_strict) + Impl(absl::string_view name, std::vector types, + bool is_receiver_style, FunctionDescriptorOptions options) : name(name), types(std::move(types)), - receiver_style(receiver_style), - is_strict(is_strict) {} + is_receiver_style(is_receiver_style), + options(options) {} std::string name; std::vector types; - bool receiver_style; - bool is_strict; + bool is_receiver_style; + FunctionDescriptorOptions options; }; std::shared_ptr impl_; diff --git a/eval/compiler/constant_folding.cc b/eval/compiler/constant_folding.cc index ff04379d2..118fc94c5 100644 --- a/eval/compiler/constant_folding.cc +++ b/eval/compiler/constant_folding.cc @@ -155,6 +155,17 @@ IsConst IsConstExpr(const Expr& expr, const Resolver& resolver) { return IsConst::kNonConst; } + auto overloads = + resolver.FindOverloads(call.function(), call.has_target(), arg_len); + // Check for any contextual overloads. If there are any, we cowardly + // avoid constant folding instead of trying to check if one of the + // overloads would be safe to use. + for (const auto& overload : overloads) { + if (overload.descriptor.is_contextual()) { + return IsConst::kNonConst; + } + } + return IsConst::kConditional; } case ExprKindCase::kUnspecifiedExpr: diff --git a/runtime/BUILD b/runtime/BUILD index cfd8cd361..bd66b4a67 100644 --- a/runtime/BUILD +++ b/runtime/BUILD @@ -116,6 +116,7 @@ cc_library( deps = [ ":function_registry", + "//common:function_descriptor", "@com_google_absl//absl/status", "@com_google_absl//absl/strings", ], @@ -320,7 +321,7 @@ cc_library( deps = [ ":runtime", ":runtime_builder", - "//common:native_type", + "//common:typeinfo", "//eval/compiler:constant_folding", "//internal:casts", "//internal:noop_delete", @@ -342,11 +343,14 @@ cc_test( deps = [ ":activation", ":constant_folding", + ":function", ":register_function_helper", ":runtime_builder", ":runtime_options", ":standard_runtime_builder_factory", "//base:function_adapter", + "//common:function_descriptor", + "//common:kind", "//common:value", "//extensions/protobuf:runtime_adapter", "//internal:testing", diff --git a/runtime/constant_folding.cc b/runtime/constant_folding.cc index f30e3947a..2d14154dc 100644 --- a/runtime/constant_folding.cc +++ b/runtime/constant_folding.cc @@ -22,7 +22,7 @@ #include "absl/log/absl_check.h" #include "absl/status/status.h" #include "absl/status/statusor.h" -#include "common/native_type.h" +#include "common/typeinfo.h" #include "eval/compiler/constant_folding.h" #include "internal/casts.h" #include "internal/noop_delete.h" @@ -44,8 +44,7 @@ using ::cel::runtime_internal::RuntimeImpl; absl::StatusOr RuntimeImplFromBuilder( RuntimeBuilder& builder ABSL_ATTRIBUTE_LIFETIME_BOUND) { Runtime& runtime = RuntimeFriendAccess::GetMutableRuntime(builder); - if (RuntimeFriendAccess::RuntimeTypeId(runtime) != - NativeTypeId::For()) { + if (RuntimeFriendAccess::RuntimeTypeId(runtime) != TypeId()) { return absl::UnimplementedError( "constant folding only supported on the default cel::Runtime " "implementation."); diff --git a/runtime/constant_folding_test.cc b/runtime/constant_folding_test.cc index 76bcdbf5c..c59d5602a 100644 --- a/runtime/constant_folding_test.cc +++ b/runtime/constant_folding_test.cc @@ -14,6 +14,7 @@ #include "runtime/constant_folding.h" +#include #include #include #include @@ -25,13 +26,13 @@ #include "absl/strings/match.h" #include "absl/strings/str_cat.h" #include "base/function_adapter.h" +#include "common/function_descriptor.h" #include "common/value.h" #include "extensions/protobuf/runtime_adapter.h" #include "internal/testing.h" #include "internal/testing_descriptor_pool.h" #include "parser/parser.h" #include "runtime/activation.h" -#include "runtime/register_function_helper.h" #include "runtime/runtime_builder.h" #include "runtime/runtime_options.h" #include "runtime/standard_runtime_builder_factory.h" @@ -82,8 +83,8 @@ TEST_P(ConstantFoldingExtTest, Runner) { CreateStandardRuntimeBuilder( internal::GetTestingDescriptorPool(), options)); - auto status = RegisterHelper, const StringValue&, const StringValue&>>:: + auto status = BinaryFunctionAdapter, const StringValue&, + const StringValue&>:: RegisterGlobalOverload( "prepend", [](const StringValue& value, const StringValue& prefix) { @@ -129,8 +130,7 @@ INSTANTIATE_TEST_SUITE_P( IsBoolValue(true)}, {"runtime_error", "[1, 2, 3, 4].exists(x, ['4'].all(y, y <= x))", IsErrorValue("No matching overloads")}, - // TODO(uncreated-issue/32): Depends on map creation - // {"map_create", "{'abc': 'def', 'abd': 'deg'}.size()", 2}, + {"map_create", "{'abc': 'def', 'abd': 'deg'}.size()", IsIntValue(2)}, {"custom_function", "prepend('def', 'abc') == 'abcdef'", IsBoolValue(true)}}), @@ -138,5 +138,91 @@ INSTANTIATE_TEST_SUITE_P( return info.param.name; }); +TEST(ConstantFoldingExtTest, LazyFunctionNotFolded) { + google::protobuf::Arena arena; + RuntimeOptions options; + + ASSERT_OK_AND_ASSIGN(cel::RuntimeBuilder builder, + CreateStandardRuntimeBuilder( + internal::GetTestingDescriptorPool(), options)); + int call_count = 0; + using FunctionAdapter = + BinaryFunctionAdapter, const StringValue&, + const StringValue&>; + auto fn = FunctionAdapter::WrapFunction( + [&call_count](const StringValue& value, const StringValue& prefix) { + call_count++; + return StringValue(absl::StrCat(prefix.ToString(), value.ToString())); + }); + FunctionDescriptor descriptor = FunctionAdapter::CreateDescriptor( + "lazy_prepend", /*receiver_style=*/false); + ASSERT_THAT(builder.function_registry().RegisterLazyFunction(descriptor), + IsOk()); + + ASSERT_THAT(EnableConstantFolding(builder), IsOk()); + + ASSERT_OK_AND_ASSIGN(auto runtime, std::move(builder).Build()); + + ASSERT_OK_AND_ASSIGN(ParsedExpr parsed_expr, + Parse("lazy_prepend('def', 'abc') == 'abcdef'")); + + ASSERT_OK_AND_ASSIGN(auto program, ProtobufRuntimeAdapter::CreateProgram( + *runtime, parsed_expr)); + EXPECT_EQ(call_count, 0); + Activation activation; + activation.InsertFunction(descriptor, std::move(fn)); + + ASSERT_OK_AND_ASSIGN(Value result, program->Evaluate(&arena, activation)); + EXPECT_EQ(call_count, 1); + EXPECT_THAT(result, IsBoolValue(true)); + + ASSERT_OK_AND_ASSIGN(result, program->Evaluate(&arena, activation)); + EXPECT_EQ(call_count, 2); + EXPECT_THAT(result, IsBoolValue(true)); +} + +TEST(ConstantFoldingExtTest, ContextualFunctionNotFolded) { + google::protobuf::Arena arena; + RuntimeOptions options; + ASSERT_OK_AND_ASSIGN(cel::RuntimeBuilder builder, + CreateStandardRuntimeBuilder( + internal::GetTestingDescriptorPool(), options)); + int call_count = 0; + + auto status = BinaryFunctionAdapter< + absl::StatusOr, const StringValue&, + const StringValue&>::Register("contextual_prepend", + /*receiver_style=*/false, + [&call_count](const StringValue& value, + const StringValue& prefix) { + call_count++; + return StringValue(absl::StrCat( + prefix.ToString(), value.ToString())); + }, + builder.function_registry(), + {/*.is_strict=*/true, + /*is_contextual=*/true}); + ASSERT_THAT(status, IsOk()); + + ASSERT_THAT(EnableConstantFolding(builder), IsOk()); + + ASSERT_OK_AND_ASSIGN(auto runtime, std::move(builder).Build()); + + ASSERT_OK_AND_ASSIGN(ParsedExpr parsed_expr, + Parse("contextual_prepend('def', 'abc') == 'abcdef'")); + + ASSERT_OK_AND_ASSIGN(auto program, ProtobufRuntimeAdapter::CreateProgram( + *runtime, parsed_expr)); + EXPECT_EQ(call_count, 0); + Activation activation; + ASSERT_OK_AND_ASSIGN(Value value, program->Evaluate(&arena, activation)); + EXPECT_EQ(call_count, 1); + EXPECT_THAT(value, IsBoolValue(true)); + + ASSERT_OK_AND_ASSIGN(value, program->Evaluate(&arena, activation)); + EXPECT_EQ(call_count, 2); + EXPECT_THAT(value, IsBoolValue(true)); +} + } // namespace } // namespace cel::extensions diff --git a/runtime/function_adapter.h b/runtime/function_adapter.h index 0b97e3515..1c96a6ea1 100644 --- a/runtime/function_adapter.h +++ b/runtime/function_adapter.h @@ -212,8 +212,15 @@ class NullaryFunctionAdapter static FunctionDescriptor CreateDescriptor(absl::string_view name, bool receiver_style, - bool is_strict = true) { - return FunctionDescriptor(name, receiver_style, {}, is_strict); + bool is_strict) { + return CreateDescriptor(name, receiver_style, + {is_strict, /*is_contextual=*/false}); + } + + static FunctionDescriptor CreateDescriptor( + absl::string_view name, bool receiver_style, + FunctionDescriptorOptions options = {}) { + return FunctionDescriptor(name, receiver_style, {}, options); } private: @@ -288,9 +295,17 @@ class UnaryFunctionAdapter : public RegisterHelper> { static FunctionDescriptor CreateDescriptor(absl::string_view name, bool receiver_style, - bool is_strict = true) { + bool is_strict) { + return CreateDescriptor( + name, receiver_style, + FunctionDescriptorOptions{is_strict, /*is_contextual=*/false}); + } + + static FunctionDescriptor CreateDescriptor( + absl::string_view name, bool receiver_style, + FunctionDescriptorOptions options = {}) { return FunctionDescriptor(name, receiver_style, - {runtime_internal::AdaptedKind()}, is_strict); + {runtime_internal::AdaptedKind()}, options); } private: @@ -419,11 +434,18 @@ class BinaryFunctionAdapter static FunctionDescriptor CreateDescriptor(absl::string_view name, bool receiver_style, - bool is_strict = true) { + bool is_strict) { + return CreateDescriptor(name, receiver_style, + {is_strict, /*is_contextual=*/false}); + } + + static FunctionDescriptor CreateDescriptor( + absl::string_view name, bool receiver_style, + FunctionDescriptorOptions options = {}) { return FunctionDescriptor(name, receiver_style, {runtime_internal::AdaptedKind(), runtime_internal::AdaptedKind()}, - is_strict); + options); } private: @@ -491,12 +513,20 @@ class TernaryFunctionAdapter static FunctionDescriptor CreateDescriptor(absl::string_view name, bool receiver_style, - bool is_strict = true) { + bool is_strict) { + return CreateDescriptor( + name, receiver_style, + FunctionDescriptorOptions{is_strict, /*is_contextual=*/false}); + } + + static FunctionDescriptor CreateDescriptor( + absl::string_view name, bool receiver_style, + FunctionDescriptorOptions options = {}) { return FunctionDescriptor( name, receiver_style, {runtime_internal::AdaptedKind(), runtime_internal::AdaptedKind(), runtime_internal::AdaptedKind()}, - is_strict); + options); } private: @@ -570,13 +600,20 @@ class QuaternaryFunctionAdapter static FunctionDescriptor CreateDescriptor(absl::string_view name, bool receiver_style, - bool is_strict = true) { + bool is_strict) { + return CreateDescriptor(name, receiver_style, + {is_strict, /*is_contextual=*/false}); + } + + static FunctionDescriptor CreateDescriptor( + absl::string_view name, bool receiver_style, + FunctionDescriptorOptions options = {}) { return FunctionDescriptor( name, receiver_style, {runtime_internal::AdaptedKind(), runtime_internal::AdaptedKind(), runtime_internal::AdaptedKind(), runtime_internal::AdaptedKind()}, - is_strict); + options); } private: @@ -664,10 +701,17 @@ class NaryFunctionAdapter static FunctionDescriptor CreateDescriptor(absl::string_view name, bool receiver_style, - bool is_strict = true) { + bool is_strict) { + return CreateDescriptor(name, receiver_style, + {is_strict, /*is_contextual=*/false}); + } + + static FunctionDescriptor CreateDescriptor( + absl::string_view name, bool receiver_style, + FunctionDescriptorOptions options = {}) { return FunctionDescriptor(name, receiver_style, {runtime_internal::AdaptedKind()...}, - is_strict); + options); } static std::unique_ptr WrapFunction(FunctionType fn) { diff --git a/runtime/register_function_helper.h b/runtime/register_function_helper.h index fbeec84bf..8cc133abc 100644 --- a/runtime/register_function_helper.h +++ b/runtime/register_function_helper.h @@ -19,6 +19,7 @@ #include "absl/status/status.h" #include "absl/strings/string_view.h" +#include "common/function_descriptor.h" #include "runtime/function_registry.h" namespace cel { @@ -44,12 +45,21 @@ class RegisterHelper { template static absl::Status Register(absl::string_view name, bool receiver_style, FunctionT&& fn, FunctionRegistry& registry, - bool strict = true) { + bool strict) { return registry.Register( AdapterT::CreateDescriptor(name, receiver_style, strict), AdapterT::WrapFunction(std::forward(fn))); } + template + static absl::Status Register(absl::string_view name, bool receiver_style, + FunctionT&& fn, FunctionRegistry& registry, + FunctionDescriptorOptions options = {}) { + return registry.Register( + AdapterT::CreateDescriptor(name, receiver_style, options), + AdapterT::WrapFunction(std::forward(fn))); + } + // Registers a global overload (.e.g. size() ) template static absl::Status RegisterGlobalOverload(absl::string_view name,