diff --git a/checker/src/main/java/dev/cel/checker/CelStandardDeclarations.java b/checker/src/main/java/dev/cel/checker/CelStandardDeclarations.java index 3865430ec..94b7ad313 100644 --- a/checker/src/main/java/dev/cel/checker/CelStandardDeclarations.java +++ b/checker/src/main/java/dev/cel/checker/CelStandardDeclarations.java @@ -1488,7 +1488,7 @@ public CelFunctionDecl functionDecl() { return celFunctionDecl; } - String functionName() { + public String functionName() { return functionName; } diff --git a/optimizer/src/main/java/dev/cel/optimizer/optimizers/BUILD.bazel b/optimizer/src/main/java/dev/cel/optimizer/optimizers/BUILD.bazel index 7674870a8..6057b3105 100644 --- a/optimizer/src/main/java/dev/cel/optimizer/optimizers/BUILD.bazel +++ b/optimizer/src/main/java/dev/cel/optimizer/optimizers/BUILD.bazel @@ -19,12 +19,14 @@ java_library( ":default_optimizer_constants", "//:auto_value", "//bundle:cel", + "//checker:standard_decl", "//common:cel_ast", "//common:cel_source", "//common:compiler_common", "//common:mutable_ast", "//common/ast", "//common/ast:mutable_expr", + "//common/internal:date_time_helpers", "//common/navigation:mutable_navigation", "//common/types", "//extensions:optional_library", diff --git a/optimizer/src/main/java/dev/cel/optimizer/optimizers/ConstantFoldingOptimizer.java b/optimizer/src/main/java/dev/cel/optimizer/optimizers/ConstantFoldingOptimizer.java index b40250fe6..294589af2 100644 --- a/optimizer/src/main/java/dev/cel/optimizer/optimizers/ConstantFoldingOptimizer.java +++ b/optimizer/src/main/java/dev/cel/optimizer/optimizers/ConstantFoldingOptimizer.java @@ -16,6 +16,8 @@ import static com.google.common.base.Preconditions.checkNotNull; import static com.google.common.collect.ImmutableList.toImmutableList; import static com.google.common.collect.MoreCollectors.onlyElement; +import static dev.cel.checker.CelStandardDeclarations.StandardFunction.DURATION; +import static dev.cel.checker.CelStandardDeclarations.StandardFunction.TIMESTAMP; import com.google.auto.value.AutoValue; import com.google.common.collect.ImmutableList; @@ -36,6 +38,7 @@ import dev.cel.common.ast.CelMutableExpr.CelMutableMap; import dev.cel.common.ast.CelMutableExpr.CelMutableStruct; import dev.cel.common.ast.CelMutableExprConverter; +import dev.cel.common.internal.DateTimeHelpers; import dev.cel.common.navigation.CelNavigableMutableAst; import dev.cel.common.navigation.CelNavigableMutableExpr; import dev.cel.common.types.SimpleType; @@ -45,6 +48,8 @@ import dev.cel.optimizer.CelOptimizationException; import dev.cel.parser.Operator; import dev.cel.runtime.CelEvaluationException; +import java.time.Duration; +import java.time.Instant; import java.util.ArrayList; import java.util.Arrays; import java.util.Collection; @@ -142,6 +147,14 @@ private boolean canFold(CelNavigableMutableExpr navigableExpr) { return false; } + // Timestamps/durations in CEL are calls, but they are effectively treated as literals. + // Expressions like timestamp(123) cannot be folded directly, but arithmetics involving + // timestamps can be optimized. + // Ex: timestamp(123) - timestamp(100) = duration("23s") + if (isCallTimestampOrDuration(navigableExpr.expr().call())) { + return false; + } + CelMutableCall mutableCall = navigableExpr.expr().call(); String functionName = mutableCall.function(); @@ -197,14 +210,14 @@ private boolean canFold(CelNavigableMutableExpr navigableExpr) { private boolean containsFoldableFunctionOnly(CelNavigableMutableExpr navigableExpr) { return navigableExpr .allNodes() - .allMatch( - node -> { - if (node.getKind().equals(Kind.CALL)) { - return foldableFunctions.contains(node.expr().call().function()); - } - - return true; - }); + .filter(node -> node.getKind().equals(Kind.CALL)) + .map(node -> node.expr().call()) + .allMatch(call -> foldableFunctions.contains(call.function())); + } + + private static boolean isCallTimestampOrDuration(CelMutableCall call) { + return call.function().equals(TIMESTAMP.functionName()) + || call.function().equals(DURATION.functionName()); } private static boolean canFoldInOperator(CelNavigableMutableExpr navigableExpr) { @@ -318,6 +331,22 @@ private Optional maybeAdaptEvaluatedResult(Object result) { } return Optional.of(CelMutableExpr.ofMap(CelMutableMap.create(mapEntries))); + } else if (result instanceof Duration) { + String durationStrArg = DateTimeHelpers.toString((Duration) result); + CelMutableCall durationCall = + CelMutableCall.create( + DURATION.functionName(), + CelMutableExpr.ofConstant(CelConstant.ofValue(durationStrArg))); + + return Optional.of(CelMutableExpr.ofCall(durationCall)); + } else if (result instanceof Instant) { + String timestampStrArg = result.toString(); + CelMutableCall timestampCall = + CelMutableCall.create( + TIMESTAMP.functionName(), + CelMutableExpr.ofConstant(CelConstant.ofValue(timestampStrArg))); + + return Optional.of(CelMutableExpr.ofCall(timestampCall)); } // Evaluated result cannot be folded (e.g: unknowns) diff --git a/optimizer/src/test/java/dev/cel/optimizer/optimizers/ConstantFoldingOptimizerTest.java b/optimizer/src/test/java/dev/cel/optimizer/optimizers/ConstantFoldingOptimizerTest.java index 4efbeb1c2..05e2e5457 100644 --- a/optimizer/src/test/java/dev/cel/optimizer/optimizers/ConstantFoldingOptimizerTest.java +++ b/optimizer/src/test/java/dev/cel/optimizer/optimizers/ConstantFoldingOptimizerTest.java @@ -46,6 +46,11 @@ @RunWith(TestParameterInjector.class) public class ConstantFoldingOptimizerTest { + private static final CelOptions CEL_OPTIONS = + CelOptions.current() + .enableTimestampEpoch(true) + .evaluateCanonicalTypesToNativeValues(true) + .build(); private static final Cel CEL = CelFactory.standardCelBuilder() .addVar("x", SimpleType.DYN) @@ -60,19 +65,20 @@ public class ConstantFoldingOptimizerTest { CelFunctionBinding.from("get_true_overload", ImmutableList.of(), unused -> true)) .addMessageTypes(TestAllTypes.getDescriptor()) .setContainer(CelContainer.ofName("cel.expr.conformance.proto3")) + .setOptions(CEL_OPTIONS) .addCompilerLibraries( CelExtensions.bindings(), CelOptionalLibrary.INSTANCE, - CelExtensions.math(CelOptions.DEFAULT), + CelExtensions.math(CEL_OPTIONS), CelExtensions.strings(), - CelExtensions.sets(CelOptions.DEFAULT), - CelExtensions.encoders(CelOptions.DEFAULT)) + CelExtensions.sets(CEL_OPTIONS), + CelExtensions.encoders(CEL_OPTIONS)) .addRuntimeLibraries( CelOptionalLibrary.INSTANCE, - CelExtensions.math(CelOptions.DEFAULT), + CelExtensions.math(CEL_OPTIONS), CelExtensions.strings(), - CelExtensions.sets(CelOptions.DEFAULT), - CelExtensions.encoders(CelOptions.DEFAULT)) + CelExtensions.sets(CEL_OPTIONS), + CelExtensions.encoders(CEL_OPTIONS)) .build(); private static final CelOptimizer CEL_OPTIMIZER = @@ -211,6 +217,15 @@ public class ConstantFoldingOptimizerTest { @TestParameters("{source: '42 != 42', expected: 'false'}") @TestParameters("{source: '[\"foo\",\"bar\"] == [\"foo\",\"bar\"]', expected: 'true'}") @TestParameters("{source: '[\"bar\",\"foo\"] == [\"foo\",\"bar\"]', expected: 'false'}") + @TestParameters("{source: 'duration(\"1h\") - duration(\"60m\")', expected: 'duration(\"0s\")'}") + @TestParameters( + "{source: 'duration(\"2h23m42s12ms42us92ns\") + duration(\"129481231298125ns\")', expected:" + + " 'duration(\"138103.243340217s\")'}") + @TestParameters( + "{source: 'timestamp(900000) - timestamp(100)', expected: 'duration(\"899900s\")'}") + @TestParameters( + "{source: 'timestamp(\"2000-01-01T00:02:03.2123Z\") + duration(\"25h2m32s42ms53us29ns\")'," + + " expected: 'timestamp(\"2000-01-02T01:04:35.254353029Z\")'}") // TODO: Support folding lists with mixed types. This requires mutable lists. // @TestParameters("{source: 'dyn([1]) + [1.0]'}") public void constantFold_success(String source, String expected) throws Exception { @@ -348,6 +363,8 @@ public void constantFold_macros_withoutMacroCallMetadata(String source) throws E @TestParameters("{source: 'get_true() == true'}") @TestParameters("{source: 'x == x'}") @TestParameters("{source: 'x == 42'}") + @TestParameters("{source: 'timestamp(100)'}") + @TestParameters("{source: 'duration(\"1h\")'}") public void constantFold_noOp(String source) throws Exception { CelAbstractSyntaxTree ast = CEL.compile(source).getAst();