diff --git a/x-pack/plugin/esql/qa/server/single-node/src/javaRestTest/java/org/elasticsearch/xpack/esql/qa/single_node/PushExpressionToLoadIT.java b/x-pack/plugin/esql/qa/server/single-node/src/javaRestTest/java/org/elasticsearch/xpack/esql/qa/single_node/PushExpressionToLoadIT.java index f7bee9e55461e..05e22cbcaa124 100644 --- a/x-pack/plugin/esql/qa/server/single-node/src/javaRestTest/java/org/elasticsearch/xpack/esql/qa/single_node/PushExpressionToLoadIT.java +++ b/x-pack/plugin/esql/qa/server/single-node/src/javaRestTest/java/org/elasticsearch/xpack/esql/qa/single_node/PushExpressionToLoadIT.java @@ -107,6 +107,40 @@ public void testLengthNotPushedToText() throws IOException { ); } + public void testVCosine() throws IOException { + test( + justType("dense_vector"), + b -> b.startArray("test").value(128).value(128).value(0).endArray(), + "| EVAL test = V_COSINE(test, [0, 255, 255])", + matchesList().item(0.5), + matchesMap().entry("test:column_at_a_time:FloatDenseVectorFromDocValues.Normalized.V_COSINE", 1) + ); + } + + public void testVHammingToByte() throws IOException { + test( + b -> b.startObject("test").field("type", "dense_vector").field("element_type", "byte").endObject(), + b -> b.startArray("test").value(100).value(100).value(0).endArray(), + "| EVAL test = V_HAMMING(test, [0, 100, 100])", + matchesList().item(6.0), + matchesMap().entry("test:column_at_a_time:ByteDenseVectorFromDocValues.V_HAMMING", 1) + ); + } + + public void testVHammingToBit() throws IOException { + test( + b -> b.startObject("test").field("type", "dense_vector").field("element_type", "bit").endObject(), + b -> b.startArray("test").value(100).value(100).value(0).endArray(), + "| EVAL test = V_HAMMING(test, [0, 100, 100])", + matchesList().item(6.0), + matchesMap().entry("test:column_at_a_time:BitDenseVectorFromDocValues.V_HAMMING", 1) + ); + } + + // + // Tests for more complex shapes. + // + /** * Tests {@code LENGTH} on a field that comes from a {@code LOOKUP JOIN}. */ @@ -120,8 +154,16 @@ public void testLengthNotPushedToLookupJoinKeyword() throws IOException { | EVAL test = LENGTH(test) """, matchesList().item(1), - matchesMap().entry("main_matching:column_at_a_time:BytesRefsFromOrds.Singleton", 1), // - sig -> {} + matchesMap().entry("main_matching:column_at_a_time:BytesRefsFromOrds.Singleton", 1), + sig -> assertMap( + sig, + matchesList().item("LuceneSourceOperator") + .item("ValuesSourceReaderOperator") // the real work is here, checkOperatorProfile checks the status + .item("LookupOperator") + .item("EvalOperator") // this one just renames the field + .item("AggregationOperator") + .item("ExchangeSinkOperator") + ) ); } @@ -131,7 +173,6 @@ public void testLengthNotPushedToLookupJoinKeyword() throws IOException { * querying it. */ public void testLengthNotPushedToLookupJoinKeywordSameName() throws IOException { - assumeFalse("fix in 137679 - we push to the index but that's just wrong!", true); String value = "v".repeat(between(0, 256)); initLookupIndex(); test(b -> { @@ -144,40 +185,197 @@ public void testLengthNotPushedToLookupJoinKeywordSameName() throws IOException | LOOKUP JOIN lookup ON matching == main_matching | EVAL test = LENGTH(test) """, - matchesList().item(1), // <--- This is incorrectly returning value.length() + matchesList().item(1), matchesMap().entry("main_matching:column_at_a_time:BytesRefsFromOrds.Singleton", 1), - // ^^^^ This is incorrectly returning test:column_at_a_time:Utf8CodePointsFromOrds.Singleton - sig -> {} + sig -> assertMap( + sig, + matchesList().item("LuceneSourceOperator") + .item("ValuesSourceReaderOperator") // the real work is here, checkOperatorProfile checks the status + .item("LookupOperator") + .item("EvalOperator") // this one just renames the field + .item("AggregationOperator") + .item("ExchangeSinkOperator") + ) ); } - public void testVCosine() throws IOException { + /** + * Tests {@code LENGTH} on a field that comes from a {@code LOOKUP JOIN}. + */ + public void testLengthPushedInsideInlineStats() throws IOException { + String value = "v".repeat(between(0, 256)); test( - justType("dense_vector"), - b -> b.startArray("test").value(128).value(128).value(0).endArray(), - "| EVAL test = V_COSINE(test, [0, 255, 255])", - matchesList().item(0.5), - matchesMap().entry("test:column_at_a_time:FloatDenseVectorFromDocValues.Normalized.V_COSINE", 1) + justType("keyword"), + b -> b.field("test", value), + """ + | INLINE STATS max_length = MAX(LENGTH(test)) + | EVAL test = LENGTH(test) + | WHERE test == max_length + """, + matchesList().item(value.length()), + matchesMap().entry("test:column_at_a_time:Utf8CodePointsFromOrds.Singleton", 1), + sig -> { + // There are two data node plans, one for each phase. + if (sig.contains("FilterOperator")) { + assertMap( + sig, + matchesList().item("LuceneSourceOperator") + .item("ValuesSourceReaderOperator") // the real work is here, checkOperatorProfile checks the status + .item("FilterOperator") + .item("EvalOperator") // this one just renames the field + .item("AggregationOperator") + .item("ExchangeSinkOperator") + ); + } else { + assertMap( + sig, + matchesList().item("LuceneSourceOperator") + .item("ValuesSourceReaderOperator") // the real work is here, checkOperatorProfile checks the status + .item("EvalOperator") // this one just renames the field + .item("AggregationOperator") + .item("ExchangeSinkOperator") + ); + } + } ); } - public void testVHammingToByte() throws IOException { + /** + * Tests {@code LENGTH} on a field that comes from a {@code LOOKUP JOIN}. + */ + public void testLengthNotPushedToInlineStatsResults() throws IOException { + String value = "v".repeat(between(0, 256)); + test(justType("keyword"), b -> b.field("test", value), """ + | INLINE STATS test2 = VALUES(test) + | EVAL test = LENGTH(test2) + """, matchesList().item(value.length()), matchesMap().entry("test:column_at_a_time:BytesRefsFromOrds.Singleton", 1), sig -> { + // There are two data node plans, one for each phase. + if (sig.contains("EvalOperator")) { + assertMap( + sig, + matchesList().item("LuceneSourceOperator") + .item("EvalOperator") // The second phase of the INLINE STATS + .item("AggregationOperator") + .item("ExchangeSinkOperator") + ); + } else { + assertMap( + sig, + matchesList().item("LuceneSourceOperator") + .item("ValuesSourceReaderOperator") + .item("AggregationOperator") + .item("ExchangeSinkOperator") + ); + } + }); + } + + /** + * Tests {@code LENGTH} on a field that comes from a {@code LOOKUP JOIN}. + */ + public void testLengthNotPushedToGroupedInlineStatsResults() throws IOException { + String value = "v".repeat(between(0, 256)); + CheckedConsumer mapping = b -> { + b.startObject("test").field("type", "keyword").endObject(); + b.startObject("group").field("type", "keyword").endObject(); + }; + test(mapping, b -> b.field("test", value).field("group", "g"), """ + | INLINE STATS test2 = VALUES(test) BY group + | EVAL test = LENGTH(test2) + """, matchesList().item(value.length()), matchesMap().extraOk(), sig -> { + // There are two data node plans, one for each phase. + if (sig.contains("EvalOperator")) { + assertMap( + sig, + matchesList().item("LuceneSourceOperator") + .item("ValuesSourceReaderOperator") + .item("RowInTableLookup") + .item("ColumnLoad") + .item("ProjectOperator") + .item("EvalOperator") + .item("AggregationOperator") + .item("ExchangeSinkOperator") + ); + } else { + assertMap( + sig, + matchesList().item("LuceneSourceOperator") + .item("ValuesSourceReaderOperator") + .item("HashAggregationOperator") + .item("ExchangeSinkOperator") + ); + } + }); + } + + /** + * LENGTH not pushed when on a fork branch. + */ + public void testLengthNotPushedToFork() throws IOException { + String value = "v".repeat(between(0, 256)); test( - b -> b.startObject("test").field("type", "dense_vector").field("element_type", "byte").endObject(), - b -> b.startArray("test").value(100).value(100).value(0).endArray(), - "| EVAL test = V_HAMMING(test, [0, 100, 100])", - matchesList().item(6.0), - matchesMap().entry("test:column_at_a_time:ByteDenseVectorFromDocValues.V_HAMMING", 1) + justType("keyword"), + b -> b.field("test", value), + """ + | FORK + (EVAL test = LENGTH(test) + 1) + (EVAL test = LENGTH(test) + 2) + """, + matchesList().item(List.of(value.length() + 1, value.length() + 2)), + matchesMap().entry("test:column_at_a_time:BytesRefsFromOrds.Singleton", 1), + sig -> assertMap( + sig, + matchesList().item("LuceneSourceOperator") + .item("ValuesSourceReaderOperator") + .item("ProjectOperator") + .item("ExchangeSinkOperator") + ) ); } - public void testVHammingToBit() throws IOException { + public void testLengthNotPushedBeforeFork() throws IOException { + String value = "v".repeat(between(0, 256)); test( - b -> b.startObject("test").field("type", "dense_vector").field("element_type", "bit").endObject(), - b -> b.startArray("test").value(100).value(100).value(0).endArray(), - "| EVAL test = V_HAMMING(test, [0, 100, 100])", - matchesList().item(6.0), - matchesMap().entry("test:column_at_a_time:BitDenseVectorFromDocValues.V_HAMMING", 1) + justType("keyword"), + b -> b.field("test", value), + """ + | EVAL test = LENGTH(test) + | FORK + (EVAL j = 1) + (EVAL j = 2) + """, + matchesList().item(value.length()), + matchesMap().entry("test:column_at_a_time:BytesRefsFromOrds.Singleton", 1), + sig -> assertMap( + sig, + matchesList().item("LuceneSourceOperator") + .item("ValuesSourceReaderOperator") + .item("ProjectOperator") + .item("ExchangeSinkOperator") + ) + ); + } + + public void testLengthNotPushedAfterFork() throws IOException { + String value = "v".repeat(between(0, 256)); + test( + justType("keyword"), + b -> b.field("test", value), + """ + | FORK + (EVAL j = 1) + (EVAL j = 2) + | EVAL test = LENGTH(test) + """, + matchesList().item(value.length()), + matchesMap().entry("test:column_at_a_time:BytesRefsFromOrds.Singleton", 1), + sig -> assertMap( + sig, + matchesList().item("LuceneSourceOperator") + .item("ValuesSourceReaderOperator") + .item("ProjectOperator") + .item("ExchangeSinkOperator") + ) ); } @@ -217,7 +415,7 @@ private void test( RestEsqlTestCase.RequestObjectBuilder builder = requestObjectBuilder().query(""" FROM test """ + eval + """ - | STATS test = VALUES(test) + | STATS test = MV_SORT(VALUES(test)) """); /* * TODO if you just do KEEP test then the load is in the data node reduce driver and not merged: @@ -265,6 +463,9 @@ private void test( } case "node_reduce" -> logger.info("node_reduce {}", sig); case "final" -> logger.info("final {}", sig); + case "main.final" -> logger.info("main final {}", sig); + case "subplan-0.final" -> logger.info("subplan-0 final {}", sig); + case "subplan-1.final" -> logger.info("subplan-1 final {}", sig); default -> throw new IllegalArgumentException("can't match " + description); } } diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/BlockLoaderWarnings.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/BlockLoaderWarnings.java index 278035dade015..33594a4611f95 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/BlockLoaderWarnings.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/BlockLoaderWarnings.java @@ -39,4 +39,9 @@ public void registerException(Class exceptionClass, String } delegate.registerException(exceptionClass, message); } + + @Override + public String toString() { + return "warnings for " + source; + } } diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/logical/local/PushExpressionsToFieldLoad.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/logical/local/PushExpressionsToFieldLoad.java index bec1fcbcd5ef2..e700b77cd7163 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/logical/local/PushExpressionsToFieldLoad.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/logical/local/PushExpressionsToFieldLoad.java @@ -7,6 +7,7 @@ package org.elasticsearch.xpack.esql.optimizer.rules.logical.local; +import org.elasticsearch.index.IndexMode; import org.elasticsearch.xpack.esql.core.expression.Attribute; import org.elasticsearch.xpack.esql.core.expression.AttributeSet; import org.elasticsearch.xpack.esql.core.expression.Expression; @@ -15,7 +16,6 @@ import org.elasticsearch.xpack.esql.core.expression.NamedExpression; import org.elasticsearch.xpack.esql.core.tree.Source; import org.elasticsearch.xpack.esql.core.type.FunctionEsField; -import org.elasticsearch.xpack.esql.core.util.Holder; import org.elasticsearch.xpack.esql.expression.function.blockloader.BlockLoaderExpression; import org.elasticsearch.xpack.esql.optimizer.LocalLogicalOptimizerContext; import org.elasticsearch.xpack.esql.plan.logical.Aggregate; @@ -42,83 +42,119 @@ public class PushExpressionsToFieldLoad extends ParameterizedRule addedAttrs = new HashMap<>(); - return plan.transformDown(LogicalPlan.class, p -> doRule(p, context, addedAttrs)); + Rule rule = new Rule(context, plan); + return plan.transformDown(LogicalPlan.class, rule::doRule); } - private LogicalPlan doRule( - LogicalPlan plan, - LocalLogicalOptimizerContext context, - Map addedAttrs - ) { - Holder planWasTransformed = new Holder<>(false); - if (plan instanceof Eval || plan instanceof Filter || plan instanceof Aggregate) { - LogicalPlan transformedPlan = plan.transformExpressionsOnly(Expression.class, e -> { - if (e instanceof BlockLoaderExpression ble) { - BlockLoaderExpression.PushedBlockLoaderExpression fuse = ble.tryPushToFieldLoading(context.searchStats()); - if (fuse != null - && context.searchStats() - .supportsLoaderConfig( - fuse.field().fieldName(), - fuse.config(), - context.configuration().pragmas().fieldExtractPreference() - )) { - planWasTransformed.set(true); - return replaceFieldsForFieldTransformations(e, addedAttrs, fuse); + private class Rule { + private final Map addedAttrs = new HashMap<>(); + + private final LocalLogicalOptimizerContext context; + private final LogicalPlan plan; + + /** + * The primary indices, lazily initialized. + */ + private List primaries; + private boolean planWasTransformed = false; + + private Rule(LocalLogicalOptimizerContext context, LogicalPlan plan) { + this.context = context; + this.plan = plan; + } + + private LogicalPlan doRule(LogicalPlan plan) { + planWasTransformed = false; + if (plan instanceof Eval || plan instanceof Filter || plan instanceof Aggregate) { + LogicalPlan transformedPlan = plan.transformExpressionsOnly(Expression.class, e -> { + if (e instanceof BlockLoaderExpression ble) { + return transformExpression(e, ble); } + return e; + }); + + // TODO rebuild everything one time rather than after each find. + if (planWasTransformed == false) { + return plan; } - return e; - }); - if (planWasTransformed.get() == false) { - return plan; + List previousAttrs = transformedPlan.output(); + // Transforms EsRelation to extract the new attributes + List addedAttrsList = addedAttrs.values().stream().toList(); + transformedPlan = transformedPlan.transformDown(EsRelation.class, esRelation -> { + AttributeSet updatedOutput = esRelation.outputSet().combine(AttributeSet.of(addedAttrsList)); + return esRelation.withAttributes(updatedOutput.stream().toList()); + }); + // Transforms Projects so the new attribute is not discarded + transformedPlan = transformedPlan.transformDown(EsqlProject.class, esProject -> { + List projections = new ArrayList<>(esProject.projections()); + projections.addAll(addedAttrsList); + return esProject.withProjections(projections); + }); + + return new EsqlProject(Source.EMPTY, transformedPlan, previousAttrs); } + return plan; + } - List previousAttrs = transformedPlan.output(); - // Transforms EsRelation to extract the new attributes - List addedAttrsList = addedAttrs.values().stream().toList(); - transformedPlan = transformedPlan.transformDown(EsRelation.class, esRelation -> { - AttributeSet updatedOutput = esRelation.outputSet().combine(AttributeSet.of(addedAttrsList)); - return esRelation.withAttributes(updatedOutput.stream().toList()); - }); - // Transforms Projects so the new attribute is not discarded - transformedPlan = transformedPlan.transformDown(EsqlProject.class, esProject -> { - List projections = new ArrayList<>(esProject.projections()); - projections.addAll(addedAttrsList); - return esProject.withProjections(projections); - }); - - return new EsqlProject(Source.EMPTY, transformedPlan, previousAttrs); + private Expression transformExpression(Expression e, BlockLoaderExpression ble) { + BlockLoaderExpression.PushedBlockLoaderExpression fuse = ble.tryPushToFieldLoading(context.searchStats()); + if (fuse == null) { + return e; + } + if (anyPrimaryContains(fuse.field()) == false) { + return e; + } + var preference = context.configuration().pragmas().fieldExtractPreference(); + if (context.searchStats().supportsLoaderConfig(fuse.field().fieldName(), fuse.config(), preference) == false) { + return e; + } + planWasTransformed = true; + return replaceFieldsForFieldTransformations(e, fuse); } - return plan; - } + private Expression replaceFieldsForFieldTransformations(Expression e, BlockLoaderExpression.PushedBlockLoaderExpression fuse) { + // Change the expression to a reference of the pushed down function on the field + FunctionEsField functionEsField = new FunctionEsField(fuse.field().field(), e.dataType(), fuse.config()); + var name = rawTemporaryName(fuse.field().name(), fuse.config().function().toString(), String.valueOf(fuse.config().hashCode())); + var newFunctionAttr = new FieldAttribute( + fuse.field().source(), + fuse.field().parentName(), + fuse.field().qualifier(), + name, + functionEsField, + fuse.field().nullable(), + new NameId(), + true + ); + Attribute.IdIgnoringWrapper key = newFunctionAttr.ignoreId(); + if (addedAttrs.containsKey(key)) { + return addedAttrs.get(key); + } + + addedAttrs.put(key, newFunctionAttr); + return newFunctionAttr; + } - private static Expression replaceFieldsForFieldTransformations( - Expression e, - Map addedAttrs, - BlockLoaderExpression.PushedBlockLoaderExpression fuse - ) { - // Change the similarity function to a reference of a transformation on the field - FunctionEsField functionEsField = new FunctionEsField(fuse.field().field(), e.dataType(), fuse.config()); - var name = rawTemporaryName(fuse.field().name(), fuse.config().function().toString(), String.valueOf(fuse.config().hashCode())); - // TODO: Check if exists before adding, retrieve the previous one - var newFunctionAttr = new FieldAttribute( - fuse.field().source(), - fuse.field().parentName(), - fuse.field().qualifier(), - name, - functionEsField, - fuse.field().nullable(), - new NameId(), - true - ); - Attribute.IdIgnoringWrapper key = newFunctionAttr.ignoreId(); - if (addedAttrs.containsKey(key)) { - return addedAttrs.get(key); + private List primaries() { + if (primaries == null) { + primaries = new ArrayList<>(2); + plan.forEachUp(EsRelation.class, r -> { + if (r.indexMode() != IndexMode.LOOKUP) { + primaries.add(r); + } + }); + } + return primaries; } - addedAttrs.put(key, newFunctionAttr); - return newFunctionAttr; + private boolean anyPrimaryContains(FieldAttribute attr) { + for (EsRelation primary : primaries()) { + if (primary.outputSet().contains(attr)) { + return true; + } + } + return false; + } } }