diff --git a/x-pack/plugin/esql/qa/server/single-node/src/javaRestTest/java/org/elasticsearch/xpack/esql/qa/single_node/PushExpressionToLoadIT.java b/x-pack/plugin/esql/qa/server/single-node/src/javaRestTest/java/org/elasticsearch/xpack/esql/qa/single_node/PushExpressionToLoadIT.java index 39a90f42c5785..1f2db0712079a 100644 --- a/x-pack/plugin/esql/qa/server/single-node/src/javaRestTest/java/org/elasticsearch/xpack/esql/qa/single_node/PushExpressionToLoadIT.java +++ b/x-pack/plugin/esql/qa/server/single-node/src/javaRestTest/java/org/elasticsearch/xpack/esql/qa/single_node/PushExpressionToLoadIT.java @@ -45,6 +45,7 @@ import static org.elasticsearch.xpack.esql.qa.single_node.RestEsqlIT.commonProfile; import static org.elasticsearch.xpack.esql.qa.single_node.RestEsqlIT.fixTypesOnProfile; import static org.hamcrest.Matchers.any; +import static org.hamcrest.Matchers.anyOf; import static org.hamcrest.Matchers.closeTo; import static org.hamcrest.Matchers.instanceOf; import static org.hamcrest.Matchers.startsWith; @@ -54,8 +55,9 @@ */ @ThreadLeakFilters(filters = TestClustersThreadFilter.class) public class PushExpressionToLoadIT extends ESRestTestCase { + @ClassRule - public static ElasticsearchCluster cluster = Clusters.testCluster(spec -> spec.plugin("inference-service-test")); + public static ElasticsearchCluster cluster = Clusters.testCluster(); @Rule(order = Integer.MIN_VALUE) public ProfileLogger profileLogger = new ProfileLogger(); @@ -365,6 +367,96 @@ public void testVHammingToBit() throws IOException { ); } + // + // Tests without STATS at the end - check that node_reduce phase works correctly + // + public void testLengthPushedWithoutTopN() throws IOException { + String textValue = "v".repeat(between(0, 256)); + test( + b -> b.startObject("test").field("type", "keyword").endObject(), + b -> b.field("test", textValue), + """ + FROM test + | EVAL fieldLength = LENGTH(test) + | LIMIT 10 + | KEEP test, fieldLength + """, + 
matchesList().item(textValue).item(textValue.length()), + matchesList().item(matchesMap().entry("name", "test").entry("type", any(String.class))) + .item(matchesMap().entry("name", "fieldLength").entry("type", any(String.class))), + Map.of( + "data", + List.of( + // Pushed down function + matchesMap().entry("test:column_at_a_time:BytesRefsFromOrds.Singleton", 1), + // Field + matchesMap().entry("test:row_stride:BytesRefsFromOrds.Singleton", 1) + ) + ), + sig -> {} + ); + } + + public void testLengthPushedWithTopN() throws IOException { + String textValue = "v".repeat(between(0, 256)); + Integer orderingValue = randomInt(); + test(b -> { + b.startObject("test").field("type", "keyword").endObject(); + b.startObject("ordering").field("type", "integer").endObject(); + }, + b -> b.field("test", textValue).field("ordering", orderingValue), + """ + FROM test + | EVAL fieldLength = LENGTH(test) + | SORT ordering DESC + | LIMIT 10 + | KEEP test + """, + matchesList().item(textValue), + matchesList().item(matchesMap().entry("name", "test").entry("type", any(String.class))), + Map.of( + "data", + List.of(matchesMap().entry("ordering:column_at_a_time:IntsFromDocValues.Singleton", 1)), + "node_reduce", + List.of( + // Pushed down function + matchesMap().entry("test:column_at_a_time:Utf8CodePointsFromOrds.Singleton", 1), + // Field + matchesMap().entry("test:row_stride:BytesRefsFromOrds.Singleton", 1) + ) + ), + sig -> {} + ); + } + + public void testLengthPushedWithTopNAsOrder() throws IOException { + String textValue = "v".repeat(between(0, 256)); + test( + b -> b.startObject("test").field("type", "keyword").endObject(), + b -> b.field("test", textValue), + """ + FROM test + | EVAL fieldLength = LENGTH(test) + | SORT fieldLength DESC + | LIMIT 10 + | KEEP test, fieldLength + """, + matchesList().item(textValue).item(textValue.length()), + matchesList().item(matchesMap().entry("name", "test").entry("type", any(String.class))) + .item(matchesMap().entry("name", 
"fieldLength").entry("type", any(String.class))), + Map.of( + "data", + List.of( + // Pushed down function + matchesMap().entry("test:column_at_a_time:Utf8CodePointsFromOrds.Singleton", 1), + // TODO It should not load the field value on the data node, but just on the node_reduce phase + matchesMap().entry("test:row_stride:BytesRefsFromOrds.Singleton", 1) + ) + ), + sig -> {} + ); + } + // // Tests for more complex shapes. // @@ -639,23 +731,34 @@ private void test( MapMatcher expectedLoaders, Consumer> assertDataNodeSig ) throws IOException { + + test( + mapping, + doc, + """ + FROM test + """ + eval + """ + | STATS test = MV_SORT(VALUES(test)) + """, + expectedValue, + matchesList().item(matchesMap().entry("name", "test").entry("type", any(String.class))), + Map.of("data", List.of(expectedLoaders)), + assertDataNodeSig + ); + } + + private void test( + CheckedConsumer mapping, + CheckedConsumer doc, + String query, + Matcher expectedValue, + Matcher columnMatcher, + Map> expectedLoadersPerDriver, + Consumer> assertDataNodeSig + ) throws IOException { indexValue(mapping, doc); - RestEsqlTestCase.RequestObjectBuilder builder = requestObjectBuilder().query(""" - FROM test - """ + eval + """ - | STATS test = MV_SORT(VALUES(test)) - """); - /* - * TODO if you just do KEEP test then the load is in the data node reduce driver and not merged: - * \_ProjectExec[[test{f}#7]] - * \_FieldExtractExec[test{f}#7]<[],[]> - * \_EsQueryExec[test], indexMode[standard]] - * \_ExchangeSourceExec[[test{f}#7],false]}, {cluster_name=test-cluster, node_name=test-cluster-0, descrip - * \_ProjectExec[[test{r}#3]] - * \_EvalExec[[LENGTH(test{f}#7) AS test#3]] - * \_LimitExec[1000[INTEGER],50] - * \_ExchangeSourceExec[[test{f}#7],false]}], query={to - */ + RestEsqlTestCase.RequestObjectBuilder builder = requestObjectBuilder().query(query); + builder.profile(true); Map result = runEsql(builder, new AssertWarnings.NoWarnings(), profileLogger, RestEsqlTestCase.Mode.SYNC); @@ -669,7 +772,7 @@ 
private void test( .entry("planning", matchesMap().extraOk()) .entry("query", matchesMap().extraOk()) ), - matchesList().item(matchesMap().entry("name", "test").entry("type", any(String.class))), + columnMatcher, matchesList().item(expectedValue) ); @SuppressWarnings("unchecked") @@ -677,14 +780,13 @@ private void test( for (Map p : profiles) { fixTypesOnProfile(p); assertThat(p, commonProfile()); - List sig = new ArrayList<>(); @SuppressWarnings("unchecked") List> operators = (List>) p.get("operators"); - for (Map o : operators) { - sig.add(checkOperatorProfile(o, expectedLoaders)); - } - String description = p.get("description").toString(); - switch (description) { + + String driverDescription = (String) p.get("description"); + List mapMatcher = expectedLoadersPerDriver.get(driverDescription); + List sig = checkOperatorProfile(driverDescription, operators, mapMatcher); + switch (driverDescription) { case "data" -> { logger.info("data {}", sig); assertDataNodeSig.accept(sig); @@ -694,7 +796,7 @@ private void test( case "main.final" -> logger.info("main final {}", sig); case "subplan-0.final" -> logger.info("subplan-0 final {}", sig); case "subplan-1.final" -> logger.info("subplan-1 final {}", sig); - default -> throw new IllegalArgumentException("can't match " + description); + default -> throw new IllegalArgumentException("can't match " + driverDescription); } } } @@ -793,48 +895,36 @@ private void initLookupIndex() throws IOException { } private CheckedConsumer justType(String type) { - return b -> b.startObject("test").field("type", type).endObject(); - } - - private static String checkOperatorProfile(Map o, MapMatcher expectedLoaders) { - String name = (String) o.get("operator"); - name = PushQueriesIT.TO_NAME.matcher(name).replaceAll(""); - if (name.equals("ValuesSourceReaderOperator")) { - MapMatcher expectedOp = matchesMap().entry("operator", startsWith(name)) - .entry("status", matchesMap().entry("readers_built", expectedLoaders).extraOk()); - assertMap(o, 
expectedOp); + return justType("test", type); + } + + private CheckedConsumer justType(String fieldName, String type) { + return b -> b.startObject(fieldName).field("type", type).endObject(); + } + + private static List checkOperatorProfile( + String driverDesc, + List> operators, + List expectedLoaders + ) { + List sig = new ArrayList<>(); + for (Map operator : operators) { + String name = (String) operator.get("operator"); + name = PushQueriesIT.TO_NAME.matcher(name).replaceAll(""); + if (name.equals("ValuesSourceReaderOperator")) { + assertNotNull("Expected loaders to match the ValuesSourceReaderOperator for driver " + driverDesc, expectedLoaders); + MapMatcher expectedOp = matchesMap().entry("operator", startsWith(name)) + .entry("status", matchesMap().entry("readers_built", anyOf(expectedLoaders.toArray(new MapMatcher[0]))).extraOk()); + assertMap("Error checking values loaded for driver " + driverDesc + "; ", operator, expectedOp); + } + sig.add(name); } - return name; + + return sig; } @Override protected String getTestRestCluster() { return cluster.getHttpAddresses(); } - - @Override - protected boolean preserveClusterUponCompletion() { - // Preserve the cluser to speed up the semantic_text tests - return true; - } - - private static boolean setupEmbeddings = false; - - private void setUpTextEmbeddingInferenceEndpoint() throws IOException { - setupEmbeddings = true; - Request request = new Request("PUT", "_inference/text_embedding/test"); - request.setJsonEntity(""" - { - "service": "text_embedding_test_service", - "service_settings": { - "model": "my_model", - "api_key": "abc64", - "dimensions": 128 - }, - "task_settings": { - } - } - """); - adminClient().performRequest(request); - } } diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/LocalLogicalPlanOptimizerTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/LocalLogicalPlanOptimizerTests.java index 0804ba1718b53..5c09bf8fe9e1d 100644 
--- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/LocalLogicalPlanOptimizerTests.java +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/LocalLogicalPlanOptimizerTests.java @@ -44,11 +44,13 @@ import org.elasticsearch.xpack.esql.expression.function.fulltext.SingleFieldFullTextFunction; import org.elasticsearch.xpack.esql.expression.function.scalar.conditional.Case; import org.elasticsearch.xpack.esql.expression.function.scalar.nulls.Coalesce; +import org.elasticsearch.xpack.esql.expression.function.scalar.string.Length; import org.elasticsearch.xpack.esql.expression.function.scalar.string.StartsWith; import org.elasticsearch.xpack.esql.expression.function.scalar.string.regex.RLike; import org.elasticsearch.xpack.esql.expression.function.scalar.string.regex.RLikeList; import org.elasticsearch.xpack.esql.expression.function.scalar.string.regex.WildcardLike; import org.elasticsearch.xpack.esql.expression.function.scalar.string.regex.WildcardLikeList; +import org.elasticsearch.xpack.esql.expression.function.vector.DotProduct; import org.elasticsearch.xpack.esql.expression.function.vector.VectorSimilarityFunction; import org.elasticsearch.xpack.esql.expression.predicate.logical.And; import org.elasticsearch.xpack.esql.expression.predicate.nulls.IsNotNull; @@ -65,22 +67,33 @@ import org.elasticsearch.xpack.esql.plan.logical.EsRelation; import org.elasticsearch.xpack.esql.plan.logical.Eval; import org.elasticsearch.xpack.esql.plan.logical.Filter; +import org.elasticsearch.xpack.esql.plan.logical.Fork; import org.elasticsearch.xpack.esql.plan.logical.Limit; import org.elasticsearch.xpack.esql.plan.logical.LogicalPlan; import org.elasticsearch.xpack.esql.plan.logical.MvExpand; import org.elasticsearch.xpack.esql.plan.logical.OrderBy; import org.elasticsearch.xpack.esql.plan.logical.Project; import org.elasticsearch.xpack.esql.plan.logical.Row; +import org.elasticsearch.xpack.esql.plan.logical.Subquery; import 
org.elasticsearch.xpack.esql.plan.logical.TopN; import org.elasticsearch.xpack.esql.plan.logical.UnaryPlan; +import org.elasticsearch.xpack.esql.plan.logical.UnionAll; import org.elasticsearch.xpack.esql.plan.logical.join.Join; import org.elasticsearch.xpack.esql.plan.logical.join.JoinConfig; import org.elasticsearch.xpack.esql.plan.logical.join.JoinTypes; import org.elasticsearch.xpack.esql.plan.logical.local.EmptyLocalSupplier; import org.elasticsearch.xpack.esql.plan.logical.local.EsqlProject; import org.elasticsearch.xpack.esql.plan.logical.local.LocalRelation; +import org.elasticsearch.xpack.esql.plan.physical.EsSourceExec; +import org.elasticsearch.xpack.esql.plan.physical.EvalExec; +import org.elasticsearch.xpack.esql.plan.physical.PhysicalPlan; +import org.elasticsearch.xpack.esql.plan.physical.ProjectExec; +import org.elasticsearch.xpack.esql.plan.physical.TopNExec; +import org.elasticsearch.xpack.esql.planner.PlannerUtils; +import org.elasticsearch.xpack.esql.planner.mapper.Mapper; import org.elasticsearch.xpack.esql.rule.RuleExecutor; import org.elasticsearch.xpack.esql.session.Configuration; +import org.elasticsearch.xpack.esql.session.Versioned; import org.elasticsearch.xpack.esql.stats.SearchStats; import org.junit.BeforeClass; @@ -95,6 +108,7 @@ import static java.util.Collections.emptyMap; import static org.elasticsearch.xpack.esql.EsqlTestUtils.L; import static org.elasticsearch.xpack.esql.EsqlTestUtils.ONE; +import static org.elasticsearch.xpack.esql.EsqlTestUtils.TEST_CFG; import static org.elasticsearch.xpack.esql.EsqlTestUtils.TEST_SEARCH_STATS; import static org.elasticsearch.xpack.esql.EsqlTestUtils.TEST_VERIFIER; import static org.elasticsearch.xpack.esql.EsqlTestUtils.THREE; @@ -111,6 +125,7 @@ import static org.elasticsearch.xpack.esql.EsqlTestUtils.testAnalyzerContext; import static org.elasticsearch.xpack.esql.EsqlTestUtils.unboundLogicalOptimizerContext; import static org.elasticsearch.xpack.esql.EsqlTestUtils.withDefaultLimitWarning; 
+import static org.elasticsearch.xpack.esql.analysis.AnalyzerTestUtils.defaultLookupResolution; import static org.elasticsearch.xpack.esql.analysis.AnalyzerTestUtils.indexResolutions; import static org.elasticsearch.xpack.esql.core.tree.Source.EMPTY; import static org.elasticsearch.xpack.esql.core.type.DataType.DENSE_VECTOR; @@ -118,7 +133,9 @@ import static org.elasticsearch.xpack.esql.core.type.DataType.KEYWORD; import static org.elasticsearch.xpack.esql.optimizer.rules.logical.OptimizerRules.TransformDirection.DOWN; import static org.elasticsearch.xpack.esql.optimizer.rules.logical.OptimizerRules.TransformDirection.UP; +import static org.elasticsearch.xpack.esql.planner.PlannerUtils.breakPlanBetweenCoordinatorAndDataNode; import static org.hamcrest.Matchers.contains; +import static org.hamcrest.Matchers.containsInAnyOrder; import static org.hamcrest.Matchers.containsString; import static org.hamcrest.Matchers.equalTo; import static org.hamcrest.Matchers.hasItem; @@ -152,6 +169,7 @@ public static void init() { EsqlTestUtils.TEST_CFG, new EsqlFunctionRegistry(), indexResolutions(test), + defaultLookupResolution(), emptyPolicyResolution(), emptyInferenceResolution() ), @@ -1878,6 +1896,329 @@ public void testKnnOnMissingField() { assertThat(Expressions.name(fullTextFunction.field()), equalTo("text")); } + private static PhysicalPlan physicalPlan(LogicalPlan logicalPlan, Analyzer analyzer) { + var mapper = new Mapper(); + return mapper.map(new Versioned<>(logicalPlan, analyzer.context().minimumVersion())); + } + + public void testReductionPlanForTopNWithPushedDownFunctions() { + var query = String.format(Locale.ROOT, """ + FROM test_all + | EVAL score = V_DOT_PRODUCT(dense_vector, [1.0, 2.0, 3.0]) + | SORT integer DESC + | LIMIT 10 + | KEEP text, score + """); + var logicalPlan = localPlan(plan(query, allTypesAnalyzer), TEST_SEARCH_STATS); + + // Verify the logical plan structure: + // EsqlProject[[text{f}#1105, score{r}#1085]] + var project = as(logicalPlan, 
EsqlProject.class); + assertThat(Expressions.names(project.projections()), contains("text", "score")); + + // TopN[[Order[integer{f}#1099,DESC,FIRST]],10[INTEGER],false] + var topN = as(project.child(), TopN.class); + assertThat(topN.limit().fold(FoldContext.small()), equalTo(10)); + var order = as(topN.order().getFirst(), Order.class); + assertThat(order.direction(), equalTo(Order.OrderDirection.DESC)); + var orderField = as(order.child(), FieldAttribute.class); + assertThat(orderField.name(), equalTo("integer")); + + // Eval[[$$dense_vector$V_DOT_PRODUCT$1451583510{f$}#1110 AS score#1085]] + var eval = as(topN.child(), Eval.class); + assertThat(eval.fields(), hasSize(1)); + var scoreAlias = eval.fields() + .stream() + .filter(f -> f.name().equals("score")) + .findFirst() + .orElseThrow(() -> new AssertionError("Field 'score' not found in eval")); + var scoreField = as(scoreAlias, Alias.class); + var scoreFieldAttr = as(scoreField.child(), FieldAttribute.class); + assertThat(scoreFieldAttr.name(), startsWith("$$dense_vector$V_DOT_PRODUCT$")); + assertThat(scoreFieldAttr.fieldName().string(), equalTo("dense_vector")); + + // EsRelation[test_all][!alias_integer, boolean{f}#1090, byte{f}#1091, cons..] 
+ var relation = as(eval.child(), EsRelation.class); + assertTrue(relation.output().contains(scoreFieldAttr)); + + // Also verify physical plan behavior + var physicalPlan = physicalPlan(logicalPlan, allTypesAnalyzer); + var coordAndDataNodePlans = breakPlanBetweenCoordinatorAndDataNode(physicalPlan, TEST_CFG); + + var coordPlan = coordAndDataNodePlans.v1(); + var coordProjectExec = as(coordPlan, ProjectExec.class); + assertThat(coordProjectExec.projections().stream().map(NamedExpression::name).toList(), containsInAnyOrder("text", "score")); + var coordTopN = as(coordProjectExec.child(), TopNExec.class); + var orderAttr = as(coordTopN.order().getFirst().child(), FieldAttribute.class); + assertThat(orderAttr.name(), equalTo("integer")); + + var reductionPlan = ((PlannerUtils.TopNReduction) PlannerUtils.reductionPlan(coordAndDataNodePlans.v2())).plan(); + var topNExec = as(reductionPlan, TopNExec.class); + var evalExec = as(topNExec.child(), EvalExec.class); + var alias = evalExec.fields().get(0); + assertThat(alias.name(), equalTo("score")); + var fieldAttr = as(alias.child(), FieldAttribute.class); + assertThat(fieldAttr.name(), startsWith("$$dense_vector$V_DOT_PRODUCT$")); + var esSourceExec = as(evalExec.child(), EsSourceExec.class); + assertTrue(esSourceExec.outputSet().stream().anyMatch(a -> a == fieldAttr)); + } + + public void testReductionPlanForTopNWithPushedDownFunctionsInOrder() { + var query = String.format(Locale.ROOT, """ + FROM test_all + | EVAL fieldLength = LENGTH(text) + | SORT fieldLength DESC + | LIMIT 10 + | KEEP text, fieldLength + """); + var logicalPlan = localPlan(plan(query, allTypesAnalyzer), TEST_SEARCH_STATS); + var physicalPlan = physicalPlan(logicalPlan, allTypesAnalyzer); + var coordAndDataNodePlans = breakPlanBetweenCoordinatorAndDataNode(physicalPlan, TEST_CFG); + + var coordPlan = coordAndDataNodePlans.v1(); + var coordProjectExec = as(coordPlan, ProjectExec.class); + 
assertThat(coordProjectExec.projections().stream().map(NamedExpression::name).toList(), containsInAnyOrder("text", "fieldLength")); + var coordTopN = as(coordProjectExec.child(), TopNExec.class); + var orderAttr = as(coordTopN.order().getFirst().child(), ReferenceAttribute.class); + assertThat(orderAttr.name(), equalTo("fieldLength")); + + var reductionPlan = ((PlannerUtils.TopNReduction) PlannerUtils.reductionPlan(coordAndDataNodePlans.v2())).plan(); + var topN = as(reductionPlan, TopNExec.class); + var eval = as(topN.child(), EvalExec.class); + var alias = eval.fields().get(0); + assertThat(alias.name(), equalTo("fieldLength")); + var fieldAttr = as(alias.child(), FieldAttribute.class); + assertThat(fieldAttr.name(), startsWith("$$text$LENGTH$")); + var esSourceExec = as(eval.child(), EsSourceExec.class); + assertTrue(esSourceExec.outputSet().stream().anyMatch(a -> a == fieldAttr)); + } + + public void testPushableFunctionsInFork() { + assumeTrue("requires functions pushdown", EsqlCapabilities.Cap.VECTOR_SIMILARITY_FUNCTIONS_PUSHDOWN.isEnabled()); + var query = """ + from test_all + | eval u = v_cosine(dense_vector, [4, 5, 6]) + | fork + (eval s = length(text) | keep s, u, keyword) + (eval t = v_dot_product(dense_vector, [1, 2, 3]) | keep t, u, keyword) + | eval x = length(keyword) + """; + var localPlan = localPlan(plan(query, allTypesAnalyzer), TEST_SEARCH_STATS); + + var eval = as(localPlan, Eval.class); + // Cosine function has not been pushed down as it targets a reference and not a field + assertThat(eval.fields().getFirst().child(), instanceOf(Length.class)); + var limit = as(eval.child(), Limit.class); + var fork = as(limit.child(), Fork.class); + assertThat(fork.children(), hasSize(2)); + + // First branch: (eval s = length(text) | keep s, u, keyword) + var project1 = as(fork.children().get(0), EsqlProject.class); + assertThat(Expressions.names(project1.projections()), containsInAnyOrder("s", "_fork", "t", "u", "keyword")); + var eval1 = 
as(project1.child(), Eval.class); + assertThat(eval1.fields(), hasSize(4)); + + // Find the "s" field which should be a pushed down LENGTH function + var sAlias = eval1.fields() + .stream() + .filter(f -> f.name().equals("s")) + .findFirst() + .orElseThrow(() -> new AssertionError("Field 's' not found in eval")); + var sField = as(sAlias, Alias.class); + var sFieldAttr = as(sField.child(), FieldAttribute.class); + assertThat(sFieldAttr.name(), startsWith("$$text$LENGTH$")); + assertThat(sFieldAttr.fieldName().string(), equalTo("text")); + + // Find the "u" field which should be a pushed down V_COSINE function + var u1Alias = eval1.fields() + .stream() + .filter(f -> f.name().equals("u")) + .findFirst() + .orElseThrow(() -> new AssertionError("Field 'u' not found in eval")); + var u1Field = as(u1Alias, Alias.class); + var u1FieldAttr = as(u1Field.child(), FieldAttribute.class); + assertThat(u1FieldAttr.name(), startsWith("$$dense_vector$V_COSINE$")); + assertThat(u1FieldAttr.fieldName().string(), equalTo("dense_vector")); + + var limit1 = as(eval1.child(), Limit.class); + // EsRelation[test_all] - verify pushed down field is in the relation output + var relation1 = as(limit1.child(), EsRelation.class); + assertTrue(relation1.output().contains(sFieldAttr)); + + // Second branch: (eval t = v_dot_product(dense_vector, [1, 2, 3]) | keep t, u, keyword) + // EsqlProject[[s{r}#55, _fork{r}#4, t{r}#11]] + var project2 = as(fork.children().get(1), EsqlProject.class); + assertThat(Expressions.names(project2.projections()), containsInAnyOrder("s", "_fork", "t", "u", "keyword")); + + // Eval[[$$dense_vector$V_DOT_PRODUCT$-1468139866{f$}#60 AS t#11, fork2[KEYWORD] AS _fork#4, null[INTEGER] AS s#55]] + var eval2 = as(project2.child(), Eval.class); + assertThat(eval2.fields(), hasSize(4)); + + // Find the "t" field which should be a pushed down V_DOT_PRODUCT function + var tAlias = eval2.fields() + .stream() + .filter(f -> f.name().equals("t")) + .findFirst() + .orElseThrow(() -> 
new AssertionError("Field 't' not found in eval")); + var tField = as(tAlias, Alias.class); + var tFieldAttr = as(tField.child(), FieldAttribute.class); + assertThat(tFieldAttr.name(), startsWith("$$dense_vector$V_DOT_PRODUCT$")); + assertThat(tFieldAttr.fieldName().string(), equalTo("dense_vector")); + + // Find the "u" field which should be the same pushed down V_COSINE function (NOTE(review): this searches eval1 again, so the equality assertion below is vacuous — likely intended eval2; confirm) + var u2Alias = eval1.fields() + .stream() + .filter(f -> f.name().equals("u")) + .findFirst() + .orElseThrow(() -> new AssertionError("Field 'u' not found in eval")); + var u2Field = as(u2Alias, Alias.class); + assertThat(u1Field, equalTo(u2Field)); + + // Limit[1000[INTEGER],false,false] + var limit2 = as(eval2.child(), Limit.class); + + // EsRelation[test_all] - verify pushed down field is in the relation output + var relation2 = as(limit2.child(), EsRelation.class); + assertTrue(relation2.output().contains(tFieldAttr)); + } + + public void testPushableFunctionsInSubqueries() { + assumeTrue("requires functions pushdown", EsqlCapabilities.Cap.VECTOR_SIMILARITY_FUNCTIONS_PUSHDOWN.isEnabled()); + var query = """ + from test_all, (from test_all | eval s = length(text) | keep s) + | eval t = v_dot_product(dense_vector, [1, 2, 3]) + | keep s, t + """; + var localPlan = localPlan(plan(query, allTypesAnalyzer), TEST_SEARCH_STATS); + + // EsqlProject[[s{r}#97, t{r}#9]] + var project = as(localPlan, EsqlProject.class); + assertThat(Expressions.names(project.projections()), contains("s", "t")); + + // Eval[[DOTPRODUCT(dense_vector{r}#82,[1.0, 2.0, 3.0][DENSE_VECTOR]) AS t#9]] + var eval = as(project.child(), Eval.class); + assertThat(eval.fields(), hasSize(1)); + // Find the "t" field which should be a NOT pushed down V_DOT_PRODUCT function + var tAlias = eval.fields() + .stream() + .filter(f -> f.name().equals("t")) + .findFirst() + .orElseThrow(() -> new AssertionError("Field 't' not found in subquery eval")); + assertThat(tAlias.child(), instanceOf(DotProduct.class)); + + // 
Limit[1000[INTEGER],false,false] + var limit = as(eval.child(), Limit.class); + assertThat(limit.limit().fold(FoldContext.small()), equalTo(1000)); + + // UnionAll[[alias_integer{r}#76, boolean{r}#77, byte{r}#78, ...]] + var unionAll = as(limit.child(), UnionAll.class); + assertThat(unionAll.children(), hasSize(2)); + + // Second branch of UnionAll - contains the subquery + // EsqlProject[[alias_integer{r}#99, boolean{r}#56, ...]] + var project2 = as(unionAll.children().get(1), EsqlProject.class); + + // Eval[[null[KEYWORD] AS alias_integer#55, null[BOOLEAN] AS boolean#56, ...]] + var eval2 = as(project2.child(), Eval.class); + + var subquery = as(eval2.child(), Subquery.class); + var subqueryProject = as(subquery.child(), EsqlProject.class); + assertThat(Expressions.names(subqueryProject.projections()), contains("s")); + var subqueryEval = as(subqueryProject.child(), Eval.class); + assertThat(subqueryEval.fields(), hasSize(1)); + + // Find the "s" field which should be a pushed down LENGTH function + var sAlias = subqueryEval.fields() + .stream() + .filter(f -> f.name().equals("s")) + .findFirst() + .orElseThrow(() -> new AssertionError("Field 's' not found in subquery eval")); + var sField = as(sAlias, Alias.class); + var sFieldAttr = as(sField.child(), FieldAttribute.class); + assertThat(sFieldAttr.name(), startsWith("$$text$LENGTH$")); + assertThat(sFieldAttr.fieldName().string(), equalTo("text")); + var subqueryLimit = as(subqueryEval.child(), Limit.class); + // EsRelation[test_all] - verify pushed down field is in the relation output + var subqueryRelation = as(subqueryLimit.child(), EsRelation.class); + assertTrue(subqueryRelation.output().contains(sFieldAttr)); + } + + public void testPushDownFunctionsLookupJoin() { + assumeTrue("requires functions pushdown", EsqlCapabilities.Cap.VECTOR_SIMILARITY_FUNCTIONS_PUSHDOWN.isEnabled()); + + var query = """ + from test + | eval s = length(first_name) + | rename languages AS language_code + | keep s, language_code, 
last_name + | lookup join languages_lookup ON language_code + | eval t = length(last_name) + | eval u = length(language_name) + """; + + var localPlan = localPlan(plan(query, analyzer), TEST_SEARCH_STATS); + + // Project[[s{r}#124, languages{f}#141 AS language_code#127, last_name{f}#142, language_name{f}#150, t{r}#134, u{r}#137]] + var project = as(localPlan, Project.class); + assertThat(Expressions.names(project.projections()), contains("s", "language_code", "last_name", "language_name", "t", "u")); + + // Eval[[$$last_name$LENGTH$1912486003{f$}#151 AS t#134, LENGTH(language_name{f}#150) AS u#137]] + var eval = as(project.child(), Eval.class); + assertThat(eval.fields(), hasSize(2)); + + // Find the "t" field which should be a pushed down LENGTH function on last_name + var tAlias = eval.fields() + .stream() + .filter(f -> f.name().equals("t")) + .findFirst() + .orElseThrow(() -> new AssertionError("Field 't' not found in eval")); + var tField = as(tAlias, Alias.class); + var tFieldAttr = as(tField.child(), FieldAttribute.class); + assertThat(tFieldAttr.name(), startsWith("$$last_name$LENGTH$")); + assertThat(tFieldAttr.fieldName().string(), equalTo("last_name")); + + // Find the "u" field which should NOT be pushed down - it's LENGTH(language_name{f}#150) + var uAlias = eval.fields() + .stream() + .filter(f -> f.name().equals("u")) + .findFirst() + .orElseThrow(() -> new AssertionError("Field 'u' not found in eval")); + var uField = as(uAlias, Alias.class); + var uLength = as(uField.child(), Length.class); + assertThat(Expressions.name(uLength.field()), equalTo("language_name")); + + var limit = as(eval.child(), Limit.class); + var join = as(limit.child(), Join.class); + assertThat(join.config().type(), equalTo(JoinTypes.LEFT)); + + // Left side of join: Eval[[$$first_name$LENGTH$1912486003{f$}#152 AS s#124]] + var leftEval = as(join.left(), Eval.class); + assertThat(leftEval.fields(), hasSize(1)); + + // Find the "s" field which should be a pushed down LENGTH 
function on first_name + var sAlias = leftEval.fields() + .stream() + .filter(f -> f.name().equals("s")) + .findFirst() + .orElseThrow(() -> new AssertionError("Field 's' not found in left eval")); + var sField = as(sAlias, Alias.class); + var sFieldAttr = as(sField.child(), FieldAttribute.class); + assertThat(sFieldAttr.name(), startsWith("$$first_name$LENGTH$")); + assertThat(sFieldAttr.fieldName().string(), equalTo("first_name")); + + // Limit[1000[INTEGER],false,false] + var leftLimit = as(leftEval.child(), Limit.class); + + // EsRelation[test] - verify pushed down field is in the relation output + var leftRelation = as(leftLimit.child(), EsRelation.class); + assertTrue(leftRelation.output().contains(sFieldAttr)); + + // Right side of join: EsRelation[languages_lookup][LOOKUP][language_code{f}#149, language_name{f}#150, $$last_..] + var rightRelation = as(join.right(), EsRelation.class); + // Verify that the pushed down field t (last_name length) is in the lookup relation output + assertTrue(rightRelation.output().contains(tFieldAttr)); + } + private IsNotNull isNotNull(Expression field) { return new IsNotNull(EMPTY, field); }