diff --git a/extensions/src/test/java/dev/cel/extensions/CelBindingsExtensionsTest.java b/extensions/src/test/java/dev/cel/extensions/CelBindingsExtensionsTest.java index ff9e31432..b87967d0e 100644 --- a/extensions/src/test/java/dev/cel/extensions/CelBindingsExtensionsTest.java +++ b/extensions/src/test/java/dev/cel/extensions/CelBindingsExtensionsTest.java @@ -22,40 +22,51 @@ import com.google.testing.junit.testparameterinjector.TestParameter; import com.google.testing.junit.testparameterinjector.TestParameterInjector; import com.google.testing.junit.testparameterinjector.TestParameters; +import dev.cel.bundle.Cel; import dev.cel.common.CelAbstractSyntaxTree; import dev.cel.common.CelFunctionDecl; import dev.cel.common.CelOptions; import dev.cel.common.CelOverloadDecl; import dev.cel.common.CelValidationException; +import dev.cel.common.exceptions.CelDivideByZeroException; import dev.cel.common.types.SimpleType; import dev.cel.common.types.StructTypeReference; -import dev.cel.compiler.CelCompiler; -import dev.cel.compiler.CelCompilerFactory; import dev.cel.expr.conformance.proto3.TestAllTypes; import dev.cel.parser.CelMacro; import dev.cel.parser.CelStandardMacro; +import dev.cel.runtime.CelEvaluationException; import dev.cel.runtime.CelFunctionBinding; -import dev.cel.runtime.CelRuntime; -import dev.cel.runtime.CelRuntimeFactory; +import dev.cel.testing.CelRuntimeFlavor; import java.util.Arrays; import java.util.List; +import java.util.Map; import java.util.concurrent.atomic.AtomicInteger; +import org.junit.Assume; +import org.junit.Before; import org.junit.Test; import org.junit.runner.RunWith; @RunWith(TestParameterInjector.class) public final class CelBindingsExtensionsTest { - private static final CelCompiler COMPILER = - CelCompilerFactory.standardCelCompilerBuilder() - .setStandardMacros(CelStandardMacro.STANDARD_MACROS) - .addLibraries(CelOptionalLibrary.INSTANCE, CelExtensions.bindings()) - .build(); - - private static final CelRuntime RUNTIME = - CelRuntimeFactory.standardCelRuntimeBuilder() - .addLibraries(CelOptionalLibrary.INSTANCE) - .build(); + @TestParameter public CelRuntimeFlavor runtimeFlavor; + @TestParameter public boolean isParseOnly; + + private Cel cel; + + @Before + public void setUp() { + // Legacy runtime does not support parsed-only evaluation mode. + Assume.assumeFalse(runtimeFlavor.equals(CelRuntimeFlavor.LEGACY) && isParseOnly); + cel = + runtimeFlavor + .builder() + .setOptions(CelOptions.current().enableHeterogeneousNumericComparisons(true).build()) + .setStandardMacros(CelStandardMacro.STANDARD_MACROS) + .addCompilerLibraries(CelOptionalLibrary.INSTANCE, CelExtensions.bindings()) + .addRuntimeLibraries(CelOptionalLibrary.INSTANCE) + .build(); + } @Test public void library() { @@ -93,9 +104,7 @@ private enum BindingTestCase { @Test public void binding_success(@TestParameter BindingTestCase testCase) throws Exception { - CelAbstractSyntaxTree ast = COMPILER.compile(testCase.source).getAst(); - CelRuntime.Program program = RUNTIME.createProgram(ast); - boolean evaluatedResult = (boolean) program.eval(); + boolean evaluatedResult = (boolean) eval(testCase.source); assertThat(evaluatedResult).isTrue(); } @@ -103,9 +112,11 @@ public void binding_success(@TestParameter BindingTestCase testCase) throws Exce @Test @TestParameters("{expr: 'false.bind(false, false, false)'}") public void binding_nonCelNamespace_success(String expr) throws Exception { - CelCompiler celCompiler = - CelCompilerFactory.standardCelCompilerBuilder() - .addLibraries(CelExtensions.bindings()) + Cel customCel = + runtimeFlavor + .builder() + .setOptions(CelOptions.current().enableHeterogeneousNumericComparisons(true).build()) + .addCompilerLibraries(CelExtensions.bindings()) .addFunctionDeclarations( CelFunctionDecl.newFunctionDeclaration( "bind", @@ -116,18 +127,16 @@ public void binding_nonCelNamespace_success(String expr) throws Exception { SimpleType.BOOL, SimpleType.BOOL, SimpleType.BOOL))) - .build(); - CelRuntime celRuntime = - CelRuntimeFactory.standardCelRuntimeBuilder() .addFunctionBindings( - CelFunctionBinding.from( - "bool_bind_bool_bool_bool", - Arrays.asList(Boolean.class, Boolean.class, Boolean.class, Boolean.class), - (args) -> true)) + CelFunctionBinding.fromOverloads( + "bind", + CelFunctionBinding.from( + "bool_bind_bool_bool_bool", + Arrays.asList(Boolean.class, Boolean.class, Boolean.class, Boolean.class), + (args) -> true))) .build(); - CelAbstractSyntaxTree ast = celCompiler.compile(expr).getAst(); - boolean result = (boolean) celRuntime.createProgram(ast).eval(); + boolean result = (boolean) eval(customCel, expr); assertThat(result).isTrue(); } @@ -135,7 +144,7 @@ public void binding_nonCelNamespace_success(String expr) throws Exception { @TestParameters("{expr: 'cel.bind(bad.name, true, bad.name)'}") public void binding_throwsCompilationException(String expr) throws Exception { CelValidationException e = - assertThrows(CelValidationException.class, () -> COMPILER.compile(expr).getAst()); + assertThrows(CelValidationException.class, () -> cel.compile(expr).getAst()); assertThat(e).hasMessageThat().contains("cel.bind() variable name must be a simple identifier"); } @@ -143,70 +152,76 @@ public void binding_throwsCompilationException(String expr) throws Exception { @Test @SuppressWarnings("Immutable") // Test only public void lazyBinding_bindingVarNeverReferenced() throws Exception { - CelCompiler celCompiler = - CelCompilerFactory.standardCelCompilerBuilder() + + AtomicInteger invocation = new AtomicInteger(); + Cel customCel = + runtimeFlavor + .builder() + .setOptions(CelOptions.current().enableHeterogeneousNumericComparisons(true).build()) .setStandardMacros(CelStandardMacro.HAS) .addMessageTypes(TestAllTypes.getDescriptor()) .addVar("msg", StructTypeReference.create(TestAllTypes.getDescriptor().getFullName())) - .addLibraries(CelExtensions.bindings()) + .addCompilerLibraries(CelExtensions.bindings()) .addFunctionDeclarations( CelFunctionDecl.newFunctionDeclaration( "get_true", CelOverloadDecl.newGlobalOverload("get_true_overload", SimpleType.BOOL))) - .build(); - AtomicInteger invocation = new AtomicInteger(); - CelRuntime celRuntime = - CelRuntimeFactory.standardCelRuntimeBuilder() - .addMessageTypes(TestAllTypes.getDescriptor()) .addFunctionBindings( - CelFunctionBinding.from( - "get_true_overload", - ImmutableList.of(), - arg -> { - invocation.getAndIncrement(); - return true; - })) + CelFunctionBinding.fromOverloads( + "get_true", + CelFunctionBinding.from( + "get_true_overload", + ImmutableList.of(), + arg -> { + invocation.getAndIncrement(); + return true; + }))) .build(); - CelAbstractSyntaxTree ast = - celCompiler.compile("cel.bind(t, get_true(), has(msg.single_int64) ? t : false)").getAst(); - boolean result = (boolean) - celRuntime - .createProgram(ast) - .eval(ImmutableMap.of("msg", TestAllTypes.getDefaultInstance())); + eval( + customCel, + "cel.bind(t, get_true(), has(msg.single_int64) ? t : false)", + ImmutableMap.of("msg", TestAllTypes.getDefaultInstance())); assertThat(result).isFalse(); assertThat(invocation.get()).isEqualTo(0); } + @Test + public void lazyBinding_throwsEvaluationException() throws Exception { + CelEvaluationException e = + assertThrows(CelEvaluationException.class, () -> eval(cel, "cel.bind(t, 1 / 0, t)")); + + assertThat(e).hasMessageThat().contains("/ by zero"); + assertThat(e).hasCauseThat().isInstanceOf(CelDivideByZeroException.class); + } + @Test @SuppressWarnings("Immutable") // Test only public void lazyBinding_accuInitEvaluatedOnce() throws Exception { - CelCompiler celCompiler = - CelCompilerFactory.standardCelCompilerBuilder() - .addLibraries(CelExtensions.bindings()) + AtomicInteger invocation = new AtomicInteger(); + Cel customCel = + runtimeFlavor + .builder() + .setOptions(CelOptions.current().enableHeterogeneousNumericComparisons(true).build()) + .addCompilerLibraries(CelExtensions.bindings()) .addFunctionDeclarations( CelFunctionDecl.newFunctionDeclaration( "get_true", CelOverloadDecl.newGlobalOverload("get_true_overload", SimpleType.BOOL))) - .build(); - AtomicInteger invocation = new AtomicInteger(); - CelRuntime celRuntime = - CelRuntimeFactory.standardCelRuntimeBuilder() .addFunctionBindings( - CelFunctionBinding.from( - "get_true_overload", - ImmutableList.of(), - arg -> { - invocation.getAndIncrement(); - return true; - })) + CelFunctionBinding.fromOverloads( + "get_true", + CelFunctionBinding.from( + "get_true_overload", + ImmutableList.of(), + arg -> { + invocation.getAndIncrement(); + return true; + }))) .build(); - CelAbstractSyntaxTree ast = - celCompiler.compile("cel.bind(t, get_true(), t && t && t && t)").getAst(); - - boolean result = (boolean) celRuntime.createProgram(ast).eval(); + boolean result = (boolean) eval(customCel, "cel.bind(t, get_true(), t && t && t && t)"); assertThat(result).isTrue(); assertThat(invocation.get()).isEqualTo(1); @@ -215,32 +230,32 @@ public void lazyBinding_accuInitEvaluatedOnce() throws Exception { @Test @SuppressWarnings("Immutable") // Test only public void lazyBinding_withNestedBinds() throws Exception { - CelCompiler celCompiler = - CelCompilerFactory.standardCelCompilerBuilder() - .addLibraries(CelExtensions.bindings()) + AtomicInteger invocation = new AtomicInteger(); + Cel customCel = + runtimeFlavor + .builder() + .setOptions(CelOptions.current().enableHeterogeneousNumericComparisons(true).build()) + .addCompilerLibraries(CelExtensions.bindings()) .addFunctionDeclarations( CelFunctionDecl.newFunctionDeclaration( "get_true", CelOverloadDecl.newGlobalOverload("get_true_overload", SimpleType.BOOL))) - .build(); - AtomicInteger invocation = new AtomicInteger(); - CelRuntime celRuntime = - CelRuntimeFactory.standardCelRuntimeBuilder() .addFunctionBindings( - CelFunctionBinding.from( - "get_true_overload", - ImmutableList.of(), - arg -> { - invocation.getAndIncrement(); - return true; - })) + CelFunctionBinding.fromOverloads( + "get_true", + CelFunctionBinding.from( + "get_true_overload", + ImmutableList.of(), + arg -> { + invocation.getAndIncrement(); + return true; + }))) .build(); - CelAbstractSyntaxTree ast = - celCompiler - .compile("cel.bind(t1, get_true(), cel.bind(t2, get_true(), t1 && t2 && t1 && t2))") - .getAst(); - - boolean result = (boolean) celRuntime.createProgram(ast).eval(); + boolean result = + (boolean) + eval( + customCel, + "cel.bind(t1, get_true(), cel.bind(t2, get_true(), t1 && t2 && t1 && t2))"); assertThat(result).isTrue(); assertThat(invocation.get()).isEqualTo(2); @@ -249,32 +264,31 @@ public void lazyBinding_withNestedBinds() throws Exception { @Test @SuppressWarnings({"Immutable", "unchecked"}) // Test only public void lazyBinding_boundAttributeInComprehension() throws Exception { - CelCompiler celCompiler = - CelCompilerFactory.standardCelCompilerBuilder() + AtomicInteger invocation = new AtomicInteger(); + Cel customCel = + runtimeFlavor + .builder() + .setOptions(CelOptions.current().enableHeterogeneousNumericComparisons(true).build()) .setStandardMacros(CelStandardMacro.MAP) - .addLibraries(CelExtensions.bindings()) + .addCompilerLibraries(CelExtensions.bindings()) .addFunctionDeclarations( CelFunctionDecl.newFunctionDeclaration( "get_true", CelOverloadDecl.newGlobalOverload("get_true_overload", SimpleType.BOOL))) - .build(); - AtomicInteger invocation = new AtomicInteger(); - CelRuntime celRuntime = - CelRuntimeFactory.standardCelRuntimeBuilder() .addFunctionBindings( - CelFunctionBinding.from( - "get_true_overload", - ImmutableList.of(), - arg -> { - invocation.getAndIncrement(); - return true; - })) + CelFunctionBinding.fromOverloads( + "get_true", + CelFunctionBinding.from( + "get_true_overload", + ImmutableList.of(), + arg -> { + invocation.getAndIncrement(); + return true; + }))) .build(); - CelAbstractSyntaxTree ast = - celCompiler.compile("cel.bind(x, get_true(), [1,2,3].map(y, y < 0 || x))").getAst(); - - List result = (List) celRuntime.createProgram(ast).eval(); + List result = + (List) eval(customCel, "cel.bind(x, get_true(), [1,2,3].map(y, y < 0 || x))"); assertThat(result).containsExactly(true, true, true); assertThat(invocation.get()).isEqualTo(1); @@ -283,38 +297,55 @@ public void lazyBinding_boundAttributeInComprehension() throws Exception { @Test @SuppressWarnings({"Immutable"}) // Test only public void lazyBinding_boundAttributeInNestedComprehension() throws Exception { - CelCompiler celCompiler = - CelCompilerFactory.standardCelCompilerBuilder() + AtomicInteger invocation = new AtomicInteger(); + Cel customCel = + runtimeFlavor + .builder() + .setOptions(CelOptions.current().enableHeterogeneousNumericComparisons(true).build()) .setStandardMacros(CelStandardMacro.EXISTS) - .addLibraries(CelExtensions.bindings()) + .addCompilerLibraries(CelExtensions.bindings()) .addFunctionDeclarations( CelFunctionDecl.newFunctionDeclaration( "get_true", CelOverloadDecl.newGlobalOverload("get_true_overload", SimpleType.BOOL))) - .build(); - AtomicInteger invocation = new AtomicInteger(); - CelRuntime celRuntime = - CelRuntimeFactory.standardCelRuntimeBuilder() .addFunctionBindings( - CelFunctionBinding.from( - "get_true_overload", - ImmutableList.of(), - arg -> { - invocation.getAndIncrement(); - return true; - })) + CelFunctionBinding.fromOverloads( + "get_true", + CelFunctionBinding.from( + "get_true_overload", + ImmutableList.of(), + arg -> { + invocation.getAndIncrement(); + return true; + }))) .build(); - CelAbstractSyntaxTree ast = - celCompiler - .compile( + boolean result = + (boolean) + eval( + customCel, "cel.bind(x, get_true(), [1,2,3].exists(unused, x && " - + "['a','b','c'].exists(unused_2, x)))") - .getAst(); - - boolean result = (boolean) celRuntime.createProgram(ast).eval(); + + "['a','b','c'].exists(unused_2, x)))"); assertThat(result).isTrue(); assertThat(invocation.get()).isEqualTo(1); } + + private Object eval(Cel cel, String expression) throws Exception { + return eval(cel, expression, ImmutableMap.of()); + } + + private Object eval(Cel cel, String expression, Map variables) throws Exception { + CelAbstractSyntaxTree ast; + if (isParseOnly) { + ast = cel.parse(expression).getAst(); + } else { + ast = cel.compile(expression).getAst(); + } + return cel.createProgram(ast).eval(variables); + } + + private Object eval(String expression) throws Exception { + return eval(this.cel, expression, ImmutableMap.of()); + } } diff --git a/runtime/src/main/java/dev/cel/runtime/planner/BUILD.bazel b/runtime/src/main/java/dev/cel/runtime/planner/BUILD.bazel index cb2ad5a82..824c918d8 100644 --- a/runtime/src/main/java/dev/cel/runtime/planner/BUILD.bazel +++ b/runtime/src/main/java/dev/cel/runtime/planner/BUILD.bazel @@ -364,6 +364,7 @@ java_library( deps = [ ":activation_wrapper", ":planned_interpretable", + "//common/exceptions:runtime_exception", "//runtime:accumulated_unknowns", "//runtime:concatenated_list_view", "//runtime:evaluation_exception", diff --git a/runtime/src/main/java/dev/cel/runtime/planner/EvalFold.java b/runtime/src/main/java/dev/cel/runtime/planner/EvalFold.java index 2631bf0b9..2eb30671e 100644 --- a/runtime/src/main/java/dev/cel/runtime/planner/EvalFold.java +++ b/runtime/src/main/java/dev/cel/runtime/planner/EvalFold.java @@ -16,6 +16,7 @@ import com.google.common.collect.ImmutableList; import com.google.errorprone.annotations.Immutable; +import dev.cel.common.exceptions.CelRuntimeException; import dev.cel.runtime.AccumulatedUnknowns; import dev.cel.runtime.CelEvaluationException; import dev.cel.runtime.ConcatenatedListView; @@ -77,8 +78,7 @@ public Object eval(GlobalResolver resolver, ExecutionFrame frame) throws CelEval if (iterRangeRaw instanceof AccumulatedUnknowns) { return iterRangeRaw; } - Folder folder = new Folder(resolver, accuVar, iterVar, iterVar2); - folder.accuVal = maybeWrapAccumulator(accuInit.eval(folder, frame)); + Folder folder = new Folder(resolver, frame, accuInit, accuVar, iterVar, iterVar2); Object result; if (iterRangeRaw instanceof Map) { @@ -104,11 +104,14 @@ private Object evalMap(Map iterRange, Folder folder, ExecutionFrame frame) boolean cond = (boolean) condition.eval(folder, frame); if (!cond) { + folder.computeResult = true; return result.eval(folder, frame); } folder.accuVal = loopStep.eval(folder, frame); + folder.initialized = true; } + folder.computeResult = true; return result.eval(folder, frame); } @@ -127,12 +130,15 @@ private Object evalList(Collection iterRange, Folder folder, ExecutionFrame f boolean cond = (boolean) condition.eval(folder, frame); if (!cond) { + folder.computeResult = true; return result.eval(folder, frame); } folder.accuVal = loopStep.eval(folder, frame); + folder.initialized = true; index++; } + folder.computeResult = true; return result.eval(folder, frame); } @@ -155,6 +161,8 @@ private static Object maybeUnwrapAccumulator(Object val) { private static class Folder implements ActivationWrapper { private final GlobalResolver resolver; + private final ExecutionFrame frame; + private final PlannedInterpretable accuInit; private final String accuVar; private final String iterVar; private final String iterVar2; @@ -162,9 +170,19 @@ private static class Folder implements ActivationWrapper { private Object iterVarVal; private Object iterVar2Val; private Object accuVal; - - private Folder(GlobalResolver resolver, String accuVar, String iterVar, String iterVar2) { + private boolean initialized = false; + private boolean computeResult = false; + + private Folder( + GlobalResolver resolver, + ExecutionFrame frame, + PlannedInterpretable accuInit, + String accuVar, + String iterVar, + String iterVar2) { this.resolver = resolver; + this.frame = frame; + this.accuInit = accuInit; this.accuVar = accuVar; this.iterVar = iterVar; this.iterVar2 = iterVar2; @@ -183,18 +201,34 @@ public boolean isLocallyBound(String name) { @Override public @Nullable Object resolve(String name) { if (name.equals(accuVar)) { + if (!initialized) { + initialized = true; + try { + accuVal = maybeWrapAccumulator(accuInit.eval(resolver, frame)); + } catch (CelEvaluationException e) { + throw new LazyEvaluationRuntimeException(e); + } + } return accuVal; } - if (name.equals(iterVar)) { - return this.iterVarVal; - } + if (!computeResult) { + if (name.equals(iterVar)) { + return this.iterVarVal; + } - if (name.equals(iterVar2)) { - return this.iterVar2Val; + if (name.equals(iterVar2)) { + return this.iterVar2Val; + } } return resolver.resolve(name); } } + + private static class LazyEvaluationRuntimeException extends CelRuntimeException { + private LazyEvaluationRuntimeException(CelEvaluationException cause) { + super(cause, cause.getErrorCode()); + } + } }