Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
62 changes: 50 additions & 12 deletions common/function_descriptor.h
Original file line number Diff line number Diff line change
Expand Up @@ -26,30 +26,68 @@

namespace cel {

struct FunctionDescriptorOptions {
// If true (strict, default), error or unknown arguments are propagated
// instead of calling the function. if false (non-strict), the function may
// receive error or unknown values as arguments.
bool is_strict = true;

// Whether the function is impure or context-sensitive.
//
// Impure functions depend on state other than the arguments received during
// the CEL expression evaluation or have visible side effects. This breaks
// some of the assumptions of the CEL evaluation model. This flag is used as a
// hint to the planner that some optimizations are not safe or not effective.
bool is_contextual = false;
};

// Coarsely describes a function for the purpose of runtime resolution of
// overloads.
class FunctionDescriptor final {
public:
FunctionDescriptor(absl::string_view name, bool receiver_style,
std::vector<Kind> types, bool is_strict = true)
: impl_(std::make_shared<Impl>(name, receiver_style, std::move(types),
is_strict)) {}
std::vector<Kind> types, bool is_strict)
: impl_(std::make_shared<Impl>(
name, std::move(types), receiver_style,
FunctionDescriptorOptions{is_strict,
/*is_contextual=*/false})) {}

FunctionDescriptor(absl::string_view name, bool receiver_style,
std::vector<Kind> types, bool is_strict,
bool is_contextual)
: impl_(std::make_shared<Impl>(
name, std::move(types), receiver_style,
FunctionDescriptorOptions{is_strict, is_contextual})) {}

FunctionDescriptor(absl::string_view name, bool is_receiver_style,
std::vector<Kind> types,
FunctionDescriptorOptions options = {})
: impl_(std::make_shared<Impl>(name, std::move(types), is_receiver_style,
options)) {}

// Function name.
const std::string& name() const { return impl_->name; }

// Whether function is receiver style i.e. true means arg0.name(args[1:]...).
bool receiver_style() const { return impl_->receiver_style; }
bool receiver_style() const { return impl_->is_receiver_style; }

// The argmument types the function accepts.
// The argument types the function accepts.
//
// TODO(uncreated-issue/17): make this kinds
const std::vector<Kind>& types() const { return impl_->types; }

// if true (strict, default), error or unknown arguments are propagated
// instead of calling the function. if false (non-strict), the function may
// receive error or unknown values as arguments.
bool is_strict() const { return impl_->is_strict; }
bool is_strict() const { return impl_->options.is_strict; }

// Whether the function is contextual (impure).
//
// Contextual functions depend on state other than the arguments received in
// the CEL expression evaluation or have visible side effects. This breaks
// some of the assumptions of CEL. This flag is used as a hint to the planner
// that some optimizations are not safe or not effective.
bool is_contextual() const { return impl_->options.is_contextual; }

// Helper for matching a descriptor. This tests that the shape is the same --
// |other| accepts the same number and types of arguments and is the same call
Expand All @@ -65,17 +103,17 @@ class FunctionDescriptor final {

private:
struct Impl final {
Impl(absl::string_view name, bool receiver_style, std::vector<Kind> types,
bool is_strict)
Impl(absl::string_view name, std::vector<Kind> types,
bool is_receiver_style, FunctionDescriptorOptions options)
: name(name),
types(std::move(types)),
receiver_style(receiver_style),
is_strict(is_strict) {}
is_receiver_style(is_receiver_style),
options(options) {}

std::string name;
std::vector<Kind> types;
bool receiver_style;
bool is_strict;
bool is_receiver_style;
FunctionDescriptorOptions options;
};

std::shared_ptr<const Impl> impl_;
Expand Down
11 changes: 11 additions & 0 deletions eval/compiler/constant_folding.cc
Original file line number Diff line number Diff line change
Expand Up @@ -155,6 +155,17 @@ IsConst IsConstExpr(const Expr& expr, const Resolver& resolver) {
return IsConst::kNonConst;
}

auto overloads =
resolver.FindOverloads(call.function(), call.has_target(), arg_len);
// Check for any contextual overloads. If there are any, we cowardly
// avoid constant folding instead of trying to check if one of the
// overloads would be safe to use.
for (const auto& overload : overloads) {
if (overload.descriptor.is_contextual()) {
return IsConst::kNonConst;
}
}

return IsConst::kConditional;
}
case ExprKindCase::kUnspecifiedExpr:
Expand Down
6 changes: 5 additions & 1 deletion runtime/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -116,6 +116,7 @@ cc_library(
deps =
[
":function_registry",
"//common:function_descriptor",
"@com_google_absl//absl/status",
"@com_google_absl//absl/strings",
],
Expand Down Expand Up @@ -320,7 +321,7 @@ cc_library(
deps = [
":runtime",
":runtime_builder",
"//common:native_type",
"//common:typeinfo",
"//eval/compiler:constant_folding",
"//internal:casts",
"//internal:noop_delete",
Expand All @@ -342,11 +343,14 @@ cc_test(
deps = [
":activation",
":constant_folding",
":function",
":register_function_helper",
":runtime_builder",
":runtime_options",
":standard_runtime_builder_factory",
"//base:function_adapter",
"//common:function_descriptor",
"//common:kind",
"//common:value",
"//extensions/protobuf:runtime_adapter",
"//internal:testing",
Expand Down
5 changes: 2 additions & 3 deletions runtime/constant_folding.cc
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
#include "absl/log/absl_check.h"
#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "common/native_type.h"
#include "common/typeinfo.h"
#include "eval/compiler/constant_folding.h"
#include "internal/casts.h"
#include "internal/noop_delete.h"
Expand All @@ -44,8 +44,7 @@ using ::cel::runtime_internal::RuntimeImpl;
absl::StatusOr<RuntimeImpl* absl_nonnull> RuntimeImplFromBuilder(
RuntimeBuilder& builder ABSL_ATTRIBUTE_LIFETIME_BOUND) {
Runtime& runtime = RuntimeFriendAccess::GetMutableRuntime(builder);
if (RuntimeFriendAccess::RuntimeTypeId(runtime) !=
NativeTypeId::For<RuntimeImpl>()) {
if (RuntimeFriendAccess::RuntimeTypeId(runtime) != TypeId<RuntimeImpl>()) {
return absl::UnimplementedError(
"constant folding only supported on the default cel::Runtime "
"implementation.");
Expand Down
96 changes: 91 additions & 5 deletions runtime/constant_folding_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@

#include "runtime/constant_folding.h"

#include <memory>
#include <string>
#include <utility>
#include <vector>
Expand All @@ -25,13 +26,13 @@
#include "absl/strings/match.h"
#include "absl/strings/str_cat.h"
#include "base/function_adapter.h"
#include "common/function_descriptor.h"
#include "common/value.h"
#include "extensions/protobuf/runtime_adapter.h"
#include "internal/testing.h"
#include "internal/testing_descriptor_pool.h"
#include "parser/parser.h"
#include "runtime/activation.h"
#include "runtime/register_function_helper.h"
#include "runtime/runtime_builder.h"
#include "runtime/runtime_options.h"
#include "runtime/standard_runtime_builder_factory.h"
Expand Down Expand Up @@ -82,8 +83,8 @@ TEST_P(ConstantFoldingExtTest, Runner) {
CreateStandardRuntimeBuilder(
internal::GetTestingDescriptorPool(), options));

auto status = RegisterHelper<BinaryFunctionAdapter<
absl::StatusOr<Value>, const StringValue&, const StringValue&>>::
auto status = BinaryFunctionAdapter<absl::StatusOr<Value>, const StringValue&,
const StringValue&>::
RegisterGlobalOverload(
"prepend",
[](const StringValue& value, const StringValue& prefix) {
Expand Down Expand Up @@ -129,14 +130,99 @@ INSTANTIATE_TEST_SUITE_P(
IsBoolValue(true)},
{"runtime_error", "[1, 2, 3, 4].exists(x, ['4'].all(y, y <= x))",
IsErrorValue("No matching overloads")},
// TODO(uncreated-issue/32): Depends on map creation
// {"map_create", "{'abc': 'def', 'abd': 'deg'}.size()", 2},
{"map_create", "{'abc': 'def', 'abd': 'deg'}.size()", IsIntValue(2)},
{"custom_function", "prepend('def', 'abc') == 'abcdef'",
IsBoolValue(true)}}),

[](const testing::TestParamInfo<TestCase>& info) {
return info.param.name;
});

TEST(ConstantFoldingExtTest, LazyFunctionNotFolded) {
google::protobuf::Arena arena;
RuntimeOptions options;

ASSERT_OK_AND_ASSIGN(cel::RuntimeBuilder builder,
CreateStandardRuntimeBuilder(
internal::GetTestingDescriptorPool(), options));
int call_count = 0;
using FunctionAdapter =
BinaryFunctionAdapter<absl::StatusOr<Value>, const StringValue&,
const StringValue&>;
auto fn = FunctionAdapter::WrapFunction(
[&call_count](const StringValue& value, const StringValue& prefix) {
call_count++;
return StringValue(absl::StrCat(prefix.ToString(), value.ToString()));
});
FunctionDescriptor descriptor = FunctionAdapter::CreateDescriptor(
"lazy_prepend", /*receiver_style=*/false);
ASSERT_THAT(builder.function_registry().RegisterLazyFunction(descriptor),
IsOk());

ASSERT_THAT(EnableConstantFolding(builder), IsOk());

ASSERT_OK_AND_ASSIGN(auto runtime, std::move(builder).Build());

ASSERT_OK_AND_ASSIGN(ParsedExpr parsed_expr,
Parse("lazy_prepend('def', 'abc') == 'abcdef'"));

ASSERT_OK_AND_ASSIGN(auto program, ProtobufRuntimeAdapter::CreateProgram(
*runtime, parsed_expr));
EXPECT_EQ(call_count, 0);
Activation activation;
activation.InsertFunction(descriptor, std::move(fn));

ASSERT_OK_AND_ASSIGN(Value result, program->Evaluate(&arena, activation));
EXPECT_EQ(call_count, 1);
EXPECT_THAT(result, IsBoolValue(true));

ASSERT_OK_AND_ASSIGN(result, program->Evaluate(&arena, activation));
EXPECT_EQ(call_count, 2);
EXPECT_THAT(result, IsBoolValue(true));
}

TEST(ConstantFoldingExtTest, ContextualFunctionNotFolded) {
google::protobuf::Arena arena;
RuntimeOptions options;
ASSERT_OK_AND_ASSIGN(cel::RuntimeBuilder builder,
CreateStandardRuntimeBuilder(
internal::GetTestingDescriptorPool(), options));
int call_count = 0;

auto status = BinaryFunctionAdapter<
absl::StatusOr<Value>, const StringValue&,
const StringValue&>::Register("contextual_prepend",
/*receiver_style=*/false,
[&call_count](const StringValue& value,
const StringValue& prefix) {
call_count++;
return StringValue(absl::StrCat(
prefix.ToString(), value.ToString()));
},
builder.function_registry(),
{/*.is_strict=*/true,
/*is_contextual=*/true});
ASSERT_THAT(status, IsOk());

ASSERT_THAT(EnableConstantFolding(builder), IsOk());

ASSERT_OK_AND_ASSIGN(auto runtime, std::move(builder).Build());

ASSERT_OK_AND_ASSIGN(ParsedExpr parsed_expr,
Parse("contextual_prepend('def', 'abc') == 'abcdef'"));

ASSERT_OK_AND_ASSIGN(auto program, ProtobufRuntimeAdapter::CreateProgram(
*runtime, parsed_expr));
EXPECT_EQ(call_count, 0);
Activation activation;
ASSERT_OK_AND_ASSIGN(Value value, program->Evaluate(&arena, activation));
EXPECT_EQ(call_count, 1);
EXPECT_THAT(value, IsBoolValue(true));

ASSERT_OK_AND_ASSIGN(value, program->Evaluate(&arena, activation));
EXPECT_EQ(call_count, 2);
EXPECT_THAT(value, IsBoolValue(true));
}

} // namespace
} // namespace cel::extensions
Loading