diff --git a/benchmarks/src/main/java/org/elasticsearch/benchmark/compute/operator/EvalBenchmark.java b/benchmarks/src/main/java/org/elasticsearch/benchmark/compute/operator/EvalBenchmark.java index 9aab4a3e3210f..d3259b9604717 100644 --- a/benchmarks/src/main/java/org/elasticsearch/benchmark/compute/operator/EvalBenchmark.java +++ b/benchmarks/src/main/java/org/elasticsearch/benchmark/compute/operator/EvalBenchmark.java @@ -27,6 +27,7 @@ import org.elasticsearch.core.TimeValue; import org.elasticsearch.xpack.esql.core.expression.Expression; import org.elasticsearch.xpack.esql.core.expression.FieldAttribute; +import org.elasticsearch.xpack.esql.core.expression.FoldContext; import org.elasticsearch.xpack.esql.core.expression.Literal; import org.elasticsearch.xpack.esql.core.expression.predicate.regex.RLikePattern; import org.elasticsearch.xpack.esql.core.tree.Source; @@ -71,12 +72,11 @@ public class EvalBenchmark { BigArrays.NON_RECYCLING_INSTANCE ); + private static final FoldContext FOLD_CONTEXT = FoldContext.small(); + private static final int BLOCK_LENGTH = 8 * 1024; - static final DriverContext driverContext = new DriverContext( - BigArrays.NON_RECYCLING_INSTANCE, - BlockFactory.getInstance(new NoopCircuitBreaker("noop"), BigArrays.NON_RECYCLING_INSTANCE) - ); + static final DriverContext driverContext = new DriverContext(BigArrays.NON_RECYCLING_INSTANCE, blockFactory); static { // Smoke test all the expected values and force loading subclasses more like prod @@ -114,11 +114,12 @@ private static EvalOperator.ExpressionEvaluator evaluator(String operation) { return switch (operation) { case "abs" -> { FieldAttribute longField = longField(); - yield EvalMapper.toEvaluator(new Abs(Source.EMPTY, longField), layout(longField)).get(driverContext); + yield EvalMapper.toEvaluator(FOLD_CONTEXT, new Abs(Source.EMPTY, longField), layout(longField)).get(driverContext); } case "add" -> { FieldAttribute longField = longField(); yield EvalMapper.toEvaluator( + FOLD_CONTEXT, new Add(Source.EMPTY, longField, new Literal(Source.EMPTY, 1L, DataType.LONG)), layout(longField) ).get(driverContext); @@ -126,6 +127,7 @@ private static EvalOperator.ExpressionEvaluator evaluator(String operation) { case "add_double" -> { FieldAttribute doubleField = doubleField(); yield EvalMapper.toEvaluator( + FOLD_CONTEXT, new Add(Source.EMPTY, doubleField, new Literal(Source.EMPTY, 1D, DataType.DOUBLE)), layout(doubleField) ).get(driverContext); @@ -140,7 +142,8 @@ private static EvalOperator.ExpressionEvaluator evaluator(String operation) { lhs = new Add(Source.EMPTY, lhs, new Literal(Source.EMPTY, 1L, DataType.LONG)); rhs = new Add(Source.EMPTY, rhs, new Literal(Source.EMPTY, 1L, DataType.LONG)); } - yield EvalMapper.toEvaluator(new Case(Source.EMPTY, condition, List.of(lhs, rhs)), layout(f1, f2)).get(driverContext); + yield EvalMapper.toEvaluator(FOLD_CONTEXT, new Case(Source.EMPTY, condition, List.of(lhs, rhs)), layout(f1, f2)) + .get(driverContext); } case "date_trunc" -> { FieldAttribute timestamp = new FieldAttribute( @@ -149,6 +152,7 @@ private static EvalOperator.ExpressionEvaluator evaluator(String operation) { new EsField("timestamp", DataType.DATETIME, Map.of(), true) ); yield EvalMapper.toEvaluator( + FOLD_CONTEXT, new DateTrunc(Source.EMPTY, new Literal(Source.EMPTY, Duration.ofHours(24), DataType.TIME_DURATION), timestamp), layout(timestamp) ).get(driverContext); @@ -156,6 +160,7 @@ private static EvalOperator.ExpressionEvaluator evaluator(String operation) { case "equal_to_const" -> { FieldAttribute longField = longField(); yield EvalMapper.toEvaluator( + FOLD_CONTEXT, new Equals(Source.EMPTY, longField, new Literal(Source.EMPTY, 100_000L, DataType.LONG)), layout(longField) ).get(driverContext); @@ -163,21 +168,21 @@ private static EvalOperator.ExpressionEvaluator evaluator(String operation) { case "long_equal_to_long" -> { FieldAttribute lhs = longField(); FieldAttribute rhs = longField(); - yield EvalMapper.toEvaluator(new Equals(Source.EMPTY, lhs, rhs), layout(lhs, rhs)).get(driverContext); + yield EvalMapper.toEvaluator(FOLD_CONTEXT, new Equals(Source.EMPTY, lhs, rhs), layout(lhs, rhs)).get(driverContext); } case "long_equal_to_int" -> { FieldAttribute lhs = longField(); FieldAttribute rhs = intField(); - yield EvalMapper.toEvaluator(new Equals(Source.EMPTY, lhs, rhs), layout(lhs, rhs)).get(driverContext); + yield EvalMapper.toEvaluator(FOLD_CONTEXT, new Equals(Source.EMPTY, lhs, rhs), layout(lhs, rhs)).get(driverContext); } case "mv_min", "mv_min_ascending" -> { FieldAttribute longField = longField(); - yield EvalMapper.toEvaluator(new MvMin(Source.EMPTY, longField), layout(longField)).get(driverContext); + yield EvalMapper.toEvaluator(FOLD_CONTEXT, new MvMin(Source.EMPTY, longField), layout(longField)).get(driverContext); } case "rlike" -> { FieldAttribute keywordField = keywordField(); RLike rlike = new RLike(Source.EMPTY, keywordField, new RLikePattern(".ar")); - yield EvalMapper.toEvaluator(rlike, layout(keywordField)).get(driverContext); + yield EvalMapper.toEvaluator(FOLD_CONTEXT, rlike, layout(keywordField)).get(driverContext); } default -> throw new UnsupportedOperationException(); }; diff --git a/docs/changelog/118602.yaml b/docs/changelog/118602.yaml new file mode 100644 index 0000000000000..a75c5dcf11da3 --- /dev/null +++ b/docs/changelog/118602.yaml @@ -0,0 +1,5 @@ +pr: 118602 +summary: Limit memory usage of `fold` +area: ES|QL +type: bug +issues: [] diff --git a/docs/reference/esql/functions/kibana/definition/bucket.json b/docs/reference/esql/functions/kibana/definition/bucket.json index 3d96de05c8407..f9c7f2f27d6f9 100644 --- a/docs/reference/esql/functions/kibana/definition/bucket.json +++ b/docs/reference/esql/functions/kibana/definition/bucket.json @@ -1599,7 +1599,7 @@ "FROM sample_data \n| WHERE @timestamp >= NOW() - 1 day and @timestamp < NOW()\n| STATS COUNT(*) BY bucket = BUCKET(@timestamp, 25, NOW() - 1 day, NOW())", "FROM employees\n| WHERE hire_date >= \"1985-01-01T00:00:00Z\" AND hire_date < \"1986-01-01T00:00:00Z\"\n| STATS AVG(salary) BY bucket = BUCKET(hire_date, 20, \"1985-01-01T00:00:00Z\", \"1986-01-01T00:00:00Z\")\n| SORT bucket", "FROM employees\n| STATS s1 = b1 + 1, s2 = BUCKET(salary / 1000 + 999, 50.) + 2 BY b1 = BUCKET(salary / 100 + 99, 50.), b2 = BUCKET(salary / 1000 + 999, 50.)\n| SORT b1, b2\n| KEEP s1, b1, s2, b2", - "FROM employees \n| STATS dates = VALUES(birth_date) BY b = BUCKET(birth_date + 1 HOUR, 1 YEAR) - 1 HOUR\n| EVAL d_count = MV_COUNT(dates)\n| SORT d_count\n| LIMIT 3" + "FROM employees\n| STATS dates = MV_SORT(VALUES(birth_date)) BY b = BUCKET(birth_date + 1 HOUR, 1 YEAR) - 1 HOUR\n| EVAL d_count = MV_COUNT(dates)\n| SORT d_count, b\n| LIMIT 3" ], "preview" : false, "snapshot_only" : false diff --git a/docs/reference/esql/functions/kibana/definition/match_operator.json b/docs/reference/esql/functions/kibana/definition/match_operator.json index 44233bbddb653..c8cbf1cf9d966 100644 --- a/docs/reference/esql/functions/kibana/definition/match_operator.json +++ b/docs/reference/esql/functions/kibana/definition/match_operator.json @@ -2,7 +2,7 @@ "comment" : "This is generated by ESQL's AbstractFunctionTestCase. Do no edit it. See ../README.md for how to regenerate it.", "type" : "operator", "name" : "match_operator", - "description" : "Performs a <> on the specified field. Returns true if the provided query matches the row.", + "description" : "Use `MATCH` to perform a <> on the specified field.\nUsing `MATCH` is equivalent to using the `match` query in the Elasticsearch Query DSL.\n\nMatch can be used on text fields, as well as other field types like boolean, dates, and numeric types.\n\nFor a simplified syntax, you can use the <> `:` operator instead of `MATCH`.\n\n`MATCH` returns true if the provided query matches the row.", "signatures" : [ { "params" : [ diff --git a/docs/reference/esql/functions/kibana/docs/match_operator.md b/docs/reference/esql/functions/kibana/docs/match_operator.md index b0b6196798087..7681c2d1ce231 100644 --- a/docs/reference/esql/functions/kibana/docs/match_operator.md +++ b/docs/reference/esql/functions/kibana/docs/match_operator.md @@ -3,7 +3,14 @@ This is generated by ESQL's AbstractFunctionTestCase. Do no edit it. See ../READ --> ### MATCH_OPERATOR -Performs a <> on the specified field. Returns true if the provided query matches the row. +Use `MATCH` to perform a <> on the specified field. +Using `MATCH` is equivalent to using the `match` query in the Elasticsearch Query DSL. + +Match can be used on text fields, as well as other field types like boolean, dates, and numeric types. + +For a simplified syntax, you can use the <> `:` operator instead of `MATCH`. + +`MATCH` returns true if the provided query matches the row. ``` FROM books diff --git a/server/src/main/java/org/elasticsearch/common/settings/Setting.java b/server/src/main/java/org/elasticsearch/common/settings/Setting.java index 32e8999f6bb26..9c4468c962c0c 100644 --- a/server/src/main/java/org/elasticsearch/common/settings/Setting.java +++ b/server/src/main/java/org/elasticsearch/common/settings/Setting.java @@ -1740,7 +1740,7 @@ public static > Setting enumSetting( * * @param key the key for the setting * @param defaultValue the default value for this setting - * @param properties properties properties for this setting like scope, filtering... + * @param properties properties for this setting like scope, filtering... * @return the setting object */ public static Setting memorySizeSetting(String key, ByteSizeValue defaultValue, Property... properties) { diff --git a/test/external-modules/esql-heap-attack/src/javaRestTest/java/org/elasticsearch/xpack/esql/heap_attack/HeapAttackIT.java b/test/external-modules/esql-heap-attack/src/javaRestTest/java/org/elasticsearch/xpack/esql/heap_attack/HeapAttackIT.java index ace3db377664c..958132b3e4076 100644 --- a/test/external-modules/esql-heap-attack/src/javaRestTest/java/org/elasticsearch/xpack/esql/heap_attack/HeapAttackIT.java +++ b/test/external-modules/esql-heap-attack/src/javaRestTest/java/org/elasticsearch/xpack/esql/heap_attack/HeapAttackIT.java @@ -194,6 +194,16 @@ private void assertCircuitBreaks(ThrowingRunnable r) throws IOException { ); } + private void assertFoldCircuitBreaks(ThrowingRunnable r) throws IOException { + ResponseException e = expectThrows(ResponseException.class, r); + Map map = responseAsMap(e.getResponse()); + logger.info("expected fold circuit breaking {}", map); + assertMap( + map, + matchesMap().entry("status", 400).entry("error", matchesMap().extraOk().entry("type", "fold_too_much_memory_exception")) + ); + } + private void assertParseFailure(ThrowingRunnable r) throws IOException { ResponseException e = expectThrows(ResponseException.class, r); Map map = responseAsMap(e.getResponse()); @@ -325,11 +335,23 @@ public void testManyConcatFromRow() throws IOException { assertManyStrings(resp, strings); } + /** + * Hits a circuit breaker by building many moderately long strings. + */ + public void testHugeManyConcatFromRow() throws IOException { + assertFoldCircuitBreaks( + () -> manyConcat( + "ROW a=9999999999999, b=99999999999999999, c=99999999999999999, d=99999999999999999, e=99999999999999999", + 5000 + ) + ); + } + /** * Fails to parse a huge huge query. */ public void testHugeHugeManyConcatFromRow() throws IOException { - assertParseFailure(() -> manyConcat("ROW a=9999, b=9999, c=9999, d=9999, e=9999", 50000)); + assertParseFailure(() -> manyConcat("ROW a=9999, b=9999, c=9999, d=9999, e=9999", 6000)); } /** @@ -387,13 +409,20 @@ public void testHugeManyRepeat() throws IOException { * Returns many moderately long strings. */ public void testManyRepeatFromRow() throws IOException { - int strings = 10000; + int strings = 300; Response resp = manyRepeat("ROW a = 99", strings); assertManyStrings(resp, strings); } /** - * Fails to parse a huge huge query. + * Hits a circuit breaker by building many moderately long strings. + */ + public void testHugeManyRepeatFromRow() throws IOException { + assertFoldCircuitBreaks(() -> manyRepeat("ROW a = 99", 400)); + } + + /** + * Fails to parse a huge, huge query. */ public void testHugeHugeManyRepeatFromRow() throws IOException { assertParseFailure(() -> manyRepeat("ROW a = 99", 100000)); diff --git a/x-pack/plugin/esql-core/src/main/java/org/elasticsearch/xpack/esql/core/expression/Expression.java b/x-pack/plugin/esql-core/src/main/java/org/elasticsearch/xpack/esql/core/expression/Expression.java index 00765a8c0528c..b254612a700df 100644 --- a/x-pack/plugin/esql-core/src/main/java/org/elasticsearch/xpack/esql/core/expression/Expression.java +++ b/x-pack/plugin/esql-core/src/main/java/org/elasticsearch/xpack/esql/core/expression/Expression.java @@ -78,12 +78,20 @@ public Expression(Source source, List children) { super(source, children); } - // whether the expression can be evaluated statically (folded) or not + /** + * Whether the expression can be evaluated statically, aka "folded", or not. + */ public boolean foldable() { return false; } - public Object fold() { + /** + * Evaluate this expression statically to a constant. It is an error to call + * this if {@link #foldable} returns false. + */ + public Object fold(FoldContext ctx) { + // TODO After removing FoldContext.unbounded from non-test code examine all calls + // for places we should use instanceof Literal instead throw new QlIllegalArgumentException("Should not fold expression"); } diff --git a/x-pack/plugin/esql-core/src/main/java/org/elasticsearch/xpack/esql/core/expression/Expressions.java b/x-pack/plugin/esql-core/src/main/java/org/elasticsearch/xpack/esql/core/expression/Expressions.java index 4e4338aad3704..739333ded0fde 100644 --- a/x-pack/plugin/esql-core/src/main/java/org/elasticsearch/xpack/esql/core/expression/Expressions.java +++ b/x-pack/plugin/esql-core/src/main/java/org/elasticsearch/xpack/esql/core/expression/Expressions.java @@ -107,10 +107,10 @@ public static boolean foldable(List exps) { return true; } - public static List fold(List exps) { + public static List fold(FoldContext ctx, List exps) { List folded = new ArrayList<>(exps.size()); for (Expression exp : exps) { - folded.add(exp.fold()); + folded.add(exp.fold(ctx)); } return folded; @@ -135,7 +135,7 @@ public static String name(Expression e) { /** * Is this {@linkplain Expression} guaranteed to have * only the {@code null} value. {@linkplain Expression}s that - * {@link Expression#fold()} to {@code null} may + * {@link Expression#fold} to {@code null} may * return {@code false} here, but should eventually be folded * into a {@link Literal} containing {@code null} which will return * {@code true} from here. diff --git a/x-pack/plugin/esql-core/src/main/java/org/elasticsearch/xpack/esql/core/expression/FoldContext.java b/x-pack/plugin/esql-core/src/main/java/org/elasticsearch/xpack/esql/core/expression/FoldContext.java new file mode 100644 index 0000000000000..25da44c5fd226 --- /dev/null +++ b/x-pack/plugin/esql-core/src/main/java/org/elasticsearch/xpack/esql/core/expression/FoldContext.java @@ -0,0 +1,178 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.esql.core.expression; + +import org.elasticsearch.common.breaker.CircuitBreaker; +import org.elasticsearch.common.breaker.CircuitBreakingException; +import org.elasticsearch.common.unit.ByteSizeValue; +import org.elasticsearch.common.unit.MemorySizeValue; +import org.elasticsearch.core.Releasable; +import org.elasticsearch.xpack.esql.core.QlClientException; +import org.elasticsearch.xpack.esql.core.tree.Source; + +import java.util.Objects; + +/** + * Context passed to {@link Expression#fold}. This is not thread safe. + */ +public class FoldContext { + private static final long SMALL = MemorySizeValue.parseBytesSizeValueOrHeapRatio("5%", "small").getBytes(); + + /** + * {@link Expression#fold} using less than 5% of heap. Fine in tests but otherwise + * calling this is a signal that you either, shouldn't be calling {@link Expression#fold} + * at all, or should pass in a shared {@link FoldContext} made by {@code Configuration}. + */ + public static FoldContext small() { + return new FoldContext(SMALL); + } + + private final long initialAllowedBytes; + private long allowedBytes; + + public FoldContext(long allowedBytes) { + this.initialAllowedBytes = allowedBytes; + this.allowedBytes = allowedBytes; + } + + /** + * The maximum allowed bytes. {@link #allowedBytes()} will be the same as this + * for an unused context. + */ + public long initialAllowedBytes() { + return initialAllowedBytes; + } + + /** + * The remaining allowed bytes. + */ + long allowedBytes() { + return allowedBytes; + } + + @Override + public boolean equals(Object o) { + if (o == null || getClass() != o.getClass()) return false; + FoldContext that = (FoldContext) o; + return initialAllowedBytes == that.initialAllowedBytes && allowedBytes == that.allowedBytes; + } + + @Override + public int hashCode() { + return Objects.hash(initialAllowedBytes, allowedBytes); + } + + @Override + public String toString() { + return "FoldContext[" + allowedBytes + '/' + initialAllowedBytes + ']'; + } + + /** + * Track an allocation. Best to call this before allocating + * if possible, but after is ok if the allocation is small. + *

+ * Note that, unlike {@link CircuitBreaker}, you don't have + * to free this allocation later. This is important because the query plan + * doesn't implement {@link Releasable} so it can't free + * consistently. But when you have to allocate big chunks of memory during + * folding and know that you are returning the memory it is kindest to + * call this with a negative number, effectively giving those bytes back. + *

+ */ + public void trackAllocation(Source source, long bytes) { + allowedBytes -= bytes; + assert allowedBytes <= initialAllowedBytes : "returned more bytes than it used"; + if (allowedBytes < 0) { + throw new FoldTooMuchMemoryException(source, bytes, initialAllowedBytes); + } + } + + /** + * Adapt this into a {@link CircuitBreaker} suitable for building bounded local + * DriverContext. This is absolutely an abuse of the {@link CircuitBreaker} contract + * and only methods used by BlockFactory are implemented. And this'll throw a + * {@link FoldTooMuchMemoryException} instead of the standard {@link CircuitBreakingException}. + * This works for the common folding implementation though. + */ + public CircuitBreaker circuitBreakerView(Source source) { + return new CircuitBreaker() { + @Override + public void circuitBreak(String fieldName, long bytesNeeded) { + throw new UnsupportedOperationException(); + } + + @Override + public void addEstimateBytesAndMaybeBreak(long bytes, String label) throws CircuitBreakingException { + trackAllocation(source, bytes); + } + + @Override + public void addWithoutBreaking(long bytes) { + assert bytes <= 0 : "we only expect this to be used for deallocation"; + allowedBytes -= bytes; + assert allowedBytes <= initialAllowedBytes : "returned more bytes than it used"; + } + + @Override + public long getUsed() { + /* + * This isn't expected to be used by we can implement it so we may as + * well. Maybe it'll be useful for debugging one day. + */ + return initialAllowedBytes - allowedBytes; + } + + @Override + public long getLimit() { + /* + * This isn't expected to be used by we can implement it so we may as + * well. Maybe it'll be useful for debugging one day. + */ + return initialAllowedBytes; + } + + @Override + public double getOverhead() { + return 1.0; + } + + @Override + public long getTrippedCount() { + return 0; + } + + @Override + public String getName() { + return REQUEST; + } + + @Override + public Durability getDurability() { + throw new UnsupportedOperationException(); + } + + @Override + public void setLimitAndOverhead(long limit, double overhead) { + throw new UnsupportedOperationException(); + } + }; + } + + public static class FoldTooMuchMemoryException extends QlClientException { + protected FoldTooMuchMemoryException(Source source, long bytesForExpression, long initialAllowedBytes) { + super( + "line {}:{}: Folding query used more than {}. The expression that pushed past the limit is [{}] which needed {}.", + source.source().getLineNumber(), + source.source().getColumnNumber(), + ByteSizeValue.ofBytes(initialAllowedBytes), + source.text(), + ByteSizeValue.ofBytes(bytesForExpression) + ); + } + } +} diff --git a/x-pack/plugin/esql-core/src/main/java/org/elasticsearch/xpack/esql/core/expression/Foldables.java b/x-pack/plugin/esql-core/src/main/java/org/elasticsearch/xpack/esql/core/expression/Foldables.java index 601758bca5918..233113c3fe1b8 100644 --- a/x-pack/plugin/esql-core/src/main/java/org/elasticsearch/xpack/esql/core/expression/Foldables.java +++ b/x-pack/plugin/esql-core/src/main/java/org/elasticsearch/xpack/esql/core/expression/Foldables.java @@ -10,9 +10,9 @@ public abstract class Foldables { - public static Object valueOf(Expression e) { + public static Object valueOf(FoldContext ctx, Expression e) { if (e.foldable()) { - return e.fold(); + return e.fold(ctx); } throw new QlIllegalArgumentException("Cannot determine value for {}", e); } diff --git a/x-pack/plugin/esql-core/src/main/java/org/elasticsearch/xpack/esql/core/expression/Literal.java b/x-pack/plugin/esql-core/src/main/java/org/elasticsearch/xpack/esql/core/expression/Literal.java index 53f559c5c82fe..afe616489d81d 100644 --- a/x-pack/plugin/esql-core/src/main/java/org/elasticsearch/xpack/esql/core/expression/Literal.java +++ b/x-pack/plugin/esql-core/src/main/java/org/elasticsearch/xpack/esql/core/expression/Literal.java @@ -98,7 +98,7 @@ public boolean resolved() { } @Override - public Object fold() { + public Object fold(FoldContext ctx) { return value; } @@ -138,7 +138,7 @@ public String nodeString() { * Utility method for creating a literal out of a foldable expression. * Throws an exception if the expression is not foldable. */ - public static Literal of(Expression foldable) { + public static Literal of(FoldContext ctx, Expression foldable) { if (foldable.foldable() == false) { throw new QlIllegalArgumentException("Foldable expression required for Literal creation; received unfoldable " + foldable); } @@ -147,7 +147,7 @@ public static Literal of(Expression foldable) { return (Literal) foldable; } - return new Literal(foldable.source(), foldable.fold(), foldable.dataType()); + return new Literal(foldable.source(), foldable.fold(ctx), foldable.dataType()); } public static Literal of(Expression source, Object value) { diff --git a/x-pack/plugin/esql-core/src/main/java/org/elasticsearch/xpack/esql/core/expression/TypeResolutions.java b/x-pack/plugin/esql-core/src/main/java/org/elasticsearch/xpack/esql/core/expression/TypeResolutions.java index b817ec17c7bda..842f3c0ddadd7 100644 --- a/x-pack/plugin/esql-core/src/main/java/org/elasticsearch/xpack/esql/core/expression/TypeResolutions.java +++ b/x-pack/plugin/esql-core/src/main/java/org/elasticsearch/xpack/esql/core/expression/TypeResolutions.java @@ -133,6 +133,14 @@ public static TypeResolution isFoldable(Expression e, String operationName, Para return TypeResolution.TYPE_RESOLVED; } + /** + * Is this {@link Expression#foldable()} and not {@code null}. + * + * @deprecated instead of calling this, check for a {@link Literal} containing + * {@code null}. Foldable expressions will be folded by other rules, + * eventually, to a {@link Literal}. + */ + @Deprecated public static TypeResolution isNotNullAndFoldable(Expression e, String operationName, ParamOrdinal paramOrd) { TypeResolution resolution = isFoldable(e, operationName, paramOrd); @@ -140,7 +148,7 @@ public static TypeResolution isNotNullAndFoldable(Expression e, String operation return resolution; } - if (e.dataType() == DataType.NULL || e.fold() == null) { + if (e.dataType() == DataType.NULL || e.fold(FoldContext.small()) == null) { resolution = new TypeResolution( format( null, diff --git a/x-pack/plugin/esql-core/src/main/java/org/elasticsearch/xpack/esql/core/expression/function/scalar/UnaryScalarFunction.java b/x-pack/plugin/esql-core/src/main/java/org/elasticsearch/xpack/esql/core/expression/function/scalar/UnaryScalarFunction.java index 8704a42ed33e2..36517b1be9ce7 100644 --- a/x-pack/plugin/esql-core/src/main/java/org/elasticsearch/xpack/esql/core/expression/function/scalar/UnaryScalarFunction.java +++ b/x-pack/plugin/esql-core/src/main/java/org/elasticsearch/xpack/esql/core/expression/function/scalar/UnaryScalarFunction.java @@ -9,6 +9,7 @@ import org.elasticsearch.common.io.stream.StreamInput; import org.elasticsearch.common.io.stream.StreamOutput; import org.elasticsearch.xpack.esql.core.expression.Expression; +import org.elasticsearch.xpack.esql.core.expression.FoldContext; import org.elasticsearch.xpack.esql.core.tree.Source; import org.elasticsearch.xpack.esql.core.util.PlanStreamInput; @@ -53,5 +54,5 @@ public boolean foldable() { } @Override - public abstract Object fold(); + public abstract Object fold(FoldContext ctx); } diff --git a/x-pack/plugin/esql-core/src/main/java/org/elasticsearch/xpack/esql/core/expression/predicate/BinaryPredicate.java b/x-pack/plugin/esql-core/src/main/java/org/elasticsearch/xpack/esql/core/expression/predicate/BinaryPredicate.java index be5caedacd50a..bf5549b31e5fa 100644 --- a/x-pack/plugin/esql-core/src/main/java/org/elasticsearch/xpack/esql/core/expression/predicate/BinaryPredicate.java +++ b/x-pack/plugin/esql-core/src/main/java/org/elasticsearch/xpack/esql/core/expression/predicate/BinaryPredicate.java @@ -7,6 +7,7 @@ package org.elasticsearch.xpack.esql.core.expression.predicate; import org.elasticsearch.xpack.esql.core.expression.Expression; +import org.elasticsearch.xpack.esql.core.expression.FoldContext; import org.elasticsearch.xpack.esql.core.expression.function.scalar.BinaryScalarFunction; import org.elasticsearch.xpack.esql.core.tree.Source; @@ -29,8 +30,8 @@ protected BinaryPredicate(Source source, Expression left, Expression right, F fu @SuppressWarnings("unchecked") @Override - public R fold() { - return function().apply((T) left().fold(), (U) right().fold()); + public R fold(FoldContext ctx) { + return function().apply((T) left().fold(ctx), (U) right().fold(ctx)); } @Override diff --git a/x-pack/plugin/esql-core/src/main/java/org/elasticsearch/xpack/esql/core/expression/predicate/Range.java b/x-pack/plugin/esql-core/src/main/java/org/elasticsearch/xpack/esql/core/expression/predicate/Range.java index 5de09f40437c7..a4e4685f764e8 100644 --- a/x-pack/plugin/esql-core/src/main/java/org/elasticsearch/xpack/esql/core/expression/predicate/Range.java +++ b/x-pack/plugin/esql-core/src/main/java/org/elasticsearch/xpack/esql/core/expression/predicate/Range.java @@ -8,6 +8,7 @@ import org.elasticsearch.common.io.stream.StreamOutput; import org.elasticsearch.xpack.esql.core.expression.Expression; +import org.elasticsearch.xpack.esql.core.expression.FoldContext; import org.elasticsearch.xpack.esql.core.expression.Literal; import org.elasticsearch.xpack.esql.core.expression.function.scalar.ScalarFunction; import org.elasticsearch.xpack.esql.core.expression.predicate.operator.comparison.BinaryComparison; @@ -99,23 +100,24 @@ public boolean foldable() { } // We cannot fold the bounds here; but if they're already literals, we can check if the range is always empty. - if (lower() instanceof Literal && upper() instanceof Literal) { - return areBoundariesInvalid(); + if (lower() instanceof Literal l && upper() instanceof Literal u) { + return areBoundariesInvalid(l.value(), u.value()); } } - return false; } @Override - public Object fold() { - if (areBoundariesInvalid()) { + public Object fold(FoldContext ctx) { + Object lowerValue = lower.fold(ctx); + Object upperValue = upper.fold(ctx); + if (areBoundariesInvalid(lowerValue, upperValue)) { return Boolean.FALSE; } - Object val = value.fold(); - Integer lowerCompare = BinaryComparison.compare(lower.fold(), val); - Integer upperCompare = BinaryComparison.compare(val, upper().fold()); + Object val = value.fold(ctx); + Integer lowerCompare = BinaryComparison.compare(lower.fold(ctx), val); + Integer upperCompare = BinaryComparison.compare(val, upper().fold(ctx)); boolean lowerComparsion = lowerCompare == null ? false : (includeLower ? lowerCompare <= 0 : lowerCompare < 0); boolean upperComparsion = upperCompare == null ? false : (includeUpper ? upperCompare <= 0 : upperCompare < 0); return lowerComparsion && upperComparsion; @@ -125,9 +127,7 @@ public Object fold() { * Check whether the boundaries are invalid ( upper < lower) or not. * If they are, the value does not have to be evaluated. */ - protected boolean areBoundariesInvalid() { - Object lowerValue = lower.fold(); - Object upperValue = upper.fold(); + protected boolean areBoundariesInvalid(Object lowerValue, Object upperValue) { if (DataType.isDateTime(value.dataType()) || DataType.isDateTime(lower.dataType()) || DataType.isDateTime(upper.dataType())) { try { if (upperValue instanceof String upperString) { diff --git a/x-pack/plugin/esql-core/src/main/java/org/elasticsearch/xpack/esql/core/expression/predicate/logical/Not.java b/x-pack/plugin/esql-core/src/main/java/org/elasticsearch/xpack/esql/core/expression/predicate/logical/Not.java index c4983b49a6bc8..218f61856accc 100644 --- a/x-pack/plugin/esql-core/src/main/java/org/elasticsearch/xpack/esql/core/expression/predicate/logical/Not.java +++ b/x-pack/plugin/esql-core/src/main/java/org/elasticsearch/xpack/esql/core/expression/predicate/logical/Not.java @@ -10,6 +10,7 @@ import org.elasticsearch.common.io.stream.StreamInput; import org.elasticsearch.xpack.esql.core.QlIllegalArgumentException; import org.elasticsearch.xpack.esql.core.expression.Expression; +import org.elasticsearch.xpack.esql.core.expression.FoldContext; import org.elasticsearch.xpack.esql.core.expression.function.scalar.UnaryScalarFunction; import org.elasticsearch.xpack.esql.core.expression.predicate.Negatable; import org.elasticsearch.xpack.esql.core.tree.NodeInfo; @@ -56,8 +57,8 @@ protected TypeResolution resolveType() { } @Override - public Object fold() { - return apply(field().fold()); + public Object fold(FoldContext ctx) { + return apply(field().fold(ctx)); } private static Boolean apply(Object input) { diff --git a/x-pack/plugin/esql-core/src/main/java/org/elasticsearch/xpack/esql/core/expression/predicate/nulls/IsNotNull.java b/x-pack/plugin/esql-core/src/main/java/org/elasticsearch/xpack/esql/core/expression/predicate/nulls/IsNotNull.java index 9879a1f5ffc29..f5542ff7c3de5 100644 --- a/x-pack/plugin/esql-core/src/main/java/org/elasticsearch/xpack/esql/core/expression/predicate/nulls/IsNotNull.java +++ b/x-pack/plugin/esql-core/src/main/java/org/elasticsearch/xpack/esql/core/expression/predicate/nulls/IsNotNull.java @@ -9,6 +9,7 @@ import org.elasticsearch.common.io.stream.NamedWriteableRegistry; import org.elasticsearch.common.io.stream.StreamInput; import org.elasticsearch.xpack.esql.core.expression.Expression; +import org.elasticsearch.xpack.esql.core.expression.FoldContext; import org.elasticsearch.xpack.esql.core.expression.Nullability; import org.elasticsearch.xpack.esql.core.expression.function.scalar.UnaryScalarFunction; import org.elasticsearch.xpack.esql.core.expression.predicate.Negatable; @@ -49,8 +50,8 @@ protected IsNotNull replaceChild(Expression newChild) { } @Override - public Object fold() { - return field().fold() != null && DataType.isNull(field().dataType()) == false; + public Object fold(FoldContext ctx) { + return DataType.isNull(field().dataType()) == false && field().fold(ctx) != null; } @Override diff --git a/x-pack/plugin/esql-core/src/main/java/org/elasticsearch/xpack/esql/core/expression/predicate/nulls/IsNull.java b/x-pack/plugin/esql-core/src/main/java/org/elasticsearch/xpack/esql/core/expression/predicate/nulls/IsNull.java index d88945045b03e..bb85791a9f85e 100644 --- a/x-pack/plugin/esql-core/src/main/java/org/elasticsearch/xpack/esql/core/expression/predicate/nulls/IsNull.java +++ b/x-pack/plugin/esql-core/src/main/java/org/elasticsearch/xpack/esql/core/expression/predicate/nulls/IsNull.java @@ -9,6 +9,7 @@ import org.elasticsearch.common.io.stream.NamedWriteableRegistry; import org.elasticsearch.common.io.stream.StreamInput; import org.elasticsearch.xpack.esql.core.expression.Expression; +import org.elasticsearch.xpack.esql.core.expression.FoldContext; import org.elasticsearch.xpack.esql.core.expression.Nullability; import org.elasticsearch.xpack.esql.core.expression.function.scalar.UnaryScalarFunction; import org.elasticsearch.xpack.esql.core.expression.predicate.Negatable; @@ -45,8 +46,8 @@ protected IsNull replaceChild(Expression newChild) { } @Override - public Object fold() { - return field().fold() == null || DataType.isNull(field().dataType()); + public Object fold(FoldContext ctx) { + return DataType.isNull(field().dataType()) || field().fold(ctx) == null; } @Override diff --git a/x-pack/plugin/esql-core/src/main/java/org/elasticsearch/xpack/esql/core/expression/predicate/operator/arithmetic/Neg.java b/x-pack/plugin/esql-core/src/main/java/org/elasticsearch/xpack/esql/core/expression/predicate/operator/arithmetic/Neg.java index 9a8a14f320cd6..b0e79704f5fda 100644 --- a/x-pack/plugin/esql-core/src/main/java/org/elasticsearch/xpack/esql/core/expression/predicate/operator/arithmetic/Neg.java +++ b/x-pack/plugin/esql-core/src/main/java/org/elasticsearch/xpack/esql/core/expression/predicate/operator/arithmetic/Neg.java @@ -8,6 +8,7 @@ import org.elasticsearch.common.io.stream.StreamOutput; import org.elasticsearch.xpack.esql.core.expression.Expression; +import org.elasticsearch.xpack.esql.core.expression.FoldContext; import org.elasticsearch.xpack.esql.core.expression.function.scalar.UnaryScalarFunction; import org.elasticsearch.xpack.esql.core.tree.NodeInfo; import org.elasticsearch.xpack.esql.core.tree.Source; @@ -53,8 +54,8 @@ protected TypeResolution resolveType() { } @Override - public Object fold() { - return Arithmetics.negate((Number) field().fold()); + public Object fold(FoldContext ctx) { + return Arithmetics.negate((Number) field().fold(ctx)); } @Override diff --git a/x-pack/plugin/esql-core/src/main/java/org/elasticsearch/xpack/esql/core/expression/predicate/regex/RegexMatch.java b/x-pack/plugin/esql-core/src/main/java/org/elasticsearch/xpack/esql/core/expression/predicate/regex/RegexMatch.java index 0f9116ade5a31..a4a0a6217161e 100644 --- a/x-pack/plugin/esql-core/src/main/java/org/elasticsearch/xpack/esql/core/expression/predicate/regex/RegexMatch.java +++ b/x-pack/plugin/esql-core/src/main/java/org/elasticsearch/xpack/esql/core/expression/predicate/regex/RegexMatch.java @@ -8,6 +8,7 @@ package org.elasticsearch.xpack.esql.core.expression.predicate.regex; import org.elasticsearch.xpack.esql.core.expression.Expression; +import org.elasticsearch.xpack.esql.core.expression.FoldContext; import org.elasticsearch.xpack.esql.core.expression.Nullability; import org.elasticsearch.xpack.esql.core.expression.function.scalar.UnaryScalarFunction; import org.elasticsearch.xpack.esql.core.tree.Source; @@ -62,7 +63,7 @@ public boolean foldable() { } @Override - public Boolean fold() { + public Boolean fold(FoldContext ctx) { throw new UnsupportedOperationException(); } diff --git a/x-pack/plugin/esql-core/src/test/java/org/elasticsearch/xpack/esql/core/expression/FoldContextTests.java b/x-pack/plugin/esql-core/src/test/java/org/elasticsearch/xpack/esql/core/expression/FoldContextTests.java new file mode 100644 index 0000000000000..2080f4007777c --- /dev/null +++ b/x-pack/plugin/esql-core/src/test/java/org/elasticsearch/xpack/esql/core/expression/FoldContextTests.java @@ -0,0 +1,97 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.esql.core.expression; + +import org.elasticsearch.common.breaker.CircuitBreaker; +import org.elasticsearch.test.ESTestCase; +import org.elasticsearch.test.EqualsHashCodeTestUtils; +import org.elasticsearch.xpack.esql.core.tree.Source; + +import static org.hamcrest.Matchers.equalTo; + +public class FoldContextTests extends ESTestCase { + public void testEq() { + EqualsHashCodeTestUtils.checkEqualsAndHashCode(randomFoldContext(), this::copy, this::mutate); + } + + private FoldContext randomFoldContext() { + FoldContext ctx = new FoldContext(randomNonNegativeLong()); + if (randomBoolean()) { + ctx.trackAllocation(Source.EMPTY, randomLongBetween(0, ctx.initialAllowedBytes())); + } + return ctx; + } + + private FoldContext copy(FoldContext ctx) { + FoldContext copy = new FoldContext(ctx.initialAllowedBytes()); + copy.trackAllocation(Source.EMPTY, ctx.initialAllowedBytes() - ctx.allowedBytes()); + return copy; + } + + private FoldContext mutate(FoldContext ctx) { + if (randomBoolean()) { + FoldContext differentInitial = new FoldContext(ctx.initialAllowedBytes() + 1); + differentInitial.trackAllocation(Source.EMPTY, differentInitial.initialAllowedBytes() - ctx.allowedBytes()); + assertThat(differentInitial.allowedBytes(), equalTo(ctx.allowedBytes())); + return differentInitial; + } else { + FoldContext differentAllowed = new FoldContext(ctx.initialAllowedBytes()); + long allowed = randomValueOtherThan(ctx.allowedBytes(), () -> randomLongBetween(0, ctx.initialAllowedBytes())); + differentAllowed.trackAllocation(Source.EMPTY, ctx.initialAllowedBytes() - allowed); + assertThat(differentAllowed.allowedBytes(), equalTo(allowed)); + return differentAllowed; + } + } + + public void testTrackAllocation() { + FoldContext ctx = new FoldContext(10); + ctx.trackAllocation(Source.synthetic("shouldn't break"), 10); + Exception e = expectThrows( + FoldContext.FoldTooMuchMemoryException.class, + () -> ctx.trackAllocation(Source.synthetic("should break"), 1) + ); + assertThat( + e.getMessage(), + equalTo( + "line -1:-1: Folding query used more than 10b. " + + "The expression that pushed past the limit is [should break] which needed 1b." + ) + ); + } + + public void testCircuitBreakerViewBreaking() { + FoldContext ctx = new FoldContext(10); + ctx.circuitBreakerView(Source.synthetic("shouldn't break")).addEstimateBytesAndMaybeBreak(10, "test"); + Exception e = expectThrows( + FoldContext.FoldTooMuchMemoryException.class, + () -> ctx.circuitBreakerView(Source.synthetic("should break")).addEstimateBytesAndMaybeBreak(1, "test") + ); + assertThat( + e.getMessage(), + equalTo( + "line -1:-1: Folding query used more than 10b. " + + "The expression that pushed past the limit is [should break] which needed 1b." + ) + ); + } + + public void testCircuitBreakerViewWithoutBreaking() { + FoldContext ctx = new FoldContext(10); + CircuitBreaker view = ctx.circuitBreakerView(Source.synthetic("shouldn't break")); + view.addEstimateBytesAndMaybeBreak(10, "test"); + view.addWithoutBreaking(-1); + assertThat(view.getUsed(), equalTo(9L)); + } + + public void testToString() { + // Random looking numbers are indeed random. Just so we have consistent numbers to assert on in toString. + FoldContext ctx = new FoldContext(123); + ctx.trackAllocation(Source.EMPTY, 22); + assertThat(ctx.toString(), equalTo("FoldContext[101/123]")); + } +} diff --git a/x-pack/plugin/esql-core/src/test/java/org/elasticsearch/xpack/esql/core/expression/predicate/RangeTests.java b/x-pack/plugin/esql-core/src/test/java/org/elasticsearch/xpack/esql/core/expression/predicate/RangeTests.java index ed4c6282368ca..cd15ed5a94cfc 100644 --- a/x-pack/plugin/esql-core/src/test/java/org/elasticsearch/xpack/esql/core/expression/predicate/RangeTests.java +++ b/x-pack/plugin/esql-core/src/test/java/org/elasticsearch/xpack/esql/core/expression/predicate/RangeTests.java @@ -8,6 +8,7 @@ package org.elasticsearch.xpack.esql.core.expression.predicate; import org.elasticsearch.test.ESTestCase; +import org.elasticsearch.xpack.esql.core.expression.FoldContext; import org.elasticsearch.xpack.esql.core.tree.Source; import org.elasticsearch.xpack.esql.core.type.DataType; import org.elasticsearch.xpack.esql.core.util.DateUtils; @@ -211,7 +212,11 @@ public void testAreBoundariesInvalid() { (Boolean) test[7], ZoneId.systemDefault() ); - assertEquals("failed on test " + i + ": " + Arrays.toString(test), test[8], range.areBoundariesInvalid()); + assertEquals( + "failed on test " + i + ": " + Arrays.toString(test), + test[8], + range.areBoundariesInvalid(range.lower().fold(FoldContext.small()), range.upper().fold(FoldContext.small())) + ); } } @@ -226,5 +231,4 @@ private static DataType randomNumericType() { private static DataType randomTextType() { return randomFrom(KEYWORD, TEXT); } - } diff --git a/x-pack/plugin/esql-core/src/test/java/org/elasticsearch/xpack/esql/core/optimizer/OptimizerRulesTests.java b/x-pack/plugin/esql-core/src/test/java/org/elasticsearch/xpack/esql/core/optimizer/OptimizerRulesTests.java index d323174d2d3d9..91b0564a5b404 100644 --- a/x-pack/plugin/esql-core/src/test/java/org/elasticsearch/xpack/esql/core/optimizer/OptimizerRulesTests.java +++ b/x-pack/plugin/esql-core/src/test/java/org/elasticsearch/xpack/esql/core/optimizer/OptimizerRulesTests.java @@ -10,6 +10,7 @@ import org.elasticsearch.test.ESTestCase; import org.elasticsearch.xpack.esql.core.expression.Expression; import org.elasticsearch.xpack.esql.core.expression.FieldAttribute; +import org.elasticsearch.xpack.esql.core.expression.FoldContext; import org.elasticsearch.xpack.esql.core.expression.Literal; import org.elasticsearch.xpack.esql.core.expression.Nullability; import org.elasticsearch.xpack.esql.core.expression.predicate.Range; @@ -104,7 +105,7 @@ public void testFoldExcludingRangeToFalse() { Range r = rangeOf(fa, SIX, false, FIVE, true); assertTrue(r.foldable()); - assertEquals(Boolean.FALSE, r.fold()); + assertEquals(Boolean.FALSE, r.fold(FoldContext.small())); } // 6 < a <= 5.5 -> FALSE @@ -113,7 +114,7 @@ public void testFoldExcludingRangeWithDifferentTypesToFalse() { Range r = rangeOf(fa, SIX, false, L(5.5d), true); assertTrue(r.foldable()); - assertEquals(Boolean.FALSE, r.fold()); + assertEquals(Boolean.FALSE, r.fold(FoldContext.small())); } // Conjunction diff --git a/x-pack/plugin/esql/qa/testFixtures/src/main/java/org/elasticsearch/xpack/esql/EsqlTestUtils.java b/x-pack/plugin/esql/qa/testFixtures/src/main/java/org/elasticsearch/xpack/esql/EsqlTestUtils.java index 66fd7d3ee5eb5..7e25fb29fdb78 100644 --- a/x-pack/plugin/esql/qa/testFixtures/src/main/java/org/elasticsearch/xpack/esql/EsqlTestUtils.java +++ b/x-pack/plugin/esql/qa/testFixtures/src/main/java/org/elasticsearch/xpack/esql/EsqlTestUtils.java @@ -41,6 +41,7 @@ import org.elasticsearch.xpack.esql.core.expression.Attribute; import org.elasticsearch.xpack.esql.core.expression.Expression; import org.elasticsearch.xpack.esql.core.expression.FieldAttribute; +import org.elasticsearch.xpack.esql.core.expression.FoldContext; import org.elasticsearch.xpack.esql.core.expression.Literal; import org.elasticsearch.xpack.esql.core.expression.ReferenceAttribute; import org.elasticsearch.xpack.esql.core.expression.predicate.Range; @@ -61,6 +62,7 @@ import org.elasticsearch.xpack.esql.expression.predicate.operator.comparison.LessThanOrEqual; import org.elasticsearch.xpack.esql.expression.predicate.operator.comparison.NotEquals; import org.elasticsearch.xpack.esql.index.EsIndex; +import org.elasticsearch.xpack.esql.optimizer.LogicalOptimizerContext; import org.elasticsearch.xpack.esql.parser.QueryParam; import org.elasticsearch.xpack.esql.plan.logical.Enrich; import org.elasticsearch.xpack.esql.plan.logical.EsRelation; @@ -350,6 +352,10 @@ public String toString() { public static final Configuration TEST_CFG = configuration(new QueryPragmas(Settings.EMPTY)); + public static LogicalOptimizerContext unboundLogicalOptimizerContext() { + return new LogicalOptimizerContext(EsqlTestUtils.TEST_CFG, FoldContext.small()); + } + public static final Verifier TEST_VERIFIER = new Verifier(new Metrics(new EsqlFunctionRegistry()), new XPackLicenseState(() -> 0L)); public static final QueryBuilderResolver MOCK_QUERY_BUILDER_RESOLVER = new MockQueryBuilderResolver(); diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/analysis/Analyzer.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/analysis/Analyzer.java index 3d1bfdfd0ef42..a11b511cb83b7 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/analysis/Analyzer.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/analysis/Analyzer.java @@ -25,6 +25,7 @@ import org.elasticsearch.xpack.esql.core.expression.Expression; import org.elasticsearch.xpack.esql.core.expression.Expressions; import org.elasticsearch.xpack.esql.core.expression.FieldAttribute; +import org.elasticsearch.xpack.esql.core.expression.FoldContext; import org.elasticsearch.xpack.esql.core.expression.Literal; import org.elasticsearch.xpack.esql.core.expression.NamedExpression; import org.elasticsearch.xpack.esql.core.expression.Nullability; @@ -325,7 +326,7 @@ protected LogicalPlan rule(Enrich plan, AnalyzerContext context) { // the policy does not exist return plan; } - final String policyName = (String) plan.policyName().fold(); + final String policyName = (String) plan.policyName().fold(FoldContext.small() /* TODO remove me */); final var resolved = context.enrichResolution().getResolvedPolicy(policyName, plan.mode()); if (resolved != null) { var policy = new EnrichPolicy(resolved.matchType(), null, List.of(), resolved.matchField(), resolved.enrichFields()); @@ -1279,16 +1280,16 @@ private static boolean supportsStringImplicitCasting(DataType type) { private static UnresolvedAttribute unresolvedAttribute(Expression value, String type, Exception e) { String message = format( "Cannot convert string [{}] to [{}], error [{}]", - value.fold(), + value.fold(FoldContext.small() /* TODO remove me */), type, (e instanceof ParsingException pe) ? pe.getErrorMessage() : e.getMessage() ); - return new UnresolvedAttribute(value.source(), String.valueOf(value.fold()), message); + return new UnresolvedAttribute(value.source(), String.valueOf(value.fold(FoldContext.small() /* TODO remove me */)), message); } private static Expression castStringLiteralToTemporalAmount(Expression from) { try { - TemporalAmount result = maybeParseTemporalAmount(from.fold().toString().strip()); + TemporalAmount result = maybeParseTemporalAmount(from.fold(FoldContext.small() /* TODO remove me */).toString().strip()); if (result == null) { return from; } @@ -1304,7 +1305,11 @@ private static Expression castStringLiteral(Expression from, DataType target) { try { return isTemporalAmount(target) ? castStringLiteralToTemporalAmount(from) - : new Literal(from.source(), EsqlDataTypeConverter.convert(from.fold(), target), target); + : new Literal( + from.source(), + EsqlDataTypeConverter.convert(from.fold(FoldContext.small() /* TODO remove me */), target), + target + ); } catch (Exception e) { return unresolvedAttribute(from, target.toString(), e); } diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/evaluator/EvalMapper.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/evaluator/EvalMapper.java index 9a2e9398f52fd..b9c2b92ea72dd 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/evaluator/EvalMapper.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/evaluator/EvalMapper.java @@ -23,6 +23,7 @@ import org.elasticsearch.xpack.esql.core.QlIllegalArgumentException; import org.elasticsearch.xpack.esql.core.expression.Attribute; import org.elasticsearch.xpack.esql.core.expression.Expression; +import org.elasticsearch.xpack.esql.core.expression.FoldContext; import org.elasticsearch.xpack.esql.core.expression.Literal; import org.elasticsearch.xpack.esql.core.expression.predicate.logical.BinaryLogic; import org.elasticsearch.xpack.esql.core.expression.predicate.logical.Not; @@ -50,13 +51,23 @@ public final class EvalMapper { private EvalMapper() {} @SuppressWarnings({ "rawtypes", "unchecked" }) - public static ExpressionEvaluator.Factory toEvaluator(Expression exp, Layout layout) { + public static ExpressionEvaluator.Factory toEvaluator(FoldContext foldCtx, Expression exp, Layout layout) { if (exp instanceof EvaluatorMapper m) { - return m.toEvaluator(e -> toEvaluator(e, layout)); + return m.toEvaluator(new EvaluatorMapper.ToEvaluator() { + @Override + public ExpressionEvaluator.Factory apply(Expression expression) { + return toEvaluator(foldCtx, expression, layout); + } + + @Override + public FoldContext foldCtx() { + return foldCtx; + } + }); } for (ExpressionMapper em : MAPPERS) { if (em.typeToken.isInstance(exp)) { - return em.map(exp, layout); + return em.map(foldCtx, exp, layout); } } throw new QlIllegalArgumentException("Unsupported expression [{}]", exp); @@ -64,9 +75,9 @@ public static ExpressionEvaluator.Factory toEvaluator(Expression exp, Layout lay static class BooleanLogic extends ExpressionMapper { @Override - public ExpressionEvaluator.Factory map(BinaryLogic bc, Layout layout) { - var leftEval = toEvaluator(bc.left(), layout); - var rightEval = toEvaluator(bc.right(), layout); + public ExpressionEvaluator.Factory map(FoldContext foldCtx, BinaryLogic bc, Layout layout) { + var leftEval = toEvaluator(foldCtx, bc.left(), layout); + var rightEval = toEvaluator(foldCtx, bc.right(), layout); /** * Evaluator for the three-valued boolean expressions. * We can't generate these with the {@link Evaluator} annotation because that @@ -142,8 +153,8 @@ public void close() { static class Nots extends ExpressionMapper { @Override - public ExpressionEvaluator.Factory map(Not not, Layout layout) { - var expEval = toEvaluator(not.field(), layout); + public ExpressionEvaluator.Factory map(FoldContext foldCtx, Not not, Layout layout) { + var expEval = toEvaluator(foldCtx, not.field(), layout); return dvrCtx -> new org.elasticsearch.xpack.esql.evaluator.predicate.operator.logical.NotEvaluator( not.source(), expEval.get(dvrCtx), @@ -154,7 +165,7 @@ public ExpressionEvaluator.Factory map(Not not, Layout layout) { static class Attributes extends ExpressionMapper { @Override - public ExpressionEvaluator.Factory map(Attribute attr, Layout layout) { + public ExpressionEvaluator.Factory map(FoldContext foldCtx, Attribute attr, Layout layout) { record Attribute(int channel) implements ExpressionEvaluator { @Override public Block eval(Page page) { @@ -189,7 +200,7 @@ public boolean eagerEvalSafeInLazy() { static class Literals extends ExpressionMapper { @Override - public ExpressionEvaluator.Factory map(Literal lit, Layout layout) { + public ExpressionEvaluator.Factory map(FoldContext foldCtx, Literal lit, Layout layout) { record LiteralsEvaluator(DriverContext context, Literal lit) implements ExpressionEvaluator { @Override public Block eval(Page page) { @@ -246,8 +257,8 @@ private static Block block(Literal lit, BlockFactory blockFactory, int positions static class IsNulls extends ExpressionMapper { @Override - public ExpressionEvaluator.Factory map(IsNull isNull, Layout layout) { - var field = toEvaluator(isNull.field(), layout); + public ExpressionEvaluator.Factory map(FoldContext foldCtx, IsNull isNull, Layout layout) { + var field = toEvaluator(foldCtx, isNull.field(), layout); return new IsNullEvaluatorFactory(field); } @@ -294,8 +305,8 @@ public String toString() { static class IsNotNulls extends ExpressionMapper { @Override - public ExpressionEvaluator.Factory map(IsNotNull isNotNull, Layout layout) { - return new IsNotNullEvaluatorFactory(toEvaluator(isNotNull.field(), layout)); + public ExpressionEvaluator.Factory map(FoldContext foldCtx, IsNotNull isNotNull, Layout layout) { + return new IsNotNullEvaluatorFactory(toEvaluator(foldCtx, isNotNull.field(), layout)); } record IsNotNullEvaluatorFactory(EvalOperator.ExpressionEvaluator.Factory field) implements ExpressionEvaluator.Factory { diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/evaluator/mapper/EvaluatorMapper.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/evaluator/mapper/EvaluatorMapper.java index d8692faef5290..5a8b3d32e7db0 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/evaluator/mapper/EvaluatorMapper.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/evaluator/mapper/EvaluatorMapper.java @@ -7,11 +7,20 @@ package org.elasticsearch.xpack.esql.evaluator.mapper; +import org.elasticsearch.common.breaker.CircuitBreaker; +import org.elasticsearch.common.util.BigArrays; import org.elasticsearch.compute.data.Block; +import org.elasticsearch.compute.data.BlockFactory; import org.elasticsearch.compute.data.Page; import org.elasticsearch.compute.operator.DriverContext; import org.elasticsearch.compute.operator.EvalOperator.ExpressionEvaluator; +import org.elasticsearch.indices.breaker.AllCircuitBreakerStats; +import org.elasticsearch.indices.breaker.CircuitBreakerService; +import org.elasticsearch.indices.breaker.CircuitBreakerStats; import org.elasticsearch.xpack.esql.core.expression.Expression; +import org.elasticsearch.xpack.esql.core.expression.FoldContext; +import org.elasticsearch.xpack.esql.core.tree.Source; +import org.elasticsearch.xpack.esql.evaluator.EvalMapper; import org.elasticsearch.xpack.esql.planner.Layout; import static org.elasticsearch.compute.data.BlockUtils.fromArrayRow; @@ -23,9 +32,12 @@ public interface EvaluatorMapper { interface ToEvaluator { ExpressionEvaluator.Factory apply(Expression expression); + + FoldContext foldCtx(); } /** + * Convert this into an {@link ExpressionEvaluator}. *

* Note for implementors: * If you are implementing this function, you should call the passed-in @@ -35,8 +47,8 @@ interface ToEvaluator { *

* Note for Callers: * If you are attempting to call this method, and you have an - * {@link Expression} and a {@link org.elasticsearch.xpack.esql.planner.Layout}, - * you likely want to call {@link org.elasticsearch.xpack.esql.evaluator.EvalMapper#toEvaluator(Expression, Layout)} + * {@link Expression} and a {@link Layout}, + * you likely want to call {@link EvalMapper#toEvaluator} * instead. On the other hand, if you already have something that * looks like the parameter for this method, you should call this method * with that function. @@ -56,19 +68,89 @@ interface ToEvaluator { /** * Fold using {@link #toEvaluator} so you don't need a "by hand" - * implementation of fold. The evaluator that it makes is "funny" - * in that it'll always call {@link Expression#fold}, but that's - * good enough. + * implementation of {@link Expression#fold}. */ - default Object fold() { - return toJavaObject(toEvaluator(e -> driverContext -> new ExpressionEvaluator() { + default Object fold(Source source, FoldContext ctx) { + /* + * OK! So! We're going to build a bunch of *stuff* that so that we can + * call toEvaluator and use it without standing up an entire compute + * engine. + * + * Step 1 is creation of a `toEvaluator` which we'll soon use to turn + * the *children* of this Expression into ExpressionEvaluators. They + * have to be foldable or else we wouldn't have ended up here. So! + * We just call `fold` on them and turn the result of that into a + * Block. + * + * If the tree of expressions is pretty deep that `fold` call will + * likely end up being implemented by calling this method for the + * child. That's fine. Recursion is how you process trees. + */ + ToEvaluator foldChildren = new ToEvaluator() { @Override - public Block eval(Page page) { - return fromArrayRow(driverContext.blockFactory(), e.fold())[0]; + public ExpressionEvaluator.Factory apply(Expression expression) { + return driverContext -> new ExpressionEvaluator() { + @Override + public Block eval(Page page) { + return fromArrayRow(driverContext.blockFactory(), expression.fold(ctx))[0]; + } + + @Override + public void close() {} + }; } @Override - public void close() {} - }).get(DriverContext.getLocalDriver()).eval(new Page(1)), 0); + public FoldContext foldCtx() { + return ctx; + } + }; + + /* + * Step 2 is to create a DriverContext that we can pass to the above. + * This DriverContext is mostly about delegating to the FoldContext. + * That'll cause us to break if we attempt to allocate a huge amount + * of memory. Neat. + * + * Specifically, we make a CircuitBreaker view of the FoldContext, then + * we wrap it in a CircuitBreakerService so we can feed it to a BigArray + * so we can feed *that* into a DriverContext. It's a bit hacky, but + * that's what's going on here. + */ + CircuitBreaker breaker = ctx.circuitBreakerView(source); + BigArrays bigArrays = new BigArrays(null, new CircuitBreakerService() { + @Override + public CircuitBreaker getBreaker(String name) { + if (name.equals(CircuitBreaker.REQUEST) == false) { + throw new UnsupportedOperationException(); + } + return breaker; + } + + @Override + public AllCircuitBreakerStats stats() { + throw new UnsupportedOperationException(); + } + + @Override + public CircuitBreakerStats stats(String name) { + throw new UnsupportedOperationException(); + } + }, CircuitBreaker.REQUEST).withCircuitBreaking(); + DriverContext driverCtx = new DriverContext(bigArrays, new BlockFactory(breaker, bigArrays)); + + /* + * Finally we can call toEvaluator on ourselves! It'll fold our children, + * convert the result into Blocks, and then we'll run that with the memory + * breaking DriverContext. + * + * Then, finally finally, we turn the result into a java object to be compatible + * with the signature of `fold`. + */ + Block block = toEvaluator(foldChildren).get(driverCtx).eval(new Page(1)); + if (block.getPositionCount() != 1) { + throw new IllegalStateException("generated odd block from fold [" + block + "]"); + } + return toJavaObject(block, 0); } } diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/evaluator/mapper/ExpressionMapper.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/evaluator/mapper/ExpressionMapper.java index 5cd830058573f..5a76080e7995c 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/evaluator/mapper/ExpressionMapper.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/evaluator/mapper/ExpressionMapper.java @@ -9,6 +9,7 @@ import org.elasticsearch.compute.operator.EvalOperator.ExpressionEvaluator; import org.elasticsearch.xpack.esql.core.expression.Expression; +import org.elasticsearch.xpack.esql.core.expression.FoldContext; import org.elasticsearch.xpack.esql.core.util.ReflectionUtils; import org.elasticsearch.xpack.esql.planner.Layout; @@ -19,5 +20,5 @@ public ExpressionMapper() { typeToken = ReflectionUtils.detectSuperTypeForRuleLike(getClass()); } - public abstract ExpressionEvaluator.Factory map(E expression, Layout layout); + public abstract ExpressionEvaluator.Factory map(FoldContext foldCtx, E expression, Layout layout); } diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/execution/PlanExecutor.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/execution/PlanExecutor.java index 974f029eab2ef..94913581f696d 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/execution/PlanExecutor.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/execution/PlanExecutor.java @@ -15,6 +15,7 @@ import org.elasticsearch.xpack.esql.action.EsqlQueryRequest; import org.elasticsearch.xpack.esql.analysis.PreAnalyzer; import org.elasticsearch.xpack.esql.analysis.Verifier; +import org.elasticsearch.xpack.esql.core.expression.FoldContext; import org.elasticsearch.xpack.esql.enrich.EnrichPolicyResolver; import org.elasticsearch.xpack.esql.expression.function.EsqlFunctionRegistry; import org.elasticsearch.xpack.esql.optimizer.LogicalOptimizerContext; @@ -56,6 +57,7 @@ public void esql( EsqlQueryRequest request, String sessionId, Configuration cfg, + FoldContext foldContext, EnrichPolicyResolver enrichPolicyResolver, EsqlExecutionInfo executionInfo, IndicesExpressionGrouper indicesExpressionGrouper, @@ -71,7 +73,7 @@ public void esql( enrichPolicyResolver, preAnalyzer, functionRegistry, - new LogicalPlanOptimizer(new LogicalOptimizerContext(cfg)), + new LogicalPlanOptimizer(new LogicalOptimizerContext(cfg, foldContext)), mapper, verifier, planningMetrics, diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/aggregate/AggregateFunction.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/aggregate/AggregateFunction.java index 011fcaccf7fe4..8aa7f697489c6 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/aggregate/AggregateFunction.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/aggregate/AggregateFunction.java @@ -12,6 +12,7 @@ import org.elasticsearch.xpack.esql.capabilities.PostAnalysisPlanVerificationAware; import org.elasticsearch.xpack.esql.common.Failures; import org.elasticsearch.xpack.esql.core.expression.Expression; +import org.elasticsearch.xpack.esql.core.expression.FoldContext; import org.elasticsearch.xpack.esql.core.expression.Literal; import org.elasticsearch.xpack.esql.core.expression.TypeResolutions; import org.elasticsearch.xpack.esql.core.expression.function.Function; @@ -92,7 +93,8 @@ public List parameters() { } public boolean hasFilter() { - return filter != null && (filter.foldable() == false || Boolean.TRUE.equals(filter.fold()) == false); + return filter != null + && (filter.foldable() == false || Boolean.TRUE.equals(filter.fold(FoldContext.small() /* TODO remove me */)) == false); } public Expression filter() { diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/aggregate/CountDistinct.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/aggregate/CountDistinct.java index f80333d83d6cb..c738dfc8ff591 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/aggregate/CountDistinct.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/aggregate/CountDistinct.java @@ -19,6 +19,7 @@ import org.elasticsearch.compute.aggregation.CountDistinctLongAggregatorFunctionSupplier; import org.elasticsearch.xpack.esql.EsqlIllegalArgumentException; import org.elasticsearch.xpack.esql.core.expression.Expression; +import org.elasticsearch.xpack.esql.core.expression.FoldContext; import org.elasticsearch.xpack.esql.core.expression.Literal; import org.elasticsearch.xpack.esql.core.tree.NodeInfo; import org.elasticsearch.xpack.esql.core.tree.Source; @@ -209,7 +210,9 @@ protected TypeResolution resolveType() { @Override public AggregatorFunctionSupplier supplier(List inputChannels) { DataType type = field().dataType(); - int precision = this.precision == null ? DEFAULT_PRECISION : ((Number) this.precision.fold()).intValue(); + int precision = this.precision == null + ? DEFAULT_PRECISION + : ((Number) this.precision.fold(FoldContext.small() /* TODO remove me */)).intValue(); if (SUPPLIERS.containsKey(type) == false) { // If the type checking did its job, this should never happen throw EsqlIllegalArgumentException.illegalDataType(type); diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/aggregate/Percentile.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/aggregate/Percentile.java index 0d57267da1e29..8c943c991d501 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/aggregate/Percentile.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/aggregate/Percentile.java @@ -16,6 +16,7 @@ import org.elasticsearch.compute.aggregation.PercentileIntAggregatorFunctionSupplier; import org.elasticsearch.compute.aggregation.PercentileLongAggregatorFunctionSupplier; import org.elasticsearch.xpack.esql.core.expression.Expression; +import org.elasticsearch.xpack.esql.core.expression.FoldContext; import org.elasticsearch.xpack.esql.core.expression.Literal; import org.elasticsearch.xpack.esql.core.tree.NodeInfo; import org.elasticsearch.xpack.esql.core.tree.Source; @@ -170,7 +171,7 @@ protected AggregatorFunctionSupplier doubleSupplier(List inputChannels) } private int percentileValue() { - return ((Number) percentile.fold()).intValue(); + return ((Number) percentile.fold(FoldContext.small() /* TODO remove me */)).intValue(); } @Override diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/aggregate/Rate.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/aggregate/Rate.java index 87ac9b77a6826..85ae65b6c5dc3 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/aggregate/Rate.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/aggregate/Rate.java @@ -18,6 +18,7 @@ import org.elasticsearch.core.TimeValue; import org.elasticsearch.xpack.esql.EsqlIllegalArgumentException; import org.elasticsearch.xpack.esql.core.expression.Expression; +import org.elasticsearch.xpack.esql.core.expression.FoldContext; import org.elasticsearch.xpack.esql.core.expression.Literal; import org.elasticsearch.xpack.esql.core.expression.UnresolvedAttribute; import org.elasticsearch.xpack.esql.core.tree.NodeInfo; @@ -156,7 +157,7 @@ long unitInMillis() { } final Object foldValue; try { - foldValue = unit.fold(); + foldValue = unit.fold(FoldContext.small() /* TODO remove me */); } catch (Exception e) { throw new IllegalArgumentException("function [" + sourceText() + "] has invalid unit [" + unit.sourceText() + "]"); } diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/aggregate/Top.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/aggregate/Top.java index 40777b4d78dc2..9be8c94266ee8 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/aggregate/Top.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/aggregate/Top.java @@ -21,6 +21,7 @@ import org.elasticsearch.compute.aggregation.TopLongAggregatorFunctionSupplier; import org.elasticsearch.xpack.esql.EsqlIllegalArgumentException; import org.elasticsearch.xpack.esql.core.expression.Expression; +import org.elasticsearch.xpack.esql.core.expression.FoldContext; import org.elasticsearch.xpack.esql.core.expression.Literal; import org.elasticsearch.xpack.esql.core.tree.NodeInfo; import org.elasticsearch.xpack.esql.core.tree.Source; @@ -115,11 +116,11 @@ Expression orderField() { } private int limitValue() { - return (int) limitField().fold(); + return (int) limitField().fold(FoldContext.small() /* TODO remove me */); } private String orderRawValue() { - return BytesRefs.toString(orderField().fold()); + return BytesRefs.toString(orderField().fold(FoldContext.small() /* TODO remove me */)); } private boolean orderValue() { diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/aggregate/WeightedAvg.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/aggregate/WeightedAvg.java index 56d034a2eae1d..bab65653ba576 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/aggregate/WeightedAvg.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/aggregate/WeightedAvg.java @@ -12,6 +12,7 @@ import org.elasticsearch.common.io.stream.StreamInput; import org.elasticsearch.common.io.stream.StreamOutput; import org.elasticsearch.xpack.esql.core.expression.Expression; +import org.elasticsearch.xpack.esql.core.expression.FoldContext; import org.elasticsearch.xpack.esql.core.expression.Literal; import org.elasticsearch.xpack.esql.core.tree.NodeInfo; import org.elasticsearch.xpack.esql.core.tree.Source; @@ -114,9 +115,15 @@ protected Expression.TypeResolution resolveType() { return resolution; } - if (weight.dataType() == DataType.NULL - || (weight.foldable() && (weight.fold() == null || weight.fold().equals(0) || weight.fold().equals(0.0)))) { - return new TypeResolution(format(null, invalidWeightError, SECOND, sourceText(), weight.foldable() ? weight.fold() : null)); + if (weight.dataType() == DataType.NULL) { + return new TypeResolution(format(null, invalidWeightError, SECOND, sourceText(), null)); + } + if (weight.foldable() == false) { + return TypeResolution.TYPE_RESOLVED; + } + Object weightVal = weight.fold(FoldContext.small()/* TODO remove me*/); + if (weightVal == null || weightVal.equals(0) || weightVal.equals(0.0)) { + return new TypeResolution(format(null, invalidWeightError, SECOND, sourceText(), weightVal)); } return TypeResolution.TYPE_RESOLVED; diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/fulltext/FullTextFunction.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/fulltext/FullTextFunction.java index 07c4bb282ba71..4da7c01139c24 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/fulltext/FullTextFunction.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/fulltext/FullTextFunction.java @@ -12,6 +12,7 @@ import org.elasticsearch.xpack.esql.capabilities.PostAnalysisPlanVerificationAware; import org.elasticsearch.xpack.esql.common.Failures; import org.elasticsearch.xpack.esql.core.expression.Expression; +import org.elasticsearch.xpack.esql.core.expression.FoldContext; import org.elasticsearch.xpack.esql.core.expression.Nullability; import org.elasticsearch.xpack.esql.core.expression.TranslationAware; import org.elasticsearch.xpack.esql.core.expression.TypeResolutions; @@ -111,7 +112,7 @@ public Expression query() { * @return query expression as an object */ public Object queryAsObject() { - Object queryAsObject = query().fold(); + Object queryAsObject = query().fold(FoldContext.small() /* TODO remove me */); if (queryAsObject instanceof BytesRef bytesRef) { return bytesRef.utf8ToString(); } diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/fulltext/Match.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/fulltext/Match.java index a12d82d6e4267..9275176cd41b1 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/fulltext/Match.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/fulltext/Match.java @@ -18,6 +18,7 @@ import org.elasticsearch.xpack.esql.common.Failures; import org.elasticsearch.xpack.esql.core.expression.Expression; import org.elasticsearch.xpack.esql.core.expression.FieldAttribute; +import org.elasticsearch.xpack.esql.core.expression.FoldContext; import org.elasticsearch.xpack.esql.core.expression.TypeResolutions; import org.elasticsearch.xpack.esql.core.planner.ExpressionTranslator; import org.elasticsearch.xpack.esql.core.querydsl.query.QueryStringQuery; @@ -215,7 +216,7 @@ public void postLogicalOptimizationVerification(Failures failures) { @Override public Object queryAsObject() { - Object queryAsObject = query().fold(); + Object queryAsObject = query().fold(FoldContext.small() /* TODO remove me */); // Convert BytesRef to string for string-based values if (queryAsObject instanceof BytesRef bytesRef) { diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/grouping/Bucket.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/grouping/Bucket.java index 113989323eff2..7a3e080f5c830 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/grouping/Bucket.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/grouping/Bucket.java @@ -18,6 +18,7 @@ import org.elasticsearch.xpack.esql.capabilities.PostOptimizationVerificationAware; import org.elasticsearch.xpack.esql.common.Failures; import org.elasticsearch.xpack.esql.core.expression.Expression; +import org.elasticsearch.xpack.esql.core.expression.FoldContext; import org.elasticsearch.xpack.esql.core.expression.Foldables; import org.elasticsearch.xpack.esql.core.expression.Literal; import org.elasticsearch.xpack.esql.core.expression.TypeResolutions; @@ -255,25 +256,25 @@ public ExpressionEvaluator.Factory toEvaluator(ToEvaluator toEvaluator) { if (field.dataType() == DataType.DATETIME || field.dataType() == DataType.DATE_NANOS) { Rounding.Prepared preparedRounding; if (buckets.dataType().isWholeNumber()) { - int b = ((Number) buckets.fold()).intValue(); - long f = foldToLong(from); - long t = foldToLong(to); + int b = ((Number) buckets.fold(toEvaluator.foldCtx())).intValue(); + long f = foldToLong(toEvaluator.foldCtx(), from); + long t = foldToLong(toEvaluator.foldCtx(), to); preparedRounding = new DateRoundingPicker(b, f, t).pickRounding().prepareForUnknown(); } else { assert DataType.isTemporalAmount(buckets.dataType()) : "Unexpected span data type [" + buckets.dataType() + "]"; - preparedRounding = DateTrunc.createRounding(buckets.fold(), DEFAULT_TZ); + preparedRounding = DateTrunc.createRounding(buckets.fold(toEvaluator.foldCtx()), DEFAULT_TZ); } return DateTrunc.evaluator(field.dataType(), source(), toEvaluator.apply(field), preparedRounding); } if (field.dataType().isNumeric()) { double roundTo; if (from != null) { - int b = ((Number) buckets.fold()).intValue(); - double f = ((Number) from.fold()).doubleValue(); - double t = ((Number) to.fold()).doubleValue(); + int b = ((Number) buckets.fold(toEvaluator.foldCtx())).intValue(); + double f = ((Number) from.fold(toEvaluator.foldCtx())).doubleValue(); + double t = ((Number) to.fold(toEvaluator.foldCtx())).doubleValue(); roundTo = pickRounding(b, f, t); } else { - roundTo = ((Number) buckets.fold()).doubleValue(); + roundTo = ((Number) buckets.fold(toEvaluator.foldCtx())).doubleValue(); } Literal rounding = new Literal(source(), roundTo, DataType.DOUBLE); @@ -416,8 +417,8 @@ public void postLogicalOptimizationVerification(Failures failures) { .add(to != null ? isFoldable(to, operation, FOURTH) : null); } - private long foldToLong(Expression e) { - Object value = Foldables.valueOf(e); + private long foldToLong(FoldContext ctx, Expression e) { + Object value = Foldables.valueOf(ctx, e); return DataType.isDateTime(e.dataType()) ? ((Number) value).longValue() : dateTimeToLong(((BytesRef) value).utf8ToString()); } diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/grouping/GroupingFunction.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/grouping/GroupingFunction.java index 0fee65d32ca98..fd025e5e67a7c 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/grouping/GroupingFunction.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/grouping/GroupingFunction.java @@ -10,6 +10,7 @@ import org.elasticsearch.xpack.esql.capabilities.PostAnalysisPlanVerificationAware; import org.elasticsearch.xpack.esql.common.Failures; import org.elasticsearch.xpack.esql.core.expression.Expression; +import org.elasticsearch.xpack.esql.core.expression.FoldContext; import org.elasticsearch.xpack.esql.core.expression.function.Function; import org.elasticsearch.xpack.esql.core.tree.Source; import org.elasticsearch.xpack.esql.evaluator.mapper.EvaluatorMapper; @@ -28,8 +29,8 @@ protected GroupingFunction(Source source, List fields) { } @Override - public Object fold() { - return EvaluatorMapper.super.fold(); + public Object fold(FoldContext ctx) { + return EvaluatorMapper.super.fold(source(), ctx); } @Override diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/scalar/EsqlScalarFunction.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/scalar/EsqlScalarFunction.java index 404ce7e3900c9..85d15f82f458a 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/scalar/EsqlScalarFunction.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/scalar/EsqlScalarFunction.java @@ -8,6 +8,7 @@ package org.elasticsearch.xpack.esql.expression.function.scalar; import org.elasticsearch.xpack.esql.core.expression.Expression; +import org.elasticsearch.xpack.esql.core.expression.FoldContext; import org.elasticsearch.xpack.esql.core.expression.function.scalar.ScalarFunction; import org.elasticsearch.xpack.esql.core.tree.Source; import org.elasticsearch.xpack.esql.evaluator.mapper.EvaluatorMapper; @@ -34,7 +35,7 @@ protected EsqlScalarFunction(Source source, List fields) { } @Override - public Object fold() { - return EvaluatorMapper.super.fold(); + public Object fold(FoldContext ctx) { + return EvaluatorMapper.super.fold(source(), ctx); } } diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/scalar/conditional/Case.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/scalar/conditional/Case.java index 824f02ca7ccbb..236e625f7abe1 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/scalar/conditional/Case.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/scalar/conditional/Case.java @@ -23,6 +23,7 @@ import org.elasticsearch.core.Releasable; import org.elasticsearch.core.Releasables; import org.elasticsearch.xpack.esql.core.expression.Expression; +import org.elasticsearch.xpack.esql.core.expression.FoldContext; import org.elasticsearch.xpack.esql.core.expression.Literal; import org.elasticsearch.xpack.esql.core.expression.Nullability; import org.elasticsearch.xpack.esql.core.expression.TypeResolutions; @@ -227,7 +228,7 @@ public boolean foldable() { if (condition.condition.foldable() == false) { return false; } - if (Boolean.TRUE.equals(condition.condition.fold())) { + if (Boolean.TRUE.equals(condition.condition.fold(FoldContext.small() /* TODO remove me - use literal true?*/))) { /* * `fold` can make four things here: * 1. `TRUE` @@ -264,7 +265,8 @@ public boolean foldable() { * And those two combine so {@code EVAL c=CASE(false, foo, b, bar, true, bort, el)} becomes * {@code EVAL c=CASE(b, bar, bort)}. */ - public Expression partiallyFold() { + public Expression partiallyFold(FoldContext ctx) { + // TODO don't throw away the results of any `fold`. That might mean looking for literal TRUE on the conditions. List newChildren = new ArrayList<>(children().size()); boolean modified = false; for (Condition condition : conditions) { @@ -274,7 +276,7 @@ public Expression partiallyFold() { continue; } modified = true; - if (Boolean.TRUE.equals(condition.condition.fold())) { + if (Boolean.TRUE.equals(condition.condition.fold(ctx))) { /* * `fold` can make four things here: * 1. `TRUE` diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/scalar/convert/FoldablesConvertFunction.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/scalar/convert/FoldablesConvertFunction.java index 842e899ebdac6..57f362f86ff4c 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/scalar/convert/FoldablesConvertFunction.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/scalar/convert/FoldablesConvertFunction.java @@ -11,6 +11,7 @@ import org.elasticsearch.xpack.esql.capabilities.PostOptimizationVerificationAware; import org.elasticsearch.xpack.esql.common.Failures; import org.elasticsearch.xpack.esql.core.expression.Expression; +import org.elasticsearch.xpack.esql.core.expression.FoldContext; import org.elasticsearch.xpack.esql.core.tree.Source; import org.elasticsearch.xpack.esql.core.type.DataType; @@ -65,8 +66,8 @@ protected final Map factories() { } @Override - public final Object fold() { - return foldToTemporalAmount(field(), sourceText(), dataType()); + public final Object fold(FoldContext ctx) { + return foldToTemporalAmount(ctx, field(), sourceText(), dataType()); } @Override diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/scalar/date/DateDiff.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/scalar/date/DateDiff.java index f6a23a5d5962e..b588832aba4cb 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/scalar/date/DateDiff.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/scalar/date/DateDiff.java @@ -232,7 +232,7 @@ public ExpressionEvaluator.Factory toEvaluator(ToEvaluator toEvaluator) { if (unit.foldable()) { try { - Part datePartField = Part.resolve(((BytesRef) unit.fold()).utf8ToString()); + Part datePartField = Part.resolve(((BytesRef) unit.fold(toEvaluator.foldCtx())).utf8ToString()); return new DateDiffConstantEvaluator.Factory(source(), datePartField, startTimestampEvaluator, endTimestampEvaluator); } catch (IllegalArgumentException e) { throw new InvalidArgumentException("invalid unit format for [{}]: {}", sourceText(), e.getMessage()); diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/scalar/date/DateExtract.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/scalar/date/DateExtract.java index 501dfd431f106..7fc5d82441802 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/scalar/date/DateExtract.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/scalar/date/DateExtract.java @@ -16,6 +16,7 @@ import org.elasticsearch.compute.operator.EvalOperator.ExpressionEvaluator; import org.elasticsearch.xpack.esql.core.InvalidArgumentException; import org.elasticsearch.xpack.esql.core.expression.Expression; +import org.elasticsearch.xpack.esql.core.expression.FoldContext; import org.elasticsearch.xpack.esql.core.expression.TypeResolutions; import org.elasticsearch.xpack.esql.core.tree.NodeInfo; import org.elasticsearch.xpack.esql.core.tree.Source; @@ -110,9 +111,9 @@ public String getWriteableName() { public ExpressionEvaluator.Factory toEvaluator(ToEvaluator toEvaluator) { var fieldEvaluator = toEvaluator.apply(children().get(1)); if (children().get(0).foldable()) { - ChronoField chrono = chronoField(); + ChronoField chrono = chronoField(toEvaluator.foldCtx()); if (chrono == null) { - BytesRef field = (BytesRef) children().get(0).fold(); + BytesRef field = (BytesRef) children().get(0).fold(toEvaluator.foldCtx()); throw new InvalidArgumentException("invalid date field for [{}]: {}", sourceText(), field.utf8ToString()); } return new DateExtractConstantEvaluator.Factory(source(), fieldEvaluator, chrono, configuration().zoneId()); @@ -121,14 +122,14 @@ public ExpressionEvaluator.Factory toEvaluator(ToEvaluator toEvaluator) { return new DateExtractEvaluator.Factory(source(), fieldEvaluator, chronoEvaluator, configuration().zoneId()); } - private ChronoField chronoField() { + private ChronoField chronoField(FoldContext ctx) { // chronoField's never checked (the return is). The foldability test is done twice and type is checked in resolveType() already. // TODO: move the slimmed down code here to toEvaluator? if (chronoField == null) { Expression field = children().get(0); try { if (field.foldable() && DataType.isString(field.dataType())) { - chronoField = (ChronoField) STRING_TO_CHRONO_FIELD.convert(field.fold()); + chronoField = (ChronoField) STRING_TO_CHRONO_FIELD.convert(field.fold(ctx)); } } catch (Exception e) { return null; diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/scalar/date/DateFormat.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/scalar/date/DateFormat.java index 920a3bb1f4a13..29648d55cadd8 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/scalar/date/DateFormat.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/scalar/date/DateFormat.java @@ -147,7 +147,7 @@ public ExpressionEvaluator.Factory toEvaluator(ToEvaluator toEvaluator) { throw new IllegalArgumentException("unsupported data type for format [" + format.dataType() + "]"); } if (format.foldable()) { - DateFormatter formatter = toFormatter(format.fold(), configuration().locale()); + DateFormatter formatter = toFormatter(format.fold(toEvaluator.foldCtx()), configuration().locale()); return new DateFormatConstantEvaluator.Factory(source(), fieldEvaluator, formatter); } var formatEvaluator = toEvaluator.apply(format); diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/scalar/date/DateParse.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/scalar/date/DateParse.java index e09fabab98d0f..7c38b54ed232b 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/scalar/date/DateParse.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/scalar/date/DateParse.java @@ -143,7 +143,7 @@ public ExpressionEvaluator.Factory toEvaluator(ToEvaluator toEvaluator) { } if (format.foldable()) { try { - DateFormatter formatter = toFormatter(format.fold()); + DateFormatter formatter = toFormatter(format.fold(toEvaluator.foldCtx())); return new DateParseConstantEvaluator.Factory(source(), fieldEvaluator, formatter); } catch (IllegalArgumentException e) { throw new InvalidArgumentException(e, "invalid date pattern for [{}]: {}", sourceText(), e.getMessage()); diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/scalar/date/DateTrunc.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/scalar/date/DateTrunc.java index a35b67d7ac3fd..7983c38cc4288 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/scalar/date/DateTrunc.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/scalar/date/DateTrunc.java @@ -225,7 +225,7 @@ public ExpressionEvaluator.Factory toEvaluator(ToEvaluator toEvaluator) { } Object foldedInterval; try { - foldedInterval = interval.fold(); + foldedInterval = interval.fold(toEvaluator.foldCtx()); if (foldedInterval == null) { throw new IllegalArgumentException("Interval cannot not be null"); } diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/scalar/date/Now.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/scalar/date/Now.java index d259fc6ae57ce..74c2da450995c 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/scalar/date/Now.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/scalar/date/Now.java @@ -14,6 +14,7 @@ import org.elasticsearch.compute.ann.Fixed; import org.elasticsearch.compute.operator.EvalOperator.ExpressionEvaluator; import org.elasticsearch.xpack.esql.core.expression.Expression; +import org.elasticsearch.xpack.esql.core.expression.FoldContext; import org.elasticsearch.xpack.esql.core.tree.NodeInfo; import org.elasticsearch.xpack.esql.core.tree.Source; import org.elasticsearch.xpack.esql.core.type.DataType; @@ -59,7 +60,7 @@ public String getWriteableName() { } @Override - public Object fold() { + public Object fold(FoldContext ctx) { return now; } diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/scalar/math/E.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/scalar/math/E.java index 757b67b47ce72..e1eceef7ed1f5 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/scalar/math/E.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/scalar/math/E.java @@ -11,6 +11,7 @@ import org.elasticsearch.common.io.stream.StreamInput; import org.elasticsearch.common.io.stream.StreamOutput; import org.elasticsearch.xpack.esql.core.expression.Expression; +import org.elasticsearch.xpack.esql.core.expression.FoldContext; import org.elasticsearch.xpack.esql.core.tree.Source; import org.elasticsearch.xpack.esql.expression.function.Example; import org.elasticsearch.xpack.esql.expression.function.FunctionInfo; @@ -49,7 +50,7 @@ public String getWriteableName() { } @Override - public Object fold() { + public Object fold(FoldContext ctx) { return Math.E; } diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/scalar/math/Pi.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/scalar/math/Pi.java index 90a4f1f091e91..32b7a0ab88b4e 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/scalar/math/Pi.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/scalar/math/Pi.java @@ -11,6 +11,7 @@ import org.elasticsearch.common.io.stream.StreamInput; import org.elasticsearch.common.io.stream.StreamOutput; import org.elasticsearch.xpack.esql.core.expression.Expression; +import org.elasticsearch.xpack.esql.core.expression.FoldContext; import org.elasticsearch.xpack.esql.core.tree.Source; import org.elasticsearch.xpack.esql.expression.function.Example; import org.elasticsearch.xpack.esql.expression.function.FunctionInfo; @@ -49,7 +50,7 @@ public String getWriteableName() { } @Override - public Object fold() { + public Object fold(FoldContext ctx) { return Math.PI; } diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/scalar/math/Tau.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/scalar/math/Tau.java index 17e5b027270d1..1a7669b7391e1 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/scalar/math/Tau.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/scalar/math/Tau.java @@ -11,6 +11,7 @@ import org.elasticsearch.common.io.stream.StreamInput; import org.elasticsearch.common.io.stream.StreamOutput; import org.elasticsearch.xpack.esql.core.expression.Expression; +import org.elasticsearch.xpack.esql.core.expression.FoldContext; import org.elasticsearch.xpack.esql.core.tree.Source; import org.elasticsearch.xpack.esql.expression.function.Example; import org.elasticsearch.xpack.esql.expression.function.FunctionInfo; @@ -51,7 +52,7 @@ public String getWriteableName() { } @Override - public Object fold() { + public Object fold(FoldContext ctx) { return TAU; } diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/scalar/multivalue/MvConcat.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/scalar/multivalue/MvConcat.java index 1996744a76567..26211258e6ca6 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/scalar/multivalue/MvConcat.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/scalar/multivalue/MvConcat.java @@ -17,6 +17,7 @@ import org.elasticsearch.compute.operator.DriverContext; import org.elasticsearch.compute.operator.EvalOperator.ExpressionEvaluator; import org.elasticsearch.xpack.esql.core.expression.Expression; +import org.elasticsearch.xpack.esql.core.expression.FoldContext; import org.elasticsearch.xpack.esql.core.expression.TypeResolutions; import org.elasticsearch.xpack.esql.core.expression.function.scalar.BinaryScalarFunction; import org.elasticsearch.xpack.esql.core.tree.NodeInfo; @@ -91,8 +92,8 @@ public ExpressionEvaluator.Factory toEvaluator(ToEvaluator toEvaluator) { } @Override - public Object fold() { - return EvaluatorMapper.super.fold(); + public Object fold(FoldContext ctx) { + return EvaluatorMapper.super.fold(source(), ctx); } @Override diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/scalar/multivalue/MvPSeriesWeightedSum.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/scalar/multivalue/MvPSeriesWeightedSum.java index 4dd447f938880..d5093964145b7 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/scalar/multivalue/MvPSeriesWeightedSum.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/scalar/multivalue/MvPSeriesWeightedSum.java @@ -115,7 +115,7 @@ public EvalOperator.ExpressionEvaluator.Factory toEvaluator(ToEvaluator toEvalua source(), toEvaluator.apply(field), ctx -> new CompensatedSum(), - (Double) p.fold() + (Double) p.fold(toEvaluator.foldCtx()) ); case NULL -> EvalOperator.CONSTANT_NULL_FACTORY; default -> throw EsqlIllegalArgumentException.illegalDataType(field.dataType()); diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/scalar/multivalue/MvSlice.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/scalar/multivalue/MvSlice.java index f4f9679dc3704..4a04524d1b23d 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/scalar/multivalue/MvSlice.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/scalar/multivalue/MvSlice.java @@ -187,8 +187,8 @@ public boolean foldable() { @Override public EvalOperator.ExpressionEvaluator.Factory toEvaluator(ToEvaluator toEvaluator) { if (start.foldable() && end.foldable()) { - int startOffset = stringToInt(String.valueOf(start.fold())); - int endOffset = stringToInt(String.valueOf(end.fold())); + int startOffset = stringToInt(String.valueOf(start.fold(toEvaluator.foldCtx()))); + int endOffset = stringToInt(String.valueOf(end.fold(toEvaluator.foldCtx()))); checkStartEnd(startOffset, endOffset); } return switch (PlannerUtils.toElementType(field.dataType())) { diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/scalar/multivalue/MvSort.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/scalar/multivalue/MvSort.java index 86538c828ece7..b68718acfcd0a 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/scalar/multivalue/MvSort.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/scalar/multivalue/MvSort.java @@ -33,6 +33,7 @@ import org.elasticsearch.xpack.esql.common.Failure; import org.elasticsearch.xpack.esql.common.Failures; import org.elasticsearch.xpack.esql.core.expression.Expression; +import org.elasticsearch.xpack.esql.core.expression.FoldContext; import org.elasticsearch.xpack.esql.core.expression.Literal; import org.elasticsearch.xpack.esql.core.tree.NodeInfo; import org.elasticsearch.xpack.esql.core.tree.Source; @@ -155,12 +156,12 @@ public EvalOperator.ExpressionEvaluator.Factory toEvaluator(ToEvaluator toEvalua sourceText(), ASC.value(), DESC.value(), - ((BytesRef) order.fold()).utf8ToString() + ((BytesRef) order.fold(toEvaluator.foldCtx())).utf8ToString() ) ); } if (order != null && order.foldable()) { - ordering = ((BytesRef) order.fold()).utf8ToString().equalsIgnoreCase((String) ASC.value()); + ordering = ((BytesRef) order.fold(toEvaluator.foldCtx())).utf8ToString().equalsIgnoreCase((String) ASC.value()); } return switch (PlannerUtils.toElementType(field.dataType())) { @@ -238,7 +239,14 @@ public void postLogicalOptimizationVerification(Failures failures) { failures.add(isFoldable(order, operation, SECOND)); if (isValidOrder() == false) { failures.add( - Failure.fail(order, INVALID_ORDER_ERROR, sourceText(), ASC.value(), DESC.value(), ((BytesRef) order.fold()).utf8ToString()) + Failure.fail( + order, + INVALID_ORDER_ERROR, + sourceText(), + ASC.value(), + DESC.value(), + ((BytesRef) order.fold(FoldContext.small() /* TODO remove me */)).utf8ToString() + ) ); } } @@ -246,7 +254,7 @@ public void postLogicalOptimizationVerification(Failures failures) { private boolean isValidOrder() { boolean isValidOrder = true; if (order != null && order.foldable()) { - Object obj = order.fold(); + Object obj = order.fold(FoldContext.small() /* TODO remove me */); String o = null; if (obj instanceof BytesRef ob) { o = ob.utf8ToString(); diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/scalar/spatial/SpatialContains.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/scalar/spatial/SpatialContains.java index 9189c6a7b8f70..b15d04aa792d9 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/scalar/spatial/SpatialContains.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/scalar/spatial/SpatialContains.java @@ -26,6 +26,7 @@ import org.elasticsearch.lucene.spatial.CoordinateEncoder; import org.elasticsearch.lucene.spatial.GeometryDocValueReader; import org.elasticsearch.xpack.esql.core.expression.Expression; +import org.elasticsearch.xpack.esql.core.expression.FoldContext; import org.elasticsearch.xpack.esql.core.tree.NodeInfo; import org.elasticsearch.xpack.esql.core.tree.Source; import org.elasticsearch.xpack.esql.core.type.DataType; @@ -214,10 +215,10 @@ protected NodeInfo info() { } @Override - public Object fold() { + public Object fold(FoldContext ctx) { try { - GeometryDocValueReader docValueReader = asGeometryDocValueReader(crsType(), left()); - Geometry rightGeom = makeGeometryFromLiteral(right()); + GeometryDocValueReader docValueReader = asGeometryDocValueReader(ctx, crsType(), left()); + Geometry rightGeom = makeGeometryFromLiteral(ctx, right()); Component2D[] components = asLuceneComponent2Ds(crsType(), rightGeom); return (crsType() == SpatialCrsType.GEO) ? GEO.geometryRelatesGeometries(docValueReader, components) diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/scalar/spatial/SpatialDisjoint.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/scalar/spatial/SpatialDisjoint.java index ee78f50c4d6bd..3e16fa163fcd6 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/scalar/spatial/SpatialDisjoint.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/scalar/spatial/SpatialDisjoint.java @@ -23,6 +23,7 @@ import org.elasticsearch.lucene.spatial.CoordinateEncoder; import org.elasticsearch.lucene.spatial.GeometryDocValueReader; import org.elasticsearch.xpack.esql.core.expression.Expression; +import org.elasticsearch.xpack.esql.core.expression.FoldContext; import org.elasticsearch.xpack.esql.core.tree.NodeInfo; import org.elasticsearch.xpack.esql.core.tree.Source; import org.elasticsearch.xpack.esql.core.type.DataType; @@ -129,10 +130,10 @@ protected NodeInfo info() { } @Override - public Object fold() { + public Object fold(FoldContext ctx) { try { - GeometryDocValueReader docValueReader = asGeometryDocValueReader(crsType(), left()); - Component2D component2D = asLuceneComponent2D(crsType(), right()); + GeometryDocValueReader docValueReader = asGeometryDocValueReader(ctx, crsType(), left()); + Component2D component2D = asLuceneComponent2D(ctx, crsType(), right()); return (crsType() == SpatialCrsType.GEO) ? GEO.geometryRelatesGeometry(docValueReader, component2D) : CARTESIAN.geometryRelatesGeometry(docValueReader, component2D); diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/scalar/spatial/SpatialEvaluatorFactory.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/scalar/spatial/SpatialEvaluatorFactory.java index 1a51af8dfeeb4..dcd53075cf69c 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/scalar/spatial/SpatialEvaluatorFactory.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/scalar/spatial/SpatialEvaluatorFactory.java @@ -171,7 +171,11 @@ protected static class SpatialEvaluatorWithConstantFactory extends SpatialEvalua @Override public EvalOperator.ExpressionEvaluator.Factory get(SpatialSourceSupplier s, EvaluatorMapper.ToEvaluator toEvaluator) { - return factoryCreator.apply(s.source(), toEvaluator.apply(s.left()), asLuceneComponent2D(s.crsType(), s.right())); + return factoryCreator.apply( + s.source(), + toEvaluator.apply(s.left()), + asLuceneComponent2D(toEvaluator.foldCtx(), s.crsType(), s.right()) + ); } } @@ -197,7 +201,11 @@ protected static class SpatialEvaluatorWithConstantArrayFactory extends SpatialE @Override public EvalOperator.ExpressionEvaluator.Factory get(SpatialSourceSupplier s, EvaluatorMapper.ToEvaluator toEvaluator) { - return factoryCreator.apply(s.source(), toEvaluator.apply(s.left()), asLuceneComponent2Ds(s.crsType(), s.right())); + return factoryCreator.apply( + s.source(), + toEvaluator.apply(s.left()), + asLuceneComponent2Ds(toEvaluator.foldCtx(), s.crsType(), s.right()) + ); } } diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/scalar/spatial/SpatialIntersects.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/scalar/spatial/SpatialIntersects.java index 8d54e5ee443c2..601550cd173bb 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/scalar/spatial/SpatialIntersects.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/scalar/spatial/SpatialIntersects.java @@ -23,6 +23,7 @@ import org.elasticsearch.lucene.spatial.CoordinateEncoder; import org.elasticsearch.lucene.spatial.GeometryDocValueReader; import org.elasticsearch.xpack.esql.core.expression.Expression; +import org.elasticsearch.xpack.esql.core.expression.FoldContext; import org.elasticsearch.xpack.esql.core.tree.NodeInfo; import org.elasticsearch.xpack.esql.core.tree.Source; import org.elasticsearch.xpack.esql.core.type.DataType; @@ -127,10 +128,10 @@ protected NodeInfo info() { } @Override - public Object fold() { + public Object fold(FoldContext ctx) { try { - GeometryDocValueReader docValueReader = asGeometryDocValueReader(crsType(), left()); - Component2D component2D = asLuceneComponent2D(crsType(), right()); + GeometryDocValueReader docValueReader = asGeometryDocValueReader(ctx, crsType(), left()); + Component2D component2D = asLuceneComponent2D(ctx, crsType(), right()); return (crsType() == SpatialCrsType.GEO) ? GEO.geometryRelatesGeometry(docValueReader, component2D) : CARTESIAN.geometryRelatesGeometry(docValueReader, component2D); diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/scalar/spatial/SpatialRelatesUtils.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/scalar/spatial/SpatialRelatesUtils.java index 6ae99ea8165cd..1b06c6dfd3dd5 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/scalar/spatial/SpatialRelatesUtils.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/scalar/spatial/SpatialRelatesUtils.java @@ -29,6 +29,7 @@ import org.elasticsearch.lucene.spatial.GeometryDocValueReader; import org.elasticsearch.lucene.spatial.GeometryDocValueWriter; import org.elasticsearch.xpack.esql.core.expression.Expression; +import org.elasticsearch.xpack.esql.core.expression.FoldContext; import org.elasticsearch.xpack.esql.core.type.DataType; import org.elasticsearch.xpack.esql.core.util.SpatialCoordinateTypes; @@ -42,8 +43,8 @@ public class SpatialRelatesUtils { /** Converts a {@link Expression} into a {@link Component2D}. */ - static Component2D asLuceneComponent2D(BinarySpatialFunction.SpatialCrsType crsType, Expression expression) { - return asLuceneComponent2D(crsType, makeGeometryFromLiteral(expression)); + static Component2D asLuceneComponent2D(FoldContext ctx, BinarySpatialFunction.SpatialCrsType crsType, Expression expression) { + return asLuceneComponent2D(crsType, makeGeometryFromLiteral(ctx, expression)); } /** Converts a {@link Geometry} into a {@link Component2D}. */ @@ -66,8 +67,8 @@ static Component2D asLuceneComponent2D(BinarySpatialFunction.SpatialCrsType type * Converts a {@link Expression} at a given {@code position} into a {@link Component2D} array. * The reason for generating an array instead of a single component is for multi-shape support with ST_CONTAINS. */ - static Component2D[] asLuceneComponent2Ds(BinarySpatialFunction.SpatialCrsType crsType, Expression expression) { - return asLuceneComponent2Ds(crsType, makeGeometryFromLiteral(expression)); + static Component2D[] asLuceneComponent2Ds(FoldContext ctx, BinarySpatialFunction.SpatialCrsType crsType, Expression expression) { + return asLuceneComponent2Ds(crsType, makeGeometryFromLiteral(ctx, expression)); } /** @@ -90,9 +91,12 @@ static Component2D[] asLuceneComponent2Ds(BinarySpatialFunction.SpatialCrsType t } /** Converts a {@link Expression} into a {@link GeometryDocValueReader} */ - static GeometryDocValueReader asGeometryDocValueReader(BinarySpatialFunction.SpatialCrsType crsType, Expression expression) - throws IOException { - Geometry geometry = makeGeometryFromLiteral(expression); + static GeometryDocValueReader asGeometryDocValueReader( + FoldContext ctx, + BinarySpatialFunction.SpatialCrsType crsType, + Expression expression + ) throws IOException { + Geometry geometry = makeGeometryFromLiteral(ctx, expression); if (crsType == BinarySpatialFunction.SpatialCrsType.GEO) { return asGeometryDocValueReader( CoordinateEncoder.GEO, @@ -167,8 +171,8 @@ private static Geometry asGeometry(BytesRefBlock valueBlock, int position) { * This function is used in two places, when evaluating a spatial constant in the SpatialRelatesFunction, as well as when * we do lucene-pushdown of spatial functions. */ - public static Geometry makeGeometryFromLiteral(Expression expr) { - return makeGeometryFromLiteralValue(valueOf(expr), expr.dataType()); + public static Geometry makeGeometryFromLiteral(FoldContext ctx, Expression expr) { + return makeGeometryFromLiteralValue(valueOf(ctx, expr), expr.dataType()); } private static Geometry makeGeometryFromLiteralValue(Object value, DataType dataType) { diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/scalar/spatial/SpatialWithin.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/scalar/spatial/SpatialWithin.java index 2005709cd37e9..9fcece1ce65bc 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/scalar/spatial/SpatialWithin.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/scalar/spatial/SpatialWithin.java @@ -23,6 +23,7 @@ import org.elasticsearch.lucene.spatial.CoordinateEncoder; import org.elasticsearch.lucene.spatial.GeometryDocValueReader; import org.elasticsearch.xpack.esql.core.expression.Expression; +import org.elasticsearch.xpack.esql.core.expression.FoldContext; import org.elasticsearch.xpack.esql.core.tree.NodeInfo; import org.elasticsearch.xpack.esql.core.tree.Source; import org.elasticsearch.xpack.esql.core.type.DataType; @@ -129,10 +130,10 @@ protected NodeInfo info() { } @Override - public Object fold() { + public Object fold(FoldContext ctx) { try { - GeometryDocValueReader docValueReader = asGeometryDocValueReader(crsType(), left()); - Component2D component2D = asLuceneComponent2D(crsType(), right()); + GeometryDocValueReader docValueReader = asGeometryDocValueReader(ctx, crsType(), left()); + Component2D component2D = asLuceneComponent2D(ctx, crsType(), right()); return (crsType() == SpatialCrsType.GEO) ? GEO.geometryRelatesGeometry(docValueReader, component2D) : CARTESIAN.geometryRelatesGeometry(docValueReader, component2D); diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/scalar/spatial/StDistance.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/scalar/spatial/StDistance.java index 3cf042a2db828..f0c25e3289cc1 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/scalar/spatial/StDistance.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/scalar/spatial/StDistance.java @@ -23,6 +23,7 @@ import org.elasticsearch.lucene.spatial.CoordinateEncoder; import org.elasticsearch.xpack.esql.EsqlIllegalArgumentException; import org.elasticsearch.xpack.esql.core.expression.Expression; +import org.elasticsearch.xpack.esql.core.expression.FoldContext; import org.elasticsearch.xpack.esql.core.tree.NodeInfo; import org.elasticsearch.xpack.esql.core.tree.Source; import org.elasticsearch.xpack.esql.core.type.DataType; @@ -280,18 +281,18 @@ protected NodeInfo info() { } @Override - public Object fold() { - var leftGeom = makeGeometryFromLiteral(left()); - var rightGeom = makeGeometryFromLiteral(right()); + public Object fold(FoldContext ctx) { + var leftGeom = makeGeometryFromLiteral(ctx, left()); + var rightGeom = makeGeometryFromLiteral(ctx, right()); return (crsType() == SpatialCrsType.GEO) ? GEO.distance(leftGeom, rightGeom) : CARTESIAN.distance(leftGeom, rightGeom); } @Override public EvalOperator.ExpressionEvaluator.Factory toEvaluator(ToEvaluator toEvaluator) { if (right().foldable()) { - return toEvaluator(toEvaluator, left(), makeGeometryFromLiteral(right()), leftDocValues); + return toEvaluator(toEvaluator, left(), makeGeometryFromLiteral(toEvaluator.foldCtx(), right()), leftDocValues); } else if (left().foldable()) { - return toEvaluator(toEvaluator, right(), makeGeometryFromLiteral(left()), rightDocValues); + return toEvaluator(toEvaluator, right(), makeGeometryFromLiteral(toEvaluator.foldCtx(), left()), rightDocValues); } else { EvalOperator.ExpressionEvaluator.Factory leftE = toEvaluator.apply(left()); EvalOperator.ExpressionEvaluator.Factory rightE = toEvaluator.apply(right()); diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/scalar/string/Hash.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/scalar/string/Hash.java index 52d33c0fc9d3d..be0a7b2fe27b2 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/scalar/string/Hash.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/scalar/string/Hash.java @@ -146,7 +146,7 @@ public EvalOperator.ExpressionEvaluator.Factory toEvaluator(ToEvaluator toEvalua if (algorithm.foldable()) { try { // hash function is created here in order to validate the algorithm is valid before evaluator is created - var hf = HashFunction.create((BytesRef) algorithm.fold()); + var hf = HashFunction.create((BytesRef) algorithm.fold(toEvaluator.foldCtx())); return new HashConstantEvaluator.Factory( source(), context -> new BreakingBytesRefBuilder(context.breaker(), "hash"), diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/scalar/string/RLike.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/scalar/string/RLike.java index 996c90a8e40bc..fb0aac0c85b38 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/scalar/string/RLike.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/scalar/string/RLike.java @@ -12,6 +12,7 @@ import org.elasticsearch.common.io.stream.StreamOutput; import org.elasticsearch.compute.operator.EvalOperator; import org.elasticsearch.xpack.esql.core.expression.Expression; +import org.elasticsearch.xpack.esql.core.expression.FoldContext; import org.elasticsearch.xpack.esql.core.expression.predicate.regex.RLikePattern; import org.elasticsearch.xpack.esql.core.tree.NodeInfo; import org.elasticsearch.xpack.esql.core.tree.Source; @@ -94,8 +95,8 @@ protected TypeResolution resolveType() { } @Override - public Boolean fold() { - return (Boolean) EvaluatorMapper.super.fold(); + public Boolean fold(FoldContext ctx) { + return (Boolean) EvaluatorMapper.super.fold(source(), ctx); } @Override diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/scalar/string/Repeat.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/scalar/string/Repeat.java index e91f03de3dd7e..363991d1556f1 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/scalar/string/Repeat.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/scalar/string/Repeat.java @@ -151,7 +151,7 @@ public ExpressionEvaluator.Factory toEvaluator(ToEvaluator toEvaluator) { ExpressionEvaluator.Factory strExpr = toEvaluator.apply(str); if (number.foldable()) { - int num = (int) number.fold(); + int num = (int) number.fold(toEvaluator.foldCtx()); if (num < 0) { throw new IllegalArgumentException("Number parameter cannot be negative, found [" + number + "]"); } diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/scalar/string/Replace.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/scalar/string/Replace.java index 4fa191244cb42..4b963b794aef0 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/scalar/string/Replace.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/scalar/string/Replace.java @@ -152,7 +152,7 @@ public ExpressionEvaluator.Factory toEvaluator(ToEvaluator toEvaluator) { if (regex.foldable() && regex.dataType() == DataType.KEYWORD) { Pattern regexPattern; try { - regexPattern = Pattern.compile(((BytesRef) regex.fold()).utf8ToString()); + regexPattern = Pattern.compile(((BytesRef) regex.fold(toEvaluator.foldCtx())).utf8ToString()); } catch (PatternSyntaxException pse) { // TODO this is not right (inconsistent). See also https://github.com/elastic/elasticsearch/issues/100038 // this should generate a header warning and return null (as do the rest of this functionality in evaluators), diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/scalar/string/Space.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/scalar/string/Space.java index 3b9a466966911..e46c0a730431d 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/scalar/string/Space.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/scalar/string/Space.java @@ -113,7 +113,7 @@ protected NodeInfo info() { @Override public ExpressionEvaluator.Factory toEvaluator(ToEvaluator toEvaluator) { if (field.foldable()) { - Object folded = field.fold(); + Object folded = field.fold(toEvaluator.foldCtx()); if (folded instanceof Integer num) { checkNumber(num); return toEvaluator.apply(new Literal(source(), " ".repeat(num), KEYWORD)); diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/scalar/string/Split.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/scalar/string/Split.java index 24762122f755b..d0c1035978ff3 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/scalar/string/Split.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/scalar/string/Split.java @@ -17,6 +17,7 @@ import org.elasticsearch.compute.operator.EvalOperator.ExpressionEvaluator; import org.elasticsearch.xpack.esql.core.InvalidArgumentException; import org.elasticsearch.xpack.esql.core.expression.Expression; +import org.elasticsearch.xpack.esql.core.expression.FoldContext; import org.elasticsearch.xpack.esql.core.expression.function.scalar.BinaryScalarFunction; import org.elasticsearch.xpack.esql.core.tree.NodeInfo; import org.elasticsearch.xpack.esql.core.tree.Source; @@ -102,8 +103,8 @@ public boolean foldable() { } @Override - public Object fold() { - return EvaluatorMapper.super.fold(); + public Object fold(FoldContext ctx) { + return EvaluatorMapper.super.fold(source(), ctx); } @Evaluator(extraName = "SingleByte") @@ -163,7 +164,7 @@ public ExpressionEvaluator.Factory toEvaluator(ToEvaluator toEvaluator) { if (right().foldable() == false) { return new SplitVariableEvaluator.Factory(source(), str, toEvaluator.apply(right()), context -> new BytesRef()); } - BytesRef delim = (BytesRef) right().fold(); + BytesRef delim = (BytesRef) right().fold(toEvaluator.foldCtx()); checkDelimiter(delim); return new SplitSingleByteEvaluator.Factory(source(), str, delim.bytes[delim.offset], context -> new BytesRef()); } diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/scalar/string/WildcardLike.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/scalar/string/WildcardLike.java index d2edb0f92e8f2..65455c708cc9b 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/scalar/string/WildcardLike.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/scalar/string/WildcardLike.java @@ -13,6 +13,7 @@ import org.elasticsearch.common.io.stream.StreamOutput; import org.elasticsearch.compute.operator.EvalOperator; import org.elasticsearch.xpack.esql.core.expression.Expression; +import org.elasticsearch.xpack.esql.core.expression.FoldContext; import org.elasticsearch.xpack.esql.core.expression.predicate.regex.WildcardPattern; import org.elasticsearch.xpack.esql.core.tree.NodeInfo; import org.elasticsearch.xpack.esql.core.tree.Source; @@ -100,8 +101,8 @@ protected TypeResolution resolveType() { } @Override - public Boolean fold() { - return (Boolean) EvaluatorMapper.super.fold(); + public Boolean fold(FoldContext ctx) { + return (Boolean) EvaluatorMapper.super.fold(source(), ctx); } @Override diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/scalar/util/Delay.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/scalar/util/Delay.java index 8ddb26f5bb786..a09aed469b6ab 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/scalar/util/Delay.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/scalar/util/Delay.java @@ -14,6 +14,7 @@ import org.elasticsearch.compute.ann.Fixed; import org.elasticsearch.compute.operator.EvalOperator.ExpressionEvaluator; import org.elasticsearch.xpack.esql.core.expression.Expression; +import org.elasticsearch.xpack.esql.core.expression.FoldContext; import org.elasticsearch.xpack.esql.core.expression.Nullability; import org.elasticsearch.xpack.esql.core.tree.NodeInfo; import org.elasticsearch.xpack.esql.core.tree.Source; @@ -84,15 +85,15 @@ public boolean foldable() { } @Override - public Object fold() { + public Object fold(FoldContext ctx) { return null; } - private long msValue() { + private long msValue(FoldContext ctx) { if (field().foldable() == false) { throw new IllegalArgumentException("function [" + sourceText() + "] has invalid argument [" + field().sourceText() + "]"); } - var ms = field().fold(); + var ms = field().fold(ctx); if (ms instanceof Duration duration) { return duration.toMillis(); } @@ -101,7 +102,7 @@ private long msValue() { @Override public ExpressionEvaluator.Factory toEvaluator(EvaluatorMapper.ToEvaluator toEvaluator) { - return new DelayEvaluator.Factory(source(), msValue()); + return new DelayEvaluator.Factory(source(), msValue(toEvaluator.foldCtx())); } @Evaluator diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/predicate/operator/arithmetic/DateTimeArithmeticOperation.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/predicate/operator/arithmetic/DateTimeArithmeticOperation.java index 8bb166fac60bb..424c080c905e3 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/predicate/operator/arithmetic/DateTimeArithmeticOperation.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/predicate/operator/arithmetic/DateTimeArithmeticOperation.java @@ -11,6 +11,7 @@ import org.elasticsearch.compute.operator.EvalOperator.ExpressionEvaluator; import org.elasticsearch.xpack.esql.ExceptionUtils; import org.elasticsearch.xpack.esql.core.expression.Expression; +import org.elasticsearch.xpack.esql.core.expression.FoldContext; import org.elasticsearch.xpack.esql.core.expression.TypeResolutions; import org.elasticsearch.xpack.esql.core.tree.Source; import org.elasticsearch.xpack.esql.core.type.DataType; @@ -111,7 +112,7 @@ protected TypeResolution checkCompatibility() { /** * Override this to allow processing literals of type {@link DataType#DATE_PERIOD} when folding constants. - * Used in {@link DateTimeArithmeticOperation#fold()}. + * Used in {@link DateTimeArithmeticOperation#fold}. * @param left the left period * @param right the right period * @return the result of the evaluation @@ -120,7 +121,7 @@ protected TypeResolution checkCompatibility() { /** * Override this to allow processing literals of type {@link DataType#TIME_DURATION} when folding constants. - * Used in {@link DateTimeArithmeticOperation#fold()}. + * Used in {@link DateTimeArithmeticOperation#fold}. * @param left the left duration * @param right the right duration * @return the result of the evaluation @@ -128,13 +129,13 @@ protected TypeResolution checkCompatibility() { abstract Duration fold(Duration left, Duration right); @Override - public final Object fold() { + public final Object fold(FoldContext ctx) { DataType leftDataType = left().dataType(); DataType rightDataType = right().dataType(); if (leftDataType == DATE_PERIOD && rightDataType == DATE_PERIOD) { // Both left and right expressions are temporal amounts; we can assume they are both foldable. - var l = left().fold(); - var r = right().fold(); + var l = left().fold(ctx); + var r = right().fold(ctx); if (l instanceof Collection || r instanceof Collection) { return null; } @@ -148,8 +149,8 @@ public final Object fold() { } if (leftDataType == TIME_DURATION && rightDataType == TIME_DURATION) { // Both left and right expressions are temporal amounts; we can assume they are both foldable. - Duration l = (Duration) left().fold(); - Duration r = (Duration) right().fold(); + Duration l = (Duration) left().fold(ctx); + Duration r = (Duration) right().fold(ctx); try { return fold(l, r); } catch (ArithmeticException e) { @@ -161,7 +162,7 @@ public final Object fold() { if (isNull(leftDataType) || isNull(rightDataType)) { return null; } - return super.fold(); + return super.fold(ctx); } @Override @@ -178,7 +179,11 @@ public ExpressionEvaluator.Factory toEvaluator(ToEvaluator toEvaluator) { temporalAmountArgument = left(); } - return millisEvaluator.apply(source(), toEvaluator.apply(datetimeArgument), (TemporalAmount) temporalAmountArgument.fold()); + return millisEvaluator.apply( + source(), + toEvaluator.apply(datetimeArgument), + (TemporalAmount) temporalAmountArgument.fold(toEvaluator.foldCtx()) + ); } else if (dataType() == DATE_NANOS) { // One of the arguments has to be a date_nanos and the other a temporal amount. Expression dateNanosArgument; @@ -191,7 +196,11 @@ public ExpressionEvaluator.Factory toEvaluator(ToEvaluator toEvaluator) { temporalAmountArgument = left(); } - return nanosEvaluator.apply(source(), toEvaluator.apply(dateNanosArgument), (TemporalAmount) temporalAmountArgument.fold()); + return nanosEvaluator.apply( + source(), + toEvaluator.apply(dateNanosArgument), + (TemporalAmount) temporalAmountArgument.fold(toEvaluator.foldCtx()) + ); } else { return super.toEvaluator(toEvaluator); } diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/predicate/operator/arithmetic/EsqlArithmeticOperation.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/predicate/operator/arithmetic/EsqlArithmeticOperation.java index 74394d796855f..e3248665ad486 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/predicate/operator/arithmetic/EsqlArithmeticOperation.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/predicate/operator/arithmetic/EsqlArithmeticOperation.java @@ -12,6 +12,7 @@ import org.elasticsearch.compute.operator.EvalOperator.ExpressionEvaluator; import org.elasticsearch.xpack.esql.EsqlIllegalArgumentException; import org.elasticsearch.xpack.esql.core.expression.Expression; +import org.elasticsearch.xpack.esql.core.expression.FoldContext; import org.elasticsearch.xpack.esql.core.expression.predicate.operator.arithmetic.BinaryArithmeticOperation; import org.elasticsearch.xpack.esql.core.tree.Source; import org.elasticsearch.xpack.esql.core.type.DataType; @@ -120,8 +121,8 @@ public interface BinaryEvaluator { } @Override - public Object fold() { - return EvaluatorMapper.super.fold(); + public Object fold(FoldContext ctx) { + return EvaluatorMapper.super.fold(source(), ctx); } public DataType dataType() { diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/predicate/operator/arithmetic/Neg.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/predicate/operator/arithmetic/Neg.java index fb32282005f02..6663ccf0ef7b6 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/predicate/operator/arithmetic/Neg.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/predicate/operator/arithmetic/Neg.java @@ -14,6 +14,7 @@ import org.elasticsearch.xpack.esql.EsqlIllegalArgumentException; import org.elasticsearch.xpack.esql.ExceptionUtils; import org.elasticsearch.xpack.esql.core.expression.Expression; +import org.elasticsearch.xpack.esql.core.expression.FoldContext; import org.elasticsearch.xpack.esql.core.tree.NodeInfo; import org.elasticsearch.xpack.esql.core.tree.Source; import org.elasticsearch.xpack.esql.core.type.DataType; @@ -87,12 +88,12 @@ else if (type == DataType.LONG) { } @Override - public final Object fold() { + public final Object fold(FoldContext ctx) { DataType dataType = field().dataType(); // For date periods and time durations, we need to treat folding differently. These types are unrepresentable, so there is no // evaluator for them - but the default folding requires an evaluator. if (dataType == DATE_PERIOD) { - Period fieldValue = (Period) field().fold(); + Period fieldValue = (Period) field().fold(ctx); try { return fieldValue.negated(); } catch (ArithmeticException e) { @@ -102,7 +103,7 @@ public final Object fold() { } } if (dataType == TIME_DURATION) { - Duration fieldValue = (Duration) field().fold(); + Duration fieldValue = (Duration) field().fold(ctx); try { return fieldValue.negated(); } catch (ArithmeticException e) { @@ -111,7 +112,7 @@ public final Object fold() { throw ExceptionUtils.math(source(), e); } } - return super.fold(); + return super.fold(ctx); } @Override diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/predicate/operator/comparison/EsqlBinaryComparison.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/predicate/operator/comparison/EsqlBinaryComparison.java index 3e2a21664aa7e..e56c19b26a902 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/predicate/operator/comparison/EsqlBinaryComparison.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/predicate/operator/comparison/EsqlBinaryComparison.java @@ -13,6 +13,7 @@ import org.elasticsearch.compute.operator.EvalOperator; import org.elasticsearch.xpack.esql.EsqlIllegalArgumentException; import org.elasticsearch.xpack.esql.core.expression.Expression; +import org.elasticsearch.xpack.esql.core.expression.FoldContext; import org.elasticsearch.xpack.esql.core.expression.TypeResolutions; import org.elasticsearch.xpack.esql.core.expression.predicate.operator.comparison.BinaryComparison; import org.elasticsearch.xpack.esql.core.tree.Source; @@ -204,8 +205,8 @@ public EvalOperator.ExpressionEvaluator.Factory toEvaluator(ToEvaluator toEvalua } @Override - public Boolean fold() { - return (Boolean) EvaluatorMapper.super.fold(); + public Boolean fold(FoldContext ctx) { + return (Boolean) EvaluatorMapper.super.fold(source(), ctx); } @Override diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/predicate/operator/comparison/In.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/predicate/operator/comparison/In.java index f596d589cdde2..2061c2626aa45 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/predicate/operator/comparison/In.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/predicate/operator/comparison/In.java @@ -17,6 +17,7 @@ import org.elasticsearch.xpack.esql.EsqlIllegalArgumentException; import org.elasticsearch.xpack.esql.core.expression.Expression; import org.elasticsearch.xpack.esql.core.expression.Expressions; +import org.elasticsearch.xpack.esql.core.expression.FoldContext; import org.elasticsearch.xpack.esql.core.expression.predicate.operator.comparison.Comparisons; import org.elasticsearch.xpack.esql.core.tree.NodeInfo; import org.elasticsearch.xpack.esql.core.tree.Source; @@ -205,11 +206,11 @@ public boolean foldable() { } @Override - public Object fold() { + public Object fold(FoldContext ctx) { if (Expressions.isGuaranteedNull(value) || list.stream().allMatch(Expressions::isGuaranteedNull)) { return null; } - return super.fold(); + return super.fold(ctx); } protected boolean areCompatible(DataType left, DataType right) { diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/predicate/operator/comparison/InsensitiveEquals.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/predicate/operator/comparison/InsensitiveEquals.java index c731e44197f2e..01564644bf5c7 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/predicate/operator/comparison/InsensitiveEquals.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/predicate/operator/comparison/InsensitiveEquals.java @@ -16,6 +16,7 @@ import org.elasticsearch.compute.ann.Evaluator; import org.elasticsearch.compute.ann.Fixed; import org.elasticsearch.xpack.esql.core.expression.Expression; +import org.elasticsearch.xpack.esql.core.expression.FoldContext; import org.elasticsearch.xpack.esql.core.expression.TypeResolutions; import org.elasticsearch.xpack.esql.core.tree.NodeInfo; import org.elasticsearch.xpack.esql.core.tree.Source; @@ -77,9 +78,9 @@ public static Automaton automaton(BytesRef val) { } @Override - public Boolean fold() { - BytesRef leftVal = BytesRefs.toBytesRef(left().fold()); - BytesRef rightVal = BytesRefs.toBytesRef(right().fold()); + public Boolean fold(FoldContext ctx) { + BytesRef leftVal = BytesRefs.toBytesRef(left().fold(ctx)); + BytesRef rightVal = BytesRefs.toBytesRef(right().fold(ctx)); if (leftVal == null || rightVal == null) { return null; } diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/predicate/operator/comparison/InsensitiveEqualsMapper.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/predicate/operator/comparison/InsensitiveEqualsMapper.java index f5704239993f9..7ea95c764f36c 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/predicate/operator/comparison/InsensitiveEqualsMapper.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/predicate/operator/comparison/InsensitiveEqualsMapper.java @@ -14,6 +14,7 @@ import org.elasticsearch.common.lucene.BytesRefs; import org.elasticsearch.compute.operator.EvalOperator.ExpressionEvaluator; import org.elasticsearch.xpack.esql.EsqlIllegalArgumentException; +import org.elasticsearch.xpack.esql.core.expression.FoldContext; import org.elasticsearch.xpack.esql.core.tree.Source; import org.elasticsearch.xpack.esql.core.type.DataType; import org.elasticsearch.xpack.esql.evaluator.mapper.ExpressionMapper; @@ -28,15 +29,15 @@ public class InsensitiveEqualsMapper extends ExpressionMapper InsensitiveEqualsEvaluator.Factory::new; @Override - public final ExpressionEvaluator.Factory map(InsensitiveEquals bc, Layout layout) { + public final ExpressionEvaluator.Factory map(FoldContext foldCtx, InsensitiveEquals bc, Layout layout) { DataType leftType = bc.left().dataType(); DataType rightType = bc.right().dataType(); - var leftEval = toEvaluator(bc.left(), layout); - var rightEval = toEvaluator(bc.right(), layout); + var leftEval = toEvaluator(foldCtx, bc.left(), layout); + var rightEval = toEvaluator(foldCtx, bc.right(), layout); if (DataType.isString(leftType)) { if (bc.right().foldable() && DataType.isString(rightType)) { - BytesRef rightVal = BytesRefs.toBytesRef(bc.right().fold()); + BytesRef rightVal = BytesRefs.toBytesRef(bc.right().fold(FoldContext.small() /* TODO remove me */)); Automaton automaton = InsensitiveEquals.automaton(rightVal); return dvrCtx -> new InsensitiveEqualsConstantEvaluator( bc.source(), @@ -51,13 +52,14 @@ public final ExpressionEvaluator.Factory map(InsensitiveEquals bc, Layout layout } public static ExpressionEvaluator.Factory castToEvaluator( + FoldContext foldCtx, InsensitiveEquals op, Layout layout, DataType required, TriFunction factory ) { - var lhs = Cast.cast(op.source(), op.left().dataType(), required, toEvaluator(op.left(), layout)); - var rhs = Cast.cast(op.source(), op.right().dataType(), required, toEvaluator(op.right(), layout)); + var lhs = Cast.cast(op.source(), op.left().dataType(), required, toEvaluator(foldCtx, op.left(), layout)); + var rhs = Cast.cast(op.source(), op.right().dataType(), required, toEvaluator(foldCtx, op.right(), layout)); return factory.apply(op.source(), lhs, rhs); } } diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/LocalLogicalOptimizerContext.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/LocalLogicalOptimizerContext.java index ef5cf50c76541..183008f900c5d 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/LocalLogicalOptimizerContext.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/LocalLogicalOptimizerContext.java @@ -7,6 +7,7 @@ package org.elasticsearch.xpack.esql.optimizer; +import org.elasticsearch.xpack.esql.core.expression.FoldContext; import org.elasticsearch.xpack.esql.session.Configuration; import org.elasticsearch.xpack.esql.stats.SearchStats; @@ -15,8 +16,8 @@ public final class LocalLogicalOptimizerContext extends LogicalOptimizerContext { private final SearchStats searchStats; - public LocalLogicalOptimizerContext(Configuration configuration, SearchStats searchStats) { - super(configuration); + public LocalLogicalOptimizerContext(Configuration configuration, FoldContext foldCtx, SearchStats searchStats) { + super(configuration, foldCtx); this.searchStats = searchStats; } diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/LocalPhysicalOptimizerContext.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/LocalPhysicalOptimizerContext.java index c11e1a4ec49e4..22e07b45310fb 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/LocalPhysicalOptimizerContext.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/LocalPhysicalOptimizerContext.java @@ -7,7 +7,8 @@ package org.elasticsearch.xpack.esql.optimizer; +import org.elasticsearch.xpack.esql.core.expression.FoldContext; import org.elasticsearch.xpack.esql.session.Configuration; import org.elasticsearch.xpack.esql.stats.SearchStats; -public record LocalPhysicalOptimizerContext(Configuration configuration, SearchStats searchStats) {} +public record LocalPhysicalOptimizerContext(Configuration configuration, FoldContext foldCtx, SearchStats searchStats) {} diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/LogicalOptimizerContext.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/LogicalOptimizerContext.java index 67148e67cbc19..da2d583674a90 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/LogicalOptimizerContext.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/LogicalOptimizerContext.java @@ -7,37 +7,44 @@ package org.elasticsearch.xpack.esql.optimizer; +import org.elasticsearch.xpack.esql.core.expression.FoldContext; import org.elasticsearch.xpack.esql.session.Configuration; import java.util.Objects; public class LogicalOptimizerContext { private final Configuration configuration; + private final FoldContext foldCtx; - public LogicalOptimizerContext(Configuration configuration) { + public LogicalOptimizerContext(Configuration configuration, FoldContext foldCtx) { this.configuration = configuration; + this.foldCtx = foldCtx; } public Configuration configuration() { return configuration; } + public FoldContext foldCtx() { + return foldCtx; + } + @Override public boolean equals(Object obj) { if (obj == this) return true; if (obj == null || obj.getClass() != this.getClass()) return false; var that = (LogicalOptimizerContext) obj; - return Objects.equals(this.configuration, that.configuration); + return this.configuration.equals(that.configuration) && this.foldCtx.equals(that.foldCtx); } @Override public int hashCode() { - return Objects.hash(configuration); + return Objects.hash(configuration, foldCtx); } @Override public String toString() { - return "LogicalOptimizerContext[" + "configuration=" + configuration + ']'; + return "LogicalOptimizerContext[configuration=" + configuration + ", foldCtx=" + foldCtx + ']'; } } diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/logical/BooleanFunctionEqualsElimination.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/logical/BooleanFunctionEqualsElimination.java index 3152f9b574767..5f463f2aa4c78 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/logical/BooleanFunctionEqualsElimination.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/logical/BooleanFunctionEqualsElimination.java @@ -13,6 +13,7 @@ import org.elasticsearch.xpack.esql.core.expression.predicate.operator.comparison.BinaryComparison; import org.elasticsearch.xpack.esql.expression.predicate.operator.comparison.Equals; import org.elasticsearch.xpack.esql.expression.predicate.operator.comparison.NotEquals; +import org.elasticsearch.xpack.esql.optimizer.LogicalOptimizerContext; import static org.elasticsearch.xpack.esql.core.expression.Literal.FALSE; import static org.elasticsearch.xpack.esql.core.expression.Literal.TRUE; @@ -28,7 +29,7 @@ public BooleanFunctionEqualsElimination() { } @Override - public Expression rule(BinaryComparison bc) { + public Expression rule(BinaryComparison bc, LogicalOptimizerContext ctx) { if ((bc instanceof Equals || bc instanceof NotEquals) && bc.left() instanceof Function) { // for expression "==" or "!=" TRUE/FALSE, return the expression itself or its negated variant diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/logical/BooleanSimplification.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/logical/BooleanSimplification.java index 73d1ea1fb6e8f..e1803872fd606 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/logical/BooleanSimplification.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/logical/BooleanSimplification.java @@ -15,6 +15,7 @@ import org.elasticsearch.xpack.esql.core.expression.predicate.logical.Not; import org.elasticsearch.xpack.esql.core.expression.predicate.logical.Or; import org.elasticsearch.xpack.esql.core.type.DataType; +import org.elasticsearch.xpack.esql.optimizer.LogicalOptimizerContext; import java.util.List; @@ -35,7 +36,7 @@ public BooleanSimplification() { } @Override - public Expression rule(ScalarFunction e) { + public Expression rule(ScalarFunction e, LogicalOptimizerContext ctx) { if (e instanceof And || e instanceof Or) { return simplifyAndOr((BinaryPredicate) e); } diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/logical/CombineBinaryComparisons.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/logical/CombineBinaryComparisons.java index 3f47c74aaf814..1c290a7c4c4fd 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/logical/CombineBinaryComparisons.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/logical/CombineBinaryComparisons.java @@ -8,6 +8,7 @@ package org.elasticsearch.xpack.esql.optimizer.rules.logical; import org.elasticsearch.xpack.esql.core.expression.Expression; +import org.elasticsearch.xpack.esql.core.expression.FoldContext; import org.elasticsearch.xpack.esql.core.expression.predicate.Predicates; import org.elasticsearch.xpack.esql.core.expression.predicate.logical.And; import org.elasticsearch.xpack.esql.core.expression.predicate.logical.BinaryLogic; @@ -20,6 +21,7 @@ import org.elasticsearch.xpack.esql.expression.predicate.operator.comparison.LessThan; import org.elasticsearch.xpack.esql.expression.predicate.operator.comparison.LessThanOrEqual; import org.elasticsearch.xpack.esql.expression.predicate.operator.comparison.NotEquals; +import org.elasticsearch.xpack.esql.optimizer.LogicalOptimizerContext; import java.util.ArrayList; import java.util.List; @@ -31,17 +33,17 @@ public CombineBinaryComparisons() { } @Override - public Expression rule(BinaryLogic e) { + public Expression rule(BinaryLogic e, LogicalOptimizerContext ctx) { if (e instanceof And and) { - return combine(and); + return combine(ctx.foldCtx(), and); } else if (e instanceof Or or) { - return combine(or); + return combine(ctx.foldCtx(), or); } return e; } // combine conjunction - private static Expression combine(And and) { + private static Expression combine(FoldContext ctx, And and) { List bcs = new ArrayList<>(); List exps = new ArrayList<>(); boolean changed = false; @@ -58,13 +60,13 @@ private static Expression combine(And and) { }); for (Expression ex : andExps) { if (ex instanceof BinaryComparison bc && (ex instanceof Equals || ex instanceof NotEquals) == false) { - if (bc.right().foldable() && (findExistingComparison(bc, bcs, true))) { + if (bc.right().foldable() && (findExistingComparison(ctx, bc, bcs, true))) { changed = true; } else { bcs.add(bc); } } else if (ex instanceof NotEquals neq) { - if (neq.right().foldable() && notEqualsIsRemovableFromConjunction(neq, bcs)) { + if (neq.right().foldable() && notEqualsIsRemovableFromConjunction(ctx, neq, bcs)) { // the non-equality can simply be dropped: either superfluous or has been merged with an updated range/inequality changed = true; } else { // not foldable OR not overlapping @@ -78,13 +80,13 @@ private static Expression combine(And and) { } // combine disjunction - private static Expression combine(Or or) { + private static Expression combine(FoldContext ctx, Or or) { List bcs = new ArrayList<>(); List exps = new ArrayList<>(); boolean changed = false; for (Expression ex : Predicates.splitOr(or)) { if (ex instanceof BinaryComparison bc) { - if (bc.right().foldable() && findExistingComparison(bc, bcs, false)) { + if (bc.right().foldable() && findExistingComparison(ctx, bc, bcs, false)) { changed = true; } else { bcs.add(bc); @@ -100,8 +102,8 @@ private static Expression combine(Or or) { * Find commonalities between the given comparison in the given list. * The method can be applied both for conjunctive (AND) or disjunctive purposes (OR). */ - private static boolean findExistingComparison(BinaryComparison main, List bcs, boolean conjunctive) { - Object value = main.right().fold(); + private static boolean findExistingComparison(FoldContext ctx, BinaryComparison main, List bcs, boolean conjunctive) { + Object value = main.right().fold(ctx); // NB: the loop modifies the list (hence why the int is used) for (int i = 0; i < bcs.size(); i++) { BinaryComparison other = bcs.get(i); @@ -113,7 +115,7 @@ private static boolean findExistingComparison(BinaryComparison main, List bcs) { - Object neqVal = notEquals.right().fold(); + private static boolean notEqualsIsRemovableFromConjunction(FoldContext ctx, NotEquals notEquals, List bcs) { + Object neqVal = notEquals.right().fold(ctx); Integer comp; // check on "condition-overlapping" inequalities: @@ -183,7 +185,7 @@ private static boolean notEqualsIsRemovableFromConjunction(NotEquals notEquals, BinaryComparison bc = bcs.get(i); if (notEquals.left().semanticEquals(bc.left())) { if (bc instanceof LessThan || bc instanceof LessThanOrEqual) { - comp = bc.right().foldable() ? BinaryComparison.compare(neqVal, bc.right().fold()) : null; + comp = bc.right().foldable() ? BinaryComparison.compare(neqVal, bc.right().fold(ctx)) : null; if (comp != null) { if (comp >= 0) { if (comp == 0 && bc instanceof LessThanOrEqual) { // a != 2 AND a <= 2 -> a < 2 @@ -193,7 +195,7 @@ private static boolean notEqualsIsRemovableFromConjunction(NotEquals notEquals, } // else: comp < 0 : a != 2 AND a nop } // else: non-comparable, nop } else if (bc instanceof GreaterThan || bc instanceof GreaterThanOrEqual) { - comp = bc.right().foldable() ? BinaryComparison.compare(neqVal, bc.right().fold()) : null; + comp = bc.right().foldable() ? BinaryComparison.compare(neqVal, bc.right().fold(ctx)) : null; if (comp != null) { if (comp <= 0) { if (comp == 0 && bc instanceof GreaterThanOrEqual) { // a != 2 AND a >= 2 -> a > 2 diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/logical/CombineDisjunctions.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/logical/CombineDisjunctions.java index 5cb377de47efc..e1cda9cb149d4 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/logical/CombineDisjunctions.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/logical/CombineDisjunctions.java @@ -16,6 +16,7 @@ import org.elasticsearch.xpack.esql.expression.function.scalar.ip.CIDRMatch; import org.elasticsearch.xpack.esql.expression.predicate.operator.comparison.Equals; import org.elasticsearch.xpack.esql.expression.predicate.operator.comparison.In; +import org.elasticsearch.xpack.esql.optimizer.LogicalOptimizerContext; import java.time.ZoneId; import java.util.ArrayList; @@ -61,7 +62,7 @@ protected static CIDRMatch createCIDRMatch(Expression k, List v) { } @Override - public Expression rule(Or or) { + public Expression rule(Or or, LogicalOptimizerContext ctx) { Expression e = or; // look only at equals, In and CIDRMatch List exps = splitOr(e); @@ -78,7 +79,7 @@ public Expression rule(Or or) { if (eq.right().foldable()) { ins.computeIfAbsent(eq.left(), k -> new LinkedHashSet<>()).add(eq.right()); if (eq.left().dataType() == DataType.IP) { - Object value = eq.right().fold(); + Object value = eq.right().fold(ctx.foldCtx()); // ImplicitCasting and ConstantFolding(includes explicit casting) are applied before CombineDisjunctions. // They fold the input IP string to an internal IP format. These happen to Equals and IN, but not for CIDRMatch, // as CIDRMatch takes strings as input, ImplicitCasting does not apply to it, and the first input to CIDRMatch is a @@ -101,7 +102,7 @@ public Expression rule(Or or) { if (in.value().dataType() == DataType.IP) { List values = new ArrayList<>(in.list().size()); for (Expression i : in.list()) { - Object value = i.fold(); + Object value = i.fold(ctx.foldCtx()); // Same as Equals. if (value instanceof BytesRef bytesRef) { value = ipToString(bytesRef); diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/logical/ConstantFolding.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/logical/ConstantFolding.java index 82fe2c6bddf50..27eec8de59020 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/logical/ConstantFolding.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/logical/ConstantFolding.java @@ -9,6 +9,7 @@ import org.elasticsearch.xpack.esql.core.expression.Expression; import org.elasticsearch.xpack.esql.core.expression.Literal; +import org.elasticsearch.xpack.esql.optimizer.LogicalOptimizerContext; public final class ConstantFolding extends OptimizerRules.OptimizerExpressionRule { @@ -17,7 +18,7 @@ public ConstantFolding() { } @Override - public Expression rule(Expression e) { - return e.foldable() ? Literal.of(e) : e; + public Expression rule(Expression e, LogicalOptimizerContext ctx) { + return e.foldable() ? Literal.of(ctx.foldCtx(), e) : e; } } diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/logical/ConvertStringToByteRef.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/logical/ConvertStringToByteRef.java index 0604750883f14..b716d8f012d21 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/logical/ConvertStringToByteRef.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/logical/ConvertStringToByteRef.java @@ -10,6 +10,7 @@ import org.apache.lucene.util.BytesRef; import org.elasticsearch.xpack.esql.core.expression.Expression; import org.elasticsearch.xpack.esql.core.expression.Literal; +import org.elasticsearch.xpack.esql.optimizer.LogicalOptimizerContext; import java.util.ArrayList; import java.util.List; @@ -21,7 +22,8 @@ public ConvertStringToByteRef() { } @Override - protected Expression rule(Literal lit) { + protected Expression rule(Literal lit, LogicalOptimizerContext ctx) { + // TODO we shouldn't be emitting String into Literals at all Object value = lit.value(); if (value == null) { diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/logical/FoldNull.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/logical/FoldNull.java index 747864625e65c..cf4c7f19baafe 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/logical/FoldNull.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/logical/FoldNull.java @@ -15,6 +15,7 @@ import org.elasticsearch.xpack.esql.expression.function.aggregate.AggregateFunction; import org.elasticsearch.xpack.esql.expression.function.grouping.Categorize; import org.elasticsearch.xpack.esql.expression.predicate.operator.comparison.In; +import org.elasticsearch.xpack.esql.optimizer.LogicalOptimizerContext; public class FoldNull extends OptimizerRules.OptimizerExpressionRule { @@ -23,7 +24,7 @@ public FoldNull() { } @Override - public Expression rule(Expression e) { + public Expression rule(Expression e, LogicalOptimizerContext ctx) { Expression result = tryReplaceIsNullIsNotNull(e); // convert an aggregate null filter into a false diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/logical/LiteralsOnTheRight.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/logical/LiteralsOnTheRight.java index d96c73d5ee4f1..6504e6042c33a 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/logical/LiteralsOnTheRight.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/logical/LiteralsOnTheRight.java @@ -9,6 +9,7 @@ import org.elasticsearch.xpack.esql.core.expression.Literal; import org.elasticsearch.xpack.esql.core.expression.predicate.BinaryOperator; +import org.elasticsearch.xpack.esql.optimizer.LogicalOptimizerContext; public final class LiteralsOnTheRight extends OptimizerRules.OptimizerExpressionRule> { @@ -17,7 +18,7 @@ public LiteralsOnTheRight() { } @Override - public BinaryOperator rule(BinaryOperator be) { + public BinaryOperator rule(BinaryOperator be, LogicalOptimizerContext ctx) { return be.left() instanceof Literal && (be.right() instanceof Literal) == false ? be.swapLeftAndRight() : be; } } diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/logical/OptimizerRules.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/logical/OptimizerRules.java index 2a0b2a6af36aa..169ac2ac8c0fe 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/logical/OptimizerRules.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/logical/OptimizerRules.java @@ -8,6 +8,7 @@ import org.elasticsearch.xpack.esql.core.expression.Expression; import org.elasticsearch.xpack.esql.core.util.ReflectionUtils; +import org.elasticsearch.xpack.esql.optimizer.LogicalOptimizerContext; import org.elasticsearch.xpack.esql.plan.logical.LogicalPlan; import org.elasticsearch.xpack.esql.rule.ParameterizedRule; import org.elasticsearch.xpack.esql.rule.Rule; @@ -36,7 +37,10 @@ public final LogicalPlan apply(LogicalPlan plan) { protected abstract LogicalPlan rule(SubPlan plan); } - public abstract static class OptimizerExpressionRule extends Rule { + public abstract static class OptimizerExpressionRule extends ParameterizedRule< + LogicalPlan, + LogicalPlan, + LogicalOptimizerContext> { private final TransformDirection direction; // overriding type token which returns the correct class but does an uncheck cast to LogicalPlan due to its generic bound @@ -49,17 +53,13 @@ public OptimizerExpressionRule(TransformDirection direction) { } @Override - public final LogicalPlan apply(LogicalPlan plan) { + public final LogicalPlan apply(LogicalPlan plan, LogicalOptimizerContext ctx) { return direction == TransformDirection.DOWN - ? plan.transformExpressionsDown(expressionTypeToken, this::rule) - : plan.transformExpressionsUp(expressionTypeToken, this::rule); + ? plan.transformExpressionsDown(expressionTypeToken, e -> rule(e, ctx)) + : plan.transformExpressionsUp(expressionTypeToken, e -> rule(e, ctx)); } - protected LogicalPlan rule(LogicalPlan plan) { - return plan; - } - - protected abstract Expression rule(E e); + protected abstract Expression rule(E e, LogicalOptimizerContext ctx); public Class expressionToken() { return expressionTypeToken; @@ -82,6 +82,7 @@ protected ParameterizedOptimizerRule(TransformDirection direction) { this.direction = direction; } + @Override public final LogicalPlan apply(LogicalPlan plan, P context) { return direction == TransformDirection.DOWN ? plan.transformDown(typeToken(), t -> rule(t, context)) diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/logical/PartiallyFoldCase.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/logical/PartiallyFoldCase.java index 118e4fc170520..0111c7cdd806a 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/logical/PartiallyFoldCase.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/logical/PartiallyFoldCase.java @@ -9,6 +9,7 @@ import org.elasticsearch.xpack.esql.core.expression.Expression; import org.elasticsearch.xpack.esql.expression.function.scalar.conditional.Case; +import org.elasticsearch.xpack.esql.optimizer.LogicalOptimizerContext; import static org.elasticsearch.xpack.esql.optimizer.rules.logical.OptimizerRules.TransformDirection.DOWN; @@ -28,7 +29,7 @@ public PartiallyFoldCase() { } @Override - protected Expression rule(Case c) { - return c.partiallyFold(); + protected Expression rule(Case c, LogicalOptimizerContext ctx) { + return c.partiallyFold(ctx.foldCtx()); } } diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/logical/PropagateEmptyRelation.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/logical/PropagateEmptyRelation.java index 8437b79454884..b6f185c856693 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/logical/PropagateEmptyRelation.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/logical/PropagateEmptyRelation.java @@ -12,9 +12,11 @@ import org.elasticsearch.compute.data.BlockUtils; import org.elasticsearch.xpack.esql.EsqlIllegalArgumentException; import org.elasticsearch.xpack.esql.core.expression.Alias; +import org.elasticsearch.xpack.esql.core.expression.FoldContext; import org.elasticsearch.xpack.esql.core.expression.NamedExpression; import org.elasticsearch.xpack.esql.expression.function.aggregate.AggregateFunction; import org.elasticsearch.xpack.esql.expression.function.aggregate.Count; +import org.elasticsearch.xpack.esql.optimizer.LogicalOptimizerContext; import org.elasticsearch.xpack.esql.plan.logical.Aggregate; import org.elasticsearch.xpack.esql.plan.logical.LogicalPlan; import org.elasticsearch.xpack.esql.plan.logical.UnaryPlan; @@ -26,15 +28,18 @@ import java.util.List; @SuppressWarnings("removal") -public class PropagateEmptyRelation extends OptimizerRules.OptimizerRule { +public class PropagateEmptyRelation extends OptimizerRules.ParameterizedOptimizerRule { + public PropagateEmptyRelation() { + super(OptimizerRules.TransformDirection.DOWN); + } @Override - protected LogicalPlan rule(UnaryPlan plan) { + protected LogicalPlan rule(UnaryPlan plan, LogicalOptimizerContext ctx) { LogicalPlan p = plan; if (plan.child() instanceof LocalRelation local && local.supplier() == LocalSupplier.EMPTY) { // only care about non-grouped aggs might return something (count) if (plan instanceof Aggregate agg && agg.groupings().isEmpty()) { - List emptyBlocks = aggsFromEmpty(agg.aggregates()); + List emptyBlocks = aggsFromEmpty(ctx.foldCtx(), agg.aggregates()); p = replacePlanByRelation(plan, LocalSupplier.of(emptyBlocks.toArray(Block[]::new))); } else { p = PruneEmptyPlans.skipPlan(plan); @@ -43,14 +48,14 @@ protected LogicalPlan rule(UnaryPlan plan) { return p; } - private List aggsFromEmpty(List aggs) { + private List aggsFromEmpty(FoldContext foldCtx, List aggs) { List blocks = new ArrayList<>(); var blockFactory = PlannerUtils.NON_BREAKING_BLOCK_FACTORY; int i = 0; for (var agg : aggs) { // there needs to be an alias if (Alias.unwrap(agg) instanceof AggregateFunction aggFunc) { - aggOutput(agg, aggFunc, blockFactory, blocks); + aggOutput(foldCtx, agg, aggFunc, blockFactory, blocks); } else { throw new EsqlIllegalArgumentException("Did not expect a non-aliased aggregation {}", agg); } @@ -61,9 +66,15 @@ private List aggsFromEmpty(List aggs) { /** * The folded aggregation output - this variant is for the coordinator/final. */ - protected void aggOutput(NamedExpression agg, AggregateFunction aggFunc, BlockFactory blockFactory, List blocks) { + protected void aggOutput( + FoldContext foldCtx, + NamedExpression agg, + AggregateFunction aggFunc, + BlockFactory blockFactory, + List blocks + ) { // look for count(literal) with literal != null - Object value = aggFunc instanceof Count count && (count.foldable() == false || count.fold() != null) ? 0L : null; + Object value = aggFunc instanceof Count count && (count.foldable() == false || count.fold(foldCtx) != null) ? 0L : null; var wrapper = BlockUtils.wrapperFor(blockFactory, PlannerUtils.toElementType(aggFunc.dataType()), 1); wrapper.accept(value); blocks.add(wrapper.builder().build()); diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/logical/PropagateEquals.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/logical/PropagateEquals.java index 0bd98db1e1d7a..5a1677f2759e3 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/logical/PropagateEquals.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/logical/PropagateEquals.java @@ -23,6 +23,7 @@ import org.elasticsearch.xpack.esql.expression.predicate.operator.comparison.LessThan; import org.elasticsearch.xpack.esql.expression.predicate.operator.comparison.LessThanOrEqual; import org.elasticsearch.xpack.esql.expression.predicate.operator.comparison.NotEquals; +import org.elasticsearch.xpack.esql.optimizer.LogicalOptimizerContext; import java.util.ArrayList; import java.util.Iterator; @@ -41,17 +42,18 @@ public PropagateEquals() { super(OptimizerRules.TransformDirection.DOWN); } - public Expression rule(BinaryLogic e) { + @Override + public Expression rule(BinaryLogic e, LogicalOptimizerContext ctx) { if (e instanceof And) { - return propagate((And) e); + return propagate((And) e, ctx); } else if (e instanceof Or) { - return propagate((Or) e); + return propagate((Or) e, ctx); } return e; } // combine conjunction - private static Expression propagate(And and) { + private static Expression propagate(And and, LogicalOptimizerContext ctx) { List ranges = new ArrayList<>(); // Only equalities, not-equalities and inequalities with a foldable .right are extracted separately; // the others go into the general 'exps'. @@ -72,7 +74,7 @@ private static Expression propagate(And and) { if (otherEq.right().foldable() && DataType.isDateTime(otherEq.left().dataType()) == false) { for (BinaryComparison eq : equals) { if (otherEq.left().semanticEquals(eq.left())) { - Integer comp = BinaryComparison.compare(eq.right().fold(), otherEq.right().fold()); + Integer comp = BinaryComparison.compare(eq.right().fold(ctx.foldCtx()), otherEq.right().fold(ctx.foldCtx())); if (comp != null) { // var cannot be equal to two different values at the same time if (comp != 0) { @@ -108,7 +110,7 @@ private static Expression propagate(And and) { // check for (BinaryComparison eq : equals) { - Object eqValue = eq.right().fold(); + Object eqValue = eq.right().fold(ctx.foldCtx()); for (Iterator iterator = ranges.iterator(); iterator.hasNext();) { Range range = iterator.next(); @@ -116,7 +118,7 @@ private static Expression propagate(And and) { if (range.value().semanticEquals(eq.left())) { // if equals is outside the interval, evaluate the whole expression to FALSE if (range.lower().foldable()) { - Integer compare = BinaryComparison.compare(range.lower().fold(), eqValue); + Integer compare = BinaryComparison.compare(range.lower().fold(ctx.foldCtx()), eqValue); if (compare != null && ( // eq outside the lower boundary compare > 0 || @@ -126,7 +128,7 @@ private static Expression propagate(And and) { } } if (range.upper().foldable()) { - Integer compare = BinaryComparison.compare(range.upper().fold(), eqValue); + Integer compare = BinaryComparison.compare(range.upper().fold(ctx.foldCtx()), eqValue); if (compare != null && ( // eq outside the upper boundary compare < 0 || @@ -146,7 +148,7 @@ private static Expression propagate(And and) { for (Iterator iter = notEquals.iterator(); iter.hasNext();) { NotEquals neq = iter.next(); if (eq.left().semanticEquals(neq.left())) { - Integer comp = BinaryComparison.compare(eqValue, neq.right().fold()); + Integer comp = BinaryComparison.compare(eqValue, neq.right().fold(ctx.foldCtx())); if (comp != null) { if (comp == 0) { // clashing and conflicting: a = 1 AND a != 1 return new Literal(and.source(), Boolean.FALSE, DataType.BOOLEAN); @@ -162,7 +164,7 @@ private static Expression propagate(And and) { for (Iterator iter = inequalities.iterator(); iter.hasNext();) { BinaryComparison bc = iter.next(); if (eq.left().semanticEquals(bc.left())) { - Integer compare = BinaryComparison.compare(eqValue, bc.right().fold()); + Integer compare = BinaryComparison.compare(eqValue, bc.right().fold(ctx.foldCtx())); if (compare != null) { if (bc instanceof LessThan || bc instanceof LessThanOrEqual) { // a = 2 AND a a < 3; a = 2 OR a < 1 -> nop // a = 2 OR 3 < a < 5 -> nop; a = 2 OR 1 < a < 3 -> 1 < a < 3; a = 2 OR 0 < a < 1 -> nop // a = 2 OR a != 2 -> TRUE; a = 2 OR a = 5 -> nop; a = 2 OR a != 5 -> a != 5 - private static Expression propagate(Or or) { + private static Expression propagate(Or or, LogicalOptimizerContext ctx) { List exps = new ArrayList<>(); List equals = new ArrayList<>(); // foldable right term Equals List notEquals = new ArrayList<>(); // foldable right term NotEquals @@ -230,13 +232,13 @@ private static Expression propagate(Or or) { // evaluate the impact of each Equal over the different types of Expressions for (Iterator iterEq = equals.iterator(); iterEq.hasNext();) { Equals eq = iterEq.next(); - Object eqValue = eq.right().fold(); + Object eqValue = eq.right().fold(ctx.foldCtx()); boolean removeEquals = false; // Equals OR NotEquals for (NotEquals neq : notEquals) { if (eq.left().semanticEquals(neq.left())) { // a = 2 OR a != ? -> ... - Integer comp = BinaryComparison.compare(eqValue, neq.right().fold()); + Integer comp = BinaryComparison.compare(eqValue, neq.right().fold(ctx.foldCtx())); if (comp != null) { if (comp == 0) { // a = 2 OR a != 2 -> TRUE return TRUE; @@ -257,8 +259,12 @@ private static Expression propagate(Or or) { for (int i = 0; i < ranges.size(); i++) { // might modify list, so use index loop Range range = ranges.get(i); if (eq.left().semanticEquals(range.value())) { - Integer lowerComp = range.lower().foldable() ? BinaryComparison.compare(eqValue, range.lower().fold()) : null; - Integer upperComp = range.upper().foldable() ? BinaryComparison.compare(eqValue, range.upper().fold()) : null; + Integer lowerComp = range.lower().foldable() + ? BinaryComparison.compare(eqValue, range.lower().fold(ctx.foldCtx())) + : null; + Integer upperComp = range.upper().foldable() + ? BinaryComparison.compare(eqValue, range.upper().fold(ctx.foldCtx())) + : null; if (lowerComp != null && lowerComp == 0) { if (range.includeLower() == false) { // a = 2 OR 2 < a < ? -> 2 <= a < ? @@ -312,7 +318,7 @@ private static Expression propagate(Or or) { for (int i = 0; i < inequalities.size(); i++) { BinaryComparison bc = inequalities.get(i); if (eq.left().semanticEquals(bc.left())) { - Integer comp = BinaryComparison.compare(eqValue, bc.right().fold()); + Integer comp = BinaryComparison.compare(eqValue, bc.right().fold(ctx.foldCtx())); if (comp != null) { if (bc instanceof GreaterThan || bc instanceof GreaterThanOrEqual) { if (comp < 0) { // a = 1 OR a > 2 -> nop diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/logical/PropagateEvalFoldables.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/logical/PropagateEvalFoldables.java index 73eaa9220fd84..66cdc992a91cb 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/logical/PropagateEvalFoldables.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/logical/PropagateEvalFoldables.java @@ -12,19 +12,20 @@ import org.elasticsearch.xpack.esql.core.expression.Expression; import org.elasticsearch.xpack.esql.core.expression.Literal; import org.elasticsearch.xpack.esql.core.expression.ReferenceAttribute; +import org.elasticsearch.xpack.esql.optimizer.LogicalOptimizerContext; import org.elasticsearch.xpack.esql.plan.logical.Eval; import org.elasticsearch.xpack.esql.plan.logical.Filter; import org.elasticsearch.xpack.esql.plan.logical.LogicalPlan; -import org.elasticsearch.xpack.esql.rule.Rule; +import org.elasticsearch.xpack.esql.rule.ParameterizedRule; /** * Replace any reference attribute with its source, if it does not affect the result. * This avoids ulterior look-ups between attributes and its source across nodes. */ -public final class PropagateEvalFoldables extends Rule { +public final class PropagateEvalFoldables extends ParameterizedRule { @Override - public LogicalPlan apply(LogicalPlan plan) { + public LogicalPlan apply(LogicalPlan plan, LogicalOptimizerContext ctx) { var collectRefs = new AttributeMap(); java.util.function.Function replaceReference = r -> collectRefs.resolve(r, r); @@ -39,7 +40,7 @@ public LogicalPlan apply(LogicalPlan plan) { shouldCollect = c.foldable(); } if (shouldCollect) { - collectRefs.put(a.toAttribute(), Literal.of(c)); + collectRefs.put(a.toAttribute(), Literal.of(ctx.foldCtx(), c)); } }); if (collectRefs.isEmpty()) { diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/logical/PropagateNullable.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/logical/PropagateNullable.java index 738ca83b47e42..e3165180e331c 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/logical/PropagateNullable.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/logical/PropagateNullable.java @@ -15,6 +15,7 @@ import org.elasticsearch.xpack.esql.core.expression.predicate.nulls.IsNotNull; import org.elasticsearch.xpack.esql.core.expression.predicate.nulls.IsNull; import org.elasticsearch.xpack.esql.expression.function.scalar.nulls.Coalesce; +import org.elasticsearch.xpack.esql.optimizer.LogicalOptimizerContext; import java.util.ArrayList; import java.util.LinkedHashSet; @@ -33,7 +34,7 @@ public PropagateNullable() { } @Override - public Expression rule(And and) { + public Expression rule(And and, LogicalOptimizerContext ctx) { List splits = Predicates.splitAnd(and); Set nullExpressions = new LinkedHashSet<>(); diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/logical/PushDownAndCombineLimits.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/logical/PushDownAndCombineLimits.java index 1cacebdf27cd2..969a6bb713eca 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/logical/PushDownAndCombineLimits.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/logical/PushDownAndCombineLimits.java @@ -8,6 +8,7 @@ package org.elasticsearch.xpack.esql.optimizer.rules.logical; import org.elasticsearch.xpack.esql.core.expression.Literal; +import org.elasticsearch.xpack.esql.optimizer.LogicalOptimizerContext; import org.elasticsearch.xpack.esql.plan.logical.Aggregate; import org.elasticsearch.xpack.esql.plan.logical.Enrich; import org.elasticsearch.xpack.esql.plan.logical.Eval; @@ -20,14 +21,18 @@ import org.elasticsearch.xpack.esql.plan.logical.join.Join; import org.elasticsearch.xpack.esql.plan.logical.join.JoinTypes; -public final class PushDownAndCombineLimits extends OptimizerRules.OptimizerRule { +public final class PushDownAndCombineLimits extends OptimizerRules.ParameterizedOptimizerRule { + + public PushDownAndCombineLimits() { + super(OptimizerRules.TransformDirection.DOWN); + } @Override - public LogicalPlan rule(Limit limit) { + public LogicalPlan rule(Limit limit, LogicalOptimizerContext ctx) { if (limit.child() instanceof Limit childLimit) { var limitSource = limit.limit(); - var l1 = (int) limitSource.fold(); - var l2 = (int) childLimit.limit().fold(); + var l1 = (int) limitSource.fold(ctx.foldCtx()); + var l2 = (int) childLimit.limit().fold(ctx.foldCtx()); return new Limit(limit.source(), Literal.of(limitSource, Math.min(l1, l2)), childLimit.child()); } else if (limit.child() instanceof UnaryPlan unary) { if (unary instanceof Eval || unary instanceof Project || unary instanceof RegexExtract || unary instanceof Enrich) { @@ -41,7 +46,7 @@ public LogicalPlan rule(Limit limit) { // we add an inner limit to MvExpand and just push down the existing limit, ie. // | MV_EXPAND | LIMIT N -> | LIMIT N | MV_EXPAND (with limit N) var limitSource = limit.limit(); - var limitVal = (int) limitSource.fold(); + var limitVal = (int) limitSource.fold(ctx.foldCtx()); Integer mvxLimit = mvx.limit(); if (mvxLimit == null || mvxLimit > limitVal) { mvx = new MvExpand(mvx.source(), mvx.child(), mvx.target(), mvx.expanded(), limitVal); @@ -54,8 +59,8 @@ public LogicalPlan rule(Limit limit) { else { Limit descendantLimit = descendantLimit(unary); if (descendantLimit != null) { - var l1 = (int) limit.limit().fold(); - var l2 = (int) descendantLimit.limit().fold(); + var l1 = (int) limit.limit().fold(ctx.foldCtx()); + var l2 = (int) descendantLimit.limit().fold(ctx.foldCtx()); if (l2 <= l1) { return new Limit(limit.source(), Literal.of(limit.limit(), l2), limit.child()); } diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/logical/ReplaceRegexMatch.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/logical/ReplaceRegexMatch.java index 1a8f8a164cc1b..7953b2b28eaaa 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/logical/ReplaceRegexMatch.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/logical/ReplaceRegexMatch.java @@ -14,6 +14,7 @@ import org.elasticsearch.xpack.esql.core.expression.predicate.regex.StringPattern; import org.elasticsearch.xpack.esql.core.type.DataType; import org.elasticsearch.xpack.esql.expression.predicate.operator.comparison.Equals; +import org.elasticsearch.xpack.esql.optimizer.LogicalOptimizerContext; import org.elasticsearch.xpack.esql.parser.ParsingException; public final class ReplaceRegexMatch extends OptimizerRules.OptimizerExpressionRule> { @@ -23,7 +24,7 @@ public ReplaceRegexMatch() { } @Override - public Expression rule(RegexMatch regexMatch) { + public Expression rule(RegexMatch regexMatch, LogicalOptimizerContext ctx) { Expression e = regexMatch; StringPattern pattern = regexMatch.pattern(); boolean matchesAll; diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/logical/ReplaceRowAsLocalRelation.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/logical/ReplaceRowAsLocalRelation.java index eebeb1dc14f48..9e7b6ce80422d 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/logical/ReplaceRowAsLocalRelation.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/logical/ReplaceRowAsLocalRelation.java @@ -8,6 +8,7 @@ package org.elasticsearch.xpack.esql.optimizer.rules.logical; import org.elasticsearch.compute.data.BlockUtils; +import org.elasticsearch.xpack.esql.optimizer.LogicalOptimizerContext; import org.elasticsearch.xpack.esql.plan.logical.LogicalPlan; import org.elasticsearch.xpack.esql.plan.logical.Row; import org.elasticsearch.xpack.esql.plan.logical.local.LocalRelation; @@ -17,13 +18,16 @@ import java.util.ArrayList; import java.util.List; -public final class ReplaceRowAsLocalRelation extends OptimizerRules.OptimizerRule { +public final class ReplaceRowAsLocalRelation extends OptimizerRules.ParameterizedOptimizerRule { + public ReplaceRowAsLocalRelation() { + super(OptimizerRules.TransformDirection.DOWN); + } @Override - protected LogicalPlan rule(Row row) { + protected LogicalPlan rule(Row row, LogicalOptimizerContext context) { var fields = row.fields(); List values = new ArrayList<>(fields.size()); - fields.forEach(f -> values.add(f.child().fold())); + fields.forEach(f -> values.add(f.child().fold(context.foldCtx()))); var blocks = BlockUtils.fromListRow(PlannerUtils.NON_BREAKING_BLOCK_FACTORY, values); return new LocalRelation(row.source(), row.output(), LocalSupplier.of(blocks)); } diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/logical/ReplaceStatsFilteredAggWithEval.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/logical/ReplaceStatsFilteredAggWithEval.java index 2cafcc2e07052..a7e56a5f25fc8 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/logical/ReplaceStatsFilteredAggWithEval.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/logical/ReplaceStatsFilteredAggWithEval.java @@ -49,7 +49,7 @@ protected LogicalPlan rule(Aggregate aggregate) { && alias.child() instanceof AggregateFunction aggFunction && aggFunction.hasFilter() && aggFunction.filter() instanceof Literal literal - && Boolean.FALSE.equals(literal.fold())) { + && Boolean.FALSE.equals(literal.value())) { Object value = aggFunction instanceof Count || aggFunction instanceof CountDistinct ? 0L : null; Alias newAlias = alias.replaceChild(Literal.of(aggFunction, value)); diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/logical/ReplaceStringCasingWithInsensitiveEquals.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/logical/ReplaceStringCasingWithInsensitiveEquals.java index 0fea7cf8ddc1f..053441bce5e1f 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/logical/ReplaceStringCasingWithInsensitiveEquals.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/logical/ReplaceStringCasingWithInsensitiveEquals.java @@ -18,6 +18,7 @@ import org.elasticsearch.xpack.esql.expression.predicate.operator.comparison.Equals; import org.elasticsearch.xpack.esql.expression.predicate.operator.comparison.InsensitiveEquals; import org.elasticsearch.xpack.esql.expression.predicate.operator.comparison.NotEquals; +import org.elasticsearch.xpack.esql.optimizer.LogicalOptimizerContext; public class ReplaceStringCasingWithInsensitiveEquals extends OptimizerRules.OptimizerExpressionRule { @@ -26,30 +27,35 @@ public ReplaceStringCasingWithInsensitiveEquals() { } @Override - protected Expression rule(ScalarFunction sf) { + protected Expression rule(ScalarFunction sf, LogicalOptimizerContext ctx) { Expression e = sf; if (sf instanceof BinaryComparison bc) { - e = rewriteBinaryComparison(sf, bc, false); + e = rewriteBinaryComparison(ctx, sf, bc, false); } else if (sf instanceof Not not && not.field() instanceof BinaryComparison bc) { - e = rewriteBinaryComparison(sf, bc, true); + e = rewriteBinaryComparison(ctx, sf, bc, true); } return e; } - private static Expression rewriteBinaryComparison(ScalarFunction sf, BinaryComparison bc, boolean negated) { + private static Expression rewriteBinaryComparison( + LogicalOptimizerContext ctx, + ScalarFunction sf, + BinaryComparison bc, + boolean negated + ) { Expression e = sf; if (bc.left() instanceof ChangeCase changeCase && bc.right().foldable()) { if (bc instanceof Equals) { - e = replaceChangeCase(bc, changeCase, negated); + e = replaceChangeCase(ctx, bc, changeCase, negated); } else if (bc instanceof NotEquals) { // not actually used currently, `!=` is built as `NOT(==)` already - e = replaceChangeCase(bc, changeCase, negated == false); + e = replaceChangeCase(ctx, bc, changeCase, negated == false); } } return e; } - private static Expression replaceChangeCase(BinaryComparison bc, ChangeCase changeCase, boolean negated) { - var foldedRight = BytesRefs.toString(bc.right().fold()); + private static Expression replaceChangeCase(LogicalOptimizerContext ctx, BinaryComparison bc, ChangeCase changeCase, boolean negated) { + var foldedRight = BytesRefs.toString(bc.right().fold(ctx.foldCtx())); var field = unwrapCase(changeCase.field()); var e = changeCase.caseType().matchesCase(foldedRight) ? new InsensitiveEquals(bc.source(), field, bc.right()) diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/logical/SimplifyComparisonsArithmetics.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/logical/SimplifyComparisonsArithmetics.java index d3a9970896c16..60ff161651f2d 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/logical/SimplifyComparisonsArithmetics.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/logical/SimplifyComparisonsArithmetics.java @@ -8,6 +8,7 @@ package org.elasticsearch.xpack.esql.optimizer.rules.logical; import org.elasticsearch.xpack.esql.core.expression.Expression; +import org.elasticsearch.xpack.esql.core.expression.FoldContext; import org.elasticsearch.xpack.esql.core.expression.Literal; import org.elasticsearch.xpack.esql.core.expression.predicate.operator.comparison.BinaryComparison; import org.elasticsearch.xpack.esql.core.type.DataType; @@ -15,6 +16,7 @@ import org.elasticsearch.xpack.esql.expression.predicate.operator.arithmetic.BinaryComparisonInversible; import org.elasticsearch.xpack.esql.expression.predicate.operator.arithmetic.Neg; import org.elasticsearch.xpack.esql.expression.predicate.operator.arithmetic.Sub; +import org.elasticsearch.xpack.esql.optimizer.LogicalOptimizerContext; import java.time.DateTimeException; import java.util.List; @@ -41,20 +43,20 @@ public SimplifyComparisonsArithmetics(BiFunction ty } @Override - protected Expression rule(BinaryComparison bc) { + protected Expression rule(BinaryComparison bc, LogicalOptimizerContext ctx) { // optimize only once the expression has a literal on the right side of the binary comparison if (bc.right() instanceof Literal) { if (bc.left() instanceof ArithmeticOperation) { - return simplifyBinaryComparison(bc); + return simplifyBinaryComparison(ctx.foldCtx(), bc); } if (bc.left() instanceof Neg) { - return foldNegation(bc); + return foldNegation(ctx.foldCtx(), bc); } } return bc; } - private Expression simplifyBinaryComparison(BinaryComparison comparison) { + private Expression simplifyBinaryComparison(FoldContext foldContext, BinaryComparison comparison) { ArithmeticOperation operation = (ArithmeticOperation) comparison.left(); // Use symbol comp: SQL operations aren't available in this package (as dependencies) String opSymbol = operation.symbol(); @@ -64,9 +66,9 @@ private Expression simplifyBinaryComparison(BinaryComparison comparison) { } OperationSimplifier simplification = null; if (isMulOrDiv(opSymbol)) { - simplification = new MulDivSimplifier(comparison); + simplification = new MulDivSimplifier(foldContext, comparison); } else if (opSymbol.equals(ADD.symbol()) || opSymbol.equals(SUB.symbol())) { - simplification = new AddSubSimplifier(comparison); + simplification = new AddSubSimplifier(foldContext, comparison); } return (simplification == null || simplification.isUnsafe(typesCompatible)) ? comparison : simplification.apply(); @@ -76,16 +78,16 @@ private static boolean isMulOrDiv(String opSymbol) { return opSymbol.equals(MUL.symbol()) || opSymbol.equals(DIV.symbol()); } - private static Expression foldNegation(BinaryComparison bc) { + private static Expression foldNegation(FoldContext ctx, BinaryComparison bc) { Literal bcLiteral = (Literal) bc.right(); - Expression literalNeg = tryFolding(new Neg(bcLiteral.source(), bcLiteral)); + Expression literalNeg = tryFolding(ctx, new Neg(bcLiteral.source(), bcLiteral)); return literalNeg == null ? bc : bc.reverse().replaceChildren(asList(((Neg) bc.left()).field(), literalNeg)); } - private static Expression tryFolding(Expression expression) { + private static Expression tryFolding(FoldContext ctx, Expression expression) { if (expression.foldable()) { try { - expression = new Literal(expression.source(), expression.fold(), expression.dataType()); + expression = new Literal(expression.source(), expression.fold(ctx), expression.dataType()); } catch (ArithmeticException | DateTimeException e) { // null signals that folding would result in an over-/underflow (such as Long.MAX_VALUE+1); the optimisation is skipped. expression = null; @@ -95,6 +97,7 @@ private static Expression tryFolding(Expression expression) { } private abstract static class OperationSimplifier { + final FoldContext foldContext; final BinaryComparison comparison; final Literal bcLiteral; final ArithmeticOperation operation; @@ -102,7 +105,8 @@ private abstract static class OperationSimplifier { final Expression opRight; final Literal opLiteral; - OperationSimplifier(BinaryComparison comparison) { + OperationSimplifier(FoldContext foldContext, BinaryComparison comparison) { + this.foldContext = foldContext; this.comparison = comparison; operation = (ArithmeticOperation) comparison.left(); bcLiteral = (Literal) comparison.right(); @@ -151,7 +155,7 @@ final Expression apply() { Expression bcRightExpression = ((BinaryComparisonInversible) operation).binaryComparisonInverse() .create(bcl.source(), bcl, opRight); - bcRightExpression = tryFolding(bcRightExpression); + bcRightExpression = tryFolding(foldContext, bcRightExpression); return bcRightExpression != null ? postProcess((BinaryComparison) comparison.replaceChildren(List.of(opLeft, bcRightExpression))) : comparison; @@ -169,8 +173,8 @@ Expression postProcess(BinaryComparison binaryComparison) { private static class AddSubSimplifier extends OperationSimplifier { - AddSubSimplifier(BinaryComparison comparison) { - super(comparison); + AddSubSimplifier(FoldContext foldContext, BinaryComparison comparison) { + super(foldContext, comparison); } @Override @@ -182,7 +186,7 @@ boolean isOpUnsafe() { if (operation.symbol().equals(SUB.symbol()) && opRight instanceof Literal == false) { // such as: 1 - x > -MAX // if next simplification step would fail on overflow anyways, skip the optimisation already - return tryFolding(new Sub(EMPTY, opLeft, bcLiteral)) == null; + return tryFolding(foldContext, new Sub(EMPTY, opLeft, bcLiteral)) == null; } return false; @@ -194,8 +198,8 @@ private static class MulDivSimplifier extends OperationSimplifier { private final boolean isDiv; // and not MUL. private final int opRightSign; // sign of the right operand in: (left) (op) (right) (comp) (literal) - MulDivSimplifier(BinaryComparison comparison) { - super(comparison); + MulDivSimplifier(FoldContext foldContext, BinaryComparison comparison) { + super(foldContext, comparison); isDiv = operation.symbol().equals(DIV.symbol()); opRightSign = sign(opRight); } diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/logical/SkipQueryOnLimitZero.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/logical/SkipQueryOnLimitZero.java index 5d98d941bb207..c6d62dee0ba42 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/logical/SkipQueryOnLimitZero.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/logical/SkipQueryOnLimitZero.java @@ -7,14 +7,19 @@ package org.elasticsearch.xpack.esql.optimizer.rules.logical; +import org.elasticsearch.xpack.esql.optimizer.LogicalOptimizerContext; import org.elasticsearch.xpack.esql.plan.logical.Limit; import org.elasticsearch.xpack.esql.plan.logical.LogicalPlan; -public final class SkipQueryOnLimitZero extends OptimizerRules.OptimizerRule { +public final class SkipQueryOnLimitZero extends OptimizerRules.ParameterizedOptimizerRule { + public SkipQueryOnLimitZero() { + super(OptimizerRules.TransformDirection.DOWN); + } + @Override - protected LogicalPlan rule(Limit limit) { + protected LogicalPlan rule(Limit limit, LogicalOptimizerContext ctx) { if (limit.limit().foldable()) { - if (Integer.valueOf(0).equals((limit.limit().fold()))) { + if (Integer.valueOf(0).equals((limit.limit().fold(ctx.foldCtx())))) { return PruneEmptyPlans.skipPlan(limit); } } diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/logical/SplitInWithFoldableValue.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/logical/SplitInWithFoldableValue.java index 9e9ae6a9a559d..870464feb4867 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/logical/SplitInWithFoldableValue.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/logical/SplitInWithFoldableValue.java @@ -11,6 +11,7 @@ import org.elasticsearch.xpack.esql.core.expression.Expressions; import org.elasticsearch.xpack.esql.core.expression.predicate.logical.Or; import org.elasticsearch.xpack.esql.expression.predicate.operator.comparison.In; +import org.elasticsearch.xpack.esql.optimizer.LogicalOptimizerContext; import java.util.ArrayList; import java.util.List; @@ -25,7 +26,7 @@ public SplitInWithFoldableValue() { } @Override - public Expression rule(In in) { + public Expression rule(In in, LogicalOptimizerContext ctx) { if (in.value().foldable()) { List foldables = new ArrayList<>(in.list().size()); List nonFoldables = new ArrayList<>(in.list().size()); @@ -36,7 +37,7 @@ public Expression rule(In in) { nonFoldables.add(e); } }); - if (foldables.size() > 0 && nonFoldables.size() > 0) { + if (foldables.isEmpty() == false && nonFoldables.isEmpty() == false) { In withFoldables = new In(in.source(), in.value(), foldables); In withoutFoldables = new In(in.source(), in.value(), nonFoldables); return new Or(in.source(), withFoldables, withoutFoldables); diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/logical/SubstituteFilteredExpression.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/logical/SubstituteFilteredExpression.java index c8369d2b08a34..62a00b79d7333 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/logical/SubstituteFilteredExpression.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/logical/SubstituteFilteredExpression.java @@ -9,6 +9,7 @@ import org.elasticsearch.xpack.esql.core.expression.Expression; import org.elasticsearch.xpack.esql.expression.function.aggregate.FilteredExpression; +import org.elasticsearch.xpack.esql.optimizer.LogicalOptimizerContext; import org.elasticsearch.xpack.esql.optimizer.rules.logical.OptimizerRules.OptimizerExpressionRule; import org.elasticsearch.xpack.esql.optimizer.rules.logical.OptimizerRules.TransformDirection; @@ -21,7 +22,7 @@ public SubstituteFilteredExpression() { } @Override - protected Expression rule(FilteredExpression filteredExpression) { + protected Expression rule(FilteredExpression filteredExpression, LogicalOptimizerContext ctx) { return filteredExpression.surrogate(); } } diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/logical/SubstituteSpatialSurrogates.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/logical/SubstituteSpatialSurrogates.java index 93512d80e1708..4b68ee941bc92 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/logical/SubstituteSpatialSurrogates.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/logical/SubstituteSpatialSurrogates.java @@ -8,6 +8,7 @@ package org.elasticsearch.xpack.esql.optimizer.rules.logical; import org.elasticsearch.xpack.esql.expression.function.scalar.spatial.SpatialRelatesFunction; +import org.elasticsearch.xpack.esql.optimizer.LogicalOptimizerContext; /** * Currently this works similarly to SurrogateExpression, leaving the logic inside the expressions, @@ -23,7 +24,7 @@ public SubstituteSpatialSurrogates() { } @Override - protected SpatialRelatesFunction rule(SpatialRelatesFunction function) { + protected SpatialRelatesFunction rule(SpatialRelatesFunction function, LogicalOptimizerContext ctx) { return function.surrogate(); } } diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/logical/local/LocalPropagateEmptyRelation.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/logical/local/LocalPropagateEmptyRelation.java index d29da1354ef3c..9259a50d5ff9e 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/logical/local/LocalPropagateEmptyRelation.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/logical/local/LocalPropagateEmptyRelation.java @@ -11,6 +11,7 @@ import org.elasticsearch.compute.data.BlockFactory; import org.elasticsearch.compute.data.BlockUtils; import org.elasticsearch.xpack.esql.core.expression.Attribute; +import org.elasticsearch.xpack.esql.core.expression.FoldContext; import org.elasticsearch.xpack.esql.core.expression.NamedExpression; import org.elasticsearch.xpack.esql.core.type.DataType; import org.elasticsearch.xpack.esql.expression.function.aggregate.AggregateFunction; @@ -25,19 +26,24 @@ * Local aggregation can only produce intermediate state that get wired into the global agg. */ public class LocalPropagateEmptyRelation extends PropagateEmptyRelation { - /** * Local variant of the aggregation that returns the intermediate value. */ @Override - protected void aggOutput(NamedExpression agg, AggregateFunction aggFunc, BlockFactory blockFactory, List blocks) { + protected void aggOutput( + FoldContext foldCtx, + NamedExpression agg, + AggregateFunction aggFunc, + BlockFactory blockFactory, + List blocks + ) { List output = AbstractPhysicalOperationProviders.intermediateAttributes(List.of(agg), List.of()); for (Attribute o : output) { DataType dataType = o.dataType(); // boolean right now is used for the internal #seen so always return true var value = dataType == DataType.BOOLEAN ? true // look for count(literal) with literal != null - : aggFunc instanceof Count count && (count.foldable() == false || count.fold() != null) ? 0L + : aggFunc instanceof Count count && (count.foldable() == false || count.fold(foldCtx) != null) ? 0L // otherwise nullify : null; var wrapper = BlockUtils.wrapperFor(blockFactory, PlannerUtils.toElementType(dataType), 1); diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/physical/local/EnableSpatialDistancePushdown.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/physical/local/EnableSpatialDistancePushdown.java index dfb1dbc8bc8f3..afeab28745c65 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/physical/local/EnableSpatialDistancePushdown.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/physical/local/EnableSpatialDistancePushdown.java @@ -16,6 +16,7 @@ import org.elasticsearch.xpack.esql.core.expression.Attribute; import org.elasticsearch.xpack.esql.core.expression.AttributeMap; import org.elasticsearch.xpack.esql.core.expression.Expression; +import org.elasticsearch.xpack.esql.core.expression.FoldContext; import org.elasticsearch.xpack.esql.core.expression.Literal; import org.elasticsearch.xpack.esql.core.expression.NameId; import org.elasticsearch.xpack.esql.core.expression.ReferenceAttribute; @@ -76,22 +77,33 @@ public class EnableSpatialDistancePushdown extends PhysicalOptimizerRules.Parame protected PhysicalPlan rule(FilterExec filterExec, LocalPhysicalOptimizerContext ctx) { PhysicalPlan plan = filterExec; if (filterExec.child() instanceof EsQueryExec esQueryExec) { - plan = rewrite(filterExec, esQueryExec, LucenePushdownPredicates.from(ctx.searchStats())); + plan = rewrite(ctx.foldCtx(), filterExec, esQueryExec, LucenePushdownPredicates.from(ctx.searchStats())); } else if (filterExec.child() instanceof EvalExec evalExec && evalExec.child() instanceof EsQueryExec esQueryExec) { - plan = rewriteBySplittingFilter(filterExec, evalExec, esQueryExec, LucenePushdownPredicates.from(ctx.searchStats())); + plan = rewriteBySplittingFilter( + ctx.foldCtx(), + filterExec, + evalExec, + esQueryExec, + LucenePushdownPredicates.from(ctx.searchStats()) + ); } return plan; } - private FilterExec rewrite(FilterExec filterExec, EsQueryExec esQueryExec, LucenePushdownPredicates lucenePushdownPredicates) { + private FilterExec rewrite( + FoldContext ctx, + FilterExec filterExec, + EsQueryExec esQueryExec, + LucenePushdownPredicates lucenePushdownPredicates + ) { // Find and rewrite any binary comparisons that involve a distance function and a literal var rewritten = filterExec.condition().transformDown(EsqlBinaryComparison.class, comparison -> { ComparisonType comparisonType = ComparisonType.from(comparison.getFunctionType()); if (comparison.left() instanceof StDistance dist && comparison.right().foldable()) { - return rewriteComparison(comparison, dist, comparison.right(), comparisonType); + return rewriteComparison(ctx, comparison, dist, comparison.right(), comparisonType); } else if (comparison.right() instanceof StDistance dist && comparison.left().foldable()) { - return rewriteComparison(comparison, dist, comparison.left(), ComparisonType.invert(comparisonType)); + return rewriteComparison(ctx, comparison, dist, comparison.left(), ComparisonType.invert(comparisonType)); } return comparison; }); @@ -120,6 +132,7 @@ private FilterExec rewrite(FilterExec filterExec, EsQueryExec esQueryExec, Lucen * */ private PhysicalPlan rewriteBySplittingFilter( + FoldContext ctx, FilterExec filterExec, EvalExec evalExec, EsQueryExec esQueryExec, @@ -142,7 +155,7 @@ private PhysicalPlan rewriteBySplittingFilter( for (Expression exp : splitAnd(filterExec.condition())) { Expression resExp = exp.transformUp(ReferenceAttribute.class, r -> aliasReplacedBy.resolve(r, r)); // Find and rewrite any binary comparisons that involve a distance function and a literal - var rewritten = rewriteDistanceFilters(resExp, distances); + var rewritten = rewriteDistanceFilters(ctx, resExp, distances); // If all pushable StDistance functions were found and re-written, we need to re-write the FILTER/EVAL combination if (rewritten.equals(resExp) == false && canPushToSource(rewritten, lucenePushdownPredicates)) { pushable.add(rewritten); @@ -181,40 +194,42 @@ private Map getPushableDistances(List aliases, Lucene return distances; } - private Expression rewriteDistanceFilters(Expression expr, Map distances) { + private Expression rewriteDistanceFilters(FoldContext ctx, Expression expr, Map distances) { return expr.transformDown(EsqlBinaryComparison.class, comparison -> { ComparisonType comparisonType = ComparisonType.from(comparison.getFunctionType()); if (comparison.left() instanceof ReferenceAttribute r && distances.containsKey(r.id()) && comparison.right().foldable()) { StDistance dist = distances.get(r.id()); - return rewriteComparison(comparison, dist, comparison.right(), comparisonType); + return rewriteComparison(ctx, comparison, dist, comparison.right(), comparisonType); } else if (comparison.right() instanceof ReferenceAttribute r && distances.containsKey(r.id()) && comparison.left().foldable()) { StDistance dist = distances.get(r.id()); - return rewriteComparison(comparison, dist, comparison.left(), ComparisonType.invert(comparisonType)); + return rewriteComparison(ctx, comparison, dist, comparison.left(), ComparisonType.invert(comparisonType)); } return comparison; }); } private Expression rewriteComparison( + FoldContext ctx, EsqlBinaryComparison comparison, StDistance dist, Expression literal, ComparisonType comparisonType ) { - Object value = literal.fold(); + Object value = literal.fold(ctx); if (value instanceof Number number) { if (dist.right().foldable()) { - return rewriteDistanceFilter(comparison, dist.left(), dist.right(), number, comparisonType); + return rewriteDistanceFilter(ctx, comparison, dist.left(), dist.right(), number, comparisonType); } else if (dist.left().foldable()) { - return rewriteDistanceFilter(comparison, dist.right(), dist.left(), number, comparisonType); + return rewriteDistanceFilter(ctx, comparison, dist.right(), dist.left(), number, comparisonType); } } return comparison; } private Expression rewriteDistanceFilter( + FoldContext ctx, EsqlBinaryComparison comparison, Expression spatialExp, Expression literalExp, @@ -222,7 +237,7 @@ private Expression rewriteDistanceFilter( ComparisonType comparisonType ) { DataType shapeDataType = getShapeDataType(spatialExp); - Geometry geometry = SpatialRelatesUtils.makeGeometryFromLiteral(literalExp); + Geometry geometry = SpatialRelatesUtils.makeGeometryFromLiteral(ctx, literalExp); if (geometry instanceof Point point) { double distance = number.doubleValue(); Source source = comparison.source(); diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/physical/local/PushTopNToSource.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/physical/local/PushTopNToSource.java index 1d838fc90bc2c..f354059d65fe8 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/physical/local/PushTopNToSource.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/physical/local/PushTopNToSource.java @@ -14,6 +14,7 @@ import org.elasticsearch.xpack.esql.core.expression.AttributeMap; import org.elasticsearch.xpack.esql.core.expression.Expression; import org.elasticsearch.xpack.esql.core.expression.FieldAttribute; +import org.elasticsearch.xpack.esql.core.expression.FoldContext; import org.elasticsearch.xpack.esql.core.expression.MetadataAttribute; import org.elasticsearch.xpack.esql.core.expression.NameId; import org.elasticsearch.xpack.esql.core.expression.ReferenceAttribute; @@ -61,7 +62,7 @@ public class PushTopNToSource extends PhysicalOptimizerRules.ParameterizedOptimi @Override protected PhysicalPlan rule(TopNExec topNExec, LocalPhysicalOptimizerContext ctx) { - Pushable pushable = evaluatePushable(topNExec, LucenePushdownPredicates.from(ctx.searchStats())); + Pushable pushable = evaluatePushable(ctx.foldCtx(), topNExec, LucenePushdownPredicates.from(ctx.searchStats())); return pushable.rewrite(topNExec); } @@ -95,18 +96,18 @@ private EsQueryExec.Sort sort() { return new EsQueryExec.GeoDistanceSort(fieldAttribute.exactAttribute(), order.direction(), point.getLat(), point.getLon()); } - private static PushableGeoDistance from(StDistance distance, Order order) { + private static PushableGeoDistance from(FoldContext ctx, StDistance distance, Order order) { if (distance.left() instanceof Attribute attr && distance.right().foldable()) { - return from(attr, distance.right(), order); + return from(ctx, attr, distance.right(), order); } else if (distance.right() instanceof Attribute attr && distance.left().foldable()) { - return from(attr, distance.left(), order); + return from(ctx, attr, distance.left(), order); } return null; } - private static PushableGeoDistance from(Attribute attr, Expression foldable, Order order) { + private static PushableGeoDistance from(FoldContext ctx, Attribute attr, Expression foldable, Order order) { if (attr instanceof FieldAttribute fieldAttribute) { - Geometry geometry = SpatialRelatesUtils.makeGeometryFromLiteral(foldable); + Geometry geometry = SpatialRelatesUtils.makeGeometryFromLiteral(ctx, foldable); if (geometry instanceof Point point) { return new PushableGeoDistance(fieldAttribute, order, point); } @@ -122,7 +123,7 @@ public PhysicalPlan rewrite(TopNExec topNExec) { } } - private static Pushable evaluatePushable(TopNExec topNExec, LucenePushdownPredicates lucenePushdownPredicates) { + private static Pushable evaluatePushable(FoldContext ctx, TopNExec topNExec, LucenePushdownPredicates lucenePushdownPredicates) { PhysicalPlan child = topNExec.child(); if (child instanceof EsQueryExec queryExec && queryExec.canPushSorts() @@ -164,7 +165,7 @@ && canPushDownOrders(topNExec.order(), lucenePushdownPredicates)) { if (distances.containsKey(resolvedAttribute.id())) { StDistance distance = distances.get(resolvedAttribute.id()); StDistance d = (StDistance) distance.transformDown(ReferenceAttribute.class, r -> aliasReplacedBy.resolve(r, r)); - PushableGeoDistance pushableGeoDistance = PushableGeoDistance.from(d, order); + PushableGeoDistance pushableGeoDistance = PushableGeoDistance.from(ctx, d, order); if (pushableGeoDistance != null) { pushableSorts.add(pushableGeoDistance.sort()); } else { diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/parser/ExpressionBuilder.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/parser/ExpressionBuilder.java index 5d70dbf1c2871..96e8135d414e4 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/parser/ExpressionBuilder.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/parser/ExpressionBuilder.java @@ -21,6 +21,7 @@ import org.elasticsearch.xpack.esql.core.expression.Alias; import org.elasticsearch.xpack.esql.core.expression.Attribute; import org.elasticsearch.xpack.esql.core.expression.Expression; +import org.elasticsearch.xpack.esql.core.expression.FoldContext; import org.elasticsearch.xpack.esql.core.expression.Literal; import org.elasticsearch.xpack.esql.core.expression.NamedExpression; import org.elasticsearch.xpack.esql.core.expression.UnresolvedAttribute; @@ -686,12 +687,20 @@ public Expression visitRegexBooleanExpression(EsqlBaseParser.RegexBooleanExpress RegexMatch result = switch (type) { case EsqlBaseParser.LIKE -> { try { - yield new WildcardLike(source, left, new WildcardPattern(pattern.fold().toString())); + yield new WildcardLike( + source, + left, + new WildcardPattern(pattern.fold(FoldContext.small() /* TODO remove me */).toString()) + ); } catch (InvalidArgumentException e) { throw new ParsingException(source, "Invalid pattern for LIKE [{}]: [{}]", pattern, e.getMessage()); } } - case EsqlBaseParser.RLIKE -> new RLike(source, left, new RLikePattern(pattern.fold().toString())); + case EsqlBaseParser.RLIKE -> new RLike( + source, + left, + new RLikePattern(pattern.fold(FoldContext.small() /* TODO remove me */).toString()) + ); default -> throw new ParsingException("Invalid predicate type for [{}]", source.text()); }; return ctx.NOT() == null ? result : new Not(source, result); diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/parser/LogicalPlanBuilder.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/parser/LogicalPlanBuilder.java index 49d77bc36fb2e..4edd0470058db 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/parser/LogicalPlanBuilder.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/parser/LogicalPlanBuilder.java @@ -23,6 +23,7 @@ import org.elasticsearch.xpack.esql.core.expression.EmptyAttribute; import org.elasticsearch.xpack.esql.core.expression.Expression; import org.elasticsearch.xpack.esql.core.expression.Expressions; +import org.elasticsearch.xpack.esql.core.expression.FoldContext; import org.elasticsearch.xpack.esql.core.expression.Literal; import org.elasticsearch.xpack.esql.core.expression.MetadataAttribute; import org.elasticsearch.xpack.esql.core.expression.NamedExpression; @@ -157,7 +158,7 @@ public PlanFactory visitEvalCommand(EsqlBaseParser.EvalCommandContext ctx) { public PlanFactory visitGrokCommand(EsqlBaseParser.GrokCommandContext ctx) { return p -> { Source source = source(ctx); - String pattern = visitString(ctx.string()).fold().toString(); + String pattern = visitString(ctx.string()).fold(FoldContext.small() /* TODO remove me */).toString(); Grok.Parser grokParser; try { grokParser = Grok.pattern(source, pattern); @@ -188,7 +189,7 @@ private void validateGrokPattern(Source source, Grok.Parser grokParser, String p @Override public PlanFactory visitDissectCommand(EsqlBaseParser.DissectCommandContext ctx) { return p -> { - String pattern = visitString(ctx.string()).fold().toString(); + String pattern = visitString(ctx.string()).fold(FoldContext.small() /* TODO remove me */).toString(); Map options = visitCommandOptions(ctx.commandOptions()); String appendSeparator = ""; for (Map.Entry item : options.entrySet()) { @@ -243,7 +244,7 @@ public Map visitCommandOptions(EsqlBaseParser.CommandOptionsCont } Map result = new HashMap<>(); for (EsqlBaseParser.CommandOptionContext option : ctx.commandOption()) { - result.put(visitIdentifier(option.identifier()), expression(option.constant()).fold()); + result.put(visitIdentifier(option.identifier()), expression(option.constant()).fold(FoldContext.small() /* TODO remove me */)); } return result; } diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plan/logical/Enrich.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plan/logical/Enrich.java index 6755a7fa30af9..9b81060349815 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plan/logical/Enrich.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plan/logical/Enrich.java @@ -25,6 +25,7 @@ import org.elasticsearch.xpack.esql.core.expression.AttributeSet; import org.elasticsearch.xpack.esql.core.expression.EmptyAttribute; import org.elasticsearch.xpack.esql.core.expression.Expression; +import org.elasticsearch.xpack.esql.core.expression.FoldContext; import org.elasticsearch.xpack.esql.core.expression.NameId; import org.elasticsearch.xpack.esql.core.expression.NamedExpression; import org.elasticsearch.xpack.esql.core.expression.ReferenceAttribute; @@ -149,7 +150,7 @@ public void writeTo(StreamOutput out) throws IOException { out.writeNamedWriteable(policyName()); out.writeNamedWriteable(matchField()); if (out.getTransportVersion().before(TransportVersions.V_8_13_0)) { - out.writeString(BytesRefs.toString(policyName().fold())); // old policy name + out.writeString(BytesRefs.toString(policyName().fold(FoldContext.small() /* TODO remove me */))); // old policy name } policy().writeTo(out); if (out.getTransportVersion().onOrAfter(TransportVersions.V_8_13_0)) { diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/planner/AbstractPhysicalOperationProviders.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/planner/AbstractPhysicalOperationProviders.java index 57ba1c8016feb..072bae21da2a3 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/planner/AbstractPhysicalOperationProviders.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/planner/AbstractPhysicalOperationProviders.java @@ -25,6 +25,7 @@ import org.elasticsearch.xpack.esql.core.expression.Attribute; import org.elasticsearch.xpack.esql.core.expression.Expression; import org.elasticsearch.xpack.esql.core.expression.Expressions; +import org.elasticsearch.xpack.esql.core.expression.FoldContext; import org.elasticsearch.xpack.esql.core.expression.NameId; import org.elasticsearch.xpack.esql.core.expression.NamedExpression; import org.elasticsearch.xpack.esql.evaluator.EvalMapper; @@ -47,9 +48,11 @@ public abstract class AbstractPhysicalOperationProviders implements PhysicalOperationProviders { private final AggregateMapper aggregateMapper = new AggregateMapper(); + private final FoldContext foldContext; private final AnalysisRegistry analysisRegistry; - AbstractPhysicalOperationProviders(AnalysisRegistry analysisRegistry) { + AbstractPhysicalOperationProviders(FoldContext foldContext, AnalysisRegistry analysisRegistry) { + this.foldContext = foldContext; this.analysisRegistry = analysisRegistry; } @@ -251,6 +254,7 @@ public static List intermediateAttributes(List aggregates, AggregatorMode mode, Layout layout, @@ -311,7 +315,11 @@ else if (mode == AggregatorMode.FINAL || mode == AggregatorMode.INTERMEDIATE) { // apply the filter only in the initial phase - as the rest of the data is already filtered if (aggregateFunction.hasFilter() && mode.isInputPartial() == false) { - EvalOperator.ExpressionEvaluator.Factory evalFactory = EvalMapper.toEvaluator(aggregateFunction.filter(), layout); + EvalOperator.ExpressionEvaluator.Factory evalFactory = EvalMapper.toEvaluator( + foldContext, + aggregateFunction.filter(), + layout + ); aggSupplier = new FilteredAggregatorFunctionSupplier(aggSupplier, evalFactory); } consumer.accept(new AggFunctionSupplierContext(aggSupplier, mode)); diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/planner/EsPhysicalOperationProviders.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/planner/EsPhysicalOperationProviders.java index 11a599a15662f..eb3d09414fcdd 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/planner/EsPhysicalOperationProviders.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/planner/EsPhysicalOperationProviders.java @@ -47,6 +47,7 @@ import org.elasticsearch.xpack.esql.core.expression.Attribute; import org.elasticsearch.xpack.esql.core.expression.Expression; import org.elasticsearch.xpack.esql.core.expression.FieldAttribute; +import org.elasticsearch.xpack.esql.core.expression.FoldContext; import org.elasticsearch.xpack.esql.core.expression.MetadataAttribute; import org.elasticsearch.xpack.esql.core.type.DataType; import org.elasticsearch.xpack.esql.core.type.MultiTypeEsField; @@ -94,8 +95,8 @@ public interface ShardContext extends org.elasticsearch.compute.lucene.ShardCont private final List shardContexts; - public EsPhysicalOperationProviders(List shardContexts, AnalysisRegistry analysisRegistry) { - super(analysisRegistry); + public EsPhysicalOperationProviders(FoldContext foldContext, List shardContexts, AnalysisRegistry analysisRegistry) { + super(foldContext, analysisRegistry); this.shardContexts = shardContexts; } @@ -161,7 +162,7 @@ public final PhysicalOperation sourcePhysicalOperation(EsQueryExec esQueryExec, List sorts = esQueryExec.sorts(); assert esQueryExec.estimatedRowSize() != null : "estimated row size not initialized"; int rowEstimatedSize = esQueryExec.estimatedRowSize(); - int limit = esQueryExec.limit() != null ? (Integer) esQueryExec.limit().fold() : NO_LIMIT; + int limit = esQueryExec.limit() != null ? (Integer) esQueryExec.limit().fold(context.foldCtx()) : NO_LIMIT; boolean scoring = esQueryExec.attrs() .stream() .anyMatch(a -> a instanceof MetadataAttribute && a.name().equals(MetadataAttribute.SCORE)); @@ -217,7 +218,7 @@ public LuceneCountOperator.Factory countSource(LocalExecutionPlannerContext cont querySupplier(queryBuilder), context.queryPragmas().dataPartitioning(), context.queryPragmas().taskConcurrency(), - limit == null ? NO_LIMIT : (Integer) limit.fold() + limit == null ? NO_LIMIT : (Integer) limit.fold(context.foldCtx()) ); } diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/planner/EsqlExpressionTranslators.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/planner/EsqlExpressionTranslators.java index a1765977ee9c2..c185bd5729879 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/planner/EsqlExpressionTranslators.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/planner/EsqlExpressionTranslators.java @@ -16,6 +16,7 @@ import org.elasticsearch.xpack.esql.core.expression.Expression; import org.elasticsearch.xpack.esql.core.expression.Expressions; import org.elasticsearch.xpack.esql.core.expression.FieldAttribute; +import org.elasticsearch.xpack.esql.core.expression.FoldContext; import org.elasticsearch.xpack.esql.core.expression.TranslationAware; import org.elasticsearch.xpack.esql.core.expression.TypedAttribute; import org.elasticsearch.xpack.esql.core.expression.function.scalar.ScalarFunction; @@ -144,7 +145,7 @@ public static void checkInsensitiveComparison(InsensitiveEquals bc) { static Query translate(InsensitiveEquals bc) { TypedAttribute attribute = checkIsPushableAttribute(bc.left()); Source source = bc.source(); - BytesRef value = BytesRefs.toBytesRef(valueOf(bc.right())); + BytesRef value = BytesRefs.toBytesRef(valueOf(FoldContext.small() /* TODO remove me */, bc.right())); String name = pushableAttributeName(attribute); return new TermQuery(source, name, value.utf8ToString(), true); } @@ -188,7 +189,7 @@ static Query translate(BinaryComparison bc, TranslatorHandler handler) { TypedAttribute attribute = checkIsPushableAttribute(bc.left()); Source source = bc.source(); String name = handler.nameOf(attribute); - Object result = bc.right().fold(); + Object result = bc.right().fold(FoldContext.small() /* TODO remove me */); Object value = result; String format = null; boolean isDateLiteralComparison = false; @@ -269,7 +270,7 @@ private static Query translateOutOfRangeComparisons(BinaryComparison bc) { return null; } Source source = bc.source(); - Object value = valueOf(bc.right()); + Object value = valueOf(FoldContext.small() /* TODO remove me */, bc.right()); // Comparisons with multi-values always return null in ESQL. if (value instanceof List) { @@ -369,7 +370,7 @@ public static Query doTranslate(ScalarFunction f, TranslatorHandler handler) { if (f instanceof CIDRMatch cm) { if (cm.ipField() instanceof FieldAttribute fa && Expressions.foldable(cm.matches())) { String targetFieldName = handler.nameOf(fa.exactAttribute()); - Set set = new LinkedHashSet<>(Expressions.fold(cm.matches())); + Set set = new LinkedHashSet<>(Expressions.fold(FoldContext.small() /* TODO remove me */, cm.matches())); Query query = new TermsQuery(f.source(), targetFieldName, set); // CIDR_MATCH applies only to single values. @@ -420,7 +421,7 @@ static Query translate( String name = handler.nameOf(attribute); try { - Geometry shape = SpatialRelatesUtils.makeGeometryFromLiteral(constantExpression); + Geometry shape = SpatialRelatesUtils.makeGeometryFromLiteral(FoldContext.small() /* TODO remove me */, constantExpression); return new SpatialRelatesQuery(bc.source(), name, bc.queryRelation(), shape, attribute.dataType()); } catch (IllegalArgumentException e) { throw new QlIllegalArgumentException(e.getMessage(), e); @@ -461,7 +462,7 @@ private static Query translate(In in, TranslatorHandler handler) { queries.add(query); } } else { - terms.add(valueOf(rhs)); + terms.add(valueOf(FoldContext.small() /* TODO remove me */, rhs)); } } } @@ -487,8 +488,8 @@ public static Query doTranslate(Range r, TranslatorHandler handler) { } private static RangeQuery translate(Range r, TranslatorHandler handler) { - Object lower = valueOf(r.lower()); - Object upper = valueOf(r.upper()); + Object lower = valueOf(FoldContext.small() /* TODO remove me */, r.lower()); + Object upper = valueOf(FoldContext.small() /* TODO remove me */, r.upper()); String format = null; DataType dataType = r.value().dataType(); diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/planner/LocalExecutionPlanner.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/planner/LocalExecutionPlanner.java index 2275cf875161e..02f6a5f9d30a7 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/planner/LocalExecutionPlanner.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/planner/LocalExecutionPlanner.java @@ -22,7 +22,6 @@ import org.elasticsearch.compute.operator.Driver; import org.elasticsearch.compute.operator.DriverContext; import org.elasticsearch.compute.operator.EvalOperator.EvalOperatorFactory; -import org.elasticsearch.compute.operator.EvalOperator.ExpressionEvaluator; import org.elasticsearch.compute.operator.FilterOperator.FilterOperatorFactory; import org.elasticsearch.compute.operator.LocalSourceOperator; import org.elasticsearch.compute.operator.LocalSourceOperator.LocalSourceFactory; @@ -55,6 +54,7 @@ import org.elasticsearch.xpack.esql.core.expression.Attribute; import org.elasticsearch.xpack.esql.core.expression.Expression; import org.elasticsearch.xpack.esql.core.expression.Expressions; +import org.elasticsearch.xpack.esql.core.expression.FoldContext; import org.elasticsearch.xpack.esql.core.expression.Literal; import org.elasticsearch.xpack.esql.core.expression.NameId; import org.elasticsearch.xpack.esql.core.expression.NamedExpression; @@ -161,13 +161,14 @@ public LocalExecutionPlanner( /** * turn the given plan into a list of drivers to execute */ - public LocalExecutionPlan plan(PhysicalPlan localPhysicalPlan) { + public LocalExecutionPlan plan(FoldContext foldCtx, PhysicalPlan localPhysicalPlan) { var context = new LocalExecutionPlannerContext( new ArrayList<>(), new Holder<>(DriverParallelism.SINGLE), configuration.pragmas(), bigArrays, blockFactory, + foldCtx, settings ); @@ -397,7 +398,7 @@ private PhysicalOperation planEval(EvalExec eval, LocalExecutionPlannerContext c PhysicalOperation source = plan(eval.child(), context); for (Alias field : eval.fields()) { - var evaluatorSupplier = EvalMapper.toEvaluator(field.child(), source.layout); + var evaluatorSupplier = EvalMapper.toEvaluator(context.foldCtx(), field.child(), source.layout); Layout.Builder layout = source.layout.builder(); layout.append(field.toAttribute()); source = source.with(new EvalOperatorFactory(evaluatorSupplier), layout.build()); @@ -418,7 +419,7 @@ private PhysicalOperation planDissect(DissectExec dissect, LocalExecutionPlanner source = source.with( new StringExtractOperator.StringExtractOperatorFactory( patternNames, - EvalMapper.toEvaluator(expr, layout), + EvalMapper.toEvaluator(context.foldCtx(), expr, layout), () -> (input) -> dissect.parser().parser().parse(input) ), layout @@ -450,7 +451,7 @@ private PhysicalOperation planGrok(GrokExec grok, LocalExecutionPlannerContext c source = source.with( new ColumnExtractOperator.Factory( types, - EvalMapper.toEvaluator(grok.inputExpression(), layout), + EvalMapper.toEvaluator(context.foldCtx(), grok.inputExpression(), layout), () -> new GrokEvaluatorExtracter(grok.pattern().grok(), grok.pattern().pattern(), fieldToPos, fieldToType) ), layout @@ -599,10 +600,6 @@ private PhysicalOperation planLookupJoin(LookupJoinExec join, LocalExecutionPlan ); } - private ExpressionEvaluator.Factory toEvaluator(Expression exp, Layout layout) { - return EvalMapper.toEvaluator(exp, layout); - } - private PhysicalOperation planLocal(LocalSourceExec localSourceExec, LocalExecutionPlannerContext context) { Layout.Builder layout = new Layout.Builder(); layout.append(localSourceExec.output()); @@ -657,12 +654,15 @@ private PhysicalOperation planProject(ProjectExec project, LocalExecutionPlanner private PhysicalOperation planFilter(FilterExec filter, LocalExecutionPlannerContext context) { PhysicalOperation source = plan(filter.child(), context); // TODO: should this be extracted into a separate eval block? - return source.with(new FilterOperatorFactory(toEvaluator(filter.condition(), source.layout)), source.layout); + return source.with( + new FilterOperatorFactory(EvalMapper.toEvaluator(context.foldCtx(), filter.condition(), source.layout)), + source.layout + ); } private PhysicalOperation planLimit(LimitExec limit, LocalExecutionPlannerContext context) { PhysicalOperation source = plan(limit.child(), context); - return source.with(new Factory((Integer) limit.limit().fold()), source.layout); + return source.with(new Factory((Integer) limit.limit().fold(context.foldCtx)), source.layout); } private PhysicalOperation planMvExpand(MvExpandExec mvExpandExec, LocalExecutionPlannerContext context) { @@ -783,6 +783,7 @@ public record LocalExecutionPlannerContext( QueryPragmas queryPragmas, BigArrays bigArrays, BlockFactory blockFactory, + FoldContext foldCtx, Settings settings ) { void addDriverFactory(DriverFactory driverFactory) { diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/planner/PlannerUtils.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/planner/PlannerUtils.java index d5afccb4a554b..437477522e5c1 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/planner/PlannerUtils.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/planner/PlannerUtils.java @@ -20,6 +20,7 @@ import org.elasticsearch.xpack.esql.EsqlIllegalArgumentException; import org.elasticsearch.xpack.esql.core.expression.AttributeSet; import org.elasticsearch.xpack.esql.core.expression.Expression; +import org.elasticsearch.xpack.esql.core.expression.FoldContext; import org.elasticsearch.xpack.esql.core.expression.predicate.Predicates; import org.elasticsearch.xpack.esql.core.tree.Node; import org.elasticsearch.xpack.esql.core.tree.Source; @@ -159,13 +160,18 @@ private static , E extends T> void forEachUpWithChildren( } } - public static PhysicalPlan localPlan(List searchContexts, Configuration configuration, PhysicalPlan plan) { - return localPlan(configuration, plan, SearchContextStats.from(searchContexts)); + public static PhysicalPlan localPlan( + List searchContexts, + Configuration configuration, + FoldContext foldCtx, + PhysicalPlan plan + ) { + return localPlan(configuration, foldCtx, plan, SearchContextStats.from(searchContexts)); } - public static PhysicalPlan localPlan(Configuration configuration, PhysicalPlan plan, SearchStats searchStats) { - final var logicalOptimizer = new LocalLogicalPlanOptimizer(new LocalLogicalOptimizerContext(configuration, searchStats)); - var physicalOptimizer = new LocalPhysicalPlanOptimizer(new LocalPhysicalOptimizerContext(configuration, searchStats)); + public static PhysicalPlan localPlan(Configuration configuration, FoldContext foldCtx, PhysicalPlan plan, SearchStats searchStats) { + final var logicalOptimizer = new LocalLogicalPlanOptimizer(new LocalLogicalOptimizerContext(configuration, foldCtx, searchStats)); + var physicalOptimizer = new LocalPhysicalPlanOptimizer(new LocalPhysicalOptimizerContext(configuration, foldCtx, searchStats)); return localPlan(plan, logicalOptimizer, physicalOptimizer); } diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/planner/TypeConverter.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/planner/TypeConverter.java index 334875927eb96..4dea8a50b5c17 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/planner/TypeConverter.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/planner/TypeConverter.java @@ -14,6 +14,9 @@ import org.elasticsearch.compute.data.Page; import org.elasticsearch.compute.operator.DriverContext; import org.elasticsearch.compute.operator.EvalOperator.ExpressionEvaluator; +import org.elasticsearch.xpack.esql.core.expression.Expression; +import org.elasticsearch.xpack.esql.core.expression.FoldContext; +import org.elasticsearch.xpack.esql.evaluator.mapper.EvaluatorMapper; import org.elasticsearch.xpack.esql.expression.function.scalar.convert.AbstractConvertFunction; class TypeConverter { @@ -33,19 +36,26 @@ public static TypeConverter fromConvertFunction(AbstractConvertFunction convertF BigArrays.NON_RECYCLING_INSTANCE ) ); - return new TypeConverter( - convertFunction.functionName(), - convertFunction.toEvaluator(e -> driverContext -> new ExpressionEvaluator() { - @Override - public org.elasticsearch.compute.data.Block eval(Page page) { - // This is a pass-through evaluator, since it sits directly on the source loading (no prior expressions) - return page.getBlock(0); - } - - @Override - public void close() {} - }).get(driverContext1) - ); + return new TypeConverter(convertFunction.functionName(), convertFunction.toEvaluator(new EvaluatorMapper.ToEvaluator() { + @Override + public ExpressionEvaluator.Factory apply(Expression expression) { + return driverContext -> new ExpressionEvaluator() { + @Override + public org.elasticsearch.compute.data.Block eval(Page page) { + // This is a pass-through evaluator, since it sits directly on the source loading (no prior expressions) + return page.getBlock(0); + } + + @Override + public void close() {} + }; + } + + @Override + public FoldContext foldCtx() { + throw new IllegalStateException("not folding"); + } + }).get(driverContext1)); } public Block convert(Block block) { diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/planner/mapper/MapperUtils.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/planner/mapper/MapperUtils.java index e881eabb38c43..b8f539ea307c9 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/planner/mapper/MapperUtils.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/planner/mapper/MapperUtils.java @@ -11,6 +11,7 @@ import org.elasticsearch.compute.aggregation.AggregatorMode; import org.elasticsearch.xpack.esql.EsqlIllegalArgumentException; import org.elasticsearch.xpack.esql.core.expression.Attribute; +import org.elasticsearch.xpack.esql.core.expression.FoldContext; import org.elasticsearch.xpack.esql.core.expression.Literal; import org.elasticsearch.xpack.esql.core.tree.Source; import org.elasticsearch.xpack.esql.core.type.DataType; @@ -90,7 +91,7 @@ static PhysicalPlan mapUnary(UnaryPlan p, PhysicalPlan child) { enrich.mode(), enrich.policy().getType(), enrich.matchField(), - BytesRefs.toString(enrich.policyName().fold()), + BytesRefs.toString(enrich.policyName().fold(FoldContext.small() /* TODO remove me */)), enrich.policy().getMatchField(), enrich.concreteIndices(), enrich.enrichFields() diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plugin/ComputeService.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plugin/ComputeService.java index 7223e6988bb19..a38236fe60954 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plugin/ComputeService.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plugin/ComputeService.java @@ -61,6 +61,7 @@ import org.elasticsearch.xpack.esql.action.EsqlQueryAction; import org.elasticsearch.xpack.esql.action.EsqlSearchShardsAction; import org.elasticsearch.xpack.esql.core.expression.Attribute; +import org.elasticsearch.xpack.esql.core.expression.FoldContext; import org.elasticsearch.xpack.esql.enrich.EnrichLookupService; import org.elasticsearch.xpack.esql.enrich.LookupFromIndexService; import org.elasticsearch.xpack.esql.plan.physical.ExchangeSinkExec; @@ -140,6 +141,7 @@ public void execute( CancellableTask rootTask, PhysicalPlan physicalPlan, Configuration configuration, + FoldContext foldContext, EsqlExecutionInfo execInfo, ActionListener listener ) { @@ -174,6 +176,7 @@ public void execute( RemoteClusterAware.LOCAL_CLUSTER_GROUP_KEY, List.of(), configuration, + foldContext, null, null ); @@ -226,6 +229,7 @@ public void execute( RemoteClusterAware.LOCAL_CLUSTER_GROUP_KEY, List.of(), configuration, + foldContext, exchangeSource, null ), @@ -460,16 +464,16 @@ public SourceProvider createSourceProvider() { context.exchangeSink(), enrichLookupService, lookupFromIndexService, - new EsPhysicalOperationProviders(contexts, searchService.getIndicesService().getAnalysis()) + new EsPhysicalOperationProviders(context.foldCtx(), contexts, searchService.getIndicesService().getAnalysis()) ); LOGGER.debug("Received physical plan:\n{}", plan); - plan = PlannerUtils.localPlan(context.searchExecutionContexts(), context.configuration, plan); + plan = PlannerUtils.localPlan(context.searchExecutionContexts(), context.configuration, context.foldCtx(), plan); // the planner will also set the driver parallelism in LocalExecutionPlanner.LocalExecutionPlan (used down below) // it's doing this in the planning of EsQueryExec (the source of the data) // see also EsPhysicalOperationProviders.sourcePhysicalOperation - LocalExecutionPlanner.LocalExecutionPlan localExecutionPlan = planner.plan(plan); + LocalExecutionPlanner.LocalExecutionPlan localExecutionPlan = planner.plan(context.foldCtx(), plan); if (LOGGER.isDebugEnabled()) { LOGGER.debug("Local execution plan:\n{}", localExecutionPlan.describe()); } @@ -715,7 +719,15 @@ public void onFailure(Exception e) { }; acquireSearchContexts(clusterAlias, shardIds, configuration, request.aliasFilters(), ActionListener.wrap(searchContexts -> { assert ThreadPool.assertCurrentThreadPool(ThreadPool.Names.SEARCH, ESQL_WORKER_THREAD_POOL_NAME); - var computeContext = new ComputeContext(sessionId, clusterAlias, searchContexts, configuration, null, exchangeSink); + var computeContext = new ComputeContext( + sessionId, + clusterAlias, + searchContexts, + configuration, + configuration.newFoldContext(), + null, + exchangeSink + ); runCompute(parentTask, computeContext, request.plan(), batchListener); }, batchListener::onFailure)); } @@ -766,6 +778,7 @@ private void runComputeOnDataNode( request.clusterAlias(), List.of(), request.configuration(), + new FoldContext(request.pragmas().foldLimit().getBytes()), exchangeSource, externalSink ), @@ -901,7 +914,15 @@ void runComputeOnRemoteCluster( exchangeSink.addCompletionListener(computeListener.acquireAvoid()); runCompute( parentTask, - new ComputeContext(localSessionId, clusterAlias, List.of(), configuration, exchangeSource, exchangeSink), + new ComputeContext( + localSessionId, + clusterAlias, + List.of(), + configuration, + configuration.newFoldContext(), + exchangeSource, + exchangeSink + ), coordinatorPlan, computeListener.acquireCompute(clusterAlias) ); @@ -925,6 +946,7 @@ record ComputeContext( String clusterAlias, List searchContexts, Configuration configuration, + FoldContext foldCtx, ExchangeSourceHandler exchangeSource, ExchangeSinkHandler exchangeSink ) { diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plugin/QueryPragmas.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plugin/QueryPragmas.java index 58e80e569ee5e..2443c3f2cda62 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plugin/QueryPragmas.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plugin/QueryPragmas.java @@ -12,12 +12,14 @@ import org.elasticsearch.common.io.stream.Writeable; import org.elasticsearch.common.settings.Setting; import org.elasticsearch.common.settings.Settings; +import org.elasticsearch.common.unit.ByteSizeValue; import org.elasticsearch.common.util.concurrent.EsExecutors; import org.elasticsearch.compute.lucene.DataPartitioning; import org.elasticsearch.compute.operator.Driver; import org.elasticsearch.compute.operator.DriverStatus; import org.elasticsearch.core.TimeValue; import org.elasticsearch.threadpool.ThreadPool; +import org.elasticsearch.xpack.esql.core.expression.Expression; import java.io.IOException; import java.util.Objects; @@ -57,6 +59,8 @@ public final class QueryPragmas implements Writeable { public static final Setting NODE_LEVEL_REDUCTION = Setting.boolSetting("node_level_reduction", true); + public static final Setting FOLD_LIMIT = Setting.memorySizeSetting("fold_limit", "5%"); + public static final QueryPragmas EMPTY = new QueryPragmas(Settings.EMPTY); private final Settings settings; @@ -134,6 +138,17 @@ public boolean nodeLevelReduction() { return NODE_LEVEL_REDUCTION.get(settings); } + /** + * The maximum amount of memory we can use for {@link Expression#fold} during planing. This + * defaults to 5% of memory available on the current node. If this method is called on the + * coordinating node, this is 5% of the coordinating node's memory. If it's called on a data + * node, it's 5% of the data node. That's an exciting inconsistency. But it's + * important. Bigger nodes have more space to do folding. + */ + public ByteSizeValue foldLimit() { + return FOLD_LIMIT.get(settings); + } + public boolean isEmpty() { return settings.isEmpty(); } diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plugin/TransportEsqlQueryAction.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plugin/TransportEsqlQueryAction.java index b44e249e38006..84173eeecc060 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plugin/TransportEsqlQueryAction.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plugin/TransportEsqlQueryAction.java @@ -43,6 +43,7 @@ import org.elasticsearch.xpack.esql.action.EsqlQueryResponse; import org.elasticsearch.xpack.esql.action.EsqlQueryTask; import org.elasticsearch.xpack.esql.core.async.AsyncTaskManagementService; +import org.elasticsearch.xpack.esql.core.expression.FoldContext; import org.elasticsearch.xpack.esql.enrich.EnrichLookupService; import org.elasticsearch.xpack.esql.enrich.EnrichPolicyResolver; import org.elasticsearch.xpack.esql.enrich.LookupFromIndexService; @@ -189,11 +190,13 @@ private void innerExecute(Task task, EsqlQueryRequest request, ActionListener computeService.execute( sessionId, (CancellableTask) task, plan, configuration, + foldCtx, executionInfo, resultListener ); @@ -201,6 +204,7 @@ private void innerExecute(Task task, EsqlQueryRequest request, ActionListener new EnrichPolicyResolver.UnresolvedPolicy((String) e.policyName().fold(), e.mode())) + .map( + e -> new EnrichPolicyResolver.UnresolvedPolicy( + (String) e.policyName().fold(FoldContext.small() /* TODO remove me*/), + e.mode() + ) + ) .collect(Collectors.toSet()); final List indices = preAnalysis.indices; diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/type/EsqlDataTypeConverter.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/type/EsqlDataTypeConverter.java index c6cff0dce7bf9..b193f2e5ad666 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/type/EsqlDataTypeConverter.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/type/EsqlDataTypeConverter.java @@ -20,6 +20,7 @@ import org.elasticsearch.xpack.esql.core.QlIllegalArgumentException; import org.elasticsearch.xpack.esql.core.expression.Expression; import org.elasticsearch.xpack.esql.core.expression.Expressions; +import org.elasticsearch.xpack.esql.core.expression.FoldContext; import org.elasticsearch.xpack.esql.core.tree.Source; import org.elasticsearch.xpack.esql.core.type.Converter; import org.elasticsearch.xpack.esql.core.type.DataType; @@ -241,9 +242,9 @@ public static Converter converterFor(DataType from, DataType to) { return null; } - public static TemporalAmount foldToTemporalAmount(Expression field, String sourceText, DataType expectedType) { + public static TemporalAmount foldToTemporalAmount(FoldContext ctx, Expression field, String sourceText, DataType expectedType) { if (field.foldable()) { - Object v = field.fold(); + Object v = field.fold(ctx); if (v instanceof BytesRef b) { try { return EsqlDataTypeConverter.parseTemporalAmount(b.utf8ToString(), expectedType); diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/CsvTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/CsvTests.java index ceb9128b65d2d..5f4671aba2cd3 100644 --- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/CsvTests.java +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/CsvTests.java @@ -55,6 +55,7 @@ import org.elasticsearch.xpack.esql.analysis.EnrichResolution; import org.elasticsearch.xpack.esql.analysis.PreAnalyzer; import org.elasticsearch.xpack.esql.core.expression.Attribute; +import org.elasticsearch.xpack.esql.core.expression.FoldContext; import org.elasticsearch.xpack.esql.core.type.DataType; import org.elasticsearch.xpack.esql.core.type.EsField; import org.elasticsearch.xpack.esql.core.type.InvalidMappedField; @@ -479,15 +480,16 @@ private static CsvTestsDataLoader.MultiIndexTestDataset testDatasets(LogicalPlan return new CsvTestsDataLoader.MultiIndexTestDataset(indexName, datasets); } - private static TestPhysicalOperationProviders testOperationProviders(CsvTestsDataLoader.MultiIndexTestDataset datasets) - throws Exception { - var indexResolution = loadIndexResolution(datasets); + private static TestPhysicalOperationProviders testOperationProviders( + FoldContext foldCtx, + CsvTestsDataLoader.MultiIndexTestDataset datasets + ) throws Exception { var indexPages = new ArrayList(); for (CsvTestsDataLoader.TestDataset dataset : datasets.datasets()) { var testData = loadPageFromCsv(CsvTests.class.getResource("/data/" + dataset.dataFileName()), dataset.typeMapping()); indexPages.add(new TestPhysicalOperationProviders.IndexPage(dataset.indexName(), testData.v1(), testData.v2())); } - return TestPhysicalOperationProviders.create(indexPages); + return TestPhysicalOperationProviders.create(foldCtx, indexPages); } private ActualResults executePlan(BigArrays bigArrays) throws Exception { @@ -495,6 +497,7 @@ private ActualResults executePlan(BigArrays bigArrays) throws Exception { var testDatasets = testDatasets(parsed); LogicalPlan analyzed = analyzedPlan(parsed, testDatasets); + FoldContext foldCtx = FoldContext.small(); EsqlSession session = new EsqlSession( getTestName(), configuration, @@ -502,21 +505,21 @@ private ActualResults executePlan(BigArrays bigArrays) throws Exception { null, null, functionRegistry, - new LogicalPlanOptimizer(new LogicalOptimizerContext(configuration)), + new LogicalPlanOptimizer(new LogicalOptimizerContext(configuration, foldCtx)), mapper, TEST_VERIFIER, new PlanningMetrics(), null, EsqlTestUtils.MOCK_QUERY_BUILDER_RESOLVER ); - TestPhysicalOperationProviders physicalOperationProviders = testOperationProviders(testDatasets); + TestPhysicalOperationProviders physicalOperationProviders = testOperationProviders(foldCtx, testDatasets); PlainActionFuture listener = new PlainActionFuture<>(); session.executeOptimizedPlan( new EsqlQueryRequest(), new EsqlExecutionInfo(randomBoolean()), - planRunner(bigArrays, physicalOperationProviders), + planRunner(bigArrays, foldCtx, physicalOperationProviders), session.optimizedPlan(analyzed), listener.delegateFailureAndWrap( // Wrap so we can capture the warnings in the calling thread @@ -576,12 +579,13 @@ private void assertWarnings(List warnings) { testCase.assertWarnings(false).assertWarnings(normalized); } - PlanRunner planRunner(BigArrays bigArrays, TestPhysicalOperationProviders physicalOperationProviders) { - return (physicalPlan, listener) -> executeSubPlan(bigArrays, physicalOperationProviders, physicalPlan, listener); + PlanRunner planRunner(BigArrays bigArrays, FoldContext foldCtx, TestPhysicalOperationProviders physicalOperationProviders) { + return (physicalPlan, listener) -> executeSubPlan(bigArrays, foldCtx, physicalOperationProviders, physicalPlan, listener); } void executeSubPlan( BigArrays bigArrays, + FoldContext foldCtx, TestPhysicalOperationProviders physicalOperationProviders, PhysicalPlan physicalPlan, ActionListener listener @@ -627,12 +631,17 @@ void executeSubPlan( // replace fragment inside the coordinator plan List drivers = new ArrayList<>(); - LocalExecutionPlan coordinatorNodeExecutionPlan = executionPlanner.plan(new OutputExec(coordinatorPlan, collectedPages::add)); + LocalExecutionPlan coordinatorNodeExecutionPlan = executionPlanner.plan( + foldCtx, + new OutputExec(coordinatorPlan, collectedPages::add) + ); drivers.addAll(coordinatorNodeExecutionPlan.createDrivers(getTestName())); if (dataNodePlan != null) { var searchStats = new DisabledSearchStats(); - var logicalTestOptimizer = new LocalLogicalPlanOptimizer(new LocalLogicalOptimizerContext(configuration, searchStats)); - var physicalTestOptimizer = new TestLocalPhysicalPlanOptimizer(new LocalPhysicalOptimizerContext(configuration, searchStats)); + var logicalTestOptimizer = new LocalLogicalPlanOptimizer(new LocalLogicalOptimizerContext(configuration, foldCtx, searchStats)); + var physicalTestOptimizer = new TestLocalPhysicalPlanOptimizer( + new LocalPhysicalOptimizerContext(configuration, foldCtx, searchStats) + ); var csvDataNodePhysicalPlan = PlannerUtils.localPlan(dataNodePlan, logicalTestOptimizer, physicalTestOptimizer); exchangeSource.addRemoteSink( @@ -643,7 +652,7 @@ void executeSubPlan( throw new AssertionError("expected no failure", e); }) ); - LocalExecutionPlan dataNodeExecutionPlan = executionPlanner.plan(csvDataNodePhysicalPlan); + LocalExecutionPlan dataNodeExecutionPlan = executionPlanner.plan(foldCtx, csvDataNodePhysicalPlan); drivers.addAll(dataNodeExecutionPlan.createDrivers(getTestName())); Randomness.shuffle(drivers); diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/analysis/AnalyzerTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/analysis/AnalyzerTests.java index f9d1890c13a1a..91f8704204863 100644 --- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/analysis/AnalyzerTests.java +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/analysis/AnalyzerTests.java @@ -1084,7 +1084,7 @@ public void testImplicitLimit() { from test """); var limit = as(plan, Limit.class); - assertThat(limit.limit().fold(), equalTo(DEFAULT_LIMIT)); + assertThat(as(limit.limit(), Literal.class).value(), equalTo(DEFAULT_LIMIT)); as(limit.child(), EsRelation.class); } @@ -1092,7 +1092,7 @@ public void testImplicitMaxLimitAfterLimit() { for (int i = -1; i <= 1; i++) { var plan = analyze("from test | limit " + (MAX_LIMIT + i)); var limit = as(plan, Limit.class); - assertThat(limit.limit().fold(), equalTo(MAX_LIMIT)); + assertThat(as(limit.limit(), Literal.class).value(), equalTo(MAX_LIMIT)); limit = as(limit.child(), Limit.class); as(limit.child(), EsRelation.class); } @@ -1109,7 +1109,7 @@ public void testImplicitMaxLimitAfterLimitAndNonLimit() { for (int i = -1; i <= 1; i++) { var plan = analyze("from test | limit " + (MAX_LIMIT + i) + " | eval s = salary * 10 | where s > 0"); var limit = as(plan, Limit.class); - assertThat(limit.limit().fold(), equalTo(MAX_LIMIT)); + assertThat(as(limit.limit(), Literal.class).value(), equalTo(MAX_LIMIT)); var filter = as(limit.child(), Filter.class); var eval = as(filter.child(), Eval.class); limit = as(eval.child(), Limit.class); @@ -1121,7 +1121,7 @@ public void testImplicitDefaultLimitAfterLimitAndBreaker() { for (var breaker : List.of("stats c = count(salary) by last_name", "sort salary")) { var plan = analyze("from test | limit 100000 | " + breaker); var limit = as(plan, Limit.class); - assertThat(limit.limit().fold(), equalTo(MAX_LIMIT)); + assertThat(as(limit.limit(), Literal.class).value(), equalTo(MAX_LIMIT)); } } @@ -1129,7 +1129,7 @@ public void testImplicitDefaultLimitAfterBreakerAndNonBreakers() { for (var breaker : List.of("stats c = count(salary) by last_name", "eval c = salary | sort c")) { var plan = analyze("from test | " + breaker + " | eval cc = c * 10 | where cc > 0"); var limit = as(plan, Limit.class); - assertThat(limit.limit().fold(), equalTo(DEFAULT_LIMIT)); + assertThat(as(limit.limit(), Literal.class).value(), equalTo(DEFAULT_LIMIT)); } } @@ -1435,7 +1435,7 @@ public void testEmptyEsRelationOnLimitZeroWithCount() throws IOException { var plan = analyzeWithEmptyFieldCapsResponse(query); var limit = as(plan, Limit.class); limit = as(limit.child(), Limit.class); - assertThat(limit.limit().fold(), equalTo(0)); + assertThat(as(limit.limit(), Literal.class).value(), equalTo(0)); var orderBy = as(limit.child(), OrderBy.class); var agg = as(orderBy.child(), Aggregate.class); assertEmptyEsRelation(agg.child()); @@ -1450,7 +1450,7 @@ public void testEmptyEsRelationOnConstantEvalAndKeep() throws IOException { var plan = analyzeWithEmptyFieldCapsResponse(query); var limit = as(plan, Limit.class); limit = as(limit.child(), Limit.class); - assertThat(limit.limit().fold(), equalTo(2)); + assertThat(as(limit.limit(), Literal.class).value(), equalTo(2)); var project = as(limit.child(), EsqlProject.class); var eval = as(project.child(), Eval.class); assertEmptyEsRelation(eval.child()); @@ -1467,7 +1467,7 @@ public void testEmptyEsRelationOnConstantEvalAndStats() throws IOException { var agg = as(limit.child(), Aggregate.class); var eval = as(agg.child(), Eval.class); limit = as(eval.child(), Limit.class); - assertThat(limit.limit().fold(), equalTo(10)); + assertThat(as(limit.limit(), Literal.class).value(), equalTo(10)); assertEmptyEsRelation(limit.child()); } @@ -2061,10 +2061,10 @@ public void testLookup() { } LogicalPlan plan = analyze(query); var limit = as(plan, Limit.class); - assertThat(limit.limit().fold(), equalTo(1000)); + assertThat(as(limit.limit(), Literal.class).value(), equalTo(1000)); var lookup = as(limit.child(), Lookup.class); - assertThat(lookup.tableName().fold(), equalTo("int_number_names")); + assertThat(as(lookup.tableName(), Literal.class).value(), equalTo("int_number_names")); assertMap(lookup.matchFields().stream().map(Object::toString).toList(), matchesList().item(startsWith("int{r}"))); assertThat( lookup.localRelation().output().stream().map(Object::toString).toList(), @@ -2343,7 +2343,7 @@ public void testCoalesceWithMixedNumericTypes() { projection = as(projections.get(3), ReferenceAttribute.class); assertEquals(projection.name(), "w"); assertEquals(projection.dataType(), DataType.DOUBLE); - assertThat(limit.limit().fold(), equalTo(1000)); + assertThat(as(limit.limit(), Literal.class).value(), equalTo(1000)); } public void testNamedParamsForIdentifiers() { diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/evaluator/mapper/EvaluatorMapperTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/evaluator/mapper/EvaluatorMapperTests.java new file mode 100644 index 0000000000000..828f9e061686b --- /dev/null +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/evaluator/mapper/EvaluatorMapperTests.java @@ -0,0 +1,43 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.esql.evaluator.mapper; + +import org.elasticsearch.test.ESTestCase; +import org.elasticsearch.xpack.esql.core.expression.FoldContext; +import org.elasticsearch.xpack.esql.core.expression.Literal; +import org.elasticsearch.xpack.esql.core.tree.Source; +import org.elasticsearch.xpack.esql.core.type.DataType; +import org.elasticsearch.xpack.esql.expression.predicate.operator.arithmetic.Add; +import org.hamcrest.Matchers; + +public class EvaluatorMapperTests extends ESTestCase { + public void testFoldCompletesWithPlentyOfMemory() { + Add add = new Add( + Source.synthetic("shouldn't break"), + new Literal(Source.EMPTY, 1, DataType.INTEGER), + new Literal(Source.EMPTY, 3, DataType.INTEGER) + ); + assertEquals(add.fold(new FoldContext(100)), 4); + } + + public void testFoldBreaksWithLittleMemory() { + Add add = new Add( + Source.synthetic("should break"), + new Literal(Source.EMPTY, 1, DataType.INTEGER), + new Literal(Source.EMPTY, 3, DataType.INTEGER) + ); + Exception e = expectThrows(FoldContext.FoldTooMuchMemoryException.class, () -> add.fold(new FoldContext(10))); + assertThat( + e.getMessage(), + Matchers.equalTo( + "line -1:-1: Folding query used more than 10b. " + + "The expression that pushed past the limit is [should break] which needed 32b." + ) + ); + } +} diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/AbstractAggregationTestCase.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/AbstractAggregationTestCase.java index c086245d6fd61..87ea6315d4f3b 100644 --- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/AbstractAggregationTestCase.java +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/AbstractAggregationTestCase.java @@ -22,6 +22,7 @@ import org.elasticsearch.core.Releasables; import org.elasticsearch.xpack.esql.core.expression.Expression; import org.elasticsearch.xpack.esql.core.expression.FieldAttribute; +import org.elasticsearch.xpack.esql.core.expression.FoldContext; import org.elasticsearch.xpack.esql.core.expression.Literal; import org.elasticsearch.xpack.esql.core.type.DataType; import org.elasticsearch.xpack.esql.core.util.NumericUtils; @@ -40,6 +41,7 @@ import java.util.stream.IntStream; import static org.elasticsearch.compute.data.BlockUtils.toJavaObject; +import static org.elasticsearch.xpack.esql.EsqlTestUtils.unboundLogicalOptimizerContext; import static org.hamcrest.Matchers.equalTo; import static org.hamcrest.Matchers.instanceOf; import static org.hamcrest.Matchers.is; @@ -263,12 +265,12 @@ private void evaluate(Expression evaluableExpression) { assertTrue(evaluableExpression.foldable()); if (testCase.foldingExceptionClass() != null) { - Throwable t = expectThrows(testCase.foldingExceptionClass(), evaluableExpression::fold); + Throwable t = expectThrows(testCase.foldingExceptionClass(), () -> evaluableExpression.fold(FoldContext.small())); assertThat(t.getMessage(), equalTo(testCase.foldingExceptionMessage())); return; } - Object result = evaluableExpression.fold(); + Object result = evaluableExpression.fold(FoldContext.small()); // Decode unsigned longs into BigIntegers if (testCase.expectedType() == DataType.UNSIGNED_LONG && result != null) { result = NumericUtils.unsignedLongAsBigInteger((Long) result); @@ -289,7 +291,7 @@ private void resolveExpression(Expression expression, Consumer onAgg expression = resolveSurrogates(expression); // As expressions may be composed of multiple functions, we need to fold nulls bottom-up - expression = expression.transformUp(e -> new FoldNull().rule(e)); + expression = expression.transformUp(e -> new FoldNull().rule(e, unboundLogicalOptimizerContext())); assertThat(expression.dataType(), equalTo(testCase.expectedType())); Expression.TypeResolution resolution = expression.typeResolved(); diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/AbstractFunctionTestCase.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/AbstractFunctionTestCase.java index 9dae485d09cf9..27ac2331943a4 100644 --- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/AbstractFunctionTestCase.java +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/AbstractFunctionTestCase.java @@ -34,6 +34,7 @@ import org.elasticsearch.xpack.esql.core.expression.Attribute; import org.elasticsearch.xpack.esql.core.expression.Expression; import org.elasticsearch.xpack.esql.core.expression.FieldAttribute; +import org.elasticsearch.xpack.esql.core.expression.FoldContext; import org.elasticsearch.xpack.esql.core.expression.Literal; import org.elasticsearch.xpack.esql.core.expression.TypeResolutions; import org.elasticsearch.xpack.esql.core.expression.predicate.nulls.IsNotNull; @@ -44,6 +45,7 @@ import org.elasticsearch.xpack.esql.core.util.NumericUtils; import org.elasticsearch.xpack.esql.core.util.StringUtils; import org.elasticsearch.xpack.esql.evaluator.EvalMapper; +import org.elasticsearch.xpack.esql.evaluator.mapper.EvaluatorMapper; import org.elasticsearch.xpack.esql.expression.function.fulltext.Match; import org.elasticsearch.xpack.esql.expression.function.scalar.conditional.Greatest; import org.elasticsearch.xpack.esql.expression.function.scalar.nulls.Coalesce; @@ -97,6 +99,7 @@ import static java.util.Map.entry; import static org.elasticsearch.compute.data.BlockUtils.toJavaObject; import static org.elasticsearch.xpack.esql.EsqlTestUtils.randomLiteral; +import static org.elasticsearch.xpack.esql.EsqlTestUtils.unboundLogicalOptimizerContext; import static org.elasticsearch.xpack.esql.SerializationTestUtils.assertSerialization; import static org.hamcrest.Matchers.either; import static org.hamcrest.Matchers.endsWith; @@ -530,14 +533,28 @@ protected final Expression buildLiteralExpression(TestCaseSupplier.TestCase test return build(testCase.getSource(), testCase.getDataAsLiterals()); } + public static EvaluatorMapper.ToEvaluator toEvaluator() { + return new EvaluatorMapper.ToEvaluator() { + @Override + public ExpressionEvaluator.Factory apply(Expression expression) { + return evaluator(expression); + } + + @Override + public FoldContext foldCtx() { + return FoldContext.small(); + } + }; + } + /** * Convert an {@link Expression} tree into a {@link ExpressionEvaluator.Factory} * for {@link ExpressionEvaluator}s in the same way as our planner. */ public static ExpressionEvaluator.Factory evaluator(Expression e) { - e = new FoldNull().rule(e); + e = new FoldNull().rule(e, unboundLogicalOptimizerContext()); if (e.foldable()) { - e = new Literal(e.source(), e.fold(), e.dataType()); + e = new Literal(e.source(), e.fold(FoldContext.small()), e.dataType()); } Layout.Builder builder = new Layout.Builder(); buildLayout(builder, e); @@ -545,7 +562,7 @@ public static ExpressionEvaluator.Factory evaluator(Expression e) { if (resolution.unresolved()) { throw new AssertionError("expected resolved " + resolution.message()); } - return EvalMapper.toEvaluator(e, builder.build()); + return EvalMapper.toEvaluator(FoldContext.small(), e, builder.build()); } protected final Page row(List values) { diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/AbstractScalarFunctionTestCase.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/AbstractScalarFunctionTestCase.java index 65b9c447170f4..64086334b7251 100644 --- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/AbstractScalarFunctionTestCase.java +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/AbstractScalarFunctionTestCase.java @@ -20,6 +20,7 @@ import org.elasticsearch.core.Releasables; import org.elasticsearch.indices.CrankyCircuitBreakerService; import org.elasticsearch.xpack.esql.core.expression.Expression; +import org.elasticsearch.xpack.esql.core.expression.FoldContext; import org.elasticsearch.xpack.esql.core.type.DataType; import org.elasticsearch.xpack.esql.core.util.NumericUtils; import org.elasticsearch.xpack.esql.optimizer.rules.logical.FoldNull; @@ -38,6 +39,7 @@ import java.util.stream.Collectors; import java.util.stream.IntStream; +import static org.elasticsearch.xpack.esql.EsqlTestUtils.unboundLogicalOptimizerContext; import static org.hamcrest.Matchers.either; import static org.hamcrest.Matchers.equalTo; import static org.hamcrest.Matchers.is; @@ -132,7 +134,7 @@ public final void testEvaluate() { if (resolution.unresolved()) { throw new AssertionError("expected resolved " + resolution.message()); } - expression = new FoldNull().rule(expression); + expression = new FoldNull().rule(expression, unboundLogicalOptimizerContext()); assertThat(expression.dataType(), equalTo(testCase.expectedType())); logger.info("Result type: " + expression.dataType()); @@ -363,11 +365,11 @@ public void testFold() { return; } assertFalse("expected resolved", expression.typeResolved().unresolved()); - Expression nullOptimized = new FoldNull().rule(expression); + Expression nullOptimized = new FoldNull().rule(expression, unboundLogicalOptimizerContext()); assertThat(nullOptimized.dataType(), equalTo(testCase.expectedType())); assertTrue(nullOptimized.foldable()); if (testCase.foldingExceptionClass() == null) { - Object result = nullOptimized.fold(); + Object result = nullOptimized.fold(FoldContext.small()); // Decode unsigned longs into BigIntegers if (testCase.expectedType() == DataType.UNSIGNED_LONG && result != null) { result = NumericUtils.unsignedLongAsBigInteger((Long) result); @@ -380,7 +382,7 @@ public void testFold() { assertWarnings(testCase.getExpectedWarnings()); } } else { - Throwable t = expectThrows(testCase.foldingExceptionClass(), nullOptimized::fold); + Throwable t = expectThrows(testCase.foldingExceptionClass(), () -> nullOptimized.fold(FoldContext.small())); assertThat(t.getMessage(), equalTo(testCase.foldingExceptionMessage())); } } diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/CheckLicenseTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/CheckLicenseTests.java index 19af9892015b2..e507640c7b23c 100644 --- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/CheckLicenseTests.java +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/CheckLicenseTests.java @@ -21,6 +21,7 @@ import org.elasticsearch.xpack.esql.analysis.AnalyzerContext; import org.elasticsearch.xpack.esql.analysis.Verifier; import org.elasticsearch.xpack.esql.core.expression.Expression; +import org.elasticsearch.xpack.esql.core.expression.FoldContext; import org.elasticsearch.xpack.esql.core.expression.function.Function; import org.elasticsearch.xpack.esql.core.tree.NodeInfo; import org.elasticsearch.xpack.esql.core.tree.Source; @@ -80,7 +81,9 @@ public EsqlFunctionRegistry snapshotRegistry() { var plan = parser.createStatement(esql); plan = plan.transformDown( Limit.class, - l -> Objects.equals(l.limit().fold(), 10) ? new LicensedLimit(l.source(), l.limit(), l.child(), functionLicenseFeature) : l + l -> Objects.equals(l.limit().fold(FoldContext.small()), 10) + ? new LicensedLimit(l.source(), l.limit(), l.child(), functionLicenseFeature) + : l ); return analyzer(registry, operationMode).analyze(plan); } diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/TestCaseSupplier.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/TestCaseSupplier.java index f2bae0c5a4979..9bf063518d4ba 100644 --- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/TestCaseSupplier.java +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/TestCaseSupplier.java @@ -1412,7 +1412,7 @@ public static final class TestCase { /** * Warnings that are added by calling {@link AbstractFunctionTestCase#evaluator} - * or {@link Expression#fold()} on the expression built by this. + * or {@link Expression#fold} on the expression built by this. */ private final String[] expectedBuildEvaluatorWarnings; @@ -1542,7 +1542,7 @@ public String[] getExpectedWarnings() { /** * Warnings that are added by calling {@link AbstractFunctionTestCase#evaluator} - * or {@link Expression#fold()} on the expression built by this. + * or {@link Expression#fold} on the expression built by this. */ public String[] getExpectedBuildEvaluatorWarnings() { return expectedBuildEvaluatorWarnings; @@ -1624,7 +1624,7 @@ public TestCase withWarning(String warning) { /** * Warnings that are added by calling {@link AbstractFunctionTestCase#evaluator} - * or {@link Expression#fold()} on the expression built by this. + * or {@link Expression#fold} on the expression built by this. */ public TestCase withBuildEvaluatorWarning(String warning) { return new TestCase( diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/scalar/conditional/CaseExtraTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/scalar/conditional/CaseExtraTests.java index de84086e3cb4e..911878a645b42 100644 --- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/scalar/conditional/CaseExtraTests.java +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/scalar/conditional/CaseExtraTests.java @@ -19,9 +19,11 @@ import org.elasticsearch.compute.operator.EvalOperator; import org.elasticsearch.test.ESTestCase; import org.elasticsearch.xpack.esql.core.expression.Expression; +import org.elasticsearch.xpack.esql.core.expression.FoldContext; import org.elasticsearch.xpack.esql.core.expression.Literal; import org.elasticsearch.xpack.esql.core.tree.Source; import org.elasticsearch.xpack.esql.core.type.DataType; +import org.elasticsearch.xpack.esql.evaluator.mapper.EvaluatorMapper; import org.elasticsearch.xpack.esql.expression.function.AbstractFunctionTestCase; import org.junit.After; @@ -67,7 +69,7 @@ public void testPartialFoldDropsFirstFalse() { ); assertThat(c.foldable(), equalTo(false)); assertThat( - c.partiallyFold(), + c.partiallyFold(FoldContext.small()), equalTo(new Case(Source.synthetic("case"), field("last_cond", DataType.BOOLEAN), List.of(field("last", DataType.LONG)))) ); } @@ -80,7 +82,7 @@ public void testPartialFoldMv() { ); assertThat(c.foldable(), equalTo(false)); assertThat( - c.partiallyFold(), + c.partiallyFold(FoldContext.small()), equalTo(new Case(Source.synthetic("case"), field("last_cond", DataType.BOOLEAN), List.of(field("last", DataType.LONG)))) ); } @@ -92,7 +94,7 @@ public void testPartialFoldNoop() { List.of(field("first", DataType.LONG), field("last", DataType.LONG)) ); assertThat(c.foldable(), equalTo(false)); - assertThat(c.partiallyFold(), sameInstance(c)); + assertThat(c.partiallyFold(FoldContext.small()), sameInstance(c)); } public void testPartialFoldFirst() { @@ -102,7 +104,7 @@ public void testPartialFoldFirst() { List.of(field("first", DataType.LONG), field("last", DataType.LONG)) ); assertThat(c.foldable(), equalTo(false)); - assertThat(c.partiallyFold(), equalTo(field("first", DataType.LONG))); + assertThat(c.partiallyFold(FoldContext.small()), equalTo(field("first", DataType.LONG))); } public void testPartialFoldFirstAfterKeepingUnknown() { @@ -118,7 +120,7 @@ public void testPartialFoldFirstAfterKeepingUnknown() { ); assertThat(c.foldable(), equalTo(false)); assertThat( - c.partiallyFold(), + c.partiallyFold(FoldContext.small()), equalTo( new Case( Source.synthetic("case"), @@ -141,7 +143,7 @@ public void testPartialFoldSecond() { ) ); assertThat(c.foldable(), equalTo(false)); - assertThat(c.partiallyFold(), equalTo(field("second", DataType.LONG))); + assertThat(c.partiallyFold(FoldContext.small()), equalTo(field("second", DataType.LONG))); } public void testPartialFoldSecondAfterDroppingFalse() { @@ -156,7 +158,7 @@ public void testPartialFoldSecondAfterDroppingFalse() { ) ); assertThat(c.foldable(), equalTo(false)); - assertThat(c.partiallyFold(), equalTo(field("second", DataType.LONG))); + assertThat(c.partiallyFold(FoldContext.small()), equalTo(field("second", DataType.LONG))); } public void testPartialFoldLast() { @@ -171,7 +173,7 @@ public void testPartialFoldLast() { ) ); assertThat(c.foldable(), equalTo(false)); - assertThat(c.partiallyFold(), equalTo(field("last", DataType.LONG))); + assertThat(c.partiallyFold(FoldContext.small()), equalTo(field("last", DataType.LONG))); } public void testPartialFoldLastAfterKeepingUnknown() { @@ -187,7 +189,7 @@ public void testPartialFoldLastAfterKeepingUnknown() { ); assertThat(c.foldable(), equalTo(false)); assertThat( - c.partiallyFold(), + c.partiallyFold(FoldContext.small()), equalTo( new Case( Source.synthetic("case"), @@ -203,7 +205,7 @@ public void testEvalCase() { DriverContext driverContext = driverContext(); Page page = new Page(driverContext.blockFactory().newConstantIntBlockWith(0, 1)); try ( - EvalOperator.ExpressionEvaluator eval = caseExpr.toEvaluator(AbstractFunctionTestCase::evaluator).get(driverContext); + EvalOperator.ExpressionEvaluator eval = caseExpr.toEvaluator(AbstractFunctionTestCase.toEvaluator()).get(driverContext); Block block = eval.eval(page) ) { return toJavaObject(block, 0); @@ -216,7 +218,7 @@ public void testEvalCase() { public void testFoldCase() { testCase(caseExpr -> { assertTrue(caseExpr.foldable()); - return caseExpr.fold(); + return caseExpr.fold(FoldContext.small()); }); } @@ -265,22 +267,31 @@ public void testCaseWithIncompatibleTypes() { public void testCaseIsLazy() { Case caseExpr = caseExpr(true, 1, true, 2); DriverContext driveContext = driverContext(); - EvalOperator.ExpressionEvaluator evaluator = caseExpr.toEvaluator(child -> { - Object value = child.fold(); - if (value != null && value.equals(2)) { - return dvrCtx -> new EvalOperator.ExpressionEvaluator() { - @Override - public Block eval(Page page) { - fail("Unexpected evaluation of 4th argument"); - return null; - } + EvaluatorMapper.ToEvaluator toEvaluator = new EvaluatorMapper.ToEvaluator() { + @Override + public EvalOperator.ExpressionEvaluator.Factory apply(Expression expression) { + Object value = expression.fold(FoldContext.small()); + if (value != null && value.equals(2)) { + return dvrCtx -> new EvalOperator.ExpressionEvaluator() { + @Override + public Block eval(Page page) { + fail("Unexpected evaluation of 4th argument"); + return null; + } - @Override - public void close() {} - }; + @Override + public void close() {} + }; + } + return AbstractFunctionTestCase.evaluator(expression); } - return AbstractFunctionTestCase.evaluator(child); - }).get(driveContext); + + @Override + public FoldContext foldCtx() { + return FoldContext.small(); + } + }; + EvalOperator.ExpressionEvaluator evaluator = caseExpr.toEvaluator(toEvaluator).get(driveContext); Page page = new Page(driveContext.blockFactory().newConstantIntBlockWith(0, 1)); try (Block block = evaluator.eval(page)) { assertEquals(1, toJavaObject(block, 0)); diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/scalar/conditional/CaseTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/scalar/conditional/CaseTests.java index 05923246520fc..23a0f2307171c 100644 --- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/scalar/conditional/CaseTests.java +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/scalar/conditional/CaseTests.java @@ -14,6 +14,7 @@ import org.elasticsearch.core.Nullable; import org.elasticsearch.test.ESTestCase; import org.elasticsearch.xpack.esql.core.expression.Expression; +import org.elasticsearch.xpack.esql.core.expression.FoldContext; import org.elasticsearch.xpack.esql.core.tree.Source; import org.elasticsearch.xpack.esql.core.type.DataType; import org.elasticsearch.xpack.esql.core.util.NumericUtils; @@ -779,7 +780,7 @@ public void testFancyFolding() { return; } assertThat(e.foldable(), equalTo(true)); - Object result = e.fold(); + Object result = e.fold(FoldContext.small()); if (testCase.getExpectedBuildEvaluatorWarnings() != null) { assertWarnings(testCase.getExpectedBuildEvaluatorWarnings()); } @@ -799,18 +800,18 @@ public void testPartialFold() { } Case c = (Case) buildFieldExpression(testCase); if (extra().expectedPartialFold == null) { - assertThat(c.partiallyFold(), sameInstance(c)); + assertThat(c.partiallyFold(FoldContext.small()), sameInstance(c)); return; } if (extra().expectedPartialFold.size() == 1) { - assertThat(c.partiallyFold(), equalTo(extra().expectedPartialFold.get(0).asField())); + assertThat(c.partiallyFold(FoldContext.small()), equalTo(extra().expectedPartialFold.get(0).asField())); return; } Case expected = build( Source.synthetic("expected"), extra().expectedPartialFold.stream().map(TestCaseSupplier.TypedData::asField).toList() ); - assertThat(c.partiallyFold(), equalTo(expected)); + assertThat(c.partiallyFold(FoldContext.small()), equalTo(expected)); } private static Function addWarnings(List warnings) { diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/scalar/date/DateExtractTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/scalar/date/DateExtractTests.java index be978eda06758..cd27ce511b317 100644 --- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/scalar/date/DateExtractTests.java +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/scalar/date/DateExtractTests.java @@ -15,6 +15,7 @@ import org.elasticsearch.xpack.esql.EsqlTestUtils; import org.elasticsearch.xpack.esql.core.InvalidArgumentException; import org.elasticsearch.xpack.esql.core.expression.Expression; +import org.elasticsearch.xpack.esql.core.expression.FoldContext; import org.elasticsearch.xpack.esql.core.expression.Literal; import org.elasticsearch.xpack.esql.core.tree.Source; import org.elasticsearch.xpack.esql.core.type.DataType; @@ -99,7 +100,7 @@ public void testAllChronoFields() { EsqlTestUtils.TEST_CFG ); - assertThat(instance.fold(), is(date.getLong(value))); + assertThat(instance.fold(FoldContext.small()), is(date.getLong(value))); assertThat( DateExtract.process(epochMilli, new BytesRef(value.name()), EsqlTestUtils.TEST_CFG.zoneId()), is(date.getLong(value)) diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/scalar/nulls/CoalesceTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/scalar/nulls/CoalesceTests.java index 797c99992815e..688341ebaa2b7 100644 --- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/scalar/nulls/CoalesceTests.java +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/scalar/nulls/CoalesceTests.java @@ -16,6 +16,7 @@ import org.elasticsearch.compute.operator.EvalOperator; import org.elasticsearch.xpack.esql.core.expression.Expression; import org.elasticsearch.xpack.esql.core.expression.FieldAttribute; +import org.elasticsearch.xpack.esql.core.expression.FoldContext; import org.elasticsearch.xpack.esql.core.expression.Literal; import org.elasticsearch.xpack.esql.core.expression.Nullability; import org.elasticsearch.xpack.esql.core.tree.Source; @@ -187,19 +188,27 @@ public void testCoalesceIsLazy() { Layout.Builder builder = new Layout.Builder(); buildLayout(builder, exp); Layout layout = builder.build(); - EvaluatorMapper.ToEvaluator toEvaluator = child -> { - if (child == evil) { - return dvrCtx -> new EvalOperator.ExpressionEvaluator() { - @Override - public Block eval(Page page) { - throw new AssertionError("shouldn't be called"); - } - - @Override - public void close() {} - }; + EvaluatorMapper.ToEvaluator toEvaluator = new EvaluatorMapper.ToEvaluator() { + @Override + public EvalOperator.ExpressionEvaluator.Factory apply(Expression expression) { + if (expression == evil) { + return dvrCtx -> new EvalOperator.ExpressionEvaluator() { + @Override + public Block eval(Page page) { + throw new AssertionError("shouldn't be called"); + } + + @Override + public void close() {} + }; + } + return EvalMapper.toEvaluator(FoldContext.small(), expression, layout); + } + + @Override + public FoldContext foldCtx() { + return FoldContext.small(); } - return EvalMapper.toEvaluator(child, layout); }; try ( EvalOperator.ExpressionEvaluator eval = exp.toEvaluator(toEvaluator).get(driverContext()); diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/scalar/string/RLikeTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/scalar/string/RLikeTests.java index 4f8adf3abaae6..6c41552a9fc52 100644 --- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/scalar/string/RLikeTests.java +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/scalar/string/RLikeTests.java @@ -12,6 +12,7 @@ import org.apache.lucene.util.BytesRef; import org.elasticsearch.xpack.esql.core.expression.Expression; +import org.elasticsearch.xpack.esql.core.expression.FoldContext; import org.elasticsearch.xpack.esql.core.expression.Literal; import org.elasticsearch.xpack.esql.core.expression.predicate.regex.RLikePattern; import org.elasticsearch.xpack.esql.core.tree.Source; @@ -159,8 +160,8 @@ protected Expression build(Source source, List args) { Expression expression = args.get(0); Literal pattern = (Literal) args.get(1); Literal caseInsensitive = args.size() > 2 ? (Literal) args.get(2) : null; - String patternString = ((BytesRef) pattern.fold()).utf8ToString(); - boolean caseInsensitiveBool = caseInsensitive != null ? (boolean) caseInsensitive.fold() : false; + String patternString = ((BytesRef) pattern.fold(FoldContext.small())).utf8ToString(); + boolean caseInsensitiveBool = caseInsensitive != null ? (boolean) caseInsensitive.fold(FoldContext.small()) : false; logger.info("pattern={} caseInsensitive={}", patternString, caseInsensitiveBool); return caseInsensitiveBool diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/scalar/string/ToLowerTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/scalar/string/ToLowerTests.java index b355feb6130a3..f779dd038454d 100644 --- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/scalar/string/ToLowerTests.java +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/scalar/string/ToLowerTests.java @@ -15,6 +15,7 @@ import org.elasticsearch.common.settings.Settings; import org.elasticsearch.xpack.esql.EsqlTestUtils; import org.elasticsearch.xpack.esql.core.expression.Expression; +import org.elasticsearch.xpack.esql.core.expression.FoldContext; import org.elasticsearch.xpack.esql.core.expression.Literal; import org.elasticsearch.xpack.esql.core.tree.Source; import org.elasticsearch.xpack.esql.core.type.DataType; @@ -54,7 +55,7 @@ public void testRandomLocale() { String testString = randomAlphaOfLength(10); Configuration cfg = randomLocaleConfig(); ToLower func = new ToLower(Source.EMPTY, new Literal(Source.EMPTY, testString, DataType.KEYWORD), cfg); - assertThat(BytesRefs.toBytesRef(testString.toLowerCase(cfg.locale())), equalTo(func.fold())); + assertThat(BytesRefs.toBytesRef(testString.toLowerCase(cfg.locale())), equalTo(func.fold(FoldContext.small()))); } private Configuration randomLocaleConfig() { diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/scalar/string/ToUpperTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/scalar/string/ToUpperTests.java index fdae4f953a0fa..3957c2e1fb2c0 100644 --- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/scalar/string/ToUpperTests.java +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/scalar/string/ToUpperTests.java @@ -15,6 +15,7 @@ import org.elasticsearch.common.settings.Settings; import org.elasticsearch.xpack.esql.EsqlTestUtils; import org.elasticsearch.xpack.esql.core.expression.Expression; +import org.elasticsearch.xpack.esql.core.expression.FoldContext; import org.elasticsearch.xpack.esql.core.expression.Literal; import org.elasticsearch.xpack.esql.core.tree.Source; import org.elasticsearch.xpack.esql.core.type.DataType; @@ -54,7 +55,7 @@ public void testRandomLocale() { String testString = randomAlphaOfLength(10); Configuration cfg = randomLocaleConfig(); ToUpper func = new ToUpper(Source.EMPTY, new Literal(Source.EMPTY, testString, DataType.KEYWORD), cfg); - assertThat(BytesRefs.toBytesRef(testString.toUpperCase(cfg.locale())), equalTo(func.fold())); + assertThat(BytesRefs.toBytesRef(testString.toUpperCase(cfg.locale())), equalTo(func.fold(FoldContext.small()))); } private Configuration randomLocaleConfig() { diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/scalar/string/WildcardLikeTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/scalar/string/WildcardLikeTests.java index eed2c7379e9e1..6626ac50d60b5 100644 --- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/scalar/string/WildcardLikeTests.java +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/scalar/string/WildcardLikeTests.java @@ -12,6 +12,7 @@ import org.apache.lucene.util.BytesRef; import org.elasticsearch.xpack.esql.core.expression.Expression; +import org.elasticsearch.xpack.esql.core.expression.FoldContext; import org.elasticsearch.xpack.esql.core.expression.Literal; import org.elasticsearch.xpack.esql.core.expression.predicate.regex.WildcardPattern; import org.elasticsearch.xpack.esql.core.tree.Source; @@ -78,8 +79,8 @@ protected Expression build(Source source, List args) { Literal pattern = (Literal) args.get(1); if (args.size() > 2) { Literal caseInsensitive = (Literal) args.get(2); - assertThat(caseInsensitive.fold(), equalTo(false)); + assertThat(caseInsensitive.fold(FoldContext.small()), equalTo(false)); } - return new WildcardLike(source, expression, new WildcardPattern(((BytesRef) pattern.fold()).utf8ToString())); + return new WildcardLike(source, expression, new WildcardPattern(((BytesRef) pattern.fold(FoldContext.small())).utf8ToString())); } } diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/predicate/operator/arithmetic/NegTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/predicate/operator/arithmetic/NegTests.java index a8c7b5b5a83fd..15860d35539e0 100644 --- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/predicate/operator/arithmetic/NegTests.java +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/predicate/operator/arithmetic/NegTests.java @@ -12,6 +12,7 @@ import org.elasticsearch.xpack.esql.VerificationException; import org.elasticsearch.xpack.esql.core.expression.Expression; +import org.elasticsearch.xpack.esql.core.expression.FoldContext; import org.elasticsearch.xpack.esql.core.expression.Literal; import org.elasticsearch.xpack.esql.core.tree.Source; import org.elasticsearch.xpack.esql.core.type.DataType; @@ -154,7 +155,7 @@ public void testEdgeCases() { private Object foldTemporalAmount(Object val) { Neg neg = new Neg(Source.EMPTY, new Literal(Source.EMPTY, val, typeOf(val))); - return neg.fold(); + return neg.fold(FoldContext.small()); } private static DataType typeOf(Object val) { diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/predicate/operator/comparison/InTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/predicate/operator/comparison/InTests.java index b004adca351ab..80f67ec8e5e3a 100644 --- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/predicate/operator/comparison/InTests.java +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/predicate/operator/comparison/InTests.java @@ -13,6 +13,7 @@ import org.elasticsearch.geo.GeometryTestUtils; import org.elasticsearch.geo.ShapeTestUtils; import org.elasticsearch.xpack.esql.core.expression.Expression; +import org.elasticsearch.xpack.esql.core.expression.FoldContext; import org.elasticsearch.xpack.esql.core.expression.Literal; import org.elasticsearch.xpack.esql.core.tree.Source; import org.elasticsearch.xpack.esql.core.type.DataType; @@ -48,27 +49,27 @@ public InTests(@Name("TestCase") Supplier testCaseSup public void testInWithContainedValue() { In in = new In(EMPTY, TWO, Arrays.asList(ONE, TWO, THREE)); - assertTrue((Boolean) in.fold()); + assertTrue((Boolean) in.fold(FoldContext.small())); } public void testInWithNotContainedValue() { In in = new In(EMPTY, THREE, Arrays.asList(ONE, TWO)); - assertFalse((Boolean) in.fold()); + assertFalse((Boolean) in.fold(FoldContext.small())); } public void testHandleNullOnLeftValue() { In in = new In(EMPTY, NULL, Arrays.asList(ONE, TWO, THREE)); - assertNull(in.fold()); + assertNull(in.fold(FoldContext.small())); in = new In(EMPTY, NULL, Arrays.asList(ONE, NULL, THREE)); - assertNull(in.fold()); + assertNull(in.fold(FoldContext.small())); } public void testHandleNullsOnRightValue() { In in = new In(EMPTY, THREE, Arrays.asList(ONE, NULL, THREE)); - assertTrue((Boolean) in.fold()); + assertTrue((Boolean) in.fold(FoldContext.small())); in = new In(EMPTY, ONE, Arrays.asList(TWO, NULL, THREE)); - assertNull(in.fold()); + assertNull(in.fold(FoldContext.small())); } private static Literal L(Object value) { diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/predicate/operator/comparison/InsensitiveEqualsTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/predicate/operator/comparison/InsensitiveEqualsTests.java index faf0a0d8f418c..6fa1112f23f45 100644 --- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/predicate/operator/comparison/InsensitiveEqualsTests.java +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/predicate/operator/comparison/InsensitiveEqualsTests.java @@ -10,6 +10,7 @@ import org.elasticsearch.common.lucene.BytesRefs; import org.elasticsearch.test.ESTestCase; import org.elasticsearch.xpack.esql.core.expression.Expression; +import org.elasticsearch.xpack.esql.core.expression.FoldContext; import org.elasticsearch.xpack.esql.core.expression.Literal; import static org.elasticsearch.xpack.esql.EsqlTestUtils.of; @@ -18,37 +19,37 @@ public class InsensitiveEqualsTests extends ESTestCase { public void testFold() { - assertTrue(insensitiveEquals(l("foo"), l("foo")).fold()); - assertTrue(insensitiveEquals(l("Foo"), l("foo")).fold()); - assertTrue(insensitiveEquals(l("Foo"), l("fOO")).fold()); - assertTrue(insensitiveEquals(l("foo*"), l("foo*")).fold()); - assertTrue(insensitiveEquals(l("foo*"), l("FOO*")).fold()); - assertTrue(insensitiveEquals(l("foo?bar"), l("foo?bar")).fold()); - assertTrue(insensitiveEquals(l("foo?bar"), l("FOO?BAR")).fold()); - assertFalse(insensitiveEquals(l("Foo"), l("fo*")).fold()); - assertFalse(insensitiveEquals(l("Fox"), l("fo?")).fold()); - assertFalse(insensitiveEquals(l("Foo"), l("*OO")).fold()); - assertFalse(insensitiveEquals(l("BarFooBaz"), l("*O*")).fold()); - assertFalse(insensitiveEquals(l("BarFooBaz"), l("bar*baz")).fold()); - assertFalse(insensitiveEquals(l("foo"), l("*")).fold()); + assertTrue(insensitiveEquals(l("foo"), l("foo")).fold(FoldContext.small())); + assertTrue(insensitiveEquals(l("Foo"), l("foo")).fold(FoldContext.small())); + assertTrue(insensitiveEquals(l("Foo"), l("fOO")).fold(FoldContext.small())); + assertTrue(insensitiveEquals(l("foo*"), l("foo*")).fold(FoldContext.small())); + assertTrue(insensitiveEquals(l("foo*"), l("FOO*")).fold(FoldContext.small())); + assertTrue(insensitiveEquals(l("foo?bar"), l("foo?bar")).fold(FoldContext.small())); + assertTrue(insensitiveEquals(l("foo?bar"), l("FOO?BAR")).fold(FoldContext.small())); + assertFalse(insensitiveEquals(l("Foo"), l("fo*")).fold(FoldContext.small())); + assertFalse(insensitiveEquals(l("Fox"), l("fo?")).fold(FoldContext.small())); + assertFalse(insensitiveEquals(l("Foo"), l("*OO")).fold(FoldContext.small())); + assertFalse(insensitiveEquals(l("BarFooBaz"), l("*O*")).fold(FoldContext.small())); + assertFalse(insensitiveEquals(l("BarFooBaz"), l("bar*baz")).fold(FoldContext.small())); + assertFalse(insensitiveEquals(l("foo"), l("*")).fold(FoldContext.small())); - assertFalse(insensitiveEquals(l("foo*bar"), l("foo\\*bar")).fold()); - assertFalse(insensitiveEquals(l("foo?"), l("foo\\?")).fold()); - assertFalse(insensitiveEquals(l("foo?bar"), l("foo\\?bar")).fold()); - assertFalse(insensitiveEquals(l(randomAlphaOfLength(10)), l("*")).fold()); - assertFalse(insensitiveEquals(l(randomAlphaOfLength(3)), l("???")).fold()); + assertFalse(insensitiveEquals(l("foo*bar"), l("foo\\*bar")).fold(FoldContext.small())); + assertFalse(insensitiveEquals(l("foo?"), l("foo\\?")).fold(FoldContext.small())); + assertFalse(insensitiveEquals(l("foo?bar"), l("foo\\?bar")).fold(FoldContext.small())); + assertFalse(insensitiveEquals(l(randomAlphaOfLength(10)), l("*")).fold(FoldContext.small())); + assertFalse(insensitiveEquals(l(randomAlphaOfLength(3)), l("???")).fold(FoldContext.small())); - assertFalse(insensitiveEquals(l("foo"), l("bar")).fold()); - assertFalse(insensitiveEquals(l("foo"), l("ba*")).fold()); - assertFalse(insensitiveEquals(l("foo"), l("*a*")).fold()); - assertFalse(insensitiveEquals(l(""), l("bar")).fold()); - assertFalse(insensitiveEquals(l("foo"), l("")).fold()); - assertFalse(insensitiveEquals(l(randomAlphaOfLength(3)), l("??")).fold()); - assertFalse(insensitiveEquals(l(randomAlphaOfLength(3)), l("????")).fold()); + assertFalse(insensitiveEquals(l("foo"), l("bar")).fold(FoldContext.small())); + assertFalse(insensitiveEquals(l("foo"), l("ba*")).fold(FoldContext.small())); + assertFalse(insensitiveEquals(l("foo"), l("*a*")).fold(FoldContext.small())); + assertFalse(insensitiveEquals(l(""), l("bar")).fold(FoldContext.small())); + assertFalse(insensitiveEquals(l("foo"), l("")).fold(FoldContext.small())); + assertFalse(insensitiveEquals(l(randomAlphaOfLength(3)), l("??")).fold(FoldContext.small())); + assertFalse(insensitiveEquals(l(randomAlphaOfLength(3)), l("????")).fold(FoldContext.small())); - assertNull(insensitiveEquals(l("foo"), Literal.NULL).fold()); - assertNull(insensitiveEquals(Literal.NULL, l("foo")).fold()); - assertNull(insensitiveEquals(Literal.NULL, Literal.NULL).fold()); + assertNull(insensitiveEquals(l("foo"), Literal.NULL).fold(FoldContext.small())); + assertNull(insensitiveEquals(Literal.NULL, l("foo")).fold(FoldContext.small())); + assertNull(insensitiveEquals(Literal.NULL, Literal.NULL).fold(FoldContext.small())); } public void testProcess() { diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/LocalLogicalPlanOptimizerTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/LocalLogicalPlanOptimizerTests.java index 0c03556241d28..11cd123c731e8 100644 --- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/LocalLogicalPlanOptimizerTests.java +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/LocalLogicalPlanOptimizerTests.java @@ -20,6 +20,7 @@ import org.elasticsearch.xpack.esql.core.expression.Expression; import org.elasticsearch.xpack.esql.core.expression.Expressions; import org.elasticsearch.xpack.esql.core.expression.FieldAttribute; +import org.elasticsearch.xpack.esql.core.expression.FoldContext; import org.elasticsearch.xpack.esql.core.expression.Literal; import org.elasticsearch.xpack.esql.core.expression.ReferenceAttribute; import org.elasticsearch.xpack.esql.core.expression.predicate.logical.And; @@ -70,6 +71,7 @@ import static org.elasticsearch.xpack.esql.EsqlTestUtils.loadMapping; import static org.elasticsearch.xpack.esql.EsqlTestUtils.statsForExistingField; import static org.elasticsearch.xpack.esql.EsqlTestUtils.statsForMissingField; +import static org.elasticsearch.xpack.esql.EsqlTestUtils.unboundLogicalOptimizerContext; import static org.elasticsearch.xpack.esql.EsqlTestUtils.withDefaultLimitWarning; import static org.elasticsearch.xpack.esql.core.tree.Source.EMPTY; import static org.hamcrest.Matchers.contains; @@ -93,7 +95,7 @@ public static void init() { mapping = loadMapping("mapping-basic.json"); EsIndex test = new EsIndex("test", mapping, Map.of("test", IndexMode.STANDARD)); IndexResolution getIndexResult = IndexResolution.valid(test); - logicalOptimizer = new LogicalPlanOptimizer(new LogicalOptimizerContext(EsqlTestUtils.TEST_CFG)); + logicalOptimizer = new LogicalPlanOptimizer(unboundLogicalOptimizerContext()); analyzer = new Analyzer( new AnalyzerContext(EsqlTestUtils.TEST_CFG, new EsqlFunctionRegistry(), getIndexResult, EsqlTestUtils.emptyPolicyResolution()), @@ -161,7 +163,7 @@ public void testMissingFieldInProject() { assertThat(Expressions.names(eval.fields()), contains("last_name")); var alias = as(eval.fields().get(0), Alias.class); var literal = as(alias.child(), Literal.class); - assertThat(literal.fold(), is(nullValue())); + assertThat(literal.value(), is(nullValue())); assertThat(literal.dataType(), is(DataType.KEYWORD)); var limit = as(eval.child(), Limit.class); @@ -304,7 +306,7 @@ public void testMissingFieldInEval() { var alias = as(eval.fields().get(0), Alias.class); var literal = as(alias.child(), Literal.class); - assertThat(literal.fold(), is(nullValue())); + assertThat(literal.value(), is(nullValue())); assertThat(literal.dataType(), is(DataType.INTEGER)); var limit = as(eval.child(), Limit.class); @@ -402,7 +404,7 @@ public void testSparseDocument() throws Exception { EsIndex index = new EsIndex("large", large, Map.of("large", IndexMode.STANDARD)); IndexResolution getIndexResult = IndexResolution.valid(index); - var logicalOptimizer = new LogicalPlanOptimizer(new LogicalOptimizerContext(EsqlTestUtils.TEST_CFG)); + var logicalOptimizer = new LogicalPlanOptimizer(unboundLogicalOptimizerContext()); var analyzer = new Analyzer( new AnalyzerContext(EsqlTestUtils.TEST_CFG, new EsqlFunctionRegistry(), getIndexResult, EsqlTestUtils.emptyPolicyResolution()), @@ -411,7 +413,7 @@ public void testSparseDocument() throws Exception { var analyzed = analyzer.analyze(parser.createStatement(query)); var optimized = logicalOptimizer.optimize(analyzed); - var localContext = new LocalLogicalOptimizerContext(EsqlTestUtils.TEST_CFG, searchStats); + var localContext = new LocalLogicalOptimizerContext(EsqlTestUtils.TEST_CFG, FoldContext.small(), searchStats); var plan = new LocalLogicalPlanOptimizer(localContext).localOptimize(optimized); var project = as(plan, Project.class); @@ -423,7 +425,7 @@ public void testSparseDocument() throws Exception { var eval = as(project.child(), Eval.class); var field = eval.fields().get(0); assertThat(Expressions.name(field), is("field005")); - assertThat(Alias.unwrap(field).fold(), Matchers.nullValue()); + assertThat(Alias.unwrap(field).fold(FoldContext.small()), Matchers.nullValue()); } // InferIsNotNull @@ -561,7 +563,7 @@ private LogicalPlan plan(String query) { } private LogicalPlan localPlan(LogicalPlan plan, SearchStats searchStats) { - var localContext = new LocalLogicalOptimizerContext(EsqlTestUtils.TEST_CFG, searchStats); + var localContext = new LocalLogicalOptimizerContext(EsqlTestUtils.TEST_CFG, FoldContext.small(), searchStats); // System.out.println(plan); var localPlan = new LocalLogicalPlanOptimizer(localContext).localOptimize(plan); // System.out.println(localPlan); diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/LocalPhysicalPlanOptimizerTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/LocalPhysicalPlanOptimizerTests.java index 73f1d47fb5baa..2c309af9caf5f 100644 --- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/LocalPhysicalPlanOptimizerTests.java +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/LocalPhysicalPlanOptimizerTests.java @@ -35,6 +35,7 @@ import org.elasticsearch.xpack.esql.analysis.Verifier; import org.elasticsearch.xpack.esql.core.expression.Alias; import org.elasticsearch.xpack.esql.core.expression.Expressions; +import org.elasticsearch.xpack.esql.core.expression.Literal; import org.elasticsearch.xpack.esql.core.expression.ReferenceAttribute; import org.elasticsearch.xpack.esql.core.tree.Source; import org.elasticsearch.xpack.esql.core.type.DataType; @@ -89,6 +90,7 @@ import static org.elasticsearch.xpack.esql.EsqlTestUtils.as; import static org.elasticsearch.xpack.esql.EsqlTestUtils.configuration; import static org.elasticsearch.xpack.esql.EsqlTestUtils.loadMapping; +import static org.elasticsearch.xpack.esql.EsqlTestUtils.unboundLogicalOptimizerContext; import static org.elasticsearch.xpack.esql.EsqlTestUtils.withDefaultLimitWarning; import static org.elasticsearch.xpack.esql.plan.physical.EsStatsQueryExec.StatsType; import static org.hamcrest.Matchers.contains; @@ -402,7 +404,7 @@ public void testMultiCountAllWithFilter() { @SuppressWarnings("unchecked") public void testSingleCountWithStatsFilter() { // an optimizer that filters out the ExtractAggregateCommonFilter rule - var logicalOptimizer = new LogicalPlanOptimizer(new LogicalOptimizerContext(config)) { + var logicalOptimizer = new LogicalPlanOptimizer(unboundLogicalOptimizerContext()) { @Override protected List> batches() { var oldBatches = super.batches(); @@ -481,7 +483,7 @@ public void testQueryStringFunction() { var project = as(exchange.child(), ProjectExec.class); var field = as(project.child(), FieldExtractExec.class); var query = as(field.child(), EsQueryExec.class); - assertThat(query.limit().fold(), is(1000)); + assertThat(as(query.limit(), Literal.class).value(), is(1000)); var expected = QueryBuilders.queryStringQuery("last_name: Smith"); assertThat(query.query().toString(), is(expected.toString())); } @@ -510,7 +512,7 @@ public void testQueryStringFunctionConjunctionWhereOperands() { var project = as(exchange.child(), ProjectExec.class); var field = as(project.child(), FieldExtractExec.class); var query = as(field.child(), EsQueryExec.class); - assertThat(query.limit().fold(), is(1000)); + assertThat(as(query.limit(), Literal.class).value(), is(1000)); Source filterSource = new Source(2, 37, "emp_no > 10000"); var range = wrapWithSingleQuery(queryText, QueryBuilders.rangeQuery("emp_no").gt(10010), "emp_no", filterSource); @@ -545,7 +547,7 @@ public void testQueryStringFunctionWithFunctionsPushedToLucene() { var project = as(exchange.child(), ProjectExec.class); var field = as(project.child(), FieldExtractExec.class); var query = as(field.child(), EsQueryExec.class); - assertThat(query.limit().fold(), is(1000)); + assertThat(as(query.limit(), Literal.class).value(), is(1000)); Source filterSource = new Source(2, 37, "cidr_match(ip, \"127.0.0.1/32\")"); var terms = wrapWithSingleQuery(queryText, QueryBuilders.termsQuery("ip", "127.0.0.1/32"), "ip", filterSource); @@ -580,7 +582,7 @@ public void testQueryStringFunctionMultipleWhereClauses() { var project = as(exchange.child(), ProjectExec.class); var field = as(project.child(), FieldExtractExec.class); var query = as(field.child(), EsQueryExec.class); - assertThat(query.limit().fold(), is(1000)); + assertThat(as(query.limit(), Literal.class).value(), is(1000)); Source filterSource = new Source(3, 8, "emp_no > 10000"); var range = wrapWithSingleQuery(queryText, QueryBuilders.rangeQuery("emp_no").gt(10010), "emp_no", filterSource); @@ -613,7 +615,7 @@ public void testQueryStringFunctionMultipleQstrClauses() { var project = as(exchange.child(), ProjectExec.class); var field = as(project.child(), FieldExtractExec.class); var query = as(field.child(), EsQueryExec.class); - assertThat(query.limit().fold(), is(1000)); + assertThat(as(query.limit(), Literal.class).value(), is(1000)); var queryStringLeft = QueryBuilders.queryStringQuery("last_name: Smith"); var queryStringRight = QueryBuilders.queryStringQuery("emp_no: [10010 TO *]"); @@ -642,7 +644,7 @@ public void testMatchFunction() { var project = as(exchange.child(), ProjectExec.class); var field = as(project.child(), FieldExtractExec.class); var query = as(field.child(), EsQueryExec.class); - assertThat(query.limit().fold(), is(1000)); + assertThat(as(query.limit(), Literal.class).value(), is(1000)); var expected = QueryBuilders.matchQuery("last_name", "Smith").lenient(true); assertThat(query.query().toString(), is(expected.toString())); } @@ -671,7 +673,7 @@ public void testMatchFunctionConjunctionWhereOperands() { var project = as(exchange.child(), ProjectExec.class); var field = as(project.child(), FieldExtractExec.class); var query = as(field.child(), EsQueryExec.class); - assertThat(query.limit().fold(), is(1000)); + assertThat(as(query.limit(), Literal.class).value(), is(1000)); Source filterSource = new Source(2, 38, "emp_no > 10000"); var range = wrapWithSingleQuery(queryText, QueryBuilders.rangeQuery("emp_no").gt(10010), "emp_no", filterSource); @@ -706,7 +708,7 @@ public void testMatchFunctionWithFunctionsPushedToLucene() { var project = as(exchange.child(), ProjectExec.class); var field = as(project.child(), FieldExtractExec.class); var query = as(field.child(), EsQueryExec.class); - assertThat(query.limit().fold(), is(1000)); + assertThat(as(query.limit(), Literal.class).value(), is(1000)); Source filterSource = new Source(2, 32, "cidr_match(ip, \"127.0.0.1/32\")"); var terms = wrapWithSingleQuery(queryText, QueryBuilders.termsQuery("ip", "127.0.0.1/32"), "ip", filterSource); @@ -740,7 +742,7 @@ public void testMatchFunctionMultipleWhereClauses() { var project = as(exchange.child(), ProjectExec.class); var field = as(project.child(), FieldExtractExec.class); var query = as(field.child(), EsQueryExec.class); - assertThat(query.limit().fold(), is(1000)); + assertThat(as(query.limit(), Literal.class).value(), is(1000)); Source filterSource = new Source(3, 8, "emp_no > 10000"); var range = wrapWithSingleQuery(queryText, QueryBuilders.rangeQuery("emp_no").gt(10010), "emp_no", filterSource); @@ -772,7 +774,7 @@ public void testMatchFunctionMultipleMatchClauses() { var project = as(exchange.child(), ProjectExec.class); var field = as(project.child(), FieldExtractExec.class); var query = as(field.child(), EsQueryExec.class); - assertThat(query.limit().fold(), is(1000)); + assertThat(as(query.limit(), Literal.class).value(), is(1000)); var queryStringLeft = QueryBuilders.matchQuery("last_name", "Smith").lenient(true); var queryStringRight = QueryBuilders.matchQuery("first_name", "John").lenient(true); @@ -801,7 +803,7 @@ public void testKqlFunction() { var project = as(exchange.child(), ProjectExec.class); var field = as(project.child(), FieldExtractExec.class); var query = as(field.child(), EsQueryExec.class); - assertThat(query.limit().fold(), is(1000)); + assertThat(as(query.limit(), Literal.class).value(), is(1000)); var expected = kqlQueryBuilder("last_name: Smith"); assertThat(query.query().toString(), is(expected.toString())); } @@ -830,7 +832,7 @@ public void testKqlFunctionConjunctionWhereOperands() { var project = as(exchange.child(), ProjectExec.class); var field = as(project.child(), FieldExtractExec.class); var query = as(field.child(), EsQueryExec.class); - assertThat(query.limit().fold(), is(1000)); + assertThat(as(query.limit(), Literal.class).value(), is(1000)); Source filterSource = new Source(2, 36, "emp_no > 10000"); var range = wrapWithSingleQuery(queryText, QueryBuilders.rangeQuery("emp_no").gt(10010), "emp_no", filterSource); @@ -865,7 +867,7 @@ public void testKqlFunctionWithFunctionsPushedToLucene() { var project = as(exchange.child(), ProjectExec.class); var field = as(project.child(), FieldExtractExec.class); var query = as(field.child(), EsQueryExec.class); - assertThat(query.limit().fold(), is(1000)); + assertThat(as(query.limit(), Literal.class).value(), is(1000)); Source filterSource = new Source(2, 36, "cidr_match(ip, \"127.0.0.1/32\")"); var terms = wrapWithSingleQuery(queryText, QueryBuilders.termsQuery("ip", "127.0.0.1/32"), "ip", filterSource); @@ -900,7 +902,7 @@ public void testKqlFunctionMultipleWhereClauses() { var project = as(exchange.child(), ProjectExec.class); var field = as(project.child(), FieldExtractExec.class); var query = as(field.child(), EsQueryExec.class); - assertThat(query.limit().fold(), is(1000)); + assertThat(as(query.limit(), Literal.class).value(), is(1000)); Source filterSource = new Source(3, 8, "emp_no > 10000"); var range = wrapWithSingleQuery(queryText, QueryBuilders.rangeQuery("emp_no").gt(10010), "emp_no", filterSource); @@ -933,7 +935,7 @@ public void testKqlFunctionMultipleKqlClauses() { var project = as(exchange.child(), ProjectExec.class); var field = as(project.child(), FieldExtractExec.class); var query = as(field.child(), EsQueryExec.class); - assertThat(query.limit().fold(), is(1000)); + assertThat(as(query.limit(), Literal.class).value(), is(1000)); var kqlQueryLeft = kqlQueryBuilder("last_name: Smith"); var kqlQueryRight = kqlQueryBuilder("emp_no > 10010"); @@ -999,7 +1001,7 @@ public void testIsNotNullPushdownFilter() { var project = as(exchange.child(), ProjectExec.class); var field = as(project.child(), FieldExtractExec.class); var query = as(field.child(), EsQueryExec.class); - assertThat(query.limit().fold(), is(1000)); + assertThat(as(query.limit(), Literal.class).value(), is(1000)); var expected = QueryBuilders.existsQuery("emp_no"); assertThat(query.query().toString(), is(expected.toString())); } @@ -1023,7 +1025,7 @@ public void testIsNullPushdownFilter() { var project = as(exchange.child(), ProjectExec.class); var field = as(project.child(), FieldExtractExec.class); var query = as(field.child(), EsQueryExec.class); - assertThat(query.limit().fold(), is(1000)); + assertThat(as(query.limit(), Literal.class).value(), is(1000)); var expected = QueryBuilders.boolQuery().mustNot(QueryBuilders.existsQuery("emp_no")); assertThat(query.query().toString(), is(expected.toString())); } @@ -1550,7 +1552,7 @@ public void testTermFunction() { var project = as(exchange.child(), ProjectExec.class); var field = as(project.child(), FieldExtractExec.class); var query = as(field.child(), EsQueryExec.class); - assertThat(query.limit().fold(), is(1000)); + assertThat(as(query.limit(), Literal.class).value(), is(1000)); var expected = QueryBuilders.termQuery("last_name", "Smith"); assertThat(query.query().toString(), is(expected.toString())); } diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/LogicalPlanOptimizerTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/LogicalPlanOptimizerTests.java index ecd0fd5f268f6..8b12267011f02 100644 --- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/LogicalPlanOptimizerTests.java +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/LogicalPlanOptimizerTests.java @@ -33,6 +33,7 @@ import org.elasticsearch.xpack.esql.core.expression.Expression; import org.elasticsearch.xpack.esql.core.expression.Expressions; import org.elasticsearch.xpack.esql.core.expression.FieldAttribute; +import org.elasticsearch.xpack.esql.core.expression.FoldContext; import org.elasticsearch.xpack.esql.core.expression.Literal; import org.elasticsearch.xpack.esql.core.expression.NamedExpression; import org.elasticsearch.xpack.esql.core.expression.Nullability; @@ -151,6 +152,7 @@ import static org.elasticsearch.xpack.esql.EsqlTestUtils.loadMapping; import static org.elasticsearch.xpack.esql.EsqlTestUtils.localSource; import static org.elasticsearch.xpack.esql.EsqlTestUtils.referenceAttribute; +import static org.elasticsearch.xpack.esql.EsqlTestUtils.unboundLogicalOptimizerContext; import static org.elasticsearch.xpack.esql.EsqlTestUtils.withDefaultLimitWarning; import static org.elasticsearch.xpack.esql.analysis.Analyzer.NO_FIELDS; import static org.elasticsearch.xpack.esql.analysis.AnalyzerTestUtils.analyze; @@ -188,6 +190,7 @@ public class LogicalPlanOptimizerTests extends ESTestCase { private static EsqlParser parser; private static Analyzer analyzer; + private static LogicalOptimizerContext logicalOptimizerCtx; private static LogicalPlanOptimizer logicalOptimizer; private static Map mapping; private static Map mappingAirports; @@ -203,7 +206,7 @@ public class LogicalPlanOptimizerTests extends ESTestCase { private static Analyzer metricsAnalyzer; private static class SubstitutionOnlyOptimizer extends LogicalPlanOptimizer { - static SubstitutionOnlyOptimizer INSTANCE = new SubstitutionOnlyOptimizer(new LogicalOptimizerContext(EsqlTestUtils.TEST_CFG)); + static SubstitutionOnlyOptimizer INSTANCE = new SubstitutionOnlyOptimizer(unboundLogicalOptimizerContext()); SubstitutionOnlyOptimizer(LogicalOptimizerContext optimizerContext) { super(optimizerContext); @@ -218,7 +221,8 @@ protected List> batches() { @BeforeClass public static void init() { parser = new EsqlParser(); - logicalOptimizer = new LogicalPlanOptimizer(new LogicalOptimizerContext(EsqlTestUtils.TEST_CFG)); + logicalOptimizerCtx = unboundLogicalOptimizerContext(); + logicalOptimizer = new LogicalPlanOptimizer(logicalOptimizerCtx); enrichResolution = new EnrichResolution(); AnalyzerTestUtils.loadEnrichPolicyResolution(enrichResolution, "languages_idx", "id", "languages_idx", "mapping-languages.json"); @@ -325,7 +329,7 @@ public void testEmptyProjectInStatWithEval() { assertThat(exprs.size(), equalTo(1)); var alias = as(exprs.get(0), Alias.class); assertThat(alias.name(), equalTo("x")); - assertThat(alias.child().fold(), equalTo(1)); + assertThat(alias.child().fold(FoldContext.small()), equalTo(1)); } /** @@ -361,11 +365,11 @@ public void testEmptyProjectInStatWithGroupAndEval() { assertThat(exprs.size(), equalTo(1)); var alias = as(exprs.get(0), Alias.class); assertThat(alias.name(), equalTo("x")); - assertThat(alias.child().fold(), equalTo(1)); + assertThat(alias.child().fold(FoldContext.small()), equalTo(1)); var filterCondition = as(filter.condition(), GreaterThan.class); assertThat(Expressions.name(filterCondition.left()), equalTo("languages")); - assertThat(filterCondition.right().fold(), equalTo(1)); + assertThat(filterCondition.right().fold(FoldContext.small()), equalTo(1)); } public void testCombineProjections() { @@ -625,7 +629,7 @@ public void testReplaceStatsFilteredAggWithEvalSingleAggWithExpression() { assertThat(alias.name(), is("sum(salary) + 1 where false")); var add = as(alias.child(), Add.class); var literal = as(add.right(), Literal.class); - assertThat(literal.fold(), is(1)); + assertThat(literal.value(), is(1)); var limit = as(eval.child(), Limit.class); var source = as(limit.child(), LocalRelation.class); @@ -658,7 +662,7 @@ public void testReplaceStatsFilteredAggWithEvalMixedFilterAndNoFilter() { var alias = as(eval.fields().get(0), Alias.class); assertTrue(alias.child().foldable()); - assertThat(alias.child().fold(), nullValue()); + assertThat(alias.child().fold(FoldContext.small()), nullValue()); assertThat(alias.child().dataType(), is(LONG)); alias = as(eval.fields().get(1), Alias.class); @@ -695,7 +699,7 @@ public void testReplaceStatsFilteredAggWithEvalFilterFalseAndNull() { var alias = as(eval.fields().get(0), Alias.class); assertTrue(alias.child().foldable()); - assertThat(alias.child().fold(), nullValue()); + assertThat(alias.child().fold(FoldContext.small()), nullValue()); assertThat(alias.child().dataType(), is(LONG)); alias = as(eval.fields().get(1), Alias.class); @@ -703,7 +707,7 @@ public void testReplaceStatsFilteredAggWithEvalFilterFalseAndNull() { alias = as(eval.fields().get(2), Alias.class); assertTrue(alias.child().foldable()); - assertThat(alias.child().fold(), nullValue()); + assertThat(alias.child().fold(FoldContext.small()), nullValue()); assertThat(alias.child().dataType(), is(LONG)); var limit = as(eval.child(), Limit.class); @@ -752,7 +756,7 @@ public void testReplaceStatsFilteredAggWithEvalCountDistinctInExpression() { assertThat(alias.name(), is("count_distinct(salary + 2) + 3 where false")); var add = as(alias.child(), Add.class); var literal = as(add.right(), Literal.class); - assertThat(literal.fold(), is(3)); + assertThat(literal.value(), is(3)); var limit = as(eval.child(), Limit.class); var source = as(limit.child(), LocalRelation.class); @@ -788,13 +792,13 @@ public void testReplaceStatsFilteredAggWithEvalSameAggWithAndWithoutFilter() { var alias = as(eval.fields().get(0), Alias.class); assertThat(Expressions.name(alias), containsString("max_a")); assertTrue(alias.child().foldable()); - assertThat(alias.child().fold(), nullValue()); + assertThat(alias.child().fold(FoldContext.small()), nullValue()); assertThat(alias.child().dataType(), is(INTEGER)); alias = as(eval.fields().get(1), Alias.class); assertThat(Expressions.name(alias), containsString("min_a")); assertTrue(alias.child().foldable()); - assertThat(alias.child().fold(), nullValue()); + assertThat(alias.child().fold(FoldContext.small()), nullValue()); assertThat(alias.child().dataType(), is(INTEGER)); var limit = as(eval.child(), Limit.class); @@ -933,7 +937,7 @@ public void testExtractStatsCommonFilterUsingJustOneAlias() { var gt = as(filter.condition(), GreaterThan.class); assertThat(Expressions.name(gt.left()), is("emp_no")); assertTrue(gt.right().foldable()); - assertThat(gt.right().fold(), is(1)); + assertThat(gt.right().fold(FoldContext.small()), is(1)); var source = as(filter.child(), EsRelation.class); } @@ -1053,7 +1057,7 @@ public void testExtractStatsCommonFilterInConjunction() { var gt = as(filter.condition(), GreaterThan.class); // name is "emp_no > 1 + 1" assertThat(Expressions.name(gt.left()), is("emp_no")); assertTrue(gt.right().foldable()); - assertThat(gt.right().fold(), is(2)); + assertThat(gt.right().fold(FoldContext.small()), is(2)); var source = as(filter.child(), EsRelation.class); } @@ -1083,12 +1087,12 @@ public void testExtractStatsCommonFilterInConjunctionWithMultipleCommonConjuncti var lt = as(and.left(), LessThan.class); assertThat(Expressions.name(lt.left()), is("emp_no")); assertTrue(lt.right().foldable()); - assertThat(lt.right().fold(), is(10)); + assertThat(lt.right().fold(FoldContext.small()), is(10)); var equals = as(and.right(), Equals.class); assertThat(Expressions.name(equals.left()), is("last_name")); assertTrue(equals.right().foldable()); - assertThat(equals.right().fold(), is(BytesRefs.toBytesRef("Doe"))); + assertThat(equals.right().fold(FoldContext.small()), is(BytesRefs.toBytesRef("Doe"))); var source = as(filter.child(), EsRelation.class); } @@ -1303,7 +1307,7 @@ public void testCombineLimits() { var anotherLimit = new Limit(EMPTY, L(limitValues[secondLimit]), oneLimit); assertEquals( new Limit(EMPTY, L(Math.min(limitValues[0], limitValues[1])), emptySource()), - new PushDownAndCombineLimits().rule(anotherLimit) + new PushDownAndCombineLimits().rule(anotherLimit, logicalOptimizerCtx) ); } @@ -1322,7 +1326,7 @@ public void testPushdownLimitsPastLeftJoin() { var limit = new Limit(EMPTY, L(10), join); - var optimizedPlan = new PushDownAndCombineLimits().rule(limit); + var optimizedPlan = new PushDownAndCombineLimits().rule(limit, logicalOptimizerCtx); assertEquals(join.replaceChildren(limit.replaceChild(join.left()), join.right()), optimizedPlan); } @@ -1340,10 +1344,7 @@ public void testMultipleCombineLimits() { var value = i == limitWithMinimum ? minimum : randomIntBetween(100, 1000); plan = new Limit(EMPTY, L(value), plan); } - assertEquals( - new Limit(EMPTY, L(minimum), relation), - new LogicalPlanOptimizer(new LogicalOptimizerContext(EsqlTestUtils.TEST_CFG)).optimize(plan) - ); + assertEquals(new Limit(EMPTY, L(minimum), relation), logicalOptimizer.optimize(plan)); } @AwaitsFix(bugUrl = "https://github.com/elastic/elasticsearch/issues/115311") @@ -1864,7 +1865,7 @@ public void testCopyDefaultLimitPastMvExpand() { assertThat(mvExpand.limit(), equalTo(1000)); var keep = as(mvExpand.child(), EsqlProject.class); var limitPastMvExpand = as(keep.child(), Limit.class); - assertThat(limitPastMvExpand.limit().fold(), equalTo(1000)); + assertThat(limitPastMvExpand.limit().fold(FoldContext.small()), equalTo(1000)); as(limitPastMvExpand.child(), EsRelation.class); } @@ -1887,7 +1888,7 @@ public void testDontPushDownLimitPastMvExpand() { assertThat(mvExpand.limit(), equalTo(10)); var project = as(mvExpand.child(), EsqlProject.class); var limit = as(project.child(), Limit.class); - assertThat(limit.limit().fold(), equalTo(1)); + assertThat(limit.limit().fold(FoldContext.small()), equalTo(1)); as(limit.child(), EsRelation.class); } @@ -1921,7 +1922,7 @@ public void testMultipleMvExpandWithSortAndLimit() { var keep = as(plan, EsqlProject.class); var topN = as(keep.child(), TopN.class); - assertThat(topN.limit().fold(), equalTo(5)); + assertThat(topN.limit().fold(FoldContext.small()), equalTo(5)); assertThat(orderNames(topN), contains("salary")); var mvExp = as(topN.child(), MvExpand.class); assertThat(mvExp.limit(), equalTo(5)); @@ -1931,7 +1932,7 @@ public void testMultipleMvExpandWithSortAndLimit() { mvExp = as(filter.child(), MvExpand.class); assertThat(mvExp.limit(), equalTo(10)); topN = as(mvExp.child(), TopN.class); - assertThat(topN.limit().fold(), equalTo(10)); + assertThat(topN.limit().fold(FoldContext.small()), equalTo(10)); filter = as(topN.child(), Filter.class); as(filter.child(), EsRelation.class); } @@ -1955,11 +1956,11 @@ public void testPushDownLimitThroughMultipleSort_AfterMvExpand() { var keep = as(plan, EsqlProject.class); var topN = as(keep.child(), TopN.class); - assertThat(topN.limit().fold(), equalTo(5)); + assertThat(topN.limit().fold(FoldContext.small()), equalTo(5)); assertThat(orderNames(topN), contains("salary", "first_name")); var mvExp = as(topN.child(), MvExpand.class); topN = as(mvExp.child(), TopN.class); - assertThat(topN.limit().fold(), equalTo(10000)); + assertThat(topN.limit().fold(FoldContext.small()), equalTo(10000)); assertThat(orderNames(topN), contains("emp_no")); as(topN.child(), EsRelation.class); } @@ -1985,14 +1986,14 @@ public void testPushDownLimitThroughMultipleSort_AfterMvExpand2() { var keep = as(plan, EsqlProject.class); var topN = as(keep.child(), TopN.class); - assertThat(topN.limit().fold(), equalTo(5)); + assertThat(topN.limit().fold(FoldContext.small()), equalTo(5)); assertThat(orderNames(topN), contains("first_name")); topN = as(topN.child(), TopN.class); - assertThat(topN.limit().fold(), equalTo(5)); + assertThat(topN.limit().fold(FoldContext.small()), equalTo(5)); assertThat(orderNames(topN), contains("salary")); var mvExp = as(topN.child(), MvExpand.class); topN = as(mvExp.child(), TopN.class); - assertThat(topN.limit().fold(), equalTo(10000)); + assertThat(topN.limit().fold(FoldContext.small()), equalTo(10000)); assertThat(orderNames(topN), contains("emp_no")); as(topN.child(), EsRelation.class); } @@ -2021,11 +2022,11 @@ public void testDontPushDownLimitPastAggregate_AndMvExpand() { var limit = as(plan, Limit.class); var filter = as(limit.child(), Filter.class); - assertThat(limit.limit().fold(), equalTo(5)); + assertThat(limit.limit().fold(FoldContext.small()), equalTo(5)); var agg = as(filter.child(), Aggregate.class); var mvExp = as(agg.child(), MvExpand.class); var topN = as(mvExp.child(), TopN.class); - assertThat(topN.limit().fold(), equalTo(50)); + assertThat(topN.limit().fold(FoldContext.small()), equalTo(50)); assertThat(orderNames(topN), contains("emp_no")); as(topN.child(), EsRelation.class); } @@ -2052,13 +2053,13 @@ public void testPushDown_TheRightLimit_PastMvExpand() { | limit 5"""); var limit = as(plan, Limit.class); - assertThat(limit.limit().fold(), equalTo(5)); + assertThat(limit.limit().fold(FoldContext.small()), equalTo(5)); var filter = as(limit.child(), Filter.class); var agg = as(filter.child(), Aggregate.class); var mvExp = as(agg.child(), MvExpand.class); assertThat(mvExp.limit(), equalTo(50)); limit = as(mvExp.child(), Limit.class); - assertThat(limit.limit().fold(), equalTo(50)); + assertThat(limit.limit().fold(FoldContext.small()), equalTo(50)); as(limit.child(), EsRelation.class); } @@ -2083,12 +2084,12 @@ public void testPushDownLimit_PastEvalAndMvExpand() { var keep = as(plan, EsqlProject.class); var topN = as(keep.child(), TopN.class); - assertThat(topN.limit().fold(), equalTo(5)); + assertThat(topN.limit().fold(FoldContext.small()), equalTo(5)); assertThat(orderNames(topN), contains("salary")); var eval = as(topN.child(), Eval.class); var mvExp = as(eval.child(), MvExpand.class); topN = as(mvExp.child(), TopN.class); - assertThat(topN.limit().fold(), equalTo(10000)); + assertThat(topN.limit().fold(FoldContext.small()), equalTo(10000)); assertThat(orderNames(topN), contains("first_name")); as(topN.child(), EsRelation.class); } @@ -2114,7 +2115,7 @@ public void testAddDefaultLimit_BeforeMvExpand_WithFilterOnExpandedField_ResultT var keep = as(plan, EsqlProject.class); var topN = as(keep.child(), TopN.class); - assertThat(topN.limit().fold(), equalTo(1000)); + assertThat(topN.limit().fold(FoldContext.small()), equalTo(1000)); assertThat(orderNames(topN), contains("salary", "first_name")); var filter = as(topN.child(), Filter.class); assertThat(filter.condition(), instanceOf(And.class)); @@ -2143,7 +2144,7 @@ public void testFilterWithSortBeforeMvExpand() { var mvExp = as(plan, MvExpand.class); assertThat(mvExp.limit(), equalTo(10)); var topN = as(mvExp.child(), TopN.class); - assertThat(topN.limit().fold(), equalTo(10)); + assertThat(topN.limit().fold(FoldContext.small()), equalTo(10)); assertThat(orderNames(topN), contains("emp_no")); var filter = as(topN.child(), Filter.class); as(filter.child(), EsRelation.class); @@ -2168,7 +2169,7 @@ public void testMultiMvExpand_SortDownBelow() { | sort first_name"""); var topN = as(plan, TopN.class); - assertThat(topN.limit().fold(), equalTo(1000)); + assertThat(topN.limit().fold(FoldContext.small()), equalTo(1000)); assertThat(orderNames(topN), contains("first_name")); var mvExpand = as(topN.child(), MvExpand.class); var filter = as(mvExpand.child(), Filter.class); @@ -2200,11 +2201,11 @@ public void testLimitThenSortBeforeMvExpand() { assertThat(mvExpand.limit(), equalTo(10000)); var project = as(mvExpand.child(), EsqlProject.class); var topN = as(project.child(), TopN.class); - assertThat(topN.limit().fold(), equalTo(7300)); + assertThat(topN.limit().fold(FoldContext.small()), equalTo(7300)); assertThat(orderNames(topN), contains("a")); mvExpand = as(topN.child(), MvExpand.class); var limit = as(mvExpand.child(), Limit.class); - assertThat(limit.limit().fold(), equalTo(7300)); + assertThat(limit.limit().fold(FoldContext.small()), equalTo(7300)); as(limit.child(), LocalRelation.class); } @@ -2224,7 +2225,7 @@ public void testRemoveUnusedSortBeforeMvExpand_DefaultLimit10000() { var topN = as(plan, TopN.class); assertThat(orderNames(topN), contains("first_name")); - assertThat(topN.limit().fold(), equalTo(10000)); + assertThat(topN.limit().fold(FoldContext.small()), equalTo(10000)); var mvExpand = as(topN.child(), MvExpand.class); var topN2 = as(mvExpand.child(), TopN.class); // TODO is it correct? Double-check AddDefaultTopN rule as(topN2.child(), EsRelation.class); @@ -2252,7 +2253,7 @@ public void testAddDefaultLimit_BeforeMvExpand_WithFilterOnExpandedField() { var keep = as(plan, EsqlProject.class); var topN = as(keep.child(), TopN.class); - assertThat(topN.limit().fold(), equalTo(15)); + assertThat(topN.limit().fold(FoldContext.small()), equalTo(15)); assertThat(orderNames(topN), contains("salary", "first_name")); var filter = as(topN.child(), Filter.class); assertThat(filter.condition(), instanceOf(And.class)); @@ -2260,7 +2261,7 @@ public void testAddDefaultLimit_BeforeMvExpand_WithFilterOnExpandedField() { topN = as(mvExp.child(), TopN.class); // the filter acts on first_name (the one used in mv_expand), so the limit 15 is not pushed down past mv_expand // instead the default limit is added - assertThat(topN.limit().fold(), equalTo(10000)); + assertThat(topN.limit().fold(FoldContext.small()), equalTo(10000)); assertThat(orderNames(topN), contains("emp_no")); as(topN.child(), EsRelation.class); } @@ -2287,7 +2288,7 @@ public void testAddDefaultLimit_BeforeMvExpand_WithFilter_NOT_OnExpandedField() var keep = as(plan, EsqlProject.class); var topN = as(keep.child(), TopN.class); - assertThat(topN.limit().fold(), equalTo(15)); + assertThat(topN.limit().fold(FoldContext.small()), equalTo(15)); assertThat(orderNames(topN), contains("salary", "first_name")); var filter = as(topN.child(), Filter.class); assertThat(filter.condition(), instanceOf(And.class)); @@ -2295,7 +2296,7 @@ public void testAddDefaultLimit_BeforeMvExpand_WithFilter_NOT_OnExpandedField() topN = as(mvExp.child(), TopN.class); // the filters after mv_expand do not act on the expanded field values, as such the limit 15 is the one being pushed down // otherwise that limit wouldn't have pushed down and the default limit was instead being added by default before mv_expanded - assertThat(topN.limit().fold(), equalTo(10000)); + assertThat(topN.limit().fold(FoldContext.small()), equalTo(10000)); assertThat(orderNames(topN), contains("emp_no")); as(topN.child(), EsRelation.class); } @@ -2323,14 +2324,14 @@ public void testAddDefaultLimit_BeforeMvExpand_WithFilterOnExpandedFieldAlias() var keep = as(plan, EsqlProject.class); var topN = as(keep.child(), TopN.class); - assertThat(topN.limit().fold(), equalTo(15)); + assertThat(topN.limit().fold(FoldContext.small()), equalTo(15)); assertThat(orderNames(topN), contains("salary", "first_name")); var filter = as(topN.child(), Filter.class); assertThat(filter.condition(), instanceOf(And.class)); var mvExp = as(filter.child(), MvExpand.class); topN = as(mvExp.child(), TopN.class); // the filter uses an alias ("x") to the expanded field ("first_name"), so the default limit is used and not the one provided - assertThat(topN.limit().fold(), equalTo(10000)); + assertThat(topN.limit().fold(FoldContext.small()), equalTo(10000)); assertThat(orderNames(topN), contains("gender")); as(topN.child(), EsRelation.class); } @@ -2369,7 +2370,7 @@ public void testSortMvExpandLimit() { var expand = as(plan, MvExpand.class); assertThat(expand.limit(), equalTo(20)); var topN = as(expand.child(), TopN.class); - assertThat(topN.limit().fold(), is(20)); + assertThat(topN.limit().fold(FoldContext.small()), is(20)); var row = as(topN.child(), EsRelation.class); } @@ -2390,7 +2391,7 @@ public void testWhereMvExpand() { var expand = as(plan, MvExpand.class); assertThat(expand.limit(), equalTo(1000)); var limit2 = as(expand.child(), Limit.class); - assertThat(limit2.limit().fold(), is(1000)); + assertThat(limit2.limit().fold(FoldContext.small()), is(1000)); var row = as(limit2.child(), LocalRelation.class); } @@ -2583,7 +2584,7 @@ public void testSimplifyLikeNoWildcard() { assertTrue(filter.condition() instanceof Equals); Equals equals = as(filter.condition(), Equals.class); - assertEquals(BytesRefs.toBytesRef("foo"), equals.right().fold()); + assertEquals(BytesRefs.toBytesRef("foo"), equals.right().fold(FoldContext.small())); assertTrue(filter.child() instanceof EsRelation); } @@ -2609,7 +2610,7 @@ public void testSimplifyRLikeNoWildcard() { assertTrue(filter.condition() instanceof Equals); Equals equals = as(filter.condition(), Equals.class); - assertEquals(BytesRefs.toBytesRef("foo"), equals.right().fold()); + assertEquals(BytesRefs.toBytesRef("foo"), equals.right().fold(FoldContext.small())); assertTrue(filter.child() instanceof EsRelation); } @@ -2773,7 +2774,7 @@ public void testEnrich() { """); var enrich = as(plan, Enrich.class); assertTrue(enrich.policyName().resolved()); - assertThat(enrich.policyName().fold(), is(BytesRefs.toBytesRef("languages_idx"))); + assertThat(enrich.policyName().fold(FoldContext.small()), is(BytesRefs.toBytesRef("languages_idx"))); var eval = as(enrich.child(), Eval.class); var limit = as(eval.child(), Limit.class); as(limit.child(), EsRelation.class); @@ -2819,7 +2820,7 @@ public void testEnrichNotNullFilter() { var filter = as(limit.child(), Filter.class); var enrich = as(filter.child(), Enrich.class); assertTrue(enrich.policyName().resolved()); - assertThat(enrich.policyName().fold(), is(BytesRefs.toBytesRef("languages_idx"))); + assertThat(enrich.policyName().fold(FoldContext.small()), is(BytesRefs.toBytesRef("languages_idx"))); var eval = as(enrich.child(), Eval.class); as(eval.child(), EsRelation.class); } @@ -2940,7 +2941,7 @@ public void testMedianReplacement() { var a = as(aggs.get(0), Alias.class); var per = as(a.child(), Percentile.class); var literal = as(per.percentile(), Literal.class); - assertThat((int) QuantileStates.MEDIAN, is(literal.fold())); + assertThat((int) QuantileStates.MEDIAN, is(literal.value())); assertThat(Expressions.names(agg.groupings()), contains("last_name")); } @@ -2949,7 +2950,7 @@ public void testSplittingInWithFoldableValue() { FieldAttribute fa = getFieldAttribute("foo"); In in = new In(EMPTY, ONE, List.of(TWO, THREE, fa, L(null))); Or expected = new Or(EMPTY, new In(EMPTY, ONE, List.of(TWO, THREE)), new In(EMPTY, ONE, List.of(fa, L(null)))); - assertThat(new SplitInWithFoldableValue().rule(in), equalTo(expected)); + assertThat(new SplitInWithFoldableValue().rule(in, logicalOptimizerCtx), equalTo(expected)); } public void testReplaceFilterWithExact() { @@ -3706,7 +3707,7 @@ private void aggFieldName(Expression exp, Class var alias = as(exp, Alias.class); var af = as(alias.child(), aggType); var field = af.field(); - var name = field.foldable() ? BytesRefs.toString(field.fold()) : Expressions.name(field); + var name = field.foldable() ? BytesRefs.toString(field.fold(FoldContext.small())) : Expressions.name(field); assertThat(name, is(fieldName)); } @@ -4118,7 +4119,7 @@ public void testNestedExpressionsWithGroupingKeyInAggs() { var value = Alias.unwrap(fields.get(0)); var math = as(value, Mod.class); assertThat(Expressions.name(math.left()), is("emp_no")); - assertThat(math.right().fold(), is(2)); + assertThat(math.right().fold(FoldContext.small()), is(2)); // languages + emp_no % 2 var add = as(Alias.unwrap(fields.get(1).canonical()), Add.class); if (add.left() instanceof Mod mod) { @@ -4127,7 +4128,7 @@ public void testNestedExpressionsWithGroupingKeyInAggs() { assertThat(Expressions.name(add.left()), is("languages")); var mod = as(add.right().canonical(), Mod.class); assertThat(Expressions.name(mod.left()), is("emp_no")); - assertThat(mod.right().fold(), is(2)); + assertThat(mod.right().fold(FoldContext.small()), is(2)); } /** @@ -4156,7 +4157,7 @@ public void testNestedExpressionsWithMultiGrouping() { var value = Alias.unwrap(fields.get(0).canonical()); var math = as(value, Mod.class); assertThat(Expressions.name(math.left()), is("emp_no")); - assertThat(math.right().fold(), is(2)); + assertThat(math.right().fold(FoldContext.small()), is(2)); // languages + salary var add = as(Alias.unwrap(fields.get(1).canonical()), Add.class); assertThat(Expressions.name(add.left()), anyOf(is("languages"), is("salary"))); @@ -4173,7 +4174,7 @@ public void testNestedExpressionsWithMultiGrouping() { assertThat(Expressions.name(add3.right()), anyOf(is("salary"), is("languages"))); // emp_no % 2 assertThat(Expressions.name(mod.left()), is("emp_no")); - assertThat(mod.right().fold(), is(2)); + assertThat(mod.right().fold(FoldContext.small()), is(2)); } /** @@ -4611,8 +4612,8 @@ public void testCountOfLiteral() { var mvCoalesce = as(mul.left(), Coalesce.class); assertThat(mvCoalesce.children().size(), equalTo(2)); var mvCount = as(mvCoalesce.children().get(0), MvCount.class); - assertThat(mvCount.fold(), equalTo(2)); - assertThat(mvCoalesce.children().get(1).fold(), equalTo(0)); + assertThat(mvCount.fold(FoldContext.small()), equalTo(2)); + assertThat(mvCoalesce.children().get(1).fold(FoldContext.small()), equalTo(0)); var count = as(mul.right(), ReferenceAttribute.class); assertThat(count.name(), equalTo("$$COUNT$s$0")); @@ -4623,8 +4624,8 @@ public void testCountOfLiteral() { var mvCoalesce_expr = as(mul_expr.left(), Coalesce.class); assertThat(mvCoalesce_expr.children().size(), equalTo(2)); var mvCount_expr = as(mvCoalesce_expr.children().get(0), MvCount.class); - assertThat(mvCount_expr.fold(), equalTo(1)); - assertThat(mvCoalesce_expr.children().get(1).fold(), equalTo(0)); + assertThat(mvCount_expr.fold(FoldContext.small()), equalTo(1)); + assertThat(mvCoalesce_expr.children().get(1).fold(FoldContext.small()), equalTo(0)); var count_expr = as(mul_expr.right(), ReferenceAttribute.class); assertThat(count_expr.name(), equalTo("$$COUNT$s$0")); @@ -4636,7 +4637,7 @@ public void testCountOfLiteral() { assertThat(mvCoalesce_null.children().size(), equalTo(2)); var mvCount_null = as(mvCoalesce_null.children().get(0), MvCount.class); assertThat(mvCount_null.field(), equalTo(NULL)); - assertThat(mvCoalesce_null.children().get(1).fold(), equalTo(0)); + assertThat(mvCoalesce_null.children().get(1).fold(FoldContext.small()), equalTo(0)); var count_null = as(mul_null.right(), ReferenceAttribute.class); assertThat(count_null.name(), equalTo("$$COUNT$s$0")); } @@ -4675,7 +4676,7 @@ public void testSumOfLiteral() { assertThat(s.name(), equalTo("s")); var mul = as(s.child(), Mul.class); var mvSum = as(mul.left(), MvSum.class); - assertThat(mvSum.fold(), equalTo(3)); + assertThat(mvSum.fold(FoldContext.small()), equalTo(3)); var count = as(mul.right(), ReferenceAttribute.class); assertThat(count.name(), equalTo("$$COUNT$s$0")); @@ -4684,7 +4685,7 @@ public void testSumOfLiteral() { assertThat(s_expr.name(), equalTo("s_expr")); var mul_expr = as(s_expr.child(), Mul.class); var mvSum_expr = as(mul_expr.left(), MvSum.class); - assertThat(mvSum_expr.fold(), equalTo(3.14)); + assertThat(mvSum_expr.fold(FoldContext.small()), equalTo(3.14)); var count_expr = as(mul_expr.right(), ReferenceAttribute.class); assertThat(count_expr.name(), equalTo("$$COUNT$s$0")); @@ -4833,7 +4834,7 @@ private static void assertAggOfConstExprs(AggOfLiteralTestCase testCase, List x.foldable() ? new Literal(x.source(), x.fold(), x.dataType()) : x); + be -> LITERALS_ON_THE_RIGHT.rule(be, logicalOptimizerCtx) + ).transformUp(x -> x.foldable() ? new Literal(x.source(), x.fold(FoldContext.small()), x.dataType()) : x); List resolvedFields = fieldAttributeExp.collectFirstChildren(x -> x instanceof FieldAttribute); for (Expression field : resolvedFields) { @@ -5679,8 +5680,7 @@ public void testSimplifyComparisonArithmeticWithFloatsAndDirectionChange() { } private void assertNullLiteral(Expression expression) { - assertEquals(Literal.class, expression.getClass()); - assertNull(expression.fold()); + assertNull(as(expression, Literal.class).value()); } @AwaitsFix(bugUrl = "https://github.com/elastic/elasticsearch/issues/108519") @@ -5776,7 +5776,7 @@ public void testReplaceStringCasingWithInsensitiveEqualsEquals() { var filter = as(limit.child(), Filter.class); var insensitive = as(filter.condition(), InsensitiveEquals.class); as(insensitive.left(), FieldAttribute.class); - var bRef = as(insensitive.right().fold(), BytesRef.class); + var bRef = as(insensitive.right().fold(FoldContext.small()), BytesRef.class); assertThat(bRef.utf8ToString(), is(value)); as(filter.child(), EsRelation.class); } @@ -5792,7 +5792,7 @@ public void testReplaceStringCasingWithInsensitiveEqualsNotEquals() { var not = as(filter.condition(), Not.class); var insensitive = as(not.field(), InsensitiveEquals.class); as(insensitive.left(), FieldAttribute.class); - var bRef = as(insensitive.right().fold(), BytesRef.class); + var bRef = as(insensitive.right().fold(FoldContext.small()), BytesRef.class); assertThat(bRef.utf8ToString(), is(value)); as(filter.child(), EsRelation.class); } @@ -5805,7 +5805,7 @@ public void testReplaceStringCasingWithInsensitiveEqualsUnwrap() { var insensitive = as(filter.condition(), InsensitiveEquals.class); var field = as(insensitive.left(), FieldAttribute.class); assertThat(field.fieldName(), is("first_name")); - var bRef = as(insensitive.right().fold(), BytesRef.class); + var bRef = as(insensitive.right().fold(FoldContext.small()), BytesRef.class); assertThat(bRef.utf8ToString(), is("VALÜ")); as(filter.child(), EsRelation.class); } @@ -5856,7 +5856,7 @@ public void testLookupSimple() { var left = as(join.left(), EsqlProject.class); assertThat(left.output().toString(), containsString("int{r}")); var limit = as(left.child(), Limit.class); - assertThat(limit.limit().fold(), equalTo(1000)); + assertThat(limit.limit().fold(FoldContext.small()), equalTo(1000)); assertThat(join.config().type(), equalTo(JoinTypes.LEFT)); assertThat(join.config().matchFields().stream().map(Object::toString).toList(), matchesList().item(startsWith("int{r}"))); @@ -5925,7 +5925,7 @@ public void testLookupStats() { } var plan = optimizedPlan(query); var limit = as(plan, Limit.class); - assertThat(limit.limit().fold(), equalTo(1000)); + assertThat(limit.limit().fold(FoldContext.small()), equalTo(1000)); var agg = as(limit.child(), Aggregate.class); assertMap( @@ -6017,7 +6017,7 @@ public void testLookupJoinPushDownFilterOnJoinKeyWithRename() { assertThat(join.config().type(), equalTo(JoinTypes.LEFT)); var project = as(join.left(), Project.class); var limit = as(project.child(), Limit.class); - assertThat(limit.limit().fold(), equalTo(1000)); + assertThat(limit.limit().fold(FoldContext.small()), equalTo(1000)); var filter = as(limit.child(), Filter.class); // assert that the rename has been undone var op = as(filter.condition(), GreaterThan.class); @@ -6061,7 +6061,7 @@ public void testLookupJoinPushDownFilterOnLeftSideField() { var project = as(join.left(), Project.class); var limit = as(project.child(), Limit.class); - assertThat(limit.limit().fold(), equalTo(1000)); + assertThat(limit.limit().fold(FoldContext.small()), equalTo(1000)); var filter = as(limit.child(), Filter.class); var op = as(filter.condition(), GreaterThan.class); var field = as(op.left(), FieldAttribute.class); @@ -6100,7 +6100,7 @@ public void testLookupJoinPushDownDisabledForLookupField() { var plan = optimizedPlan(query); var limit = as(plan, Limit.class); - assertThat(limit.limit().fold(), equalTo(1000)); + assertThat(limit.limit().fold(FoldContext.small()), equalTo(1000)); var filter = as(limit.child(), Filter.class); var op = as(filter.condition(), Equals.class); @@ -6144,7 +6144,7 @@ public void testLookupJoinPushDownSeparatedForConjunctionBetweenLeftAndRightFiel var plan = optimizedPlan(query); var limit = as(plan, Limit.class); - assertThat(limit.limit().fold(), equalTo(1000)); + assertThat(limit.limit().fold(FoldContext.small()), equalTo(1000)); // filter kept in place, working on the right side var filter = as(limit.child(), Filter.class); EsqlBinaryComparison op = as(filter.condition(), Equals.class); @@ -6195,7 +6195,7 @@ public void testLookupJoinPushDownDisabledForDisjunctionBetweenLeftAndRightField var plan = optimizedPlan(query); var limit = as(plan, Limit.class); - assertThat(limit.limit().fold(), equalTo(1000)); + assertThat(limit.limit().fold(FoldContext.small()), equalTo(1000)); var filter = as(limit.child(), Filter.class); var or = as(filter.condition(), Or.class); @@ -6289,7 +6289,7 @@ public void testTranslateMixedAggsWithMathWithoutGrouping() { as(addEval.child(), EsRelation.class); assertThat(Expressions.attribute(mul.left()).id(), equalTo(finalAggs.aggregates().get(1).id())); - assertThat(mul.right().fold(), equalTo(1.1)); + assertThat(mul.right().fold(FoldContext.small()), equalTo(1.1)); assertThat(finalAggs.aggregateType(), equalTo(Aggregate.AggregateType.STANDARD)); Max maxRate = as(Alias.unwrap(finalAggs.aggregates().get(0)), Max.class); @@ -6304,7 +6304,7 @@ public void testTranslateMixedAggsWithMathWithoutGrouping() { ToPartial toPartialMaxCost = as(Alias.unwrap(aggsByTsid.aggregates().get(1)), ToPartial.class); assertThat(Expressions.attribute(toPartialMaxCost.field()).id(), equalTo(addEval.fields().get(0).id())); assertThat(Expressions.attribute(add.left()).name(), equalTo("network.cost")); - assertThat(add.right().fold(), equalTo(0.2)); + assertThat(add.right().fold(FoldContext.small()), equalTo(0.2)); } public void testTranslateMetricsGroupedByOneDimension() { @@ -6533,7 +6533,7 @@ METRICS k8s avg(round(1.05 * rate(network.total_bytes_in))) BY bucket(@timestamp assertThat(Expressions.attribute(finalAgg.groupings().get(1)).id(), equalTo(aggsByTsid.aggregates().get(1).id())); assertThat(Expressions.attribute(mul.left()).id(), equalTo(aggsByTsid.aggregates().get(0).id())); - assertThat(mul.right().fold(), equalTo(1.05)); + assertThat(mul.right().fold(FoldContext.small()), equalTo(1.05)); assertThat(aggsByTsid.aggregateType(), equalTo(Aggregate.AggregateType.METRICS)); Rate rate = as(Alias.unwrap(aggsByTsid.aggregates().get(0)), Rate.class); assertThat(Expressions.attribute(rate.field()).name(), equalTo("network.total_bytes_in")); diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/PhysicalPlanOptimizerTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/PhysicalPlanOptimizerTests.java index a1e399df4233a..2e620256a41ef 100644 --- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/PhysicalPlanOptimizerTests.java +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/PhysicalPlanOptimizerTests.java @@ -51,6 +51,7 @@ import org.elasticsearch.xpack.esql.core.expression.Expression; import org.elasticsearch.xpack.esql.core.expression.Expressions; import org.elasticsearch.xpack.esql.core.expression.FieldAttribute; +import org.elasticsearch.xpack.esql.core.expression.FoldContext; import org.elasticsearch.xpack.esql.core.expression.Literal; import org.elasticsearch.xpack.esql.core.expression.MetadataAttribute; import org.elasticsearch.xpack.esql.core.expression.NamedExpression; @@ -164,6 +165,7 @@ import static org.elasticsearch.xpack.esql.EsqlTestUtils.configuration; import static org.elasticsearch.xpack.esql.EsqlTestUtils.loadMapping; import static org.elasticsearch.xpack.esql.EsqlTestUtils.statsForMissingField; +import static org.elasticsearch.xpack.esql.EsqlTestUtils.unboundLogicalOptimizerContext; import static org.elasticsearch.xpack.esql.EsqlTestUtils.withDefaultLimitWarning; import static org.elasticsearch.xpack.esql.SerializationTestUtils.assertSerialization; import static org.elasticsearch.xpack.esql.analysis.AnalyzerTestUtils.analyze; @@ -245,7 +247,7 @@ public PhysicalPlanOptimizerTests(String name, Configuration config) { @Before public void init() { parser = new EsqlParser(); - logicalOptimizer = new LogicalPlanOptimizer(new LogicalOptimizerContext(EsqlTestUtils.TEST_CFG)); + logicalOptimizer = new LogicalPlanOptimizer(unboundLogicalOptimizerContext()); physicalPlanOptimizer = new PhysicalPlanOptimizer(new PhysicalOptimizerContext(config)); EsqlFunctionRegistry functionRegistry = new EsqlFunctionRegistry(); mapper = new Mapper(); @@ -1117,7 +1119,7 @@ public void testLimit() { var fieldExtract = as(project.child(), FieldExtractExec.class); var source = source(fieldExtract.child()); assertThat(source.estimatedRowSize(), equalTo(allFieldRowSize + Integer.BYTES)); - assertThat(source.limit().fold(), is(10)); + assertThat(source.limit().fold(FoldContext.small()), is(10)); } /** @@ -1199,7 +1201,7 @@ public void testPushLimitToSource() { var leaves = extract.collectLeaves(); assertEquals(1, leaves.size()); var source = as(leaves.get(0), EsQueryExec.class); - assertThat(source.limit().fold(), is(10)); + assertThat(source.limit().fold(FoldContext.small()), is(10)); // extra ints for doc id and emp_no_10 assertThat(source.estimatedRowSize(), equalTo(allFieldRowSize + Integer.BYTES * 2)); } @@ -1246,7 +1248,7 @@ public void testPushLimitAndFilterToSource() { var source = source(extract.child()); assertThat(source.estimatedRowSize(), equalTo(allFieldRowSize + Integer.BYTES * 2)); - assertThat(source.limit().fold(), is(10)); + assertThat(source.limit().fold(FoldContext.small()), is(10)); var rq = as(sv(source.query(), "emp_no"), RangeQueryBuilder.class); assertThat(rq.fieldName(), equalTo("emp_no")); assertThat(rq.from(), equalTo(0)); @@ -2902,14 +2904,14 @@ public void testAvgSurrogateFunctionAfterRenameAndLimit() { var eval = as(project.child(), EvalExec.class); var limit = as(eval.child(), LimitExec.class); assertThat(limit.limit(), instanceOf(Literal.class)); - assertThat(limit.limit().fold(), equalTo(10000)); + assertThat(limit.limit().fold(FoldContext.small()), equalTo(10000)); var aggFinal = as(limit.child(), AggregateExec.class); assertThat(aggFinal.getMode(), equalTo(FINAL)); var aggPartial = as(aggFinal.child(), AggregateExec.class); assertThat(aggPartial.getMode(), equalTo(INITIAL)); limit = as(aggPartial.child(), LimitExec.class); assertThat(limit.limit(), instanceOf(Literal.class)); - assertThat(limit.limit().fold(), equalTo(10)); + assertThat(limit.limit().fold(FoldContext.small()), equalTo(10)); var exchange = as(limit.child(), ExchangeExec.class); project = as(exchange.child(), ProjectExec.class); @@ -2918,7 +2920,7 @@ public void testAvgSurrogateFunctionAfterRenameAndLimit() { var fieldExtract = as(project.child(), FieldExtractExec.class); assertThat(Expressions.names(fieldExtract.attributesToExtract()), is(expectedFields)); var source = source(fieldExtract.child()); - assertThat(source.limit().fold(), equalTo(10)); + assertThat(source.limit().fold(FoldContext.small()), equalTo(10)); } /** @@ -4849,15 +4851,15 @@ public void testPushSpatialDistanceMultiEvalToSource() { var alias1 = as(eval.fields().get(0), Alias.class); assertThat(alias1.name(), is("poi")); var poi = as(alias1.child(), Literal.class); - assertThat(poi.fold(), instanceOf(BytesRef.class)); + assertThat(poi.value(), instanceOf(BytesRef.class)); var alias2 = as(eval.fields().get(1), Alias.class); assertThat(alias2.name(), is("distance")); var stDistance = as(alias2.child(), StDistance.class); var location = as(stDistance.left(), FieldAttribute.class); assertThat(location.fieldName(), is("location")); var poiRef = as(stDistance.right(), Literal.class); - assertThat(poiRef.fold(), instanceOf(BytesRef.class)); - assertThat(poiRef.fold().toString(), is(poi.fold().toString())); + assertThat(poiRef.value(), instanceOf(BytesRef.class)); + assertThat(poiRef.value().toString(), is(poi.value().toString())); // Validate the filter condition var and = as(filter.condition(), And.class); @@ -6108,15 +6110,15 @@ public void testPushCompoundTopNDistanceWithCompoundFilterAndCompoundEvalToSourc var alias1 = as(evalExec.fields().get(0), Alias.class); assertThat(alias1.name(), is("poi")); var poi = as(alias1.child(), Literal.class); - assertThat(poi.fold(), instanceOf(BytesRef.class)); + assertThat(poi.value(), instanceOf(BytesRef.class)); var alias2 = as(evalExec.fields().get(1), Alias.class); assertThat(alias2.name(), is("distance")); var stDistance = as(alias2.child(), StDistance.class); var location = as(stDistance.left(), FieldAttribute.class); assertThat(location.fieldName(), is("location")); var poiRef = as(stDistance.right(), Literal.class); - assertThat(poiRef.fold(), instanceOf(BytesRef.class)); - assertThat(poiRef.fold().toString(), is(poi.fold().toString())); + assertThat(poiRef.value(), instanceOf(BytesRef.class)); + assertThat(poiRef.value().toString(), is(poi.value().toString())); extract = as(evalExec.child(), FieldExtractExec.class); assertThat(names(extract.attributesToExtract()), contains("location")); var source = source(extract.child()); @@ -6197,7 +6199,7 @@ public void testPushCompoundTopNDistanceWithDeeplyNestedCompoundEvalToSource() { var alias1 = as(evalExec.fields().get(0), Alias.class); assertThat(alias1.name(), is("poi")); var poi = as(alias1.child(), Literal.class); - assertThat(poi.fold(), instanceOf(BytesRef.class)); + assertThat(poi.value(), instanceOf(BytesRef.class)); var alias4 = as(evalExec.fields().get(3), Alias.class); assertThat(alias4.name(), is("loc2")); as(alias4.child(), FieldAttribute.class); @@ -6210,8 +6212,8 @@ public void testPushCompoundTopNDistanceWithDeeplyNestedCompoundEvalToSource() { var refLocation = as(stDistance.left(), ReferenceAttribute.class); assertThat(refLocation.name(), is("loc3")); var poiRef = as(stDistance.right(), Literal.class); - assertThat(poiRef.fold(), instanceOf(BytesRef.class)); - assertThat(poiRef.fold().toString(), is(poi.fold().toString())); + assertThat(poiRef.value(), instanceOf(BytesRef.class)); + assertThat(poiRef.value().toString(), is(poi.value().toString())); var alias7 = as(evalExec.fields().get(6), Alias.class); assertThat(alias7.name(), is("distance")); as(alias7.child(), ReferenceAttribute.class); @@ -6294,15 +6296,15 @@ public void testPushCompoundTopNDistanceWithCompoundFilterAndNestedCompoundEvalT var alias1 = as(evalExec.fields().get(0), Alias.class); assertThat(alias1.name(), is("poi")); var poi = as(alias1.child(), Literal.class); - assertThat(poi.fold(), instanceOf(BytesRef.class)); + assertThat(poi.value(), instanceOf(BytesRef.class)); var alias2 = as(evalExec.fields().get(1), Alias.class); assertThat(alias2.name(), is("distance")); var stDistance = as(alias2.child(), StDistance.class); var location = as(stDistance.left(), FieldAttribute.class); assertThat(location.fieldName(), is("location")); var poiRef = as(stDistance.right(), Literal.class); - assertThat(poiRef.fold(), instanceOf(BytesRef.class)); - assertThat(poiRef.fold().toString(), is(poi.fold().toString())); + assertThat(poiRef.value(), instanceOf(BytesRef.class)); + assertThat(poiRef.value().toString(), is(poi.value().toString())); extract = as(evalExec.child(), FieldExtractExec.class); assertThat(names(extract.attributesToExtract()), contains("location")); var source = source(extract.child()); @@ -6834,7 +6836,7 @@ public void testManyEnrich() { var fragment = as(exchange.child(), FragmentExec.class); var partialTopN = as(fragment.fragment(), TopN.class); var enrich2 = as(partialTopN.child(), Enrich.class); - assertThat(BytesRefs.toString(enrich2.policyName().fold()), equalTo("departments")); + assertThat(BytesRefs.toString(enrich2.policyName().fold(FoldContext.small())), equalTo("departments")); assertThat(enrich2.mode(), equalTo(Enrich.Mode.ANY)); var eval = as(enrich2.child(), Eval.class); as(eval.child(), EsRelation.class); @@ -6860,7 +6862,7 @@ public void testManyEnrich() { var fragment = as(exchange.child(), FragmentExec.class); var partialTopN = as(fragment.fragment(), TopN.class); var enrich2 = as(partialTopN.child(), Enrich.class); - assertThat(BytesRefs.toString(enrich2.policyName().fold()), equalTo("departments")); + assertThat(BytesRefs.toString(enrich2.policyName().fold(FoldContext.small())), equalTo("departments")); assertThat(enrich2.mode(), equalTo(Enrich.Mode.ANY)); var eval = as(enrich2.child(), Eval.class); as(eval.child(), EsRelation.class); @@ -7452,7 +7454,7 @@ private LocalExecutionPlanner.LocalExecutionPlan physicalOperationsFromPhysicalP // The TopN needs an estimated row size for the planner to work var plans = PlannerUtils.breakPlanBetweenCoordinatorAndDataNode(EstimatesRowSize.estimateRowSize(0, plan), config); plan = useDataNodePlan ? plans.v2() : plans.v1(); - plan = PlannerUtils.localPlan(List.of(), config, plan); + plan = PlannerUtils.localPlan(List.of(), config, FoldContext.small(), plan); LocalExecutionPlanner planner = new LocalExecutionPlanner( "test", "", @@ -7465,10 +7467,10 @@ private LocalExecutionPlanner.LocalExecutionPlan physicalOperationsFromPhysicalP new ExchangeSinkHandler(null, 10, () -> 10), null, null, - new EsPhysicalOperationProviders(List.of(), null) + new EsPhysicalOperationProviders(FoldContext.small(), List.of(), null) ); - return planner.plan(plan); + return planner.plan(FoldContext.small(), plan); } private List> findFieldNamesInLookupJoinDescription(LocalExecutionPlanner.LocalExecutionPlan physicalOperations) { @@ -7565,7 +7567,7 @@ public void testReductionPlanForTopN() { PhysicalPlan reduction = PlannerUtils.reductionPlan(plans.v2()); TopNExec reductionTopN = as(reduction, TopNExec.class); assertThat(reductionTopN.estimatedRowSize(), equalTo(allFieldRowSize)); - assertThat(reductionTopN.limit().fold(), equalTo(limit)); + assertThat(reductionTopN.limit().fold(FoldContext.small()), equalTo(limit)); } public void testReductionPlanForAggs() { @@ -7721,7 +7723,7 @@ private PhysicalPlan optimizedPlan(PhysicalPlan plan, SearchStats searchStats) { // individually hence why here the plan is kept as is var l = p.transformUp(FragmentExec.class, fragment -> { - var localPlan = PlannerUtils.localPlan(config, fragment, searchStats); + var localPlan = PlannerUtils.localPlan(config, FoldContext.small(), fragment, searchStats); return EstimatesRowSize.estimateRowSize(fragment.estimatedRowSize(), localPlan); }); diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/TestPlannerOptimizer.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/TestPlannerOptimizer.java index 9fe479dbb8625..e6a7d110f8c09 100644 --- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/TestPlannerOptimizer.java +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/TestPlannerOptimizer.java @@ -9,6 +9,7 @@ import org.elasticsearch.xpack.esql.EsqlTestUtils; import org.elasticsearch.xpack.esql.analysis.Analyzer; +import org.elasticsearch.xpack.esql.core.expression.FoldContext; import org.elasticsearch.xpack.esql.parser.EsqlParser; import org.elasticsearch.xpack.esql.plan.physical.EstimatesRowSize; import org.elasticsearch.xpack.esql.plan.physical.PhysicalPlan; @@ -26,7 +27,7 @@ public class TestPlannerOptimizer { private final Configuration config; public TestPlannerOptimizer(Configuration config, Analyzer analyzer) { - this(config, analyzer, new LogicalPlanOptimizer(new LogicalOptimizerContext(config))); + this(config, analyzer, new LogicalPlanOptimizer(new LogicalOptimizerContext(config, FoldContext.small()))); } public TestPlannerOptimizer(Configuration config, Analyzer analyzer, LogicalPlanOptimizer logicalOptimizer) { @@ -61,8 +62,13 @@ private PhysicalPlan optimizedPlan(PhysicalPlan plan, SearchStats searchStats) { // this is of no use in the unit tests, which checks the plan as a whole instead of each // individually hence why here the plan is kept as is - var logicalTestOptimizer = new LocalLogicalPlanOptimizer(new LocalLogicalOptimizerContext(config, searchStats)); - var physicalTestOptimizer = new TestLocalPhysicalPlanOptimizer(new LocalPhysicalOptimizerContext(config, searchStats), true); + var logicalTestOptimizer = new LocalLogicalPlanOptimizer( + new LocalLogicalOptimizerContext(config, FoldContext.small(), searchStats) + ); + var physicalTestOptimizer = new TestLocalPhysicalPlanOptimizer( + new LocalPhysicalOptimizerContext(config, FoldContext.small(), searchStats), + true + ); var l = PlannerUtils.localPlan(physicalPlan, logicalTestOptimizer, physicalTestOptimizer); // handle local reduction alignment diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/rules/LogicalOptimizerContextTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/rules/LogicalOptimizerContextTests.java new file mode 100644 index 0000000000000..5d2fec0fc8181 --- /dev/null +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/rules/LogicalOptimizerContextTests.java @@ -0,0 +1,62 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.esql.optimizer.rules; + +import org.elasticsearch.test.ESTestCase; +import org.elasticsearch.test.EqualsHashCodeTestUtils; +import org.elasticsearch.xpack.esql.ConfigurationTestUtils; +import org.elasticsearch.xpack.esql.EsqlTestUtils; +import org.elasticsearch.xpack.esql.core.expression.FoldContext; +import org.elasticsearch.xpack.esql.core.tree.Source; +import org.elasticsearch.xpack.esql.optimizer.LogicalOptimizerContext; +import org.elasticsearch.xpack.esql.session.Configuration; + +import static org.hamcrest.Matchers.equalTo; + +public class LogicalOptimizerContextTests extends ESTestCase { + public void testToString() { + // Random looking numbers for FoldContext are indeed random. Just so we have consistent numbers to assert on in toString. + LogicalOptimizerContext ctx = new LogicalOptimizerContext(EsqlTestUtils.TEST_CFG, new FoldContext(102)); + ctx.foldCtx().trackAllocation(Source.EMPTY, 99); + assertThat( + ctx.toString(), + equalTo("LogicalOptimizerContext[configuration=" + EsqlTestUtils.TEST_CFG + ", foldCtx=FoldContext[3/102]]") + ); + } + + public void testEqualsAndHashCode() { + EqualsHashCodeTestUtils.checkEqualsAndHashCode(randomLogicalOptimizerContext(), this::copy, this::mutate); + } + + private LogicalOptimizerContext randomLogicalOptimizerContext() { + return new LogicalOptimizerContext(ConfigurationTestUtils.randomConfiguration(), randomFoldContext()); + } + + private LogicalOptimizerContext copy(LogicalOptimizerContext c) { + return new LogicalOptimizerContext(c.configuration(), c.foldCtx()); + } + + private LogicalOptimizerContext mutate(LogicalOptimizerContext c) { + Configuration configuration = c.configuration(); + FoldContext foldCtx = c.foldCtx(); + if (randomBoolean()) { + configuration = randomValueOtherThan(configuration, ConfigurationTestUtils::randomConfiguration); + } else { + foldCtx = randomValueOtherThan(foldCtx, this::randomFoldContext); + } + return new LogicalOptimizerContext(configuration, foldCtx); + } + + private FoldContext randomFoldContext() { + FoldContext ctx = new FoldContext(randomNonNegativeLong()); + if (randomBoolean()) { + ctx.trackAllocation(Source.EMPTY, randomLongBetween(0, ctx.initialAllowedBytes())); + } + return ctx; + } +} diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/rules/logical/BooleanFunctionEqualsEliminationTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/rules/logical/BooleanFunctionEqualsEliminationTests.java index 08c8612d8097c..c0c145aee5382 100644 --- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/rules/logical/BooleanFunctionEqualsEliminationTests.java +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/rules/logical/BooleanFunctionEqualsEliminationTests.java @@ -22,25 +22,26 @@ import static java.util.Arrays.asList; import static org.elasticsearch.xpack.esql.EsqlTestUtils.getFieldAttribute; import static org.elasticsearch.xpack.esql.EsqlTestUtils.notEqualsOf; +import static org.elasticsearch.xpack.esql.EsqlTestUtils.unboundLogicalOptimizerContext; import static org.elasticsearch.xpack.esql.core.expression.Literal.FALSE; import static org.elasticsearch.xpack.esql.core.expression.Literal.NULL; import static org.elasticsearch.xpack.esql.core.expression.Literal.TRUE; import static org.elasticsearch.xpack.esql.core.tree.Source.EMPTY; public class BooleanFunctionEqualsEliminationTests extends ESTestCase { + private Expression booleanFunctionEqualElimination(BinaryComparison e) { + return new BooleanFunctionEqualsElimination().rule(e, unboundLogicalOptimizerContext()); + } public void testBoolEqualsSimplificationOnExpressions() { - BooleanFunctionEqualsElimination s = new BooleanFunctionEqualsElimination(); Expression exp = new GreaterThan(EMPTY, getFieldAttribute(), new Literal(EMPTY, 0, DataType.INTEGER), null); - assertEquals(exp, s.rule(new Equals(EMPTY, exp, TRUE))); + assertEquals(exp, booleanFunctionEqualElimination(new Equals(EMPTY, exp, TRUE))); // TODO: Replace use of QL Not with ESQL Not - assertEquals(new Not(EMPTY, exp), s.rule(new Equals(EMPTY, exp, FALSE))); + assertEquals(new Not(EMPTY, exp), booleanFunctionEqualElimination(new Equals(EMPTY, exp, FALSE))); } public void testBoolEqualsSimplificationOnFields() { - BooleanFunctionEqualsElimination s = new BooleanFunctionEqualsElimination(); - FieldAttribute field = getFieldAttribute(); List comparisons = asList( @@ -55,7 +56,7 @@ public void testBoolEqualsSimplificationOnFields() { ); for (BinaryComparison comparison : comparisons) { - assertEquals(comparison, s.rule(comparison)); + assertEquals(comparison, booleanFunctionEqualElimination(comparison)); } } diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/rules/logical/BooleanSimplificationTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/rules/logical/BooleanSimplificationTests.java index 3b1f8cfc83af3..5b4bf806518de 100644 --- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/rules/logical/BooleanSimplificationTests.java +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/rules/logical/BooleanSimplificationTests.java @@ -9,9 +9,11 @@ import org.elasticsearch.test.ESTestCase; import org.elasticsearch.xpack.esql.core.expression.Expression; +import org.elasticsearch.xpack.esql.core.expression.function.scalar.ScalarFunction; import org.elasticsearch.xpack.esql.core.expression.predicate.logical.And; import org.elasticsearch.xpack.esql.core.expression.predicate.logical.Or; +import static org.elasticsearch.xpack.esql.EsqlTestUtils.unboundLogicalOptimizerContext; import static org.elasticsearch.xpack.esql.core.expression.Literal.FALSE; import static org.elasticsearch.xpack.esql.core.expression.Literal.TRUE; import static org.elasticsearch.xpack.esql.core.tree.Source.EMPTY; @@ -20,33 +22,31 @@ public class BooleanSimplificationTests extends ESTestCase { private static final Expression DUMMY_EXPRESSION = new org.elasticsearch.xpack.esql.core.optimizer.OptimizerRulesTests.DummyBooleanExpression(EMPTY, 0); - public void testBoolSimplifyOr() { - BooleanSimplification simplification = new BooleanSimplification(); + private Expression booleanSimplification(ScalarFunction e) { + return new BooleanSimplification().rule(e, unboundLogicalOptimizerContext()); + } - assertEquals(TRUE, simplification.rule(new Or(EMPTY, TRUE, TRUE))); - assertEquals(TRUE, simplification.rule(new Or(EMPTY, TRUE, DUMMY_EXPRESSION))); - assertEquals(TRUE, simplification.rule(new Or(EMPTY, DUMMY_EXPRESSION, TRUE))); + public void testBoolSimplifyOr() { + assertEquals(TRUE, booleanSimplification(new Or(EMPTY, TRUE, TRUE))); + assertEquals(TRUE, booleanSimplification(new Or(EMPTY, TRUE, DUMMY_EXPRESSION))); + assertEquals(TRUE, booleanSimplification(new Or(EMPTY, DUMMY_EXPRESSION, TRUE))); - assertEquals(FALSE, simplification.rule(new Or(EMPTY, FALSE, FALSE))); - assertEquals(DUMMY_EXPRESSION, simplification.rule(new Or(EMPTY, FALSE, DUMMY_EXPRESSION))); - assertEquals(DUMMY_EXPRESSION, simplification.rule(new Or(EMPTY, DUMMY_EXPRESSION, FALSE))); + assertEquals(FALSE, booleanSimplification(new Or(EMPTY, FALSE, FALSE))); + assertEquals(DUMMY_EXPRESSION, booleanSimplification(new Or(EMPTY, FALSE, DUMMY_EXPRESSION))); + assertEquals(DUMMY_EXPRESSION, booleanSimplification(new Or(EMPTY, DUMMY_EXPRESSION, FALSE))); } public void testBoolSimplifyAnd() { - BooleanSimplification simplification = new BooleanSimplification(); + assertEquals(TRUE, booleanSimplification(new And(EMPTY, TRUE, TRUE))); + assertEquals(DUMMY_EXPRESSION, booleanSimplification(new And(EMPTY, TRUE, DUMMY_EXPRESSION))); + assertEquals(DUMMY_EXPRESSION, booleanSimplification(new And(EMPTY, DUMMY_EXPRESSION, TRUE))); - assertEquals(TRUE, simplification.rule(new And(EMPTY, TRUE, TRUE))); - assertEquals(DUMMY_EXPRESSION, simplification.rule(new And(EMPTY, TRUE, DUMMY_EXPRESSION))); - assertEquals(DUMMY_EXPRESSION, simplification.rule(new And(EMPTY, DUMMY_EXPRESSION, TRUE))); - - assertEquals(FALSE, simplification.rule(new And(EMPTY, FALSE, FALSE))); - assertEquals(FALSE, simplification.rule(new And(EMPTY, FALSE, DUMMY_EXPRESSION))); - assertEquals(FALSE, simplification.rule(new And(EMPTY, DUMMY_EXPRESSION, FALSE))); + assertEquals(FALSE, booleanSimplification(new And(EMPTY, FALSE, FALSE))); + assertEquals(FALSE, booleanSimplification(new And(EMPTY, FALSE, DUMMY_EXPRESSION))); + assertEquals(FALSE, booleanSimplification(new And(EMPTY, DUMMY_EXPRESSION, FALSE))); } public void testBoolCommonFactorExtraction() { - BooleanSimplification simplification = new BooleanSimplification(); - Expression a1 = new org.elasticsearch.xpack.esql.core.optimizer.OptimizerRulesTests.DummyBooleanExpression(EMPTY, 1); Expression a2 = new org.elasticsearch.xpack.esql.core.optimizer.OptimizerRulesTests.DummyBooleanExpression(EMPTY, 1); Expression b = new org.elasticsearch.xpack.esql.core.optimizer.OptimizerRulesTests.DummyBooleanExpression(EMPTY, 2); @@ -55,7 +55,7 @@ public void testBoolCommonFactorExtraction() { Or actual = new Or(EMPTY, new And(EMPTY, a1, b), new And(EMPTY, a2, c)); And expected = new And(EMPTY, a1, new Or(EMPTY, b, c)); - assertEquals(expected, simplification.rule(actual)); + assertEquals(expected, booleanSimplification(actual)); } } diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/rules/logical/CombineBinaryComparisonsTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/rules/logical/CombineBinaryComparisonsTests.java index d388369e0b167..a0d23731ae82d 100644 --- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/rules/logical/CombineBinaryComparisonsTests.java +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/rules/logical/CombineBinaryComparisonsTests.java @@ -12,6 +12,7 @@ import org.elasticsearch.xpack.esql.core.expression.Expression; import org.elasticsearch.xpack.esql.core.expression.FieldAttribute; import org.elasticsearch.xpack.esql.core.expression.predicate.logical.And; +import org.elasticsearch.xpack.esql.core.expression.predicate.logical.BinaryLogic; import org.elasticsearch.xpack.esql.core.expression.predicate.logical.Or; import org.elasticsearch.xpack.esql.expression.predicate.operator.comparison.GreaterThan; import org.elasticsearch.xpack.esql.expression.predicate.operator.comparison.GreaterThanOrEqual; @@ -37,6 +38,7 @@ import static org.elasticsearch.xpack.esql.EsqlTestUtils.lessThanOf; import static org.elasticsearch.xpack.esql.EsqlTestUtils.lessThanOrEqualOf; import static org.elasticsearch.xpack.esql.EsqlTestUtils.notEqualsOf; +import static org.elasticsearch.xpack.esql.EsqlTestUtils.unboundLogicalOptimizerContext; import static org.elasticsearch.xpack.esql.core.expression.Literal.FALSE; import static org.elasticsearch.xpack.esql.core.expression.Literal.TRUE; import static org.elasticsearch.xpack.esql.core.tree.Source.EMPTY; @@ -45,19 +47,17 @@ import static org.elasticsearch.xpack.esql.core.type.DataType.KEYWORD; public class CombineBinaryComparisonsTests extends ESTestCase { - - private static final Expression DUMMY_EXPRESSION = - new org.elasticsearch.xpack.esql.core.optimizer.OptimizerRulesTests.DummyBooleanExpression(EMPTY, 0); + private Expression combine(BinaryLogic e) { + return new CombineBinaryComparisons().rule(e, unboundLogicalOptimizerContext()); + } public void testCombineBinaryComparisonsNotComparable() { FieldAttribute fa = getFieldAttribute(); LessThanOrEqual lte = lessThanOrEqualOf(fa, SIX); LessThan lt = lessThanOf(fa, FALSE); - CombineBinaryComparisons rule = new CombineBinaryComparisons(); - And and = new And(EMPTY, lte, lt); - Expression exp = rule.rule(and); + Expression exp = combine(and); assertEquals(exp, and); } @@ -67,9 +67,7 @@ public void testCombineBinaryComparisonsUpper() { LessThanOrEqual lte = lessThanOrEqualOf(fa, SIX); LessThan lt = lessThanOf(fa, FIVE); - CombineBinaryComparisons rule = new CombineBinaryComparisons(); - - Expression exp = rule.rule(new And(EMPTY, lte, lt)); + Expression exp = combine(new And(EMPTY, lte, lt)); assertEquals(LessThan.class, exp.getClass()); LessThan r = (LessThan) exp; assertEquals(FIVE, r.right()); @@ -81,9 +79,7 @@ public void testCombineBinaryComparisonsLower() { GreaterThanOrEqual gte = greaterThanOrEqualOf(fa, SIX); GreaterThan gt = greaterThanOf(fa, FIVE); - CombineBinaryComparisons rule = new CombineBinaryComparisons(); - - Expression exp = rule.rule(new And(EMPTY, gte, gt)); + Expression exp = combine(new And(EMPTY, gte, gt)); assertEquals(GreaterThanOrEqual.class, exp.getClass()); GreaterThanOrEqual r = (GreaterThanOrEqual) exp; assertEquals(SIX, r.right()); @@ -95,9 +91,7 @@ public void testCombineBinaryComparisonsInclude() { GreaterThanOrEqual gte = greaterThanOrEqualOf(fa, FIVE); GreaterThan gt = greaterThanOf(fa, FIVE); - CombineBinaryComparisons rule = new CombineBinaryComparisons(); - - Expression exp = rule.rule(new And(EMPTY, gte, gt)); + Expression exp = combine(new And(EMPTY, gte, gt)); assertEquals(GreaterThan.class, exp.getClass()); GreaterThan r = (GreaterThan) exp; assertEquals(FIVE, r.right()); @@ -111,9 +105,7 @@ public void testCombineMultipleBinaryComparisons() { LessThanOrEqual lte = lessThanOrEqualOf(fa, L(7)); LessThan lt = lessThanOf(fa, SIX); - CombineBinaryComparisons rule = new CombineBinaryComparisons(); - - Expression exp = rule.rule(new And(EMPTY, gte, new And(EMPTY, gt, new And(EMPTY, lt, lte)))); + Expression exp = combine(new And(EMPTY, gte, new And(EMPTY, gt, new And(EMPTY, lt, lte)))); assertEquals(And.class, exp.getClass()); And and = (And) exp; assertEquals(gt, and.left()); @@ -128,10 +120,8 @@ public void testCombineMixedMultipleBinaryComparisons() { LessThanOrEqual lte = lessThanOrEqualOf(fa, L(7)); Expression ne = notEqualsOf(fa, FIVE); - CombineBinaryComparisons rule = new CombineBinaryComparisons(); - // TRUE AND a != 5 AND 4 < a <= 7 - Expression exp = rule.rule(new And(EMPTY, gte, new And(EMPTY, TRUE, new And(EMPTY, gt, new And(EMPTY, ne, lte))))); + Expression exp = combine(new And(EMPTY, gte, new And(EMPTY, TRUE, new And(EMPTY, gt, new And(EMPTY, ne, lte))))); assertEquals(And.class, exp.getClass()); And and = ((And) exp); assertEquals(And.class, and.right().getClass()); @@ -150,8 +140,7 @@ public void testCombineComparisonsIntoRange() { GreaterThanOrEqual gte = greaterThanOrEqualOf(fa, ONE); LessThan lt = lessThanOf(fa, FIVE); - CombineBinaryComparisons rule = new CombineBinaryComparisons(); - Expression exp = rule.rule(new And(EMPTY, gte, lt)); + Expression exp = combine(new And(EMPTY, gte, lt)); assertEquals(And.class, exp.getClass()); And and = (And) exp; @@ -167,8 +156,7 @@ public void testCombineBinaryComparisonsConjunction_Neq2AndGt3() { GreaterThan gt = greaterThanOf(fa, THREE); And and = new And(EMPTY, neq, gt); - CombineBinaryComparisons rule = new CombineBinaryComparisons(); - Expression exp = rule.rule(and); + Expression exp = combine(and); assertEquals(gt, exp); } @@ -180,8 +168,7 @@ public void testCombineBinaryComparisonsConjunction_Neq2AndGte2() { GreaterThanOrEqual gte = greaterThanOrEqualOf(fa, TWO); And and = new And(EMPTY, neq, gte); - CombineBinaryComparisons rule = new CombineBinaryComparisons(); - Expression exp = rule.rule(and); + Expression exp = combine(and); assertEquals(GreaterThan.class, exp.getClass()); GreaterThan gt = (GreaterThan) exp; assertEquals(TWO, gt.right()); @@ -195,8 +182,7 @@ public void testCombineBinaryComparisonsConjunction_Neq2AndGte1() { GreaterThanOrEqual gte = greaterThanOrEqualOf(fa, ONE); And and = new And(EMPTY, neq, gte); - CombineBinaryComparisons rule = new CombineBinaryComparisons(); - Expression exp = rule.rule(and); + Expression exp = combine(and); assertEquals(And.class, exp.getClass()); // can't optimize } @@ -208,8 +194,7 @@ public void testCombineBinaryComparisonsConjunction_Neq2AndLte3() { LessThanOrEqual lte = lessThanOrEqualOf(fa, THREE); And and = new And(EMPTY, neq, lte); - CombineBinaryComparisons rule = new CombineBinaryComparisons(); - Expression exp = rule.rule(and); + Expression exp = combine(and); assertEquals(and, exp); // can't optimize } @@ -221,8 +206,7 @@ public void testCombineBinaryComparisonsConjunction_Neq2AndLte2() { LessThanOrEqual lte = lessThanOrEqualOf(fa, TWO); And and = new And(EMPTY, neq, lte); - CombineBinaryComparisons rule = new CombineBinaryComparisons(); - Expression exp = rule.rule(and); + Expression exp = combine(and); assertEquals(LessThan.class, exp.getClass()); LessThan lt = (LessThan) exp; assertEquals(TWO, lt.right()); @@ -236,8 +220,7 @@ public void testCombineBinaryComparisonsConjunction_Neq2AndLte1() { LessThanOrEqual lte = lessThanOrEqualOf(fa, ONE); And and = new And(EMPTY, neq, lte); - CombineBinaryComparisons rule = new CombineBinaryComparisons(); - Expression exp = rule.rule(and); + Expression exp = combine(and); assertEquals(lte, exp); } @@ -251,8 +234,7 @@ public void testCombineBinaryComparisonsDisjunctionNotComparable() { Or or = new Or(EMPTY, gt1, gt2); - CombineBinaryComparisons rule = new CombineBinaryComparisons(); - Expression exp = rule.rule(or); + Expression exp = combine(or); assertEquals(exp, or); } @@ -266,8 +248,7 @@ public void testCombineBinaryComparisonsDisjunctionLowerBound() { Or or = new Or(EMPTY, gt1, new Or(EMPTY, gt2, gt3)); - CombineBinaryComparisons rule = new CombineBinaryComparisons(); - Expression exp = rule.rule(or); + Expression exp = combine(or); assertEquals(GreaterThan.class, exp.getClass()); GreaterThan gt = (GreaterThan) exp; @@ -284,8 +265,7 @@ public void testCombineBinaryComparisonsDisjunctionIncludeLowerBounds() { Or or = new Or(EMPTY, new Or(EMPTY, gt1, gt2), gte3); - CombineBinaryComparisons rule = new CombineBinaryComparisons(); - Expression exp = rule.rule(or); + Expression exp = combine(or); assertEquals(GreaterThan.class, exp.getClass()); GreaterThan gt = (GreaterThan) exp; @@ -302,8 +282,7 @@ public void testCombineBinaryComparisonsDisjunctionUpperBound() { Or or = new Or(EMPTY, new Or(EMPTY, lt1, lt2), lt3); - CombineBinaryComparisons rule = new CombineBinaryComparisons(); - Expression exp = rule.rule(or); + Expression exp = combine(or); assertEquals(LessThan.class, exp.getClass()); LessThan lt = (LessThan) exp; @@ -320,8 +299,7 @@ public void testCombineBinaryComparisonsDisjunctionIncludeUpperBounds() { Or or = new Or(EMPTY, lt2, new Or(EMPTY, lte2, lt1)); - CombineBinaryComparisons rule = new CombineBinaryComparisons(); - Expression exp = rule.rule(or); + Expression exp = combine(or); assertEquals(LessThanOrEqual.class, exp.getClass()); LessThanOrEqual lte = (LessThanOrEqual) exp; @@ -340,8 +318,7 @@ public void testCombineBinaryComparisonsDisjunctionOfLowerAndUpperBounds() { Or or = new Or(EMPTY, new Or(EMPTY, lt2, gt3), new Or(EMPTY, lt1, gt4)); - CombineBinaryComparisons rule = new CombineBinaryComparisons(); - Expression exp = rule.rule(or); + Expression exp = combine(or); assertEquals(Or.class, exp.getClass()); Or ro = (Or) exp; @@ -367,7 +344,7 @@ public void testBooleanSimplificationCommonExpressionSubstraction() { And right = new And(EMPTY, a2, common); Or or = new Or(EMPTY, left, right); - Expression exp = new BooleanSimplification().rule(or); + Expression exp = new BooleanSimplification().rule(or, unboundLogicalOptimizerContext()); assertEquals(new And(EMPTY, common, new Or(EMPTY, a1, a2)), exp); } @@ -391,8 +368,7 @@ public void testBinaryComparisonAndOutOfRangeNotEqualsDifferentFields() { ); for (And and : testCases) { - CombineBinaryComparisons rule = new CombineBinaryComparisons(); - Expression exp = rule.rule(and); + Expression exp = combine(and); assertEquals("Rule should not have transformed [" + and.nodeString() + "]", and, exp); } } diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/rules/logical/CombineDisjunctionsTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/rules/logical/CombineDisjunctionsTests.java index 043d18dac9fd4..bb5f2fd3505e9 100644 --- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/rules/logical/CombineDisjunctionsTests.java +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/rules/logical/CombineDisjunctionsTests.java @@ -38,16 +38,24 @@ import static org.elasticsearch.xpack.esql.EsqlTestUtils.lessThanOf; import static org.elasticsearch.xpack.esql.EsqlTestUtils.randomLiteral; import static org.elasticsearch.xpack.esql.EsqlTestUtils.relation; +import static org.elasticsearch.xpack.esql.EsqlTestUtils.unboundLogicalOptimizerContext; import static org.elasticsearch.xpack.esql.core.tree.Source.EMPTY; import static org.hamcrest.Matchers.contains; public class CombineDisjunctionsTests extends ESTestCase { + private Expression combineDisjunctions(Or e) { + return new CombineDisjunctions().rule(e, unboundLogicalOptimizerContext()); + } + + private LogicalPlan combineDisjunctions(LogicalPlan l) { + return new CombineDisjunctions().apply(l, unboundLogicalOptimizerContext()); + } public void testTwoEqualsWithOr() { FieldAttribute fa = getFieldAttribute(); Or or = new Or(EMPTY, equalsOf(fa, ONE), equalsOf(fa, TWO)); - Expression e = new CombineDisjunctions().rule(or); + Expression e = combineDisjunctions(or); assertEquals(In.class, e.getClass()); In in = (In) e; assertEquals(fa, in.value()); @@ -58,7 +66,7 @@ public void testTwoEqualsWithSameValue() { FieldAttribute fa = getFieldAttribute(); Or or = new Or(EMPTY, equalsOf(fa, ONE), equalsOf(fa, ONE)); - Expression e = new CombineDisjunctions().rule(or); + Expression e = combineDisjunctions(or); assertEquals(Equals.class, e.getClass()); Equals eq = (Equals) e; assertEquals(fa, eq.left()); @@ -69,7 +77,7 @@ public void testOneEqualsOneIn() { FieldAttribute fa = getFieldAttribute(); Or or = new Or(EMPTY, equalsOf(fa, ONE), new In(EMPTY, fa, List.of(TWO))); - Expression e = new CombineDisjunctions().rule(or); + Expression e = combineDisjunctions(or); assertEquals(In.class, e.getClass()); In in = (In) e; assertEquals(fa, in.value()); @@ -80,7 +88,7 @@ public void testOneEqualsOneInWithSameValue() { FieldAttribute fa = getFieldAttribute(); Or or = new Or(EMPTY, equalsOf(fa, ONE), new In(EMPTY, fa, asList(ONE, TWO))); - Expression e = new CombineDisjunctions().rule(or); + Expression e = combineDisjunctions(or); assertEquals(In.class, e.getClass()); In in = (In) e; assertEquals(fa, in.value()); @@ -92,7 +100,7 @@ public void testSingleValueInToEquals() { Equals equals = equalsOf(fa, ONE); Or or = new Or(EMPTY, equals, new In(EMPTY, fa, List.of(ONE))); - Expression e = new CombineDisjunctions().rule(or); + Expression e = combineDisjunctions(or); assertEquals(equals, e); } @@ -101,7 +109,7 @@ public void testEqualsBehindAnd() { And and = new And(EMPTY, equalsOf(fa, ONE), equalsOf(fa, TWO)); Filter dummy = new Filter(EMPTY, relation(), and); - LogicalPlan transformed = new CombineDisjunctions().apply(dummy); + LogicalPlan transformed = combineDisjunctions(dummy); assertSame(dummy, transformed); assertEquals(and, ((Filter) transformed).condition()); } @@ -111,7 +119,7 @@ public void testTwoEqualsDifferentFields() { FieldAttribute fieldTwo = getFieldAttribute("TWO"); Or or = new Or(EMPTY, equalsOf(fieldOne, ONE), equalsOf(fieldTwo, TWO)); - Expression e = new CombineDisjunctions().rule(or); + Expression e = combineDisjunctions(or); assertEquals(or, e); } @@ -120,7 +128,7 @@ public void testMultipleIn() { Or firstOr = new Or(EMPTY, new In(EMPTY, fa, List.of(ONE)), new In(EMPTY, fa, List.of(TWO))); Or secondOr = new Or(EMPTY, firstOr, new In(EMPTY, fa, List.of(THREE))); - Expression e = new CombineDisjunctions().rule(secondOr); + Expression e = combineDisjunctions(secondOr); assertEquals(In.class, e.getClass()); In in = (In) e; assertEquals(fa, in.value()); @@ -132,7 +140,7 @@ public void testOrWithNonCombinableExpressions() { Or firstOr = new Or(EMPTY, new In(EMPTY, fa, List.of(ONE)), lessThanOf(fa, TWO)); Or secondOr = new Or(EMPTY, firstOr, new In(EMPTY, fa, List.of(THREE))); - Expression e = new CombineDisjunctions().rule(secondOr); + Expression e = combineDisjunctions(secondOr); assertEquals(Or.class, e.getClass()); Or or = (Or) e; assertEquals(or.left(), firstOr.right()); @@ -160,7 +168,7 @@ public void testCombineCIDRMatch() { cidrs.add(new CIDRMatch(EMPTY, faa, ipa2)); cidrs.add(new CIDRMatch(EMPTY, fab, ipb2)); Or oldOr = (Or) Predicates.combineOr(cidrs); - Expression e = new CombineDisjunctions().rule(oldOr); + Expression e = combineDisjunctions(oldOr); assertEquals(Or.class, e.getClass()); Or newOr = (Or) e; assertEquals(CIDRMatch.class, newOr.left().getClass()); @@ -211,7 +219,7 @@ public void testCombineCIDRMatchEqualsIns() { Or oldOr = (Or) Predicates.combineOr(all); - Expression e = new CombineDisjunctions().rule(oldOr); + Expression e = combineDisjunctions(oldOr); assertEquals(Or.class, e.getClass()); Or newOr = (Or) e; assertEquals(Or.class, newOr.left().getClass()); diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/rules/logical/ConstantFoldingTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/rules/logical/ConstantFoldingTests.java index 01af91271e1ba..8a8585b8d0ab5 100644 --- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/rules/logical/ConstantFoldingTests.java +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/rules/logical/ConstantFoldingTests.java @@ -11,6 +11,7 @@ import org.elasticsearch.xpack.esql.core.expression.Alias; import org.elasticsearch.xpack.esql.core.expression.Expression; import org.elasticsearch.xpack.esql.core.expression.Expressions; +import org.elasticsearch.xpack.esql.core.expression.FoldContext; import org.elasticsearch.xpack.esql.core.expression.Literal; import org.elasticsearch.xpack.esql.core.expression.Nullability; import org.elasticsearch.xpack.esql.core.expression.predicate.BinaryOperator; @@ -45,68 +46,77 @@ import static org.elasticsearch.xpack.esql.EsqlTestUtils.notEqualsOf; import static org.elasticsearch.xpack.esql.EsqlTestUtils.of; import static org.elasticsearch.xpack.esql.EsqlTestUtils.rangeOf; +import static org.elasticsearch.xpack.esql.EsqlTestUtils.unboundLogicalOptimizerContext; import static org.elasticsearch.xpack.esql.core.expression.Literal.FALSE; import static org.elasticsearch.xpack.esql.core.expression.Literal.NULL; import static org.elasticsearch.xpack.esql.core.expression.Literal.TRUE; import static org.elasticsearch.xpack.esql.core.tree.Source.EMPTY; public class ConstantFoldingTests extends ESTestCase { + private Expression constantFolding(Expression e) { + return new ConstantFolding().rule(e, unboundLogicalOptimizerContext()); + } public void testConstantFolding() { Expression exp = new Add(EMPTY, TWO, THREE); assertTrue(exp.foldable()); - Expression result = new ConstantFolding().rule(exp); - assertTrue(result instanceof Literal); - assertEquals(5, ((Literal) result).value()); + Expression result = constantFolding(exp); + assertEquals(5, as(result, Literal.class).value()); // check now with an alias - result = new ConstantFolding().rule(new Alias(EMPTY, "a", exp)); + result = constantFolding(new Alias(EMPTY, "a", exp)); assertEquals("a", Expressions.name(result)); assertEquals(Alias.class, result.getClass()); } public void testConstantFoldingBinaryComparison() { - assertEquals(FALSE, new ConstantFolding().rule(greaterThanOf(TWO, THREE)).canonical()); - assertEquals(FALSE, new ConstantFolding().rule(greaterThanOrEqualOf(TWO, THREE)).canonical()); - assertEquals(FALSE, new ConstantFolding().rule(equalsOf(TWO, THREE)).canonical()); - assertEquals(TRUE, new ConstantFolding().rule(notEqualsOf(TWO, THREE)).canonical()); - assertEquals(TRUE, new ConstantFolding().rule(lessThanOrEqualOf(TWO, THREE)).canonical()); - assertEquals(TRUE, new ConstantFolding().rule(lessThanOf(TWO, THREE)).canonical()); + assertEquals(FALSE, constantFolding(greaterThanOf(TWO, THREE)).canonical()); + assertEquals(FALSE, constantFolding(greaterThanOrEqualOf(TWO, THREE)).canonical()); + assertEquals(FALSE, constantFolding(equalsOf(TWO, THREE)).canonical()); + assertEquals(TRUE, constantFolding(notEqualsOf(TWO, THREE)).canonical()); + assertEquals(TRUE, constantFolding(lessThanOrEqualOf(TWO, THREE)).canonical()); + assertEquals(TRUE, constantFolding(lessThanOf(TWO, THREE)).canonical()); } public void testConstantFoldingBinaryLogic() { - assertEquals(FALSE, new ConstantFolding().rule(new And(EMPTY, greaterThanOf(TWO, THREE), TRUE)).canonical()); - assertEquals(TRUE, new ConstantFolding().rule(new Or(EMPTY, greaterThanOrEqualOf(TWO, THREE), TRUE)).canonical()); + assertEquals(FALSE, constantFolding(new And(EMPTY, greaterThanOf(TWO, THREE), TRUE)).canonical()); + assertEquals(TRUE, constantFolding(new Or(EMPTY, greaterThanOrEqualOf(TWO, THREE), TRUE)).canonical()); } public void testConstantFoldingBinaryLogic_WithNullHandling() { - assertEquals(Nullability.TRUE, new ConstantFolding().rule(new And(EMPTY, NULL, TRUE)).canonical().nullable()); - assertEquals(Nullability.TRUE, new ConstantFolding().rule(new And(EMPTY, TRUE, NULL)).canonical().nullable()); - assertEquals(FALSE, new ConstantFolding().rule(new And(EMPTY, NULL, FALSE)).canonical()); - assertEquals(FALSE, new ConstantFolding().rule(new And(EMPTY, FALSE, NULL)).canonical()); - assertEquals(Nullability.TRUE, new ConstantFolding().rule(new And(EMPTY, NULL, NULL)).canonical().nullable()); - - assertEquals(TRUE, new ConstantFolding().rule(new Or(EMPTY, NULL, TRUE)).canonical()); - assertEquals(TRUE, new ConstantFolding().rule(new Or(EMPTY, TRUE, NULL)).canonical()); - assertEquals(Nullability.TRUE, new ConstantFolding().rule(new Or(EMPTY, NULL, FALSE)).canonical().nullable()); - assertEquals(Nullability.TRUE, new ConstantFolding().rule(new Or(EMPTY, FALSE, NULL)).canonical().nullable()); - assertEquals(Nullability.TRUE, new ConstantFolding().rule(new Or(EMPTY, NULL, NULL)).canonical().nullable()); + assertEquals(Nullability.TRUE, constantFolding(new And(EMPTY, NULL, TRUE)).canonical().nullable()); + assertEquals(Nullability.TRUE, constantFolding(new And(EMPTY, TRUE, NULL)).canonical().nullable()); + assertEquals(FALSE, constantFolding(new And(EMPTY, NULL, FALSE)).canonical()); + assertEquals(FALSE, constantFolding(new And(EMPTY, FALSE, NULL)).canonical()); + assertEquals(Nullability.TRUE, constantFolding(new And(EMPTY, NULL, NULL)).canonical().nullable()); + + assertEquals(TRUE, constantFolding(new Or(EMPTY, NULL, TRUE)).canonical()); + assertEquals(TRUE, constantFolding(new Or(EMPTY, TRUE, NULL)).canonical()); + assertEquals(Nullability.TRUE, constantFolding(new Or(EMPTY, NULL, FALSE)).canonical().nullable()); + assertEquals(Nullability.TRUE, constantFolding(new Or(EMPTY, FALSE, NULL)).canonical().nullable()); + assertEquals(Nullability.TRUE, constantFolding(new Or(EMPTY, NULL, NULL)).canonical().nullable()); } public void testConstantFoldingRange() { - assertEquals(true, new ConstantFolding().rule(rangeOf(FIVE, FIVE, true, new Literal(EMPTY, 10, DataType.INTEGER), false)).fold()); - assertEquals(false, new ConstantFolding().rule(rangeOf(FIVE, FIVE, false, new Literal(EMPTY, 10, DataType.INTEGER), false)).fold()); + assertEquals( + true, + constantFolding(rangeOf(FIVE, FIVE, true, new Literal(EMPTY, 10, DataType.INTEGER), false)).fold(FoldContext.small()) + ); + assertEquals( + false, + constantFolding(rangeOf(FIVE, FIVE, false, new Literal(EMPTY, 10, DataType.INTEGER), false)).fold(FoldContext.small()) + ); } public void testConstantNot() { - assertEquals(FALSE, new ConstantFolding().rule(new Not(EMPTY, TRUE))); - assertEquals(TRUE, new ConstantFolding().rule(new Not(EMPTY, FALSE))); + assertEquals(FALSE, constantFolding(new Not(EMPTY, TRUE))); + assertEquals(TRUE, constantFolding(new Not(EMPTY, FALSE))); } public void testConstantFoldingLikes() { - assertEquals(TRUE, new ConstantFolding().rule(new WildcardLike(EMPTY, of("test_emp"), new WildcardPattern("test*"))).canonical()); - assertEquals(TRUE, new ConstantFolding().rule(new RLike(EMPTY, of("test_emp"), new RLikePattern("test.emp"))).canonical()); + assertEquals(TRUE, constantFolding(new WildcardLike(EMPTY, of("test_emp"), new WildcardPattern("test*"))).canonical()); + assertEquals(TRUE, constantFolding(new RLike(EMPTY, of("test_emp"), new RLikePattern("test.emp"))).canonical()); } public void testArithmeticFolding() { @@ -125,7 +135,7 @@ public void testFoldRange() { Expression value = new Literal(EMPTY, 12, DataType.INTEGER); Range range = new Range(EMPTY, value, lowerBound, randomBoolean(), upperBound, randomBoolean(), randomZone()); - Expression folded = new ConstantFolding().rule(range); + Expression folded = constantFolding(range); assertTrue((Boolean) as(folded, Literal.class).value()); } @@ -156,16 +166,15 @@ public void testFoldRangeWithInvalidBoundaries() { // Just applying this to the range directly won't perform a transformDown. LogicalPlan filter = new Filter(EMPTY, emptySource(), range); - Filter foldedOnce = as(new ConstantFolding().apply(filter), Filter.class); + Filter foldedOnce = as(new ConstantFolding().apply(filter, unboundLogicalOptimizerContext()), Filter.class); // We need to run the rule twice, because during the first run only the boundaries can be folded - the range doesn't know it's // foldable, yet. - Filter foldedTwice = as(new ConstantFolding().apply(foldedOnce), Filter.class); + Filter foldedTwice = as(new ConstantFolding().apply(foldedOnce, unboundLogicalOptimizerContext()), Filter.class); assertFalse((Boolean) as(foldedTwice.condition(), Literal.class).value()); } - private static Object foldOperator(BinaryOperator b) { - return ((Literal) new ConstantFolding().rule(b)).value(); + private Object foldOperator(BinaryOperator b) { + return ((Literal) constantFolding(b)).value(); } - } diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/rules/logical/FoldNullTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/rules/logical/FoldNullTests.java index ae31576184938..252b25a214bb8 100644 --- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/rules/logical/FoldNullTests.java +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/rules/logical/FoldNullTests.java @@ -9,6 +9,7 @@ import org.elasticsearch.test.ESTestCase; import org.elasticsearch.xpack.esql.core.expression.Expression; +import org.elasticsearch.xpack.esql.core.expression.FoldContext; import org.elasticsearch.xpack.esql.core.expression.Literal; import org.elasticsearch.xpack.esql.core.expression.predicate.logical.And; import org.elasticsearch.xpack.esql.core.expression.predicate.logical.Or; @@ -67,9 +68,11 @@ import java.util.List; import static org.elasticsearch.xpack.esql.EsqlTestUtils.L; +import static org.elasticsearch.xpack.esql.EsqlTestUtils.as; import static org.elasticsearch.xpack.esql.EsqlTestUtils.configuration; import static org.elasticsearch.xpack.esql.EsqlTestUtils.getFieldAttribute; import static org.elasticsearch.xpack.esql.EsqlTestUtils.greaterThanOf; +import static org.elasticsearch.xpack.esql.EsqlTestUtils.unboundLogicalOptimizerContext; import static org.elasticsearch.xpack.esql.core.expression.Literal.NULL; import static org.elasticsearch.xpack.esql.core.expression.Literal.TRUE; import static org.elasticsearch.xpack.esql.core.tree.Source.EMPTY; @@ -85,28 +88,27 @@ import static org.elasticsearch.xpack.esql.core.type.DataType.VERSION; public class FoldNullTests extends ESTestCase { + private Expression foldNull(Expression e) { + return new FoldNull().rule(e, unboundLogicalOptimizerContext()); + } public void testBasicNullFolding() { - FoldNull rule = new FoldNull(); - assertNullLiteral(rule.rule(new Add(EMPTY, L(randomInt()), Literal.NULL))); - assertNullLiteral(rule.rule(new Round(EMPTY, Literal.NULL, null))); - assertNullLiteral(rule.rule(new Pow(EMPTY, Literal.NULL, Literal.NULL))); - assertNullLiteral(rule.rule(new DateFormat(EMPTY, Literal.NULL, Literal.NULL, null))); - assertNullLiteral(rule.rule(new DateParse(EMPTY, Literal.NULL, Literal.NULL))); - assertNullLiteral(rule.rule(new DateTrunc(EMPTY, Literal.NULL, Literal.NULL))); - assertNullLiteral(rule.rule(new Substring(EMPTY, Literal.NULL, Literal.NULL, Literal.NULL))); + assertNullLiteral(foldNull(new Add(EMPTY, L(randomInt()), Literal.NULL))); + assertNullLiteral(foldNull(new Round(EMPTY, Literal.NULL, null))); + assertNullLiteral(foldNull(new Pow(EMPTY, Literal.NULL, Literal.NULL))); + assertNullLiteral(foldNull(new DateFormat(EMPTY, Literal.NULL, Literal.NULL, null))); + assertNullLiteral(foldNull(new DateParse(EMPTY, Literal.NULL, Literal.NULL))); + assertNullLiteral(foldNull(new DateTrunc(EMPTY, Literal.NULL, Literal.NULL))); + assertNullLiteral(foldNull(new Substring(EMPTY, Literal.NULL, Literal.NULL, Literal.NULL))); } public void testNullFoldingIsNotNull() { - FoldNull foldNull = new FoldNull(); - assertEquals(true, foldNull.rule(new IsNotNull(EMPTY, TRUE)).fold()); - assertEquals(false, foldNull.rule(new IsNotNull(EMPTY, NULL)).fold()); + assertEquals(true, foldNull(new IsNotNull(EMPTY, TRUE)).fold(FoldContext.small())); + assertEquals(false, foldNull(new IsNotNull(EMPTY, NULL)).fold(FoldContext.small())); } @SuppressWarnings("unchecked") public void testNullFoldingDoesNotApplyOnAbstractMultivalueFunction() throws Exception { - FoldNull rule = new FoldNull(); - List> items = List.of( MvDedupe.class, MvFirst.class, @@ -119,119 +121,112 @@ public void testNullFoldingDoesNotApplyOnAbstractMultivalueFunction() throws Exc for (Class clazz : items) { Constructor ctor = clazz.getConstructor(Source.class, Expression.class); AbstractMultivalueFunction conditionalFunction = ctor.newInstance(EMPTY, getFieldAttribute("a")); - assertEquals(conditionalFunction, rule.rule(conditionalFunction)); + assertEquals(conditionalFunction, foldNull(conditionalFunction)); conditionalFunction = ctor.newInstance(EMPTY, NULL); - assertEquals(NULL, rule.rule(conditionalFunction)); + assertEquals(NULL, foldNull(conditionalFunction)); } // avg and count ar different just because they know the return type in advance (all the others infer the type from the input) MvAvg avg = new MvAvg(EMPTY, getFieldAttribute("a")); - assertEquals(avg, rule.rule(avg)); + assertEquals(avg, foldNull(avg)); avg = new MvAvg(EMPTY, NULL); - assertEquals(new Literal(EMPTY, null, DOUBLE), rule.rule(avg)); + assertEquals(new Literal(EMPTY, null, DOUBLE), foldNull(avg)); MvCount count = new MvCount(EMPTY, getFieldAttribute("a")); - assertEquals(count, rule.rule(count)); + assertEquals(count, foldNull(count)); count = new MvCount(EMPTY, NULL); - assertEquals(new Literal(EMPTY, null, INTEGER), rule.rule(count)); + assertEquals(new Literal(EMPTY, null, INTEGER), foldNull(count)); } public void testNullFoldingIsNull() { - FoldNull foldNull = new FoldNull(); - assertEquals(true, foldNull.rule(new IsNull(EMPTY, NULL)).fold()); - assertEquals(false, foldNull.rule(new IsNull(EMPTY, TRUE)).fold()); + assertEquals(true, foldNull(new IsNull(EMPTY, NULL)).fold(FoldContext.small())); + assertEquals(false, foldNull(new IsNull(EMPTY, TRUE)).fold(FoldContext.small())); } public void testGenericNullableExpression() { FoldNull rule = new FoldNull(); // arithmetic - assertNullLiteral(rule.rule(new Add(EMPTY, getFieldAttribute("a"), NULL))); + assertNullLiteral(foldNull(new Add(EMPTY, getFieldAttribute("a"), NULL))); // comparison - assertNullLiteral(rule.rule(greaterThanOf(getFieldAttribute("a"), NULL))); + assertNullLiteral(foldNull(greaterThanOf(getFieldAttribute("a"), NULL))); // regex - assertNullLiteral(rule.rule(new RLike(EMPTY, NULL, new RLikePattern("123")))); + assertNullLiteral(foldNull(new RLike(EMPTY, NULL, new RLikePattern("123")))); // date functions - assertNullLiteral(rule.rule(new DateExtract(EMPTY, NULL, NULL, configuration("")))); + assertNullLiteral(foldNull(new DateExtract(EMPTY, NULL, NULL, configuration("")))); // math functions - assertNullLiteral(rule.rule(new Cos(EMPTY, NULL))); + assertNullLiteral(foldNull(new Cos(EMPTY, NULL))); // string functions - assertNullLiteral(rule.rule(new LTrim(EMPTY, NULL))); + assertNullLiteral(foldNull(new LTrim(EMPTY, NULL))); // spatial - assertNullLiteral(rule.rule(new SpatialCentroid(EMPTY, NULL))); + assertNullLiteral(foldNull(new SpatialCentroid(EMPTY, NULL))); // ip - assertNullLiteral(rule.rule(new CIDRMatch(EMPTY, NULL, List.of(NULL)))); + assertNullLiteral(foldNull(new CIDRMatch(EMPTY, NULL, List.of(NULL)))); // conversion - assertNullLiteral(rule.rule(new ToString(EMPTY, NULL))); + assertNullLiteral(foldNull(new ToString(EMPTY, NULL))); } public void testNullFoldingDoesNotApplyOnLogicalExpressions() { - FoldNull rule = new FoldNull(); - Or or = new Or(EMPTY, NULL, TRUE); - assertEquals(or, rule.rule(or)); + assertEquals(or, foldNull(or)); or = new Or(EMPTY, NULL, NULL); - assertEquals(or, rule.rule(or)); + assertEquals(or, foldNull(or)); And and = new And(EMPTY, NULL, TRUE); - assertEquals(and, rule.rule(and)); + assertEquals(and, foldNull(and)); and = new And(EMPTY, NULL, NULL); - assertEquals(and, rule.rule(and)); + assertEquals(and, foldNull(and)); } @SuppressWarnings("unchecked") public void testNullFoldingDoesNotApplyOnAggregate() throws Exception { - FoldNull rule = new FoldNull(); - List> items = List.of(Max.class, Min.class); for (Class clazz : items) { Constructor ctor = clazz.getConstructor(Source.class, Expression.class); AggregateFunction conditionalFunction = ctor.newInstance(EMPTY, getFieldAttribute("a")); - assertEquals(conditionalFunction, rule.rule(conditionalFunction)); + assertEquals(conditionalFunction, foldNull(conditionalFunction)); conditionalFunction = ctor.newInstance(EMPTY, NULL); - assertEquals(NULL, rule.rule(conditionalFunction)); + assertEquals(NULL, foldNull(conditionalFunction)); } Avg avg = new Avg(EMPTY, getFieldAttribute("a")); - assertEquals(avg, rule.rule(avg)); + assertEquals(avg, foldNull(avg)); avg = new Avg(EMPTY, NULL); - assertEquals(new Literal(EMPTY, null, DOUBLE), rule.rule(avg)); + assertEquals(new Literal(EMPTY, null, DOUBLE), foldNull(avg)); Count count = new Count(EMPTY, getFieldAttribute("a")); - assertEquals(count, rule.rule(count)); + assertEquals(count, foldNull(count)); count = new Count(EMPTY, NULL); - assertEquals(count, rule.rule(count)); + assertEquals(count, foldNull(count)); CountDistinct countd = new CountDistinct(EMPTY, getFieldAttribute("a"), getFieldAttribute("a")); - assertEquals(countd, rule.rule(countd)); + assertEquals(countd, foldNull(countd)); countd = new CountDistinct(EMPTY, NULL, NULL); - assertEquals(new Literal(EMPTY, null, LONG), rule.rule(countd)); + assertEquals(new Literal(EMPTY, null, LONG), foldNull(countd)); Median median = new Median(EMPTY, getFieldAttribute("a")); - assertEquals(median, rule.rule(median)); + assertEquals(median, foldNull(median)); median = new Median(EMPTY, NULL); - assertEquals(new Literal(EMPTY, null, DOUBLE), rule.rule(median)); + assertEquals(new Literal(EMPTY, null, DOUBLE), foldNull(median)); MedianAbsoluteDeviation medianad = new MedianAbsoluteDeviation(EMPTY, getFieldAttribute("a")); - assertEquals(medianad, rule.rule(medianad)); + assertEquals(medianad, foldNull(medianad)); medianad = new MedianAbsoluteDeviation(EMPTY, NULL); - assertEquals(new Literal(EMPTY, null, DOUBLE), rule.rule(medianad)); + assertEquals(new Literal(EMPTY, null, DOUBLE), foldNull(medianad)); Percentile percentile = new Percentile(EMPTY, getFieldAttribute("a"), getFieldAttribute("a")); - assertEquals(percentile, rule.rule(percentile)); + assertEquals(percentile, foldNull(percentile)); percentile = new Percentile(EMPTY, NULL, NULL); - assertEquals(new Literal(EMPTY, null, DOUBLE), rule.rule(percentile)); + assertEquals(new Literal(EMPTY, null, DOUBLE), foldNull(percentile)); Sum sum = new Sum(EMPTY, getFieldAttribute("a")); - assertEquals(sum, rule.rule(sum)); + assertEquals(sum, foldNull(sum)); sum = new Sum(EMPTY, NULL); - assertEquals(new Literal(EMPTY, null, DOUBLE), rule.rule(sum)); + assertEquals(new Literal(EMPTY, null, DOUBLE), foldNull(sum)); } public void testNullFoldableDoesNotApplyToIsNullAndNotNull() { - FoldNull rule = new FoldNull(); - DataType numericType = randomFrom(INTEGER, LONG, DOUBLE); DataType genericType = randomFrom(INTEGER, LONG, DOUBLE, UNSIGNED_LONG, KEYWORD, TEXT, GEO_POINT, GEO_SHAPE, VERSION, IP); List items = List.of( @@ -260,29 +255,26 @@ public void testNullFoldableDoesNotApplyToIsNullAndNotNull() { ); for (Expression item : items) { Expression isNull = new IsNull(EMPTY, item); - Expression transformed = rule.rule(isNull); + Expression transformed = foldNull(isNull); assertEquals(isNull, transformed); IsNotNull isNotNull = new IsNotNull(EMPTY, item); - transformed = rule.rule(isNotNull); + transformed = foldNull(isNotNull); assertEquals(isNotNull, transformed); } } public void testNullBucketGetsFolded() { - FoldNull foldNull = new FoldNull(); - assertEquals(NULL, foldNull.rule(new Bucket(EMPTY, NULL, NULL, NULL, NULL))); + assertEquals(NULL, foldNull(new Bucket(EMPTY, NULL, NULL, NULL, NULL))); } public void testNullCategorizeGroupingNotFolded() { - FoldNull foldNull = new FoldNull(); Categorize categorize = new Categorize(EMPTY, NULL); - assertEquals(categorize, foldNull.rule(categorize)); + assertEquals(categorize, foldNull(categorize)); } private void assertNullLiteral(Expression expression) { - assertEquals(Literal.class, expression.getClass()); - assertNull(expression.fold()); + assertNull(as(expression, Literal.class).value()); } } diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/rules/logical/LiteralsOnTheRightTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/rules/logical/LiteralsOnTheRightTests.java index 17e69e81444c5..1664e9f4653bb 100644 --- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/rules/logical/LiteralsOnTheRightTests.java +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/rules/logical/LiteralsOnTheRightTests.java @@ -15,6 +15,7 @@ import static org.elasticsearch.xpack.esql.EsqlTestUtils.FIVE; import static org.elasticsearch.xpack.esql.EsqlTestUtils.equalsOf; +import static org.elasticsearch.xpack.esql.EsqlTestUtils.unboundLogicalOptimizerContext; import static org.elasticsearch.xpack.esql.core.tree.Source.EMPTY; import static org.elasticsearch.xpack.esql.core.type.DataType.INTEGER; @@ -22,7 +23,7 @@ public class LiteralsOnTheRightTests extends ESTestCase { public void testLiteralsOnTheRight() { Alias a = new Alias(EMPTY, "a", new Literal(EMPTY, 10, INTEGER)); - Expression result = new LiteralsOnTheRight().rule(equalsOf(FIVE, a)); + Expression result = new LiteralsOnTheRight().rule(equalsOf(FIVE, a), unboundLogicalOptimizerContext()); assertTrue(result instanceof Equals); Equals eq = (Equals) result; assertEquals(a, eq.left()); diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/rules/logical/PropagateEqualsTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/rules/logical/PropagateEqualsTests.java index 55091653e75d4..a6c0d838b2c21 100644 --- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/rules/logical/PropagateEqualsTests.java +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/rules/logical/PropagateEqualsTests.java @@ -14,6 +14,7 @@ import org.elasticsearch.xpack.esql.core.expression.predicate.Predicates; import org.elasticsearch.xpack.esql.core.expression.predicate.Range; import org.elasticsearch.xpack.esql.core.expression.predicate.logical.And; +import org.elasticsearch.xpack.esql.core.expression.predicate.logical.BinaryLogic; import org.elasticsearch.xpack.esql.core.expression.predicate.logical.Or; import org.elasticsearch.xpack.esql.core.type.DataType; import org.elasticsearch.xpack.esql.expression.predicate.operator.comparison.Equals; @@ -37,11 +38,15 @@ import static org.elasticsearch.xpack.esql.EsqlTestUtils.lessThanOrEqualOf; import static org.elasticsearch.xpack.esql.EsqlTestUtils.notEqualsOf; import static org.elasticsearch.xpack.esql.EsqlTestUtils.rangeOf; +import static org.elasticsearch.xpack.esql.EsqlTestUtils.unboundLogicalOptimizerContext; import static org.elasticsearch.xpack.esql.core.expression.Literal.FALSE; import static org.elasticsearch.xpack.esql.core.expression.Literal.TRUE; import static org.elasticsearch.xpack.esql.core.tree.Source.EMPTY; public class PropagateEqualsTests extends ESTestCase { + private Expression propagateEquals(BinaryLogic e) { + return new PropagateEquals().rule(e, unboundLogicalOptimizerContext()); + } // a == 1 AND a == 2 -> FALSE public void testDualEqualsConjunction() { @@ -49,8 +54,7 @@ public void testDualEqualsConjunction() { Equals eq1 = equalsOf(fa, ONE); Equals eq2 = equalsOf(fa, TWO); - PropagateEquals rule = new PropagateEquals(); - Expression exp = rule.rule(new And(EMPTY, eq1, eq2)); + Expression exp = propagateEquals(new And(EMPTY, eq1, eq2)); assertEquals(FALSE, exp); } @@ -60,8 +64,7 @@ public void testEliminateRangeByEqualsOutsideInterval() { Equals eq1 = equalsOf(fa, new Literal(EMPTY, 10, DataType.INTEGER)); Range r = rangeOf(fa, ONE, false, new Literal(EMPTY, 10, DataType.INTEGER), false); - PropagateEquals rule = new PropagateEquals(); - Expression exp = rule.rule(new And(EMPTY, eq1, r)); + Expression exp = propagateEquals(new And(EMPTY, eq1, r)); assertEquals(FALSE, exp); } @@ -71,8 +74,7 @@ public void testPropagateEquals_VarNeq3AndVarEq3() { NotEquals neq = notEqualsOf(fa, THREE); Equals eq = equalsOf(fa, THREE); - PropagateEquals rule = new PropagateEquals(); - Expression exp = rule.rule(new And(EMPTY, neq, eq)); + Expression exp = propagateEquals(new And(EMPTY, neq, eq)); assertEquals(FALSE, exp); } @@ -82,8 +84,7 @@ public void testPropagateEquals_VarNeq4AndVarEq3() { NotEquals neq = notEqualsOf(fa, FOUR); Equals eq = equalsOf(fa, THREE); - PropagateEquals rule = new PropagateEquals(); - Expression exp = rule.rule(new And(EMPTY, neq, eq)); + Expression exp = propagateEquals(new And(EMPTY, neq, eq)); assertEquals(Equals.class, exp.getClass()); assertEquals(eq, exp); } @@ -94,8 +95,7 @@ public void testPropagateEquals_VarEq2AndVarLt2() { Equals eq = equalsOf(fa, TWO); LessThan lt = lessThanOf(fa, TWO); - PropagateEquals rule = new PropagateEquals(); - Expression exp = rule.rule(new And(EMPTY, eq, lt)); + Expression exp = propagateEquals(new And(EMPTY, eq, lt)); assertEquals(FALSE, exp); } @@ -105,8 +105,7 @@ public void testPropagateEquals_VarEq2AndVarLte2() { Equals eq = equalsOf(fa, TWO); LessThanOrEqual lt = lessThanOrEqualOf(fa, TWO); - PropagateEquals rule = new PropagateEquals(); - Expression exp = rule.rule(new And(EMPTY, eq, lt)); + Expression exp = propagateEquals(new And(EMPTY, eq, lt)); assertEquals(eq, exp); } @@ -116,8 +115,7 @@ public void testPropagateEquals_VarEq2AndVarLte1() { Equals eq = equalsOf(fa, TWO); LessThanOrEqual lt = lessThanOrEqualOf(fa, ONE); - PropagateEquals rule = new PropagateEquals(); - Expression exp = rule.rule(new And(EMPTY, eq, lt)); + Expression exp = propagateEquals(new And(EMPTY, eq, lt)); assertEquals(FALSE, exp); } @@ -127,8 +125,7 @@ public void testPropagateEquals_VarEq2AndVarGt2() { Equals eq = equalsOf(fa, TWO); GreaterThan gt = greaterThanOf(fa, TWO); - PropagateEquals rule = new PropagateEquals(); - Expression exp = rule.rule(new And(EMPTY, eq, gt)); + Expression exp = propagateEquals(new And(EMPTY, eq, gt)); assertEquals(FALSE, exp); } @@ -138,8 +135,7 @@ public void testPropagateEquals_VarEq2AndVarGte2() { Equals eq = equalsOf(fa, TWO); GreaterThanOrEqual gte = greaterThanOrEqualOf(fa, TWO); - PropagateEquals rule = new PropagateEquals(); - Expression exp = rule.rule(new And(EMPTY, eq, gte)); + Expression exp = propagateEquals(new And(EMPTY, eq, gte)); assertEquals(eq, exp); } @@ -149,8 +145,7 @@ public void testPropagateEquals_VarEq2AndVarLt3() { Equals eq = equalsOf(fa, TWO); GreaterThan gt = greaterThanOf(fa, THREE); - PropagateEquals rule = new PropagateEquals(); - Expression exp = rule.rule(new And(EMPTY, eq, gt)); + Expression exp = propagateEquals(new And(EMPTY, eq, gt)); assertEquals(FALSE, exp); } @@ -162,9 +157,8 @@ public void testPropagateEquals_VarEq2AndVarLt3AndVarGt1AndVarNeq4() { GreaterThan gt = greaterThanOf(fa, ONE); NotEquals neq = notEqualsOf(fa, FOUR); - PropagateEquals rule = new PropagateEquals(); Expression and = Predicates.combineAnd(asList(eq, lt, gt, neq)); - Expression exp = rule.rule((And) and); + Expression exp = propagateEquals((And) and); assertEquals(eq, exp); } @@ -176,9 +170,8 @@ public void testPropagateEquals_VarEq2AndVarRangeGt1Lt3AndVarGt0AndVarNeq4() { GreaterThan gt = greaterThanOf(fa, new Literal(EMPTY, 0, DataType.INTEGER)); NotEquals neq = notEqualsOf(fa, FOUR); - PropagateEquals rule = new PropagateEquals(); Expression and = Predicates.combineAnd(asList(eq, range, gt, neq)); - Expression exp = rule.rule((And) and); + Expression exp = propagateEquals((And) and); assertEquals(eq, exp); } @@ -188,8 +181,7 @@ public void testPropagateEquals_VarEq2OrVarGt1() { Equals eq = equalsOf(fa, TWO); GreaterThan gt = greaterThanOf(fa, ONE); - PropagateEquals rule = new PropagateEquals(); - Expression exp = rule.rule(new Or(EMPTY, eq, gt)); + Expression exp = propagateEquals(new Or(EMPTY, eq, gt)); assertEquals(gt, exp); } @@ -199,8 +191,7 @@ public void testPropagateEquals_VarEq2OrVarGte2() { Equals eq = equalsOf(fa, TWO); GreaterThan gt = greaterThanOf(fa, TWO); - PropagateEquals rule = new PropagateEquals(); - Expression exp = rule.rule(new Or(EMPTY, eq, gt)); + Expression exp = propagateEquals(new Or(EMPTY, eq, gt)); assertEquals(GreaterThanOrEqual.class, exp.getClass()); GreaterThanOrEqual gte = (GreaterThanOrEqual) exp; assertEquals(TWO, gte.right()); @@ -212,8 +203,7 @@ public void testPropagateEquals_VarEq2OrVarLt3() { Equals eq = equalsOf(fa, TWO); LessThan lt = lessThanOf(fa, THREE); - PropagateEquals rule = new PropagateEquals(); - Expression exp = rule.rule(new Or(EMPTY, eq, lt)); + Expression exp = propagateEquals(new Or(EMPTY, eq, lt)); assertEquals(lt, exp); } @@ -223,8 +213,7 @@ public void testPropagateEquals_VarEq3OrVarLt3() { Equals eq = equalsOf(fa, THREE); LessThan lt = lessThanOf(fa, THREE); - PropagateEquals rule = new PropagateEquals(); - Expression exp = rule.rule(new Or(EMPTY, eq, lt)); + Expression exp = propagateEquals(new Or(EMPTY, eq, lt)); assertEquals(LessThanOrEqual.class, exp.getClass()); LessThanOrEqual lte = (LessThanOrEqual) exp; assertEquals(THREE, lte.right()); @@ -236,8 +225,7 @@ public void testPropagateEquals_VarEq2OrVarRangeGt1Lt3() { Equals eq = equalsOf(fa, TWO); Range range = rangeOf(fa, ONE, false, THREE, false); - PropagateEquals rule = new PropagateEquals(); - Expression exp = rule.rule(new Or(EMPTY, eq, range)); + Expression exp = propagateEquals(new Or(EMPTY, eq, range)); assertEquals(range, exp); } @@ -247,8 +235,7 @@ public void testPropagateEquals_VarEq2OrVarRangeGt2Lt3() { Equals eq = equalsOf(fa, TWO); Range range = rangeOf(fa, TWO, false, THREE, false); - PropagateEquals rule = new PropagateEquals(); - Expression exp = rule.rule(new Or(EMPTY, eq, range)); + Expression exp = propagateEquals(new Or(EMPTY, eq, range)); assertEquals(Range.class, exp.getClass()); Range r = (Range) exp; assertEquals(TWO, r.lower()); @@ -263,8 +250,7 @@ public void testPropagateEquals_VarEq3OrVarRangeGt2Lt3() { Equals eq = equalsOf(fa, THREE); Range range = rangeOf(fa, TWO, false, THREE, false); - PropagateEquals rule = new PropagateEquals(); - Expression exp = rule.rule(new Or(EMPTY, eq, range)); + Expression exp = propagateEquals(new Or(EMPTY, eq, range)); assertEquals(Range.class, exp.getClass()); Range r = (Range) exp; assertEquals(TWO, r.lower()); @@ -279,8 +265,7 @@ public void testPropagateEquals_VarEq2OrVarNeq2() { Equals eq = equalsOf(fa, TWO); NotEquals neq = notEqualsOf(fa, TWO); - PropagateEquals rule = new PropagateEquals(); - Expression exp = rule.rule(new Or(EMPTY, eq, neq)); + Expression exp = propagateEquals(new Or(EMPTY, eq, neq)); assertEquals(TRUE, exp); } @@ -290,8 +275,7 @@ public void testPropagateEquals_VarEq2OrVarNeq5() { Equals eq = equalsOf(fa, TWO); NotEquals neq = notEqualsOf(fa, FIVE); - PropagateEquals rule = new PropagateEquals(); - Expression exp = rule.rule(new Or(EMPTY, eq, neq)); + Expression exp = propagateEquals(new Or(EMPTY, eq, neq)); assertEquals(NotEquals.class, exp.getClass()); NotEquals ne = (NotEquals) exp; assertEquals(FIVE, ne.right()); @@ -305,8 +289,7 @@ public void testPropagateEquals_VarEq2OrVarRangeGt3Lt4OrVarGt2OrVarNe2() { GreaterThan gt = greaterThanOf(fa, TWO); NotEquals neq = notEqualsOf(fa, TWO); - PropagateEquals rule = new PropagateEquals(); - Expression exp = rule.rule((Or) Predicates.combineOr(asList(eq, range, neq, gt))); + Expression exp = propagateEquals((Or) Predicates.combineOr(asList(eq, range, neq, gt))); assertEquals(TRUE, exp); } @@ -317,8 +300,7 @@ public void testPropagateEquals_ignoreDateTimeFields() { Equals eq2 = equalsOf(fa, TWO); And and = new And(EMPTY, eq1, eq2); - PropagateEquals rule = new PropagateEquals(); - Expression exp = rule.rule(and); + Expression exp = propagateEquals(and); assertEquals(and, exp); } @@ -328,8 +310,7 @@ public void testEliminateRangeByEqualsInInterval() { Equals eq1 = equalsOf(fa, ONE); Range r = rangeOf(fa, ONE, true, new Literal(EMPTY, 10, DataType.INTEGER), false); - PropagateEquals rule = new PropagateEquals(); - Expression exp = rule.rule(new And(EMPTY, eq1, r)); + Expression exp = propagateEquals(new And(EMPTY, eq1, r)); assertEquals(eq1, exp); } } diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/rules/logical/PropagateNullableTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/rules/logical/PropagateNullableTests.java index d1d6a7fbaa208..d35890e5b56bb 100644 --- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/rules/logical/PropagateNullableTests.java +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/rules/logical/PropagateNullableTests.java @@ -31,11 +31,20 @@ import static org.elasticsearch.xpack.esql.EsqlTestUtils.greaterThanOf; import static org.elasticsearch.xpack.esql.EsqlTestUtils.lessThanOf; import static org.elasticsearch.xpack.esql.EsqlTestUtils.relation; +import static org.elasticsearch.xpack.esql.EsqlTestUtils.unboundLogicalOptimizerContext; import static org.elasticsearch.xpack.esql.core.expression.Literal.FALSE; import static org.elasticsearch.xpack.esql.core.tree.Source.EMPTY; import static org.elasticsearch.xpack.esql.core.type.DataType.BOOLEAN; public class PropagateNullableTests extends ESTestCase { + private Expression propagateNullable(And e) { + return new PropagateNullable().rule(e, unboundLogicalOptimizerContext()); + } + + private LogicalPlan propagateNullable(LogicalPlan p) { + return new PropagateNullable().apply(p, unboundLogicalOptimizerContext()); + } + private Literal nullOf(DataType dataType) { return new Literal(Source.EMPTY, null, dataType); } @@ -45,7 +54,7 @@ public void testIsNullAndNotNull() { FieldAttribute fa = getFieldAttribute(); And and = new And(EMPTY, new IsNull(EMPTY, fa), new IsNotNull(EMPTY, fa)); - assertEquals(FALSE, new PropagateNullable().rule(and)); + assertEquals(FALSE, propagateNullable(and)); } // a IS NULL AND b IS NOT NULL AND c IS NULL AND d IS NOT NULL AND e IS NULL AND a IS NOT NULL => false @@ -58,7 +67,7 @@ public void testIsNullAndNotNullMultiField() { And and = new And(EMPTY, andOne, new And(EMPTY, andThree, andTwo)); - assertEquals(FALSE, new PropagateNullable().rule(and)); + assertEquals(FALSE, propagateNullable(and)); } // a IS NULL AND a > 1 => a IS NULL AND false @@ -67,7 +76,7 @@ public void testIsNullAndComparison() { IsNull isNull = new IsNull(EMPTY, fa); And and = new And(EMPTY, isNull, greaterThanOf(fa, ONE)); - assertEquals(new And(EMPTY, isNull, nullOf(BOOLEAN)), new PropagateNullable().rule(and)); + assertEquals(new And(EMPTY, isNull, nullOf(BOOLEAN)), propagateNullable(and)); } // a IS NULL AND b < 1 AND c < 1 AND a < 1 => a IS NULL AND b < 1 AND c < 1 => a IS NULL AND b < 1 AND c < 1 @@ -79,7 +88,7 @@ public void testIsNullAndMultipleComparison() { And and = new And(EMPTY, isNull, nestedAnd); And top = new And(EMPTY, and, lessThanOf(fa, ONE)); - Expression optimized = new PropagateNullable().rule(top); + Expression optimized = propagateNullable(top); Expression expected = new And(EMPTY, and, nullOf(BOOLEAN)); assertEquals(Predicates.splitAnd(expected), Predicates.splitAnd(optimized)); } @@ -97,7 +106,7 @@ public void testIsNullAndDeeplyNestedExpression() { Expression kept = new And(EMPTY, isNull, lessThanOf(getFieldAttribute("b"), THREE)); And and = new And(EMPTY, nullified, kept); - Expression optimized = new PropagateNullable().rule(and); + Expression optimized = propagateNullable(and); Expression expected = new And(EMPTY, new And(EMPTY, nullOf(BOOLEAN), nullOf(BOOLEAN)), kept); assertEquals(Predicates.splitAnd(expected), Predicates.splitAnd(optimized)); @@ -110,13 +119,13 @@ public void testIsNullInDisjunction() { Or or = new Or(EMPTY, new IsNull(EMPTY, fa), new IsNotNull(EMPTY, fa)); Filter dummy = new Filter(EMPTY, relation(), or); - LogicalPlan transformed = new PropagateNullable().apply(dummy); + LogicalPlan transformed = propagateNullable(dummy); assertSame(dummy, transformed); assertEquals(or, ((Filter) transformed).condition()); or = new Or(EMPTY, new IsNull(EMPTY, fa), greaterThanOf(fa, ONE)); dummy = new Filter(EMPTY, relation(), or); - transformed = new PropagateNullable().apply(dummy); + transformed = propagateNullable(dummy); assertSame(dummy, transformed); assertEquals(or, ((Filter) transformed).condition()); } @@ -129,7 +138,7 @@ public void testIsNullDisjunction() { Or or = new Or(EMPTY, isNull, greaterThanOf(fa, THREE)); And and = new And(EMPTY, new Add(EMPTY, fa, ONE), or); - assertEquals(and, new PropagateNullable().rule(and)); + assertEquals(and, propagateNullable(and)); } public void testDoNotOptimizeIsNullAndMultipleComparisonWithConstants() { @@ -141,7 +150,7 @@ public void testDoNotOptimizeIsNullAndMultipleComparisonWithConstants() { And aIsNull_AND_bLT1_AND_cLT1 = new And(EMPTY, aIsNull, bLT1_AND_cLT1); And aIsNull_AND_bLT1_AND_cLT1_AND_aLT1 = new And(EMPTY, aIsNull_AND_bLT1_AND_cLT1, lessThanOf(a, ONE)); - Expression optimized = new PropagateNullable().rule(aIsNull_AND_bLT1_AND_cLT1_AND_aLT1); + Expression optimized = propagateNullable(aIsNull_AND_bLT1_AND_cLT1_AND_aLT1); Literal nullLiteral = new Literal(EMPTY, null, BOOLEAN); assertEquals(asList(aIsNull, nullLiteral, nullLiteral, nullLiteral), Predicates.splitAnd(optimized)); } diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/rules/logical/ReplaceRegexMatchTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/rules/logical/ReplaceRegexMatchTests.java index c7206c6971bde..b9ffc39e5e130 100644 --- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/rules/logical/ReplaceRegexMatchTests.java +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/rules/logical/ReplaceRegexMatchTests.java @@ -10,8 +10,10 @@ import org.elasticsearch.test.ESTestCase; import org.elasticsearch.xpack.esql.core.expression.Expression; import org.elasticsearch.xpack.esql.core.expression.FieldAttribute; +import org.elasticsearch.xpack.esql.core.expression.FoldContext; import org.elasticsearch.xpack.esql.core.expression.predicate.nulls.IsNotNull; import org.elasticsearch.xpack.esql.core.expression.predicate.regex.RLikePattern; +import org.elasticsearch.xpack.esql.core.expression.predicate.regex.RegexMatch; import org.elasticsearch.xpack.esql.core.expression.predicate.regex.WildcardPattern; import org.elasticsearch.xpack.esql.core.util.StringUtils; import org.elasticsearch.xpack.esql.expression.function.scalar.string.RLike; @@ -20,16 +22,20 @@ import static java.util.Arrays.asList; import static org.elasticsearch.xpack.esql.EsqlTestUtils.getFieldAttribute; +import static org.elasticsearch.xpack.esql.EsqlTestUtils.unboundLogicalOptimizerContext; import static org.elasticsearch.xpack.esql.core.tree.Source.EMPTY; public class ReplaceRegexMatchTests extends ESTestCase { + private Expression replaceRegexMatch(RegexMatch e) { + return new ReplaceRegexMatch().rule(e, unboundLogicalOptimizerContext()); + } public void testMatchAllWildcardLikeToExist() { for (String s : asList("*", "**", "***")) { WildcardPattern pattern = new WildcardPattern(s); FieldAttribute fa = getFieldAttribute(); WildcardLike l = new WildcardLike(EMPTY, fa, pattern); - Expression e = new ReplaceRegexMatch().rule(l); + Expression e = replaceRegexMatch(l); assertEquals(IsNotNull.class, e.getClass()); IsNotNull inn = (IsNotNull) e; assertEquals(fa, inn.field()); @@ -40,7 +46,7 @@ public void testMatchAllRLikeToExist() { RLikePattern pattern = new RLikePattern(".*"); FieldAttribute fa = getFieldAttribute(); RLike l = new RLike(EMPTY, fa, pattern); - Expression e = new ReplaceRegexMatch().rule(l); + Expression e = replaceRegexMatch(l); assertEquals(IsNotNull.class, e.getClass()); IsNotNull inn = (IsNotNull) e; assertEquals(fa, inn.field()); @@ -51,11 +57,11 @@ public void testExactMatchWildcardLike() { WildcardPattern pattern = new WildcardPattern(s); FieldAttribute fa = getFieldAttribute(); WildcardLike l = new WildcardLike(EMPTY, fa, pattern); - Expression e = new ReplaceRegexMatch().rule(l); + Expression e = replaceRegexMatch(l); assertEquals(Equals.class, e.getClass()); Equals eq = (Equals) e; assertEquals(fa, eq.left()); - assertEquals(s.replace("\\", StringUtils.EMPTY), eq.right().fold()); + assertEquals(s.replace("\\", StringUtils.EMPTY), eq.right().fold(FoldContext.small())); } } @@ -63,11 +69,11 @@ public void testExactMatchRLike() { RLikePattern pattern = new RLikePattern("abc"); FieldAttribute fa = getFieldAttribute(); RLike l = new RLike(EMPTY, fa, pattern); - Expression e = new ReplaceRegexMatch().rule(l); + Expression e = replaceRegexMatch(l); assertEquals(Equals.class, e.getClass()); Equals eq = (Equals) e; assertEquals(fa, eq.left()); - assertEquals("abc", eq.right().fold()); + assertEquals("abc", eq.right().fold(FoldContext.small())); } } diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/rules/physical/local/PushTopNToSourceTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/rules/physical/local/PushTopNToSourceTests.java index 2429bcb1a1b04..90c8ae1032325 100644 --- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/rules/physical/local/PushTopNToSourceTests.java +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/rules/physical/local/PushTopNToSourceTests.java @@ -19,6 +19,7 @@ import org.elasticsearch.xpack.esql.core.expression.Attribute; import org.elasticsearch.xpack.esql.core.expression.Expression; import org.elasticsearch.xpack.esql.core.expression.FieldAttribute; +import org.elasticsearch.xpack.esql.core.expression.FoldContext; import org.elasticsearch.xpack.esql.core.expression.Literal; import org.elasticsearch.xpack.esql.core.expression.MetadataAttribute; import org.elasticsearch.xpack.esql.core.expression.Nullability; @@ -417,7 +418,7 @@ private static void assertNoPushdownSort(TestPhysicalPlanBuilder builder, String private static PhysicalPlan pushTopNToSource(TopNExec topNExec) { var configuration = EsqlTestUtils.configuration("from test"); - var ctx = new LocalPhysicalOptimizerContext(configuration, SearchStats.EMPTY); + var ctx = new LocalPhysicalOptimizerContext(configuration, FoldContext.small(), SearchStats.EMPTY); var pushTopNToSource = new PushTopNToSource(); return pushTopNToSource.rule(topNExec, ctx); } diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/parser/ExpressionTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/parser/ExpressionTests.java index 710637c05a900..85d4017b166fa 100644 --- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/parser/ExpressionTests.java +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/parser/ExpressionTests.java @@ -10,6 +10,7 @@ import org.elasticsearch.test.ESTestCase; import org.elasticsearch.xpack.esql.core.expression.Alias; import org.elasticsearch.xpack.esql.core.expression.Expression; +import org.elasticsearch.xpack.esql.core.expression.FoldContext; import org.elasticsearch.xpack.esql.core.expression.Literal; import org.elasticsearch.xpack.esql.core.expression.UnresolvedAttribute; import org.elasticsearch.xpack.esql.core.expression.UnresolvedStar; @@ -617,16 +618,14 @@ public void testSimplifyInWithSingleElementList() { Equals eq = (Equals) e; assertThat(eq.left(), instanceOf(UnresolvedAttribute.class)); assertThat(((UnresolvedAttribute) eq.left()).name(), equalTo("a")); - assertThat(eq.right(), instanceOf(Literal.class)); - assertThat(eq.right().fold(), equalTo(1)); + assertThat(as(eq.right(), Literal.class).value(), equalTo(1)); e = whereExpression("1 IN (a)"); assertThat(e, instanceOf(Equals.class)); eq = (Equals) e; assertThat(eq.right(), instanceOf(UnresolvedAttribute.class)); assertThat(((UnresolvedAttribute) eq.right()).name(), equalTo("a")); - assertThat(eq.left(), instanceOf(Literal.class)); - assertThat(eq.left().fold(), equalTo(1)); + assertThat(eq.left().fold(FoldContext.small()), equalTo(1)); e = whereExpression("1 NOT IN (a)"); assertThat(e, instanceOf(Not.class)); @@ -635,9 +634,7 @@ public void testSimplifyInWithSingleElementList() { eq = (Equals) e; assertThat(eq.right(), instanceOf(UnresolvedAttribute.class)); assertThat(((UnresolvedAttribute) eq.right()).name(), equalTo("a")); - assertThat(eq.left(), instanceOf(Literal.class)); - assertThat(eq.left().fold(), equalTo(1)); - + assertThat(eq.left().fold(FoldContext.small()), equalTo(1)); } private Expression whereExpression(String e) { diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/parser/StatementParserTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/parser/StatementParserTests.java index 0ebec962b79a6..a4712ae77b5d8 100644 --- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/parser/StatementParserTests.java +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/parser/StatementParserTests.java @@ -16,6 +16,7 @@ import org.elasticsearch.xpack.esql.core.expression.Attribute; import org.elasticsearch.xpack.esql.core.expression.EmptyAttribute; import org.elasticsearch.xpack.esql.core.expression.Expressions; +import org.elasticsearch.xpack.esql.core.expression.FoldContext; import org.elasticsearch.xpack.esql.core.expression.Literal; import org.elasticsearch.xpack.esql.core.expression.MetadataAttribute; import org.elasticsearch.xpack.esql.core.expression.NamedExpression; @@ -1132,51 +1133,51 @@ public void testInputParams() { assertThat(field.name(), is("x")); assertThat(field, instanceOf(Alias.class)); Alias alias = (Alias) field; - assertThat(alias.child().fold(), is(1)); + assertThat(alias.child().fold(FoldContext.small()), is(1)); field = row.fields().get(1); assertThat(field.name(), is("y")); assertThat(field, instanceOf(Alias.class)); alias = (Alias) field; - assertThat(alias.child().fold(), is("2")); + assertThat(alias.child().fold(FoldContext.small()), is("2")); field = row.fields().get(2); assertThat(field.name(), is("a")); assertThat(field, instanceOf(Alias.class)); alias = (Alias) field; - assertThat(alias.child().fold(), is("2 days")); + assertThat(alias.child().fold(FoldContext.small()), is("2 days")); field = row.fields().get(3); assertThat(field.name(), is("b")); assertThat(field, instanceOf(Alias.class)); alias = (Alias) field; - assertThat(alias.child().fold(), is("4 hours")); + assertThat(alias.child().fold(FoldContext.small()), is("4 hours")); field = row.fields().get(4); assertThat(field.name(), is("c")); assertThat(field, instanceOf(Alias.class)); alias = (Alias) field; - assertThat(alias.child().fold().getClass(), is(String.class)); - assertThat(alias.child().fold().toString(), is("1.2.3")); + assertThat(alias.child().fold(FoldContext.small()).getClass(), is(String.class)); + assertThat(alias.child().fold(FoldContext.small()).toString(), is("1.2.3")); field = row.fields().get(5); assertThat(field.name(), is("d")); assertThat(field, instanceOf(Alias.class)); alias = (Alias) field; - assertThat(alias.child().fold().getClass(), is(String.class)); - assertThat(alias.child().fold().toString(), is("127.0.0.1")); + assertThat(alias.child().fold(FoldContext.small()).getClass(), is(String.class)); + assertThat(alias.child().fold(FoldContext.small()).toString(), is("127.0.0.1")); field = row.fields().get(6); assertThat(field.name(), is("e")); assertThat(field, instanceOf(Alias.class)); alias = (Alias) field; - assertThat(alias.child().fold(), is(9)); + assertThat(alias.child().fold(FoldContext.small()), is(9)); field = row.fields().get(7); assertThat(field.name(), is("f")); assertThat(field, instanceOf(Alias.class)); alias = (Alias) field; - assertThat(alias.child().fold(), is(11)); + assertThat(alias.child().fold(FoldContext.small()), is(11)); } public void testMissingInputParams() { @@ -1193,13 +1194,13 @@ public void testNamedParams() { assertThat(field.name(), is("x")); assertThat(field, instanceOf(Alias.class)); Alias alias = (Alias) field; - assertThat(alias.child().fold(), is(1)); + assertThat(alias.child().fold(FoldContext.small()), is(1)); field = row.fields().get(1); assertThat(field.name(), is("y")); assertThat(field, instanceOf(Alias.class)); alias = (Alias) field; - assertThat(alias.child().fold(), is(1)); + assertThat(alias.child().fold(FoldContext.small()), is(1)); } public void testInvalidNamedParams() { @@ -1240,13 +1241,13 @@ public void testPositionalParams() { assertThat(field.name(), is("x")); assertThat(field, instanceOf(Alias.class)); Alias alias = (Alias) field; - assertThat(alias.child().fold(), is(1)); + assertThat(alias.child().fold(FoldContext.small()), is(1)); field = row.fields().get(1); assertThat(field.name(), is("y")); assertThat(field, instanceOf(Alias.class)); alias = (Alias) field; - assertThat(alias.child().fold(), is(1)); + assertThat(alias.child().fold(FoldContext.small()), is(1)); } public void testInvalidPositionalParams() { @@ -2060,7 +2061,7 @@ private void assertStringAsLookupIndexPattern(String string, String statement) { var plan = statement(statement); var lookup = as(plan, Lookup.class); var tableName = as(lookup.tableName(), Literal.class); - assertThat(tableName.fold(), equalTo(string)); + assertThat(tableName.fold(FoldContext.small()), equalTo(string)); } public void testIdPatternUnquoted() throws Exception { @@ -2128,7 +2129,7 @@ public void testLookup() { var plan = statement(query); var lookup = as(plan, Lookup.class); var tableName = as(lookup.tableName(), Literal.class); - assertThat(tableName.fold(), equalTo("t")); + assertThat(tableName.fold(FoldContext.small()), equalTo("t")); assertThat(lookup.matchFields(), hasSize(1)); var matchField = as(lookup.matchFields().get(0), UnresolvedAttribute.class); assertThat(matchField.name(), equalTo("j")); @@ -2309,7 +2310,7 @@ public void testMatchOperatorConstantQueryString() { var match = (Match) filter.condition(); var matchField = (UnresolvedAttribute) match.field(); assertThat(matchField.name(), equalTo("field")); - assertThat(match.query().fold(), equalTo("value")); + assertThat(match.query().fold(FoldContext.small()), equalTo("value")); } public void testInvalidMatchOperator() { @@ -2344,7 +2345,7 @@ public void testMatchFunctionFieldCasting() { var toInteger = (ToInteger) function.children().get(0); var matchField = (UnresolvedAttribute) toInteger.field(); assertThat(matchField.name(), equalTo("field")); - assertThat(function.children().get(1).fold(), equalTo("value")); + assertThat(function.children().get(1).fold(FoldContext.small()), equalTo("value")); } public void testMatchOperatorFieldCasting() { @@ -2354,6 +2355,6 @@ public void testMatchOperatorFieldCasting() { var toInteger = (ToInteger) match.field(); var matchField = (UnresolvedAttribute) toInteger.field(); assertThat(matchField.name(), equalTo("field")); - assertThat(match.query().fold(), equalTo("value")); + assertThat(match.query().fold(FoldContext.small()), equalTo("value")); } } diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/plan/QueryPlanTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/plan/QueryPlanTests.java index a254207865ad5..2f47a672a68d0 100644 --- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/plan/QueryPlanTests.java +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/plan/QueryPlanTests.java @@ -11,6 +11,7 @@ import org.elasticsearch.xpack.esql.core.expression.Alias; import org.elasticsearch.xpack.esql.core.expression.Expressions; import org.elasticsearch.xpack.esql.core.expression.FieldAttribute; +import org.elasticsearch.xpack.esql.core.expression.FoldContext; import org.elasticsearch.xpack.esql.core.expression.Literal; import org.elasticsearch.xpack.esql.core.expression.NamedExpression; import org.elasticsearch.xpack.esql.expression.predicate.operator.arithmetic.Add; @@ -42,7 +43,7 @@ public void testTransformWithExpressionTopLevel() throws Exception { assertEquals(Limit.class, transformed.getClass()); Limit l = (Limit) transformed; - assertEquals(24, l.limit().fold()); + assertEquals(24, l.limit().fold(FoldContext.small())); } public void testTransformWithExpressionTree() throws Exception { @@ -53,7 +54,7 @@ public void testTransformWithExpressionTree() throws Exception { assertEquals(OrderBy.class, transformed.getClass()); OrderBy order = (OrderBy) transformed; assertEquals(Limit.class, order.child().getClass()); - assertEquals(24, ((Limit) order.child()).limit().fold()); + assertEquals(24, ((Limit) order.child()).limit().fold(FoldContext.small())); } public void testTransformWithExpressionTopLevelInCollection() throws Exception { @@ -83,12 +84,12 @@ public void testForEachWithExpressionTopLevel() throws Exception { List list = new ArrayList<>(); project.forEachExpression(Literal.class, l -> { - if (l.fold().equals(42)) { - list.add(l.fold()); + if (l.value().equals(42)) { + list.add(l.value()); } }); - assertEquals(singletonList(one.child().fold()), list); + assertEquals(singletonList(one.child().fold(FoldContext.small())), list); } public void testForEachWithExpressionTree() throws Exception { @@ -97,12 +98,12 @@ public void testForEachWithExpressionTree() throws Exception { List list = new ArrayList<>(); o.forEachExpressionDown(Literal.class, l -> { - if (l.fold().equals(42)) { - list.add(l.fold()); + if (l.value().equals(42)) { + list.add(l.value()); } }); - assertEquals(singletonList(limit.limit().fold()), list); + assertEquals(singletonList(limit.limit().fold(FoldContext.small())), list); } public void testForEachWithExpressionTopLevelInCollection() throws Exception { @@ -129,12 +130,12 @@ public void testForEachWithExpressionTreeInCollection() throws Exception { List list = new ArrayList<>(); project.forEachExpression(Literal.class, l -> { - if (l.fold().equals(42)) { - list.add(l.fold()); + if (l.value().equals(42)) { + list.add(l.value()); } }); - assertEquals(singletonList(one.child().fold()), list); + assertEquals(singletonList(one.child().fold(FoldContext.small())), list); } public void testPlanExpressions() { diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/planner/EvalMapperTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/planner/EvalMapperTests.java index 5a7547d011c0f..e2eb05b0c14d3 100644 --- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/planner/EvalMapperTests.java +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/planner/EvalMapperTests.java @@ -20,6 +20,7 @@ import org.elasticsearch.xpack.esql.TestBlockFactory; import org.elasticsearch.xpack.esql.core.expression.Expression; import org.elasticsearch.xpack.esql.core.expression.FieldAttribute; +import org.elasticsearch.xpack.esql.core.expression.FoldContext; import org.elasticsearch.xpack.esql.core.expression.Literal; import org.elasticsearch.xpack.esql.core.expression.predicate.logical.And; import org.elasticsearch.xpack.esql.core.expression.predicate.logical.Not; @@ -145,7 +146,7 @@ public void testEvaluatorSuppliers() { lb.append(LONG); Layout layout = lb.build(); - var supplier = EvalMapper.toEvaluator(expression, layout); + var supplier = EvalMapper.toEvaluator(FoldContext.small(), expression, layout); EvalOperator.ExpressionEvaluator evaluator1 = supplier.get(driverContext()); EvalOperator.ExpressionEvaluator evaluator2 = supplier.get(driverContext()); assertNotNull(evaluator1); diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/planner/FilterTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/planner/FilterTests.java index 55f32d07fc2cb..4191f42f08237 100644 --- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/planner/FilterTests.java +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/planner/FilterTests.java @@ -30,7 +30,6 @@ import org.elasticsearch.xpack.esql.index.IndexResolution; import org.elasticsearch.xpack.esql.io.stream.PlanStreamInput; import org.elasticsearch.xpack.esql.io.stream.PlanStreamOutput; -import org.elasticsearch.xpack.esql.optimizer.LogicalOptimizerContext; import org.elasticsearch.xpack.esql.optimizer.LogicalPlanOptimizer; import org.elasticsearch.xpack.esql.optimizer.PhysicalOptimizerContext; import org.elasticsearch.xpack.esql.optimizer.PhysicalPlanOptimizer; @@ -52,6 +51,7 @@ import static org.elasticsearch.xpack.esql.ConfigurationTestUtils.randomConfiguration; import static org.elasticsearch.xpack.esql.EsqlTestUtils.TEST_VERIFIER; import static org.elasticsearch.xpack.esql.EsqlTestUtils.loadMapping; +import static org.elasticsearch.xpack.esql.EsqlTestUtils.unboundLogicalOptimizerContext; import static org.elasticsearch.xpack.esql.EsqlTestUtils.withDefaultLimitWarning; import static org.elasticsearch.xpack.esql.SerializationTestUtils.assertSerialization; import static org.elasticsearch.xpack.esql.core.util.Queries.Clause.FILTER; @@ -78,7 +78,7 @@ public static void init() { Map mapping = loadMapping("mapping-basic.json"); EsIndex test = new EsIndex("test", mapping, Map.of("test", IndexMode.STANDARD)); IndexResolution getIndexResult = IndexResolution.valid(test); - logicalOptimizer = new LogicalPlanOptimizer(new LogicalOptimizerContext(EsqlTestUtils.TEST_CFG)); + logicalOptimizer = new LogicalPlanOptimizer(unboundLogicalOptimizerContext()); physicalPlanOptimizer = new PhysicalPlanOptimizer(new PhysicalOptimizerContext(EsqlTestUtils.TEST_CFG)); mapper = new Mapper(); diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/planner/LocalExecutionPlannerTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/planner/LocalExecutionPlannerTests.java index 5d8da21c6faad..a1648c67d9bd4 100644 --- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/planner/LocalExecutionPlannerTests.java +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/planner/LocalExecutionPlannerTests.java @@ -32,6 +32,7 @@ import org.elasticsearch.search.internal.ContextIndexSearcher; import org.elasticsearch.xpack.esql.TestBlockFactory; import org.elasticsearch.xpack.esql.core.expression.FieldAttribute; +import org.elasticsearch.xpack.esql.core.expression.FoldContext; import org.elasticsearch.xpack.esql.core.expression.Literal; import org.elasticsearch.xpack.esql.core.tree.Source; import org.elasticsearch.xpack.esql.core.type.DataType; @@ -83,6 +84,7 @@ public void closeIndex() throws IOException { public void testLuceneSourceOperatorHugeRowSize() throws IOException { int estimatedRowSize = randomEstimatedRowSize(estimatedRowSizeIsHuge); LocalExecutionPlanner.LocalExecutionPlan plan = planner().plan( + FoldContext.small(), new EsQueryExec(Source.EMPTY, index(), IndexMode.STANDARD, List.of(), null, null, null, estimatedRowSize) ); assertThat(plan.driverFactories.size(), lessThanOrEqualTo(pragmas.taskConcurrency())); @@ -98,6 +100,7 @@ public void testLuceneTopNSourceOperator() throws IOException { EsQueryExec.FieldSort sort = new EsQueryExec.FieldSort(sortField, Order.OrderDirection.ASC, Order.NullsPosition.LAST); Literal limit = new Literal(Source.EMPTY, 10, DataType.INTEGER); LocalExecutionPlanner.LocalExecutionPlan plan = planner().plan( + FoldContext.small(), new EsQueryExec(Source.EMPTY, index(), IndexMode.STANDARD, List.of(), null, limit, List.of(sort), estimatedRowSize) ); assertThat(plan.driverFactories.size(), lessThanOrEqualTo(pragmas.taskConcurrency())); @@ -113,6 +116,7 @@ public void testLuceneTopNSourceOperatorDistanceSort() throws IOException { EsQueryExec.GeoDistanceSort sort = new EsQueryExec.GeoDistanceSort(sortField, Order.OrderDirection.ASC, 1, -1); Literal limit = new Literal(Source.EMPTY, 10, DataType.INTEGER); LocalExecutionPlanner.LocalExecutionPlan plan = planner().plan( + FoldContext.small(), new EsQueryExec(Source.EMPTY, index(), IndexMode.STANDARD, List.of(), null, limit, List.of(sort), estimatedRowSize) ); assertThat(plan.driverFactories.size(), lessThanOrEqualTo(pragmas.taskConcurrency())); @@ -187,7 +191,7 @@ private EsPhysicalOperationProviders esPhysicalOperationProviders() throws IOExc ); } releasables.add(searcher); - return new EsPhysicalOperationProviders(shardContexts, null); + return new EsPhysicalOperationProviders(FoldContext.small(), shardContexts, null); } private IndexReader reader() { diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/planner/TestPhysicalOperationProviders.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/planner/TestPhysicalOperationProviders.java index 01dd4db123ee2..628737aa36c6c 100644 --- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/planner/TestPhysicalOperationProviders.java +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/planner/TestPhysicalOperationProviders.java @@ -41,6 +41,7 @@ import org.elasticsearch.xpack.esql.TestBlockFactory; import org.elasticsearch.xpack.esql.core.expression.Attribute; import org.elasticsearch.xpack.esql.core.expression.FieldAttribute; +import org.elasticsearch.xpack.esql.core.expression.FoldContext; import org.elasticsearch.xpack.esql.core.type.DataType; import org.elasticsearch.xpack.esql.core.type.MultiTypeEsField; import org.elasticsearch.xpack.esql.core.util.SpatialCoordinateTypes; @@ -71,13 +72,13 @@ public class TestPhysicalOperationProviders extends AbstractPhysicalOperationProviders { private final List indexPages; - private TestPhysicalOperationProviders(List indexPages, AnalysisRegistry analysisRegistry) { - super(analysisRegistry); + private TestPhysicalOperationProviders(FoldContext foldContext, List indexPages, AnalysisRegistry analysisRegistry) { + super(foldContext, analysisRegistry); this.indexPages = indexPages; } - public static TestPhysicalOperationProviders create(List indexPages) throws IOException { - return new TestPhysicalOperationProviders(indexPages, createAnalysisRegistry()); + public static TestPhysicalOperationProviders create(FoldContext foldContext, List indexPages) throws IOException { + return new TestPhysicalOperationProviders(foldContext, indexPages, createAnalysisRegistry()); } public record IndexPage(String index, Page page, List columnNames) { diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/plugin/ClusterRequestTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/plugin/ClusterRequestTests.java index f2a619f0dbd89..f3b1d84e507a5 100644 --- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/plugin/ClusterRequestTests.java +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/plugin/ClusterRequestTests.java @@ -26,7 +26,6 @@ import org.elasticsearch.xpack.esql.expression.function.EsqlFunctionRegistry; import org.elasticsearch.xpack.esql.index.EsIndex; import org.elasticsearch.xpack.esql.index.IndexResolution; -import org.elasticsearch.xpack.esql.optimizer.LogicalOptimizerContext; import org.elasticsearch.xpack.esql.optimizer.LogicalPlanOptimizer; import org.elasticsearch.xpack.esql.parser.EsqlParser; import org.elasticsearch.xpack.esql.plan.logical.LogicalPlan; @@ -39,10 +38,10 @@ import static org.elasticsearch.xpack.esql.ConfigurationTestUtils.randomConfiguration; import static org.elasticsearch.xpack.esql.ConfigurationTestUtils.randomTables; -import static org.elasticsearch.xpack.esql.EsqlTestUtils.TEST_CFG; import static org.elasticsearch.xpack.esql.EsqlTestUtils.TEST_VERIFIER; import static org.elasticsearch.xpack.esql.EsqlTestUtils.emptyPolicyResolution; import static org.elasticsearch.xpack.esql.EsqlTestUtils.loadMapping; +import static org.elasticsearch.xpack.esql.EsqlTestUtils.unboundLogicalOptimizerContext; import static org.elasticsearch.xpack.esql.EsqlTestUtils.withDefaultLimitWarning; import static org.hamcrest.Matchers.equalTo; @@ -187,7 +186,7 @@ static LogicalPlan parse(String query) { Map mapping = loadMapping("mapping-basic.json"); EsIndex test = new EsIndex("test", mapping, Map.of("test", IndexMode.STANDARD)); IndexResolution getIndexResult = IndexResolution.valid(test); - var logicalOptimizer = new LogicalPlanOptimizer(new LogicalOptimizerContext(TEST_CFG)); + var logicalOptimizer = new LogicalPlanOptimizer(unboundLogicalOptimizerContext()); var analyzer = new Analyzer( new AnalyzerContext(EsqlTestUtils.TEST_CFG, new EsqlFunctionRegistry(), getIndexResult, emptyPolicyResolution()), TEST_VERIFIER diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/plugin/DataNodeRequestSerializationTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/plugin/DataNodeRequestSerializationTests.java index 2cc733c2ea2e3..fac3495697da8 100644 --- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/plugin/DataNodeRequestSerializationTests.java +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/plugin/DataNodeRequestSerializationTests.java @@ -21,6 +21,7 @@ import org.elasticsearch.xpack.esql.EsqlTestUtils; import org.elasticsearch.xpack.esql.analysis.Analyzer; import org.elasticsearch.xpack.esql.analysis.AnalyzerContext; +import org.elasticsearch.xpack.esql.core.expression.FoldContext; import org.elasticsearch.xpack.esql.core.type.EsField; import org.elasticsearch.xpack.esql.expression.function.EsqlFunctionRegistry; import org.elasticsearch.xpack.esql.index.EsIndex; @@ -289,7 +290,7 @@ static LogicalPlan parse(String query) { Map mapping = loadMapping("mapping-basic.json"); EsIndex test = new EsIndex("test", mapping, Map.of("test", IndexMode.STANDARD)); IndexResolution getIndexResult = IndexResolution.valid(test); - var logicalOptimizer = new LogicalPlanOptimizer(new LogicalOptimizerContext(TEST_CFG)); + var logicalOptimizer = new LogicalPlanOptimizer(new LogicalOptimizerContext(TEST_CFG, FoldContext.small())); var analyzer = new Analyzer( new AnalyzerContext(EsqlTestUtils.TEST_CFG, new EsqlFunctionRegistry(), getIndexResult, emptyPolicyResolution()), TEST_VERIFIER diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/stats/PlanExecutorMetricsTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/stats/PlanExecutorMetricsTests.java index 539cd0314a4d1..a3c5cd9168b4f 100644 --- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/stats/PlanExecutorMetricsTests.java +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/stats/PlanExecutorMetricsTests.java @@ -28,6 +28,7 @@ import org.elasticsearch.xpack.esql.action.EsqlQueryRequest; import org.elasticsearch.xpack.esql.action.EsqlResolveFieldsAction; import org.elasticsearch.xpack.esql.analysis.EnrichResolution; +import org.elasticsearch.xpack.esql.core.expression.FoldContext; import org.elasticsearch.xpack.esql.enrich.EnrichPolicyResolver; import org.elasticsearch.xpack.esql.execution.PlanExecutor; import org.elasticsearch.xpack.esql.session.EsqlSession; @@ -119,6 +120,7 @@ public void testFailedMetric() { request, randomAlphaOfLength(10), EsqlTestUtils.TEST_CFG, + FoldContext.small(), enrichResolver, new EsqlExecutionInfo(randomBoolean()), groupIndicesByCluster, @@ -149,6 +151,7 @@ public void onFailure(Exception e) { request, randomAlphaOfLength(10), EsqlTestUtils.TEST_CFG, + FoldContext.small(), enrichResolver, new EsqlExecutionInfo(randomBoolean()), groupIndicesByCluster,