From eb0273d3207a5b96d948408adc5b01ddf99c599b Mon Sep 17 00:00:00 2001 From: Sokwhan Huh Date: Tue, 25 Nov 2025 17:47:35 -0800 Subject: [PATCH] Plan calls PiperOrigin-RevId: 836872858 --- .../src/main/java/dev/cel/runtime/BUILD.bazel | 4 + .../dev/cel/runtime/CelResolvedOverload.java | 15 +- .../dev/cel/runtime/DefaultDispatcher.java | 25 +- .../java/dev/cel/runtime/planner/BUILD.bazel | 42 +++ .../dev/cel/runtime/planner/EvalUnary.java | 66 +++++ .../cel/runtime/planner/EvalVarArgsCall.java | 70 +++++ .../cel/runtime/planner/EvalZeroArity.java | 63 ++++ .../cel/runtime/planner/PlannedProgram.java | 18 +- .../cel/runtime/planner/ProgramPlanner.java | 126 +++++++- .../java/dev/cel/runtime/planner/BUILD.bazel | 17 ++ .../runtime/planner/ProgramPlannerTest.java | 273 +++++++++++++++++- 11 files changed, 704 insertions(+), 15 deletions(-) create mode 100644 runtime/src/main/java/dev/cel/runtime/planner/EvalUnary.java create mode 100644 runtime/src/main/java/dev/cel/runtime/planner/EvalVarArgsCall.java create mode 100644 runtime/src/main/java/dev/cel/runtime/planner/EvalZeroArity.java diff --git a/runtime/src/main/java/dev/cel/runtime/BUILD.bazel b/runtime/src/main/java/dev/cel/runtime/BUILD.bazel index 35d5358e9..847092cc3 100644 --- a/runtime/src/main/java/dev/cel/runtime/BUILD.bazel +++ b/runtime/src/main/java/dev/cel/runtime/BUILD.bazel @@ -128,6 +128,7 @@ java_library( ":resolved_overload", "//:auto_value", "//common:error_codes", + "//common/annotations", "@maven//:com_google_code_findbugs_annotations", "@maven//:com_google_errorprone_error_prone_annotations", "@maven//:com_google_guava_guava", @@ -146,6 +147,7 @@ cel_android_library( ":resolved_overload_android", "//:auto_value", "//common:error_codes", + "//common/annotations", "@maven//:com_google_code_findbugs_annotations", "@maven//:com_google_errorprone_error_prone_annotations", "@maven_android//:com_google_guava_guava", @@ -1172,6 +1174,7 @@ java_library( ":function_overload", ":unknown_attributes", "//:auto_value", + "//common/annotations", "@maven//:com_google_errorprone_error_prone_annotations", "@maven//:com_google_guava_guava", ], @@ -1186,6 +1189,7 @@ cel_android_library( ":function_overload_android", ":unknown_attributes_android", "//:auto_value", + "//common/annotations", "@maven//:com_google_errorprone_error_prone_annotations", "@maven_android//:com_google_guava_guava", ], diff --git a/runtime/src/main/java/dev/cel/runtime/CelResolvedOverload.java b/runtime/src/main/java/dev/cel/runtime/CelResolvedOverload.java index f7cd9f2c0..f6a0c4f99 100644 --- a/runtime/src/main/java/dev/cel/runtime/CelResolvedOverload.java +++ b/runtime/src/main/java/dev/cel/runtime/CelResolvedOverload.java @@ -17,6 +17,7 @@ import com.google.auto.value.AutoValue; import com.google.common.collect.ImmutableList; import com.google.errorprone.annotations.Immutable; +import dev.cel.common.annotations.Internal; import java.util.List; import java.util.Map; @@ -26,7 +27,8 @@ */ @AutoValue @Immutable -abstract class CelResolvedOverload { +@Internal +public abstract class CelResolvedOverload { /** The overload id of the function. */ public abstract String getOverloadId(); @@ -78,7 +80,14 @@ public static CelResolvedOverload of( * Returns true if the overload's expected argument types match the types of the given arguments. */ boolean canHandle(Object[] arguments) { - ImmutableList> parameterTypes = getParameterTypes(); + return canHandle(arguments, getParameterTypes(), isStrict()); + } + + /** + * Returns true if the overload's expected argument types match the types of the given arguments. + */ + public static boolean canHandle( + Object[] arguments, ImmutableList> parameterTypes, boolean isStrict) { if (parameterTypes.size() != arguments.length) { return false; } @@ -96,7 +105,7 @@ boolean canHandle(Object[] arguments) { if (arg instanceof Exception || arg instanceof CelUnknownSet) { // Only non-strict functions can accept errors/unknowns as arguments to a function - if (!isStrict()) { + if (!isStrict) { // Skip assignability check below, but continue to validate remaining args continue; } diff --git a/runtime/src/main/java/dev/cel/runtime/DefaultDispatcher.java b/runtime/src/main/java/dev/cel/runtime/DefaultDispatcher.java index adc99deb7..35ce243b3 100644 --- a/runtime/src/main/java/dev/cel/runtime/DefaultDispatcher.java +++ b/runtime/src/main/java/dev/cel/runtime/DefaultDispatcher.java @@ -14,6 +14,7 @@ package dev.cel.runtime; +import static com.google.common.base.Preconditions.checkArgument; import static com.google.common.base.Preconditions.checkNotNull; import com.google.auto.value.AutoBuilder; @@ -22,17 +23,27 @@ import com.google.errorprone.annotations.CanIgnoreReturnValue; import com.google.errorprone.annotations.Immutable; import dev.cel.common.CelErrorCode; +import dev.cel.common.annotations.Internal; import java.util.ArrayList; import java.util.List; import java.util.Map; import java.util.Optional; -/** Default implementation of dispatcher. */ +/** + * Default implementation of dispatcher. + * + *

CEL Library Internals. Do Not Use. + */ @Immutable -final class DefaultDispatcher implements CelFunctionResolver { +@Internal +public final class DefaultDispatcher implements CelFunctionResolver { private final ImmutableMap overloads; + public Optional findOverload(String functionName) { + return Optional.ofNullable(overloads.get(functionName)); + } + @Override public Optional findOverloadMatchingArgs( String functionName, List overloadIds, Object[] args) throws CelEvaluationException { @@ -101,24 +112,26 @@ Optional findSingleNonStrictOverload(List overloadI return Optional.empty(); } - static Builder newBuilder() { + public static Builder newBuilder() { return new AutoBuilder_DefaultDispatcher_Builder(); } + /** Builder for {@link DefaultDispatcher}. */ @AutoBuilder(ofClass = DefaultDispatcher.class) - abstract static class Builder { + public abstract static class Builder { abstract ImmutableMap overloads(); abstract ImmutableMap.Builder overloadsBuilder(); @CanIgnoreReturnValue - Builder addOverload( + public Builder addOverload( String overloadId, List> argTypes, boolean isStrict, CelFunctionOverload overload) { checkNotNull(overloadId); + checkArgument(!overloadId.isEmpty(), "Overload ID cannot be empty."); checkNotNull(argTypes); checkNotNull(overload); @@ -127,7 +140,7 @@ Builder addOverload( return this; } - abstract DefaultDispatcher build(); + public abstract DefaultDispatcher build(); } DefaultDispatcher(ImmutableMap overloads) { diff --git a/runtime/src/main/java/dev/cel/runtime/planner/BUILD.bazel b/runtime/src/main/java/dev/cel/runtime/planner/BUILD.bazel index a0423310c..ee8987886 100644 --- a/runtime/src/main/java/dev/cel/runtime/planner/BUILD.bazel +++ b/runtime/src/main/java/dev/cel/runtime/planner/BUILD.bazel @@ -19,6 +19,9 @@ java_library( ":eval_create_list", ":eval_create_map", ":eval_create_struct", + ":eval_unary", + ":eval_var_args_call", + ":eval_zero_arity", ":planned_program", "//:auto_value", "//common:cel_ast", @@ -28,10 +31,12 @@ java_library( "//common/types", "//common/types:type_providers", "//common/values:cel_value_provider", + "//runtime:dispatcher", "//runtime:evaluation_exception", "//runtime:evaluation_exception_builder", "//runtime:interpretable", "//runtime:program", + "//runtime:resolved_overload", "@maven//:com_google_code_findbugs_annotations", "@maven//:com_google_errorprone_error_prone_annotations", "@maven//:com_google_guava_guava", @@ -45,6 +50,7 @@ java_library( "//:auto_value", "//runtime:activation", "//runtime:evaluation_exception", + "//runtime:evaluation_exception_builder", "//runtime:function_resolver", "//runtime:interpretable", "//runtime:program", @@ -95,6 +101,42 @@ java_library( ], ) +java_library( + name = "eval_zero_arity", + srcs = ["EvalZeroArity.java"], + deps = [ + "//runtime:evaluation_exception", + "//runtime:evaluation_listener", + "//runtime:function_resolver", + "//runtime:interpretable", + "//runtime:resolved_overload", + ], +) + +java_library( + name = "eval_unary", + srcs = ["EvalUnary.java"], + deps = [ + "//runtime:evaluation_exception", + "//runtime:evaluation_listener", + "//runtime:function_resolver", + "//runtime:interpretable", + "//runtime:resolved_overload", + ], +) + +java_library( + name = "eval_var_args_call", + srcs = ["EvalVarArgsCall.java"], + deps = [ + "//runtime:evaluation_exception", + "//runtime:evaluation_listener", + "//runtime:function_resolver", + "//runtime:interpretable", + "//runtime:resolved_overload", + ], +) + java_library( name = "eval_create_struct", srcs = ["EvalCreateStruct.java"], diff --git a/runtime/src/main/java/dev/cel/runtime/planner/EvalUnary.java b/runtime/src/main/java/dev/cel/runtime/planner/EvalUnary.java new file mode 100644 index 000000000..c6daff4b2 --- /dev/null +++ b/runtime/src/main/java/dev/cel/runtime/planner/EvalUnary.java @@ -0,0 +1,66 @@ +// Copyright 2025 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package dev.cel.runtime.planner; + +import dev.cel.runtime.CelEvaluationException; +import dev.cel.runtime.CelEvaluationListener; +import dev.cel.runtime.CelFunctionResolver; +import dev.cel.runtime.CelResolvedOverload; +import dev.cel.runtime.GlobalResolver; +import dev.cel.runtime.Interpretable; + +final class EvalUnary implements Interpretable { + + private final CelResolvedOverload resolvedOverload; + private final Interpretable arg; + + @Override + public Object eval(GlobalResolver resolver) throws CelEvaluationException { + Object argVal = arg.eval(resolver); + Object[] arguments = new Object[] {argVal}; + + return resolvedOverload.getDefinition().apply(arguments); + } + + @Override + public Object eval(GlobalResolver resolver, CelEvaluationListener listener) { + // TODO: Implement support + throw new UnsupportedOperationException("Not yet supported"); + } + + @Override + public Object eval(GlobalResolver resolver, CelFunctionResolver lateBoundFunctionResolver) { + // TODO: Implement support + throw new UnsupportedOperationException("Not yet supported"); + } + + @Override + public Object eval( + GlobalResolver resolver, + CelFunctionResolver lateBoundFunctionResolver, + CelEvaluationListener listener) { + // TODO: Implement support + throw new UnsupportedOperationException("Not yet supported"); + } + + static EvalUnary create(CelResolvedOverload resolvedOverload, Interpretable arg) { + return new EvalUnary(resolvedOverload, arg); + } + + private EvalUnary(CelResolvedOverload resolvedOverload, Interpretable arg) { + this.resolvedOverload = resolvedOverload; + this.arg = arg; + } +} diff --git a/runtime/src/main/java/dev/cel/runtime/planner/EvalVarArgsCall.java b/runtime/src/main/java/dev/cel/runtime/planner/EvalVarArgsCall.java new file mode 100644 index 000000000..48fc7ba04 --- /dev/null +++ b/runtime/src/main/java/dev/cel/runtime/planner/EvalVarArgsCall.java @@ -0,0 +1,70 @@ +// Copyright 2025 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package dev.cel.runtime.planner; + +import dev.cel.runtime.CelEvaluationException; +import dev.cel.runtime.CelEvaluationListener; +import dev.cel.runtime.CelFunctionResolver; +import dev.cel.runtime.CelResolvedOverload; +import dev.cel.runtime.GlobalResolver; +import dev.cel.runtime.Interpretable; + +@SuppressWarnings("Immutable") +final class EvalVarArgsCall implements Interpretable { + + private final CelResolvedOverload resolvedOverload; + private final Interpretable[] args; + + @Override + public Object eval(GlobalResolver resolver) throws CelEvaluationException { + Object[] argVals = new Object[args.length]; + for (int i = 0; i < args.length; i++) { + Interpretable arg = args[i]; + argVals[i] = arg.eval(resolver); + } + + return resolvedOverload.getDefinition().apply(argVals); + } + + @Override + public Object eval(GlobalResolver resolver, CelEvaluationListener listener) { + // TODO: Implement support + throw new UnsupportedOperationException("Not yet supported"); + } + + @Override + public Object eval(GlobalResolver resolver, CelFunctionResolver lateBoundFunctionResolver) { + // TODO: Implement support + throw new UnsupportedOperationException("Not yet supported"); + } + + @Override + public Object eval( + GlobalResolver resolver, + CelFunctionResolver lateBoundFunctionResolver, + CelEvaluationListener listener) { + // TODO: Implement support + throw new UnsupportedOperationException("Not yet supported"); + } + + static EvalVarArgsCall create(CelResolvedOverload resolvedOverload, Interpretable[] args) { + return new EvalVarArgsCall(resolvedOverload, args); + } + + private EvalVarArgsCall(CelResolvedOverload resolvedOverload, Interpretable[] args) { + this.resolvedOverload = resolvedOverload; + this.args = args; + } +} diff --git a/runtime/src/main/java/dev/cel/runtime/planner/EvalZeroArity.java b/runtime/src/main/java/dev/cel/runtime/planner/EvalZeroArity.java new file mode 100644 index 000000000..813b84629 --- /dev/null +++ b/runtime/src/main/java/dev/cel/runtime/planner/EvalZeroArity.java @@ -0,0 +1,63 @@ +// Copyright 2025 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package dev.cel.runtime.planner; + +import dev.cel.runtime.CelEvaluationException; +import dev.cel.runtime.CelEvaluationListener; +import dev.cel.runtime.CelFunctionResolver; +import dev.cel.runtime.CelResolvedOverload; +import dev.cel.runtime.GlobalResolver; +import dev.cel.runtime.Interpretable; + +final class EvalZeroArity implements Interpretable { + + private static final Object[] EMPTY_ARRAY = new Object[0]; + + private final CelResolvedOverload resolvedOverload; + + @Override + public Object eval(GlobalResolver resolver) throws CelEvaluationException { + return resolvedOverload.getDefinition().apply(EMPTY_ARRAY); + } + + @Override + public Object eval(GlobalResolver resolver, CelEvaluationListener listener) { + // TODO: Implement support + throw new UnsupportedOperationException("Not yet supported"); + } + + @Override + public Object eval(GlobalResolver resolver, CelFunctionResolver lateBoundFunctionResolver) { + // TODO: Implement support + throw new UnsupportedOperationException("Not yet supported"); + } + + @Override + public Object eval( + GlobalResolver resolver, + CelFunctionResolver lateBoundFunctionResolver, + CelEvaluationListener listener) { + // TODO: Implement support + throw new UnsupportedOperationException("Not yet supported"); + } + + static EvalZeroArity create(CelResolvedOverload resolvedOverload) { + return new EvalZeroArity(resolvedOverload); + } + + private EvalZeroArity(CelResolvedOverload resolvedOverload) { + this.resolvedOverload = resolvedOverload; + } +} diff --git a/runtime/src/main/java/dev/cel/runtime/planner/PlannedProgram.java b/runtime/src/main/java/dev/cel/runtime/planner/PlannedProgram.java index c8aa1bc0d..34aac58a9 100644 --- a/runtime/src/main/java/dev/cel/runtime/planner/PlannedProgram.java +++ b/runtime/src/main/java/dev/cel/runtime/planner/PlannedProgram.java @@ -18,6 +18,7 @@ import com.google.errorprone.annotations.Immutable; import dev.cel.runtime.Activation; import dev.cel.runtime.CelEvaluationException; +import dev.cel.runtime.CelEvaluationExceptionBuilder; import dev.cel.runtime.CelFunctionResolver; import dev.cel.runtime.GlobalResolver; import dev.cel.runtime.Interpretable; @@ -31,12 +32,12 @@ abstract class PlannedProgram implements Program { @Override public Object eval() throws CelEvaluationException { - return interpretable().eval(GlobalResolver.EMPTY); + return evalOrThrow(interpretable(), GlobalResolver.EMPTY); } @Override public Object eval(Map mapValue) throws CelEvaluationException { - return interpretable().eval(Activation.copyOf(mapValue)); + return evalOrThrow(interpretable(), Activation.copyOf(mapValue)); } @Override @@ -45,6 +46,19 @@ public Object eval(Map mapValue, CelFunctionResolver lateBoundFunctio throw new UnsupportedOperationException("Late bound functions not supported yet"); } + private Object evalOrThrow(Interpretable interpretable, GlobalResolver resolver) + throws CelEvaluationException { + try { + return interpretable.eval(resolver); + } catch (RuntimeException e) { + throw newCelEvaluationException(e); + } + } + + private static CelEvaluationException newCelEvaluationException(Exception e) { + return CelEvaluationExceptionBuilder.newBuilder(e.getMessage()).setCause(e).build(); + } + static Program create(Interpretable interpretable) { return new AutoValue_PlannedProgram(interpretable); } diff --git a/runtime/src/main/java/dev/cel/runtime/planner/ProgramPlanner.java b/runtime/src/main/java/dev/cel/runtime/planner/ProgramPlanner.java index ec71155d6..805be193c 100644 --- a/runtime/src/main/java/dev/cel/runtime/planner/ProgramPlanner.java +++ b/runtime/src/main/java/dev/cel/runtime/planner/ProgramPlanner.java @@ -17,12 +17,14 @@ import com.google.auto.value.AutoValue; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; +import com.google.errorprone.annotations.CheckReturnValue; import javax.annotation.concurrent.ThreadSafe; import dev.cel.common.CelAbstractSyntaxTree; import dev.cel.common.CelContainer; import dev.cel.common.annotations.Internal; import dev.cel.common.ast.CelConstant; import dev.cel.common.ast.CelExpr; +import dev.cel.common.ast.CelExpr.CelCall; import dev.cel.common.ast.CelExpr.CelList; import dev.cel.common.ast.CelExpr.CelMap; import dev.cel.common.ast.CelExpr.CelStruct; @@ -36,9 +38,12 @@ import dev.cel.common.values.CelValueProvider; import dev.cel.runtime.CelEvaluationException; import dev.cel.runtime.CelEvaluationExceptionBuilder; +import dev.cel.runtime.CelResolvedOverload; +import dev.cel.runtime.DefaultDispatcher; import dev.cel.runtime.Interpretable; import dev.cel.runtime.Program; import java.util.NoSuchElementException; +import java.util.Optional; /** * {@code ProgramPlanner} resolves functions, types, and identifiers at plan time given a @@ -50,6 +55,7 @@ public final class ProgramPlanner { private final CelTypeProvider typeProvider; private final CelValueProvider valueProvider; + private final DefaultDispatcher dispatcher; private final AttributeFactory attributeFactory; /** @@ -73,6 +79,8 @@ private Interpretable plan(CelExpr celExpr, PlannerContext ctx) { return planConstant(celExpr.constant()); case IDENT: return planIdent(celExpr, ctx); + case CALL: + return planCall(celExpr, ctx); case LIST: return planCreateList(celExpr, ctx); case STRUCT: @@ -138,6 +146,53 @@ private Interpretable planCheckedIdent( return EvalAttribute.create(attributeFactory.newAbsoluteAttribute(identRef.name())); } + private Interpretable planCall(CelExpr expr, PlannerContext ctx) { + ResolvedFunction resolvedFunction = resolveFunction(expr, ctx.referenceMap()); + CelExpr target = resolvedFunction.target().orElse(null); + int argCount = expr.call().args().size(); + if (target != null) { + argCount++; + } + + Interpretable[] evaluatedArgs = new Interpretable[argCount]; + + int offset = 0; + if (target != null) { + evaluatedArgs[0] = plan(target, ctx); + offset++; + } + + ImmutableList args = expr.call().args(); + for (int argIndex = 0; argIndex < args.size(); argIndex++) { + evaluatedArgs[argIndex + offset] = plan(args.get(argIndex), ctx); + } + + // TODO: Handle all specialized calls (logical operators, conditionals, equals etc) + String functionName = resolvedFunction.functionName(); + + CelResolvedOverload resolvedOverload = null; + if (resolvedFunction.overloadId().isPresent()) { + resolvedOverload = dispatcher.findOverload(resolvedFunction.overloadId().get()).orElse(null); + } + + if (resolvedOverload == null) { + // Parsed-only function dispatch + resolvedOverload = + dispatcher + .findOverload(functionName) + .orElseThrow(() -> new NoSuchElementException("Overload not found: " + functionName)); + } + + switch (argCount) { + case 0: + return EvalZeroArity.create(resolvedOverload); + case 1: + return EvalUnary.create(resolvedOverload, evaluatedArgs[0]); + default: + return EvalVarArgsCall.create(resolvedOverload, evaluatedArgs); + } + } + private Interpretable planCreateStruct(CelExpr celExpr, PlannerContext ctx) { CelStruct struct = celExpr.struct(); CelType structType = @@ -193,6 +248,69 @@ private Interpretable planCreateMap(CelExpr celExpr, PlannerContext ctx) { return EvalCreateMap.create(keys, values); } + /** + * resolveFunction determines the call target, function name, and overload name (when unambiguous) + * from the given call expr. + */ + private ResolvedFunction resolveFunction( + CelExpr expr, ImmutableMap referenceMap) { + CelCall call = expr.call(); + Optional target = call.target(); + String functionName = call.function(); + + CelReference reference = referenceMap.get(expr.id()); + if (reference != null) { + // Checked expression + if (reference.overloadIds().size() == 1) { + ResolvedFunction.Builder builder = + ResolvedFunction.newBuilder() + .setFunctionName(functionName) + .setOverloadId(reference.overloadIds().get(0)); + + target.ifPresent(builder::setTarget); + + return builder.build(); + } + } + + // Parsed-only. + // TODO: Handle containers. + if (!target.isPresent()) { + return ResolvedFunction.newBuilder().setFunctionName(functionName).build(); + } else { + return ResolvedFunction.newBuilder() + .setFunctionName(functionName) + .setTarget(target.get()) + .build(); + } + } + + @AutoValue + abstract static class ResolvedFunction { + + abstract String functionName(); + + abstract Optional target(); + + abstract Optional overloadId(); + + @AutoValue.Builder + abstract static class Builder { + abstract Builder setFunctionName(String functionName); + + abstract Builder setTarget(CelExpr target); + + abstract Builder setOverloadId(String overloadId); + + @CheckReturnValue + abstract ResolvedFunction build(); + } + + private static Builder newBuilder() { + return new AutoValue_ProgramPlanner_ResolvedFunction.Builder(); + } + } + @AutoValue abstract static class PlannerContext { @@ -206,14 +324,16 @@ private static PlannerContext create(CelAbstractSyntaxTree ast) { } public static ProgramPlanner newPlanner( - CelTypeProvider typeProvider, CelValueProvider valueProvider) { - return new ProgramPlanner(typeProvider, valueProvider); + CelTypeProvider typeProvider, CelValueProvider valueProvider, DefaultDispatcher dispatcher) { + return new ProgramPlanner(typeProvider, valueProvider, dispatcher); } - private ProgramPlanner(CelTypeProvider typeProvider, CelValueProvider valueProvider) { + private ProgramPlanner( + CelTypeProvider typeProvider, CelValueProvider valueProvider, DefaultDispatcher dispatcher) { this.typeProvider = typeProvider; this.valueProvider = valueProvider; // TODO: Container support + this.dispatcher = dispatcher; this.attributeFactory = AttributeFactory.newAttributeFactory(CelContainer.newBuilder().build(), typeProvider); } diff --git a/runtime/src/test/java/dev/cel/runtime/planner/BUILD.bazel b/runtime/src/test/java/dev/cel/runtime/planner/BUILD.bazel index 0b36e20ac..94a427685 100644 --- a/runtime/src/test/java/dev/cel/runtime/planner/BUILD.bazel +++ b/runtime/src/test/java/dev/cel/runtime/planner/BUILD.bazel @@ -17,6 +17,8 @@ java_library( "//common:cel_ast", "//common:cel_descriptor_util", "//common:cel_source", + "//common:compiler_common", + "//common:operator", "//common:options", "//common/ast", "//common/internal:cel_descriptor_pools", @@ -34,8 +36,23 @@ java_library( "//compiler:compiler_builder", "//extensions", "//runtime", + "//runtime:dispatcher", + "//runtime:function_binding", "//runtime:program", + "//runtime:resolved_overload", + "//runtime:runtime_equality", + "//runtime:runtime_helpers", "//runtime/planner:program_planner", + "//runtime/standard:add", + "//runtime/standard:divide", + "//runtime/standard:equals", + "//runtime/standard:greater", + "//runtime/standard:greater_equals", + "//runtime/standard:index", + "//runtime/standard:less", + "//runtime/standard:logical_not", + "//runtime/standard:not_strictly_false", + "//runtime/standard:standard_function", "@cel_spec//proto/cel/expr/conformance/proto3:test_all_types_java_proto", "@maven//:com_google_guava_guava", "@maven//:com_google_testparameterinjector_test_parameter_injector", diff --git a/runtime/src/test/java/dev/cel/runtime/planner/ProgramPlannerTest.java b/runtime/src/test/java/dev/cel/runtime/planner/ProgramPlannerTest.java index 32651d73d..56bc1810a 100644 --- a/runtime/src/test/java/dev/cel/runtime/planner/ProgramPlannerTest.java +++ b/runtime/src/test/java/dev/cel/runtime/planner/ProgramPlannerTest.java @@ -15,12 +15,17 @@ package dev.cel.runtime.planner; import static com.google.common.truth.Truth.assertThat; +import static dev.cel.common.CelFunctionDecl.newFunctionDeclaration; +import static dev.cel.common.CelOverloadDecl.newGlobalOverload; +import static dev.cel.common.CelOverloadDecl.newMemberOverload; import static java.nio.charset.StandardCharsets.UTF_8; import static org.junit.Assert.assertThrows; +import com.google.common.collect.ImmutableCollection; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; import com.google.common.collect.ImmutableSet; +import com.google.common.collect.Iterables; import com.google.common.primitives.UnsignedLong; import com.google.testing.junit.testparameterinjector.TestParameter; import com.google.testing.junit.testparameterinjector.TestParameterInjector; @@ -28,6 +33,7 @@ import dev.cel.common.CelDescriptorUtil; import dev.cel.common.CelOptions; import dev.cel.common.CelSource; +import dev.cel.common.Operator; import dev.cel.common.ast.CelExpr; import dev.cel.common.internal.CelDescriptorPool; import dev.cel.common.internal.DefaultDescriptorPool; @@ -53,17 +59,36 @@ import dev.cel.expr.conformance.proto3.TestAllTypes; import dev.cel.extensions.CelExtensions; import dev.cel.runtime.CelEvaluationException; +import dev.cel.runtime.CelFunctionBinding; +import dev.cel.runtime.CelFunctionOverload; +import dev.cel.runtime.CelResolvedOverload; +import dev.cel.runtime.DefaultDispatcher; import dev.cel.runtime.Program; +import dev.cel.runtime.RuntimeEquality; +import dev.cel.runtime.RuntimeHelpers; +import dev.cel.runtime.standard.AddOperator; +import dev.cel.runtime.standard.CelStandardFunction; +import dev.cel.runtime.standard.DivideOperator; +import dev.cel.runtime.standard.EqualsOperator; +import dev.cel.runtime.standard.GreaterEqualsOperator; +import dev.cel.runtime.standard.GreaterOperator; +import dev.cel.runtime.standard.IndexOperator; +import dev.cel.runtime.standard.LessOperator; +import dev.cel.runtime.standard.LogicalNotOperator; +import dev.cel.runtime.standard.NotStrictlyFalseFunction; import org.junit.Test; import org.junit.runner.RunWith; @RunWith(TestParameterInjector.class) public final class ProgramPlannerTest { // Note that the following deps will be built from top-level builder APIs + private static final CelOptions CEL_OPTIONS = CelOptions.current().build(); private static final CelTypeProvider TYPE_PROVIDER = new CombinedCelTypeProvider( DefaultTypeProvider.getInstance(), new ProtoMessageTypeProvider(ImmutableSet.of(TestAllTypes.getDescriptor()))); + private static final RuntimeEquality RUNTIME_EQUALITY = + RuntimeEquality.create(RuntimeHelpers.create(), CEL_OPTIONS); private static final CelDescriptorPool DESCRIPTOR_POOL = DefaultDescriptorPool.create( CelDescriptorUtil.getAllDescriptorsFromFileDescriptor( @@ -74,14 +99,160 @@ public final class ProgramPlannerTest { ProtoMessageValueProvider.newInstance(CelOptions.DEFAULT, DYNAMIC_PROTO); private static final ProgramPlanner PLANNER = - ProgramPlanner.newPlanner(TYPE_PROVIDER, VALUE_PROVIDER); + ProgramPlanner.newPlanner(TYPE_PROVIDER, VALUE_PROVIDER, newDispatcher()); private static final CelCompiler CEL_COMPILER = CelCompilerFactory.standardCelCompilerBuilder() + .addVar("map_var", MapType.create(SimpleType.STRING, SimpleType.DYN)) .addVar("int_var", SimpleType.INT) + .addVar("dyn_var", SimpleType.DYN) + .addFunctionDeclarations( + newFunctionDeclaration("zero", newGlobalOverload("zero_overload", SimpleType.INT)), + newFunctionDeclaration("error", newGlobalOverload("error_overload", SimpleType.INT)), + newFunctionDeclaration( + "neg", + newGlobalOverload("neg_int", SimpleType.INT, SimpleType.INT), + newGlobalOverload("neg_double", SimpleType.DOUBLE, SimpleType.DOUBLE)), + newFunctionDeclaration( + "concat", + newGlobalOverload( + "concat_bytes_bytes", SimpleType.BYTES, SimpleType.BYTES, SimpleType.BYTES), + newMemberOverload( + "bytes_concat_bytes", SimpleType.BYTES, SimpleType.BYTES, SimpleType.BYTES))) .addMessageTypes(TestAllTypes.getDescriptor()) .addLibraries(CelExtensions.optional()) .build(); + /** + * Configure dispatcher for testing purposes. This is done manually here, but this should be + * driven by the top-level runtime APIs in the future + */ + private static DefaultDispatcher newDispatcher() { + DefaultDispatcher.Builder builder = DefaultDispatcher.newBuilder(); + + // Subsetted StdLib + addBindings( + builder, Operator.INDEX.getFunction(), fromStandardFunction(IndexOperator.create())); + addBindings( + builder, + Operator.LOGICAL_NOT.getFunction(), + fromStandardFunction(LogicalNotOperator.create())); + addBindings(builder, Operator.ADD.getFunction(), fromStandardFunction(AddOperator.create())); + addBindings( + builder, Operator.GREATER.getFunction(), fromStandardFunction(GreaterOperator.create())); + addBindings( + builder, + Operator.GREATER_EQUALS.getFunction(), + fromStandardFunction(GreaterEqualsOperator.create())); + addBindings(builder, Operator.LESS.getFunction(), fromStandardFunction(LessOperator.create())); + addBindings( + builder, Operator.DIVIDE.getFunction(), fromStandardFunction(DivideOperator.create())); + addBindings( + builder, Operator.EQUALS.getFunction(), fromStandardFunction(EqualsOperator.create())); + addBindings( + builder, + Operator.NOT_STRICTLY_FALSE.getFunction(), + fromStandardFunction(NotStrictlyFalseFunction.create())); + + // Custom functions + addBindings( + builder, + "zero", + CelFunctionBinding.from("zero_overload", ImmutableList.of(), (unused) -> 0L)); + addBindings( + builder, + "error", + CelFunctionBinding.from( + "error_overload", + ImmutableList.of(), + (unused) -> { + throw new IllegalArgumentException("Intentional error"); + })); + addBindings( + builder, + "neg", + CelFunctionBinding.from("neg_int", Long.class, arg -> -arg), + CelFunctionBinding.from("neg_double", Double.class, arg -> -arg)); + addBindings( + builder, + "concat", + CelFunctionBinding.from( + "concat_bytes_bytes", + CelByteString.class, + CelByteString.class, + ProgramPlannerTest::concatenateByteArrays), + CelFunctionBinding.from( + "bytes_concat_bytes", + CelByteString.class, + CelByteString.class, + ProgramPlannerTest::concatenateByteArrays)); + + return builder.build(); + } + + private static void addBindings( + DefaultDispatcher.Builder builder, + String functionName, + CelFunctionBinding... functionBindings) { + addBindings(builder, functionName, ImmutableSet.copyOf(functionBindings)); + } + + private static void addBindings( + DefaultDispatcher.Builder builder, + String functionName, + ImmutableCollection overloadBindings) { + if (overloadBindings.isEmpty()) { + throw new IllegalArgumentException("Invalid bindings"); + } + // TODO: Runtime top-level APIs currently does not allow grouping overloads with + // the function name. This capability will have to be added. + if (overloadBindings.size() == 1) { + CelFunctionBinding singleBinding = Iterables.getOnlyElement(overloadBindings); + builder.addOverload( + functionName, + singleBinding.getArgTypes(), + singleBinding.isStrict(), + args -> guardedOp(functionName, args, singleBinding)); + } else { + overloadBindings.forEach( + overload -> + builder.addOverload( + overload.getOverloadId(), + overload.getArgTypes(), + overload.isStrict(), + args -> guardedOp(functionName, args, overload))); + + // Setup dynamic dispatch + CelFunctionOverload dynamicDispatchDef = + args -> { + for (CelFunctionBinding overload : overloadBindings) { + if (CelResolvedOverload.canHandle( + args, overload.getArgTypes(), overload.isStrict())) { + return overload.getDefinition().apply(args); + } + } + + throw new IllegalArgumentException( + "No matching overload for function: " + functionName); + }; + + boolean allOverloadsStrict = overloadBindings.stream().allMatch(CelFunctionBinding::isStrict); + builder.addOverload( + functionName, ImmutableList.of(), /* isStrict= */ allOverloadsStrict, dynamicDispatchDef); + } + } + + /** Creates an invocation guard around the overload definition. */ + private static Object guardedOp( + String functionName, Object[] args, CelFunctionBinding singleBinding) + throws CelEvaluationException { + if (!CelResolvedOverload.canHandle( + args, singleBinding.getArgTypes(), singleBinding.isStrict())) { + throw new IllegalArgumentException("No matching overload for function: " + functionName); + } + + return singleBinding.getDefinition().apply(args); + } + @TestParameter boolean isParseOnly; @Test @@ -206,6 +377,89 @@ public void planCreateStruct_withFields() throws Exception { .isEqualTo(TestAllTypes.newBuilder().setSingleString("foo").setSingleBool(true).build()); } + @Test + public void plan_call_zeroArgs() throws Exception { + CelAbstractSyntaxTree ast = compile("zero()"); + Program program = PLANNER.plan(ast); + + Long result = (Long) program.eval(); + + assertThat(result).isEqualTo(0L); + } + + @Test + public void plan_call_throws() throws Exception { + CelAbstractSyntaxTree ast = compile("error()"); + Program program = PLANNER.plan(ast); + + CelEvaluationException e = assertThrows(CelEvaluationException.class, program::eval); + assertThat(e).hasMessageThat().contains("evaluation error: Intentional error"); + assertThat(e).hasCauseThat().isInstanceOf(IllegalArgumentException.class); + } + + @Test + public void plan_call_oneArg_int() throws Exception { + CelAbstractSyntaxTree ast = compile("neg(1)"); + Program program = PLANNER.plan(ast); + + Long result = (Long) program.eval(); + + assertThat(result).isEqualTo(-1L); + } + + @Test + public void plan_call_oneArg_double() throws Exception { + CelAbstractSyntaxTree ast = compile("neg(2.5)"); + Program program = PLANNER.plan(ast); + + Double result = (Double) program.eval(); + + assertThat(result).isEqualTo(-2.5d); + } + + @Test + public void plan_call_twoArgs_global() throws Exception { + CelAbstractSyntaxTree ast = compile("concat(b'abc', b'def')"); + Program program = PLANNER.plan(ast); + + CelByteString result = (CelByteString) program.eval(); + + assertThat(result).isEqualTo(CelByteString.of("abcdef".getBytes(UTF_8))); + } + + @Test + public void plan_call_twoArgs_receiver() throws Exception { + CelAbstractSyntaxTree ast = compile("b'abc'.concat(b'def')"); + Program program = PLANNER.plan(ast); + + CelByteString result = (CelByteString) program.eval(); + + assertThat(result).isEqualTo(CelByteString.of("abcdef".getBytes(UTF_8))); + } + + @Test + public void plan_call_mapIndex() throws Exception { + CelAbstractSyntaxTree ast = compile("map_var['key'][1]"); + Program program = PLANNER.plan(ast); + ImmutableMap mapVarPayload = ImmutableMap.of("key", ImmutableList.of(1L, 2L)); + + Long result = (Long) program.eval(ImmutableMap.of("map_var", mapVarPayload)); + + assertThat(result).isEqualTo(2L); + } + + @Test + public void plan_call_noMatchingOverload_throws() throws Exception { + CelAbstractSyntaxTree ast = compile("concat(b'abc', dyn_var)"); + Program program = PLANNER.plan(ast); + + CelEvaluationException e = + assertThrows( + CelEvaluationException.class, + () -> program.eval(ImmutableMap.of("dyn_var", "Impossible Overload"))); + assertThat(e).hasMessageThat().contains("No matching overload for function: concat"); + } + private CelAbstractSyntaxTree compile(String expression) throws Exception { CelAbstractSyntaxTree ast = CEL_COMPILER.parse(expression).getAst(); if (isParseOnly) { @@ -215,6 +469,23 @@ private CelAbstractSyntaxTree compile(String expression) throws Exception { return CEL_COMPILER.check(ast).getAst(); } + private static CelByteString concatenateByteArrays(CelByteString bytes1, CelByteString bytes2) { + if (bytes1.isEmpty()) { + return bytes2; + } + + if (bytes2.isEmpty()) { + return bytes1; + } + + return bytes1.concat(bytes2); + } + + private static ImmutableSet fromStandardFunction( + CelStandardFunction standardFunction) { + return standardFunction.newFunctionBindings(CEL_OPTIONS, RUNTIME_EQUALITY); + } + @SuppressWarnings("ImmutableEnumChecker") // Test only private enum ConstantTestCase { NULL("null", NullValue.NULL_VALUE),