diff --git a/benchmarks/src/main/java/org/elasticsearch/benchmark/compute/operator/AggregatorBenchmark.java b/benchmarks/src/main/java/org/elasticsearch/benchmark/compute/operator/AggregatorBenchmark.java index d5fe1b4a697e0..30d807a11fb56 100644 --- a/benchmarks/src/main/java/org/elasticsearch/benchmark/compute/operator/AggregatorBenchmark.java +++ b/benchmarks/src/main/java/org/elasticsearch/benchmark/compute/operator/AggregatorBenchmark.java @@ -197,7 +197,7 @@ private static Operator operator(DriverContext driverContext, String grouping, S }; return new HashAggregationOperator( List.of(supplier(op, dataType, filter).groupingAggregatorFactory(AggregatorMode.SINGLE, List.of(groups.size()))), - () -> BlockHash.build(groups, driverContext.blockFactory(), 16 * 1024, false), + () -> BlockHash.build(groups, driverContext.blockFactory(), 16 * 1024, false, TOP_N_LIMIT), driverContext ); } diff --git a/benchmarks/src/main/java/org/elasticsearch/benchmark/compute/operator/ValuesAggregatorBenchmark.java b/benchmarks/src/main/java/org/elasticsearch/benchmark/compute/operator/ValuesAggregatorBenchmark.java index dfd56996e1c15..e28f327075327 100644 --- a/benchmarks/src/main/java/org/elasticsearch/benchmark/compute/operator/ValuesAggregatorBenchmark.java +++ b/benchmarks/src/main/java/org/elasticsearch/benchmark/compute/operator/ValuesAggregatorBenchmark.java @@ -123,7 +123,7 @@ private static Operator operator(DriverContext driverContext, int groups, String List groupSpec = List.of(new BlockHash.GroupSpec(0, ElementType.LONG)); return new HashAggregationOperator( List.of(supplier(dataType).groupingAggregatorFactory(AggregatorMode.SINGLE, List.of(1))), - () -> BlockHash.build(groupSpec, driverContext.blockFactory(), 16 * 1024, false), + () -> BlockHash.build(groupSpec, driverContext.blockFactory(), 16 * 1024, false, 100), driverContext ) { @Override diff --git a/docs/changelog/130111.yaml b/docs/changelog/130111.yaml new file mode 100644 index 0000000000000..78f0610e83aed --- 
/dev/null +++ b/docs/changelog/130111.yaml @@ -0,0 +1,5 @@ +pr: 130111 +summary: Plug TopN agg groupings filtering into the language +area: ES|QL +type: enhancement +issues: [] diff --git a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/blockhash/BlockHash.java b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/blockhash/BlockHash.java index 63f4d9c96bcd0..52243a3ce74cc 100644 --- a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/blockhash/BlockHash.java +++ b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/blockhash/BlockHash.java @@ -159,12 +159,20 @@ public boolean isCategorize() { * null handling and remove this flag, but we need to disable these in * production until we can. And this lets us continue to compile and * test them. + * @param maxTopNLimit the maximum limit for TopN groups to use a TopNBlockHash. + * This usually comes from {@code QueryPragma.maxTopNAggsLimit()}. 
*/ - public static BlockHash build(List groups, BlockFactory blockFactory, int emitBatchSize, boolean allowBrokenOptimizations) { + public static BlockHash build( + List groups, + BlockFactory blockFactory, + int emitBatchSize, + boolean allowBrokenOptimizations, + int maxTopNLimit + ) { if (groups.size() == 1) { GroupSpec group = groups.get(0); - if (group.topNDef() != null && group.elementType() == ElementType.LONG) { - TopNDef topNDef = group.topNDef(); + TopNDef topNDef = group.topNDef(); + if (topNDef != null && maxTopNLimit > 0 && group.elementType() == ElementType.LONG && topNDef.limit() < maxTopNLimit) { return new LongTopNBlockHash(group.channel(), topNDef.asc(), topNDef.nullsFirst(), topNDef.limit(), blockFactory); } return newForElementType(group.channel(), group.elementType(), blockFactory); diff --git a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/operator/HashAggregationOperator.java b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/operator/HashAggregationOperator.java index cbce712ed9cdb..64d7ec7a19f16 100644 --- a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/operator/HashAggregationOperator.java +++ b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/operator/HashAggregationOperator.java @@ -46,7 +46,8 @@ public record HashAggregationOperatorFactory( AggregatorMode aggregatorMode, List aggregators, int maxPageSize, - AnalysisRegistry analysisRegistry + AnalysisRegistry analysisRegistry, + int maxTopNLimit ) implements OperatorFactory { @Override public Operator get(DriverContext driverContext) { @@ -65,7 +66,7 @@ public Operator get(DriverContext driverContext) { } return new HashAggregationOperator( aggregators, - () -> BlockHash.build(groups, driverContext.blockFactory(), maxPageSize, false), + () -> BlockHash.build(groups, driverContext.blockFactory(), maxPageSize, false, maxTopNLimit), driverContext ); } diff --git 
a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/operator/TimeSeriesAggregationOperator.java b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/operator/TimeSeriesAggregationOperator.java index 6ab0291c718a7..d97d52314d9b8 100644 --- a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/operator/TimeSeriesAggregationOperator.java +++ b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/operator/TimeSeriesAggregationOperator.java @@ -35,7 +35,8 @@ public record Factory( List groups, AggregatorMode aggregatorMode, List aggregators, - int maxPageSize + int maxPageSize, + int maxTopNLimit ) implements OperatorFactory { @Override public Operator get(DriverContext driverContext) { @@ -48,7 +49,8 @@ public Operator get(DriverContext driverContext) { groups, driverContext.blockFactory(), maxPageSize, - true // we can enable optimizations as the inputs are vectors + true, // we can enable optimizations as the inputs are vectors + maxTopNLimit ); } }, driverContext); diff --git a/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/aggregation/GroupingAggregatorFunctionTestCase.java b/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/aggregation/GroupingAggregatorFunctionTestCase.java index f040b86850133..fffce4f3a524b 100644 --- a/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/aggregation/GroupingAggregatorFunctionTestCase.java +++ b/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/aggregation/GroupingAggregatorFunctionTestCase.java @@ -117,7 +117,8 @@ private Operator.OperatorFactory simpleWithMode( mode, List.of(supplier.groupingAggregatorFactory(mode, channels(mode))), randomPageSize(), - null + null, + 100 ); } else { return new RandomizingHashAggregationOperatorFactory( @@ -125,7 +126,8 @@ private Operator.OperatorFactory simpleWithMode( mode, List.of(supplier.groupingAggregatorFactory(mode, channels(mode))), randomPageSize(), - null 
+ null, + 100 ); } } @@ -824,7 +826,8 @@ private record RandomizingHashAggregationOperatorFactory( AggregatorMode aggregatorMode, List aggregators, int maxPageSize, - AnalysisRegistry analysisRegistry + AnalysisRegistry analysisRegistry, + int maxTopNLimit ) implements Operator.OperatorFactory { @Override @@ -838,7 +841,7 @@ public Operator get(DriverContext driverContext) { analysisRegistry, maxPageSize ) - : BlockHash.build(groups, driverContext.blockFactory(), maxPageSize, false); + : BlockHash.build(groups, driverContext.blockFactory(), maxPageSize, false, maxTopNLimit); return new BlockHashWrapper(driverContext.blockFactory(), blockHash) { @Override @@ -886,7 +889,8 @@ public String describe() { aggregatorMode, aggregators, maxPageSize, - analysisRegistry + analysisRegistry, + maxTopNLimit ).describe(); } } diff --git a/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/aggregation/blockhash/BlockHashRandomizedTests.java b/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/aggregation/blockhash/BlockHashRandomizedTests.java index 990827b3dc693..a5a38d0f8223a 100644 --- a/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/aggregation/blockhash/BlockHashRandomizedTests.java +++ b/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/aggregation/blockhash/BlockHashRandomizedTests.java @@ -256,7 +256,7 @@ private BlockHash newBlockHash(BlockFactory blockFactory, int emitBatchSize, Lis } return forcePackedHash ? 
new PackedValuesBlockHash(specs, blockFactory, emitBatchSize) - : BlockHash.build(specs, blockFactory, emitBatchSize, true); + : BlockHash.build(specs, blockFactory, emitBatchSize, true, 100); } private static final int LOOKUP_POSITIONS = 1_000; diff --git a/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/aggregation/blockhash/BlockHashTests.java b/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/aggregation/blockhash/BlockHashTests.java index be97055eb9a7e..cb988778573f5 100644 --- a/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/aggregation/blockhash/BlockHashTests.java +++ b/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/aggregation/blockhash/BlockHashTests.java @@ -47,6 +47,8 @@ import static org.hamcrest.Matchers.endsWith; import static org.hamcrest.Matchers.equalTo; import static org.hamcrest.Matchers.greaterThan; +import static org.hamcrest.Matchers.instanceOf; +import static org.hamcrest.Matchers.not; import static org.hamcrest.Matchers.startsWith; public class BlockHashTests extends BlockHashTestCase { @@ -1544,7 +1546,8 @@ public void testTimeSeriesBlockHash() throws Exception { List.of(new BlockHash.GroupSpec(0, ElementType.BYTES_REF), new BlockHash.GroupSpec(1, ElementType.LONG)), blockFactory, 32 * 1024, - forcePackedHash + forcePackedHash, + 100 ); int numPages = between(1, 100); int globalTsid = -1; @@ -1660,6 +1663,34 @@ public void close() { } } + public void testTopNBlockHashLimit() { + int limit = randomIntBetween(1, 100); + + try ( + var hash = BlockHash.build( + List.of(new BlockHash.GroupSpec(0, ElementType.LONG, null, new BlockHash.TopNDef(0, true, false, limit))), + blockFactory, + 32 * 1024, + forcePackedHash, + randomIntBetween(limit, 1000) + ) + ) { + assertThat(hash, instanceOf(LongTopNBlockHash.class)); + } + + try ( + var hash = BlockHash.build( + List.of(new BlockHash.GroupSpec(0, ElementType.LONG, null, new BlockHash.TopNDef(0, true, false, limit))), + 
blockFactory, + 32 * 1024, + forcePackedHash, + randomIntBetween(1, limit) + ) + ) { + assertThat(hash, not(instanceOf(LongTopNBlockHash.class))); + } + } + /** * Hash some values into a single block of group ids. If the hash produces * more than one block of group ids this will fail. @@ -1713,6 +1744,6 @@ private BlockHash buildBlockHash(int emitBatchSize, Block... values) { } return forcePackedHash ? new PackedValuesBlockHash(specs, blockFactory, emitBatchSize) - : BlockHash.build(specs, blockFactory, emitBatchSize, true); + : BlockHash.build(specs, blockFactory, emitBatchSize, true, 100); } } diff --git a/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/aggregation/blockhash/CategorizeBlockHashTests.java b/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/aggregation/blockhash/CategorizeBlockHashTests.java index 9ce086307acee..aadebb1c57685 100644 --- a/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/aggregation/blockhash/CategorizeBlockHashTests.java +++ b/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/aggregation/blockhash/CategorizeBlockHashTests.java @@ -513,7 +513,8 @@ public void testCategorize_withDriver() { new MaxLongAggregatorFunctionSupplier().groupingAggregatorFactory(AggregatorMode.INITIAL, List.of(1)) ), 16 * 1024, - analysisRegistry + analysisRegistry, + 100 ).get(driverContext) ), new PageConsumerOperator(intermediateOutput::add) @@ -532,7 +533,8 @@ public void testCategorize_withDriver() { new MaxLongAggregatorFunctionSupplier().groupingAggregatorFactory(AggregatorMode.INITIAL, List.of(1)) ), 16 * 1024, - analysisRegistry + analysisRegistry, + 100 ).get(driverContext) ), new PageConsumerOperator(intermediateOutput::add) @@ -553,7 +555,8 @@ public void testCategorize_withDriver() { new MaxLongAggregatorFunctionSupplier().groupingAggregatorFactory(AggregatorMode.FINAL, List.of(3, 4)) ), 16 * 1024, - analysisRegistry + analysisRegistry, + 100 ).get(driverContext) ), new 
PageConsumerOperator(finalOutput::add) diff --git a/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/aggregation/blockhash/CategorizePackedValuesBlockHashTests.java b/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/aggregation/blockhash/CategorizePackedValuesBlockHashTests.java index d0eb89eafd841..346ea21631d5d 100644 --- a/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/aggregation/blockhash/CategorizePackedValuesBlockHashTests.java +++ b/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/aggregation/blockhash/CategorizePackedValuesBlockHashTests.java @@ -151,7 +151,8 @@ public void testCategorize_withDriver() { AggregatorMode.INITIAL, List.of(new ValuesBytesRefAggregatorFunctionSupplier().groupingAggregatorFactory(AggregatorMode.INITIAL, List.of(0))), 16 * 1024, - analysisRegistry + analysisRegistry, + 100 ).get(driverContext) ), new PageConsumerOperator(intermediateOutput::add) @@ -167,7 +168,8 @@ public void testCategorize_withDriver() { AggregatorMode.INITIAL, List.of(new ValuesBytesRefAggregatorFunctionSupplier().groupingAggregatorFactory(AggregatorMode.INITIAL, List.of(0))), 16 * 1024, - analysisRegistry + analysisRegistry, + 100 ).get(driverContext) ), new PageConsumerOperator(intermediateOutput::add) @@ -185,7 +187,8 @@ public void testCategorize_withDriver() { AggregatorMode.FINAL, List.of(new ValuesBytesRefAggregatorFunctionSupplier().groupingAggregatorFactory(AggregatorMode.FINAL, List.of(2))), 16 * 1024, - analysisRegistry + analysisRegistry, + 100 ).get(driverContext) ), new PageConsumerOperator(finalOutput::add) diff --git a/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/operator/HashAggregationOperatorTests.java b/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/operator/HashAggregationOperatorTests.java index 0e9c0e33d22cd..18d76cbe07354 100644 --- 
a/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/operator/HashAggregationOperatorTests.java +++ b/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/operator/HashAggregationOperatorTests.java @@ -64,7 +64,8 @@ protected Operator.OperatorFactory simpleWithMode(SimpleOptions options, Aggrega new MaxLongAggregatorFunctionSupplier().groupingAggregatorFactory(mode, maxChannels) ), randomPageSize(), - null + null, + 100 ); } @@ -103,6 +104,7 @@ protected void assertSimpleOutput(List input, List results) { public void testTopNNullsLast() { boolean ascOrder = randomBoolean(); + int limit = 3; var groups = new Long[] { 0L, 10L, 20L, 30L, 40L, 50L }; if (ascOrder) { Arrays.sort(groups, Comparator.reverseOrder()); @@ -113,14 +115,15 @@ public void testTopNNullsLast() { try ( var operator = new HashAggregationOperator.HashAggregationOperatorFactory( - List.of(new BlockHash.GroupSpec(groupChannel, ElementType.LONG, null, new BlockHash.TopNDef(0, ascOrder, false, 3))), + List.of(new BlockHash.GroupSpec(groupChannel, ElementType.LONG, null, new BlockHash.TopNDef(0, ascOrder, false, limit))), mode, List.of( new SumLongAggregatorFunctionSupplier().groupingAggregatorFactory(mode, aggregatorChannels), new MaxLongAggregatorFunctionSupplier().groupingAggregatorFactory(mode, aggregatorChannels) ), randomPageSize(), - null + null, + randomIntBetween(limit, 1000) ).get(driverContext()) ) { var page = new Page( @@ -180,6 +183,7 @@ public void testTopNNullsLast() { public void testTopNNullsFirst() { boolean ascOrder = randomBoolean(); + int limit = 3; var groups = new Long[] { 0L, 10L, 20L, 30L, 40L, 50L }; if (ascOrder) { Arrays.sort(groups, Comparator.reverseOrder()); @@ -190,14 +194,15 @@ public void testTopNNullsFirst() { try ( var operator = new HashAggregationOperator.HashAggregationOperatorFactory( - List.of(new BlockHash.GroupSpec(groupChannel, ElementType.LONG, null, new BlockHash.TopNDef(0, ascOrder, true, 3))), + List.of(new 
BlockHash.GroupSpec(groupChannel, ElementType.LONG, null, new BlockHash.TopNDef(0, ascOrder, true, limit))), mode, List.of( new SumLongAggregatorFunctionSupplier().groupingAggregatorFactory(mode, aggregatorChannels), new MaxLongAggregatorFunctionSupplier().groupingAggregatorFactory(mode, aggregatorChannels) ), randomPageSize(), - null + null, + randomIntBetween(limit, 1000) ).get(driverContext()) ) { var page = new Page( diff --git a/x-pack/plugin/esql/qa/server/src/main/java/org/elasticsearch/xpack/esql/qa/rest/generative/EsqlQueryGenerator.java b/x-pack/plugin/esql/qa/server/src/main/java/org/elasticsearch/xpack/esql/qa/rest/generative/EsqlQueryGenerator.java index ef02d4a1f8c98..143c028c77194 100644 --- a/x-pack/plugin/esql/qa/server/src/main/java/org/elasticsearch/xpack/esql/qa/rest/generative/EsqlQueryGenerator.java +++ b/x-pack/plugin/esql/qa/server/src/main/java/org/elasticsearch/xpack/esql/qa/rest/generative/EsqlQueryGenerator.java @@ -23,6 +23,7 @@ import org.elasticsearch.xpack.esql.qa.rest.generative.command.pipe.RenameGenerator; import org.elasticsearch.xpack.esql.qa.rest.generative.command.pipe.SortGenerator; import org.elasticsearch.xpack.esql.qa.rest.generative.command.pipe.StatsGenerator; +import org.elasticsearch.xpack.esql.qa.rest.generative.command.pipe.TopNStatsGenerator; import org.elasticsearch.xpack.esql.qa.rest.generative.command.pipe.WhereGenerator; import org.elasticsearch.xpack.esql.qa.rest.generative.command.source.FromGenerator; @@ -65,6 +66,7 @@ public record QueryExecuted(String query, int depth, List outputSchema, RenameGenerator.INSTANCE, SortGenerator.INSTANCE, StatsGenerator.INSTANCE, + TopNStatsGenerator.INSTANCE, WhereGenerator.INSTANCE ); diff --git a/x-pack/plugin/esql/qa/server/src/main/java/org/elasticsearch/xpack/esql/qa/rest/generative/command/pipe/StatsGenerator.java b/x-pack/plugin/esql/qa/server/src/main/java/org/elasticsearch/xpack/esql/qa/rest/generative/command/pipe/StatsGenerator.java index 
b0ce7f43af997..4a5a04fa9c99d 100644 --- a/x-pack/plugin/esql/qa/server/src/main/java/org/elasticsearch/xpack/esql/qa/rest/generative/command/pipe/StatsGenerator.java +++ b/x-pack/plugin/esql/qa/server/src/main/java/org/elasticsearch/xpack/esql/qa/rest/generative/command/pipe/StatsGenerator.java @@ -36,26 +36,7 @@ public CommandDescription generate( return EMPTY_DESCRIPTION; } StringBuilder cmd = new StringBuilder(" | stats "); - int nStats = randomIntBetween(1, 5); - for (int i = 0; i < nStats; i++) { - String name; - if (randomBoolean()) { - name = EsqlQueryGenerator.randomIdentifier(); - } else { - name = EsqlQueryGenerator.randomName(previousOutput); - if (name == null) { - name = EsqlQueryGenerator.randomIdentifier(); - } - } - String expression = EsqlQueryGenerator.agg(nonNull); - if (i > 0) { - cmd.append(","); - } - cmd.append(" "); - cmd.append(name); - cmd.append(" = "); - cmd.append(expression); - } + addStatsAggregations(previousOutput, nonNull, cmd); if (randomBoolean()) { var col = EsqlQueryGenerator.randomGroupableName(nonNull); if (col != null) { @@ -77,4 +58,33 @@ public ValidationResult validateOutput( // TODO validate columns return VALIDATION_OK; } + + public static void addStatsAggregations( + List previousOutput, + List nonNullColumns, + StringBuilder cmd + ) { + assert nonNullColumns.isEmpty() == false : "nonNullColumns should not be empty"; + + int nStats = randomIntBetween(1, 5); + for (int i = 0; i < nStats; i++) { + String name; + if (randomBoolean()) { + name = EsqlQueryGenerator.randomIdentifier(); + } else { + name = EsqlQueryGenerator.randomName(previousOutput); + if (name == null) { + name = EsqlQueryGenerator.randomIdentifier(); + } + } + String expression = EsqlQueryGenerator.agg(nonNullColumns); + if (i > 0) { + cmd.append(","); + } + cmd.append(" "); + cmd.append(name); + cmd.append(" = "); + cmd.append(expression); + } + } } diff --git 
a/x-pack/plugin/esql/qa/server/src/main/java/org/elasticsearch/xpack/esql/qa/rest/generative/command/pipe/TopNStatsGenerator.java b/x-pack/plugin/esql/qa/server/src/main/java/org/elasticsearch/xpack/esql/qa/rest/generative/command/pipe/TopNStatsGenerator.java new file mode 100644 index 0000000000000..e6f0b52cfac8c --- /dev/null +++ b/x-pack/plugin/esql/qa/server/src/main/java/org/elasticsearch/xpack/esql/qa/rest/generative/command/pipe/TopNStatsGenerator.java @@ -0,0 +1,66 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.esql.qa.rest.generative.command.pipe; + +import org.elasticsearch.xpack.esql.qa.rest.generative.EsqlQueryGenerator; +import org.elasticsearch.xpack.esql.qa.rest.generative.command.CommandGenerator; + +import java.util.List; +import java.util.Map; +import java.util.stream.Collectors; + +import static org.elasticsearch.test.ESTestCase.randomFrom; +import static org.elasticsearch.test.ESTestCase.randomIntBetween; +import static org.elasticsearch.xpack.esql.qa.rest.generative.command.pipe.StatsGenerator.addStatsAggregations; + +public class TopNStatsGenerator implements CommandGenerator { + + public static final String TOP_N_STATS = "top_n_stats"; + public static final CommandGenerator INSTANCE = new TopNStatsGenerator(); + + @Override + public CommandDescription generate( + List previousCommands, + List previousOutput, + QuerySchema schema + ) { + List nonNull = previousOutput.stream() + .filter(EsqlQueryGenerator::fieldCanBeUsed) + .filter(x -> x.type().equals("null") == false) + .collect(Collectors.toList()); + if (nonNull.isEmpty()) { + return EMPTY_DESCRIPTION; + } + StringBuilder cmd = new StringBuilder(" | stats "); + addStatsAggregations(previousOutput, nonNull, cmd); + + var col = 
EsqlQueryGenerator.randomGroupableName(nonNull); + if (col == null) { + // Either a TopN + Stats with groupings, or nothing + return EMPTY_DESCRIPTION; + } + cmd.append(" by " + col); + cmd.append(" | sort " + col + " " + randomFrom("", " ASC", " DESC")); + cmd.append(" | limit " + randomIntBetween(1, 1000)); + + return new CommandDescription(TOP_N_STATS, this, cmd.toString(), Map.of()); + } + + @Override + public ValidationResult validateOutput( + List previousCommands, + CommandDescription commandDescription, + List previousColumns, + List> previousOutput, + List columns, + List> output + ) { + // TODO validate columns + return VALIDATION_OK; + } +} diff --git a/x-pack/plugin/esql/qa/testFixtures/src/main/resources/stats_top_n.csv-spec b/x-pack/plugin/esql/qa/testFixtures/src/main/resources/stats_top_n.csv-spec new file mode 100644 index 0000000000000..201fa732906f3 --- /dev/null +++ b/x-pack/plugin/esql/qa/testFixtures/src/main/resources/stats_top_n.csv-spec @@ -0,0 +1,57 @@ +count +from employees | stats c = count(height) | sort c | limit 100; + +c:long +100 +; + +countGrouping +from employees | stats c = count(height) by languages.long | sort languages.long | limit 100; + +c:long | languages.long:long +15 | 1 +19 | 2 +17 | 3 +18 | 4 +21 | 5 +10 | null +; + +multipleAggs +from employees | stats c = count(height), m = max(height) | sort c | limit 100; + +c:long | m:double +100 | 2.1 +; + +multipleAggsGrouping +from employees | stats c = count(height), m = max(height) by languages.long | sort languages.long | limit 100; + +c:long | m:double | languages.long:long +15 | 2.06 | 1 +19 | 2.1 | 2 +17 | 2.1 | 3 +18 | 2.0 | 4 +21 | 2.1 | 5 +10 | 2.1 | null +; + +keepGroupAfterStats +from employees | stats MAX(height) BY height, languages.long | sort height desc | limit 3 | KEEP height; + +height:double +2.1 +2.1 +2.1 +; + +multipleTopN +from employees +| stats c = count(height) by languages.long +| sort languages.long | limit 4 +| sort languages.long desc | limit 2; + 
+c:long | languages.long:long +17 | 3 +18 | 4 +; diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/LocalLogicalPlanOptimizer.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/LocalLogicalPlanOptimizer.java index 39f37f952ae02..9c54438c2b203 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/LocalLogicalPlanOptimizer.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/LocalLogicalPlanOptimizer.java @@ -17,6 +17,7 @@ import org.elasticsearch.xpack.esql.optimizer.rules.logical.local.LocalPropagateEmptyRelation; import org.elasticsearch.xpack.esql.optimizer.rules.logical.local.LocalSubstituteSurrogateExpressions; import org.elasticsearch.xpack.esql.optimizer.rules.logical.local.ReplaceFieldWithConstantOrNull; +import org.elasticsearch.xpack.esql.optimizer.rules.logical.local.ReplaceTopNAggregateWithTopNAndAggregate; import org.elasticsearch.xpack.esql.optimizer.rules.logical.local.ReplaceTopNWithLimitAndSort; import org.elasticsearch.xpack.esql.plan.logical.LogicalPlan; import org.elasticsearch.xpack.esql.rule.ParameterizedRuleExecutor; @@ -44,6 +45,7 @@ public class LocalLogicalPlanOptimizer extends ParameterizedRuleExecutor( "Local rewrite", Limiter.ONCE, + new ReplaceTopNAggregateWithTopNAndAggregate(), new ReplaceTopNWithLimitAndSort(), new ReplaceFieldWithConstantOrNull(), new InferIsNotNull(), diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/LogicalPlanOptimizer.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/LogicalPlanOptimizer.java index ca117bfff34d6..2695a9e6d39dc 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/LogicalPlanOptimizer.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/LogicalPlanOptimizer.java @@ -54,6 +54,7 @@ import 
org.elasticsearch.xpack.esql.optimizer.rules.logical.ReplaceRowAsLocalRelation; import org.elasticsearch.xpack.esql.optimizer.rules.logical.ReplaceStatsFilteredAggWithEval; import org.elasticsearch.xpack.esql.optimizer.rules.logical.ReplaceStringCasingWithInsensitiveEquals; +import org.elasticsearch.xpack.esql.optimizer.rules.logical.ReplaceTopNAndAggregateWithTopNAggregate; import org.elasticsearch.xpack.esql.optimizer.rules.logical.ReplaceTrivialTypeConversions; import org.elasticsearch.xpack.esql.optimizer.rules.logical.SetAsOptimized; import org.elasticsearch.xpack.esql.optimizer.rules.logical.SimplifyComparisonsArithmetics; @@ -208,6 +209,12 @@ protected static Batch operators(boolean local) { } protected static Batch cleanup() { - return new Batch<>("Clean Up", new ReplaceLimitAndSortAsTopN(), new ReplaceRowAsLocalRelation(), new PropgateUnmappedFields()); + return new Batch<>( + "Clean Up", + new ReplaceLimitAndSortAsTopN(), + new ReplaceTopNAndAggregateWithTopNAggregate(), + new ReplaceRowAsLocalRelation(), + new PropgateUnmappedFields() + ); } } diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/logical/ReplaceTopNAndAggregateWithTopNAggregate.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/logical/ReplaceTopNAndAggregateWithTopNAggregate.java new file mode 100644 index 0000000000000..e1a5741a04fed --- /dev/null +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/logical/ReplaceTopNAndAggregateWithTopNAggregate.java @@ -0,0 +1,50 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. 
+ */ + +package org.elasticsearch.xpack.esql.optimizer.rules.logical; + +import org.elasticsearch.xpack.esql.plan.logical.Aggregate; +import org.elasticsearch.xpack.esql.plan.logical.LogicalPlan; +import org.elasticsearch.xpack.esql.plan.logical.TimeSeriesAggregate; +import org.elasticsearch.xpack.esql.plan.logical.TopN; +import org.elasticsearch.xpack.esql.plan.logical.TopNAggregate; +import org.elasticsearch.xpack.esql.rule.Rule; + +/** + * Looks for: + *
<pre>
+ * {@link TopN}
+ * \_{@link Aggregate}
+ * </pre>
+ * And replaces it with a {@link TopNAggregate}.
+ * <p>
+ * {@link TimeSeriesAggregate} subclass should not appear here after a {@link TopN}. See {@link TranslateTimeSeriesAggregate}.
+ * </p>
+ */ +public class ReplaceTopNAndAggregateWithTopNAggregate extends Rule { + + @Override + public LogicalPlan apply(LogicalPlan plan) { + return plan.transformUp(TopN.class, this::applyRule); + } + + private LogicalPlan applyRule(TopN topN) { + if (topN.child() instanceof Aggregate aggregate && aggregate instanceof TopNAggregate == false) { + assert aggregate.getClass() == Aggregate.class : "Only Aggregate can be replaced with TopNAggregate"; + + return new TopNAggregate( + aggregate.source(), + aggregate.child(), + aggregate.groupings(), + aggregate.aggregates(), + topN.order(), + topN.limit() + ); + } + return topN; + } +} diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/logical/local/ReplaceTopNAggregateWithTopNAndAggregate.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/logical/local/ReplaceTopNAggregateWithTopNAndAggregate.java new file mode 100644 index 0000000000000..9b311048f94bc --- /dev/null +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/logical/local/ReplaceTopNAggregateWithTopNAndAggregate.java @@ -0,0 +1,35 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. 
+ */ + +package org.elasticsearch.xpack.esql.optimizer.rules.logical.local; + +import org.elasticsearch.xpack.esql.optimizer.rules.logical.OptimizerRules; +import org.elasticsearch.xpack.esql.plan.logical.Aggregate; +import org.elasticsearch.xpack.esql.plan.logical.LogicalPlan; +import org.elasticsearch.xpack.esql.plan.logical.TopN; +import org.elasticsearch.xpack.esql.plan.logical.TopNAggregate; + +import static org.elasticsearch.xpack.esql.optimizer.rules.logical.OptimizerRules.TransformDirection.UP; + +/** + * Break TopNAggregate back into TopN + Aggregate to allow the order rules to kick in. + */ +public class ReplaceTopNAggregateWithTopNAndAggregate extends OptimizerRules.OptimizerRule { + public ReplaceTopNAggregateWithTopNAndAggregate() { + super(UP); + } + + @Override + protected LogicalPlan rule(TopNAggregate plan) { + return new TopN( + plan.source(), + new Aggregate(plan.source(), plan.child(), plan.groupings(), plan.aggregates()), + plan.order(), + plan.limit() + ); + } +} diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plan/PlanWritables.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plan/PlanWritables.java index 9312f2abdf509..3bf4bd519a3b3 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plan/PlanWritables.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plan/PlanWritables.java @@ -24,6 +24,7 @@ import org.elasticsearch.xpack.esql.plan.logical.Sample; import org.elasticsearch.xpack.esql.plan.logical.TimeSeriesAggregate; import org.elasticsearch.xpack.esql.plan.logical.TopN; +import org.elasticsearch.xpack.esql.plan.logical.TopNAggregate; import org.elasticsearch.xpack.esql.plan.logical.inference.Completion; import org.elasticsearch.xpack.esql.plan.logical.inference.Rerank; import org.elasticsearch.xpack.esql.plan.logical.join.InlineJoin; @@ -55,6 +56,7 @@ import org.elasticsearch.xpack.esql.plan.physical.ShowExec; import 
org.elasticsearch.xpack.esql.plan.physical.SubqueryExec; import org.elasticsearch.xpack.esql.plan.physical.TimeSeriesAggregateExec; +import org.elasticsearch.xpack.esql.plan.physical.TopNAggregateExec; import org.elasticsearch.xpack.esql.plan.physical.TopNExec; import org.elasticsearch.xpack.esql.plan.physical.inference.CompletionExec; import org.elasticsearch.xpack.esql.plan.physical.inference.RerankExec; @@ -95,7 +97,8 @@ public static List logical() { Rerank.ENTRY, Sample.ENTRY, TimeSeriesAggregate.ENTRY, - TopN.ENTRY + TopN.ENTRY, + TopNAggregate.ENTRY ); } @@ -125,7 +128,8 @@ public static List physical() { ShowExec.ENTRY, SubqueryExec.ENTRY, TimeSeriesAggregateExec.ENTRY, - TopNExec.ENTRY + TopNExec.ENTRY, + TopNAggregateExec.ENTRY ); } diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plan/QueryPlan.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plan/QueryPlan.java index 81a89950b0a02..e16a825ebc9fc 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plan/QueryPlan.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plan/QueryPlan.java @@ -89,6 +89,7 @@ public AttributeSet references() { /** * This very likely needs to be overridden for {@link QueryPlan#references} to be correct when inheriting. * This can be called on unresolved plans and therefore must not rely on calls to {@link QueryPlan#output()}. + * @see #references() */ protected AttributeSet computeReferences() { return Expressions.references(expressions()); diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plan/logical/TopNAggregate.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plan/logical/TopNAggregate.java new file mode 100644 index 0000000000000..d4855a22b3710 --- /dev/null +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plan/logical/TopNAggregate.java @@ -0,0 +1,112 @@ +/* + * Copyright Elasticsearch B.V. 
and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ +package org.elasticsearch.xpack.esql.plan.logical; + +import org.elasticsearch.common.io.stream.NamedWriteableRegistry; +import org.elasticsearch.common.io.stream.StreamInput; +import org.elasticsearch.common.io.stream.StreamOutput; +import org.elasticsearch.xpack.esql.core.capabilities.Resolvables; +import org.elasticsearch.xpack.esql.core.expression.Attribute; +import org.elasticsearch.xpack.esql.core.expression.Expression; +import org.elasticsearch.xpack.esql.core.expression.NamedExpression; +import org.elasticsearch.xpack.esql.core.tree.NodeInfo; +import org.elasticsearch.xpack.esql.core.tree.Source; +import org.elasticsearch.xpack.esql.expression.Order; + +import java.io.IOException; +import java.util.List; +import java.util.Objects; + +/** + * An {@link Aggregate} that additionally carries a sort {@code order} and a {@code limit}. + */ +public class TopNAggregate extends Aggregate { + public static final NamedWriteableRegistry.Entry ENTRY = new NamedWriteableRegistry.Entry( + LogicalPlan.class, + "TopNAggregate", + TopNAggregate::new + ); + + private final List order; + private final Expression limit; + + // NOTE(review): never read or written in this class - confirm it is needed and does not hide a field of Aggregate. + protected List lazyOutput; + + public TopNAggregate( + Source source, + LogicalPlan child, + List groupings, + List aggregates, + List order, + Expression limit + ) { + super(source, child, groupings, aggregates); + this.order = order; + this.limit = limit; + } + + /** + * Reads the node from {@code in}: the {@link Aggregate} state first, then {@code order}, then {@code limit}, + * which is the order {@link #writeTo} writes them in. + */ + public TopNAggregate(StreamInput in) throws IOException { + super(in); + this.order = in.readCollectionAsList(Order::new); + this.limit = in.readNamedWriteable(Expression.class); + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + super.writeTo(out); + out.writeCollection(order); + out.writeNamedWriteable(limit); + } + + @Override + public String getWriteableName() { + return ENTRY.name; + } + + @Override + protected NodeInfo info() { + return
NodeInfo.create(this, TopNAggregate::new, child(), groupings, aggregates, order, limit); + } + + @Override + public TopNAggregate replaceChild(LogicalPlan newChild) { + return new TopNAggregate(source(), newChild, groupings, aggregates, order, limit); + } + + /** The sort order carried by this node. */ + public List order() { + return order; + } + + /** The limit carried by this node. */ + public Expression limit() { + return limit; + } + + /** Resolved only when the {@link Aggregate} expressions, every order entry and the limit are resolved. */ + @Override + public boolean expressionsResolved() { + return super.expressionsResolved() && Resolvables.resolved(order) && limit.resolved(); + } + + // hashCode and equals cover groupings, aggregates, order, limit and child; source() is not part of either. + @Override + public int hashCode() { + return Objects.hash(groupings, aggregates, order, limit, child()); + } + + @Override + public boolean equals(Object obj) { + if (this == obj) { + return true; + } + + if (obj == null || getClass() != obj.getClass()) { + return false; + } + + TopNAggregate other = (TopNAggregate) obj; + return Objects.equals(groupings, other.groupings) + && Objects.equals(aggregates, other.aggregates) + && Objects.equals(order, other.order) + && Objects.equals(limit, other.limit) + && Objects.equals(child(), other.child()); + } +} diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plan/physical/TimeSeriesAggregateExec.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plan/physical/TimeSeriesAggregateExec.java index 23c4303fb6e0d..652fac5315ef8 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plan/physical/TimeSeriesAggregateExec.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plan/physical/TimeSeriesAggregateExec.java @@ -98,7 +98,7 @@ public TimeSeriesAggregateExec replaceChild(PhysicalPlan newChild) { } @Override - public AggregateExec withAggregates(List newAggregates) { + public TimeSeriesAggregateExec withAggregates(List newAggregates) { return new TimeSeriesAggregateExec( source(), child(), @@ -126,7 +126,7 @@ public TimeSeriesAggregateExec withMode(AggregatorMode newMode) { } @Override - protected AggregateExec withEstimatedSize(int estimatedRowSize)
{ + protected TimeSeriesAggregateExec withEstimatedSize(int estimatedRowSize) { return new TimeSeriesAggregateExec( source(), child(), diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plan/physical/TopNAggregateExec.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plan/physical/TopNAggregateExec.java new file mode 100644 index 0000000000000..85335dd373860 --- /dev/null +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plan/physical/TopNAggregateExec.java @@ -0,0 +1,179 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.esql.plan.physical; + +import org.elasticsearch.common.io.stream.NamedWriteableRegistry; +import org.elasticsearch.common.io.stream.StreamInput; +import org.elasticsearch.common.io.stream.StreamOutput; +import org.elasticsearch.compute.aggregation.AggregatorMode; +import org.elasticsearch.xpack.esql.core.expression.Attribute; +import org.elasticsearch.xpack.esql.core.expression.Expression; +import org.elasticsearch.xpack.esql.core.expression.NamedExpression; +import org.elasticsearch.xpack.esql.core.tree.NodeInfo; +import org.elasticsearch.xpack.esql.core.tree.Source; +import org.elasticsearch.xpack.esql.expression.Order; + +import java.io.IOException; +import java.util.List; +import java.util.Objects; + +/** + * An {@link AggregateExec} that additionally carries a sort {@code order} and a {@code limit}. + */ +public class TopNAggregateExec extends AggregateExec implements EstimatesRowSize { + public static final NamedWriteableRegistry.Entry ENTRY = new NamedWriteableRegistry.Entry( + PhysicalPlan.class, + "TopNAggregateExec", + TopNAggregateExec::new + ); + + private final List order; + private final Expression limit; + + public TopNAggregateExec( + Source source, + PhysicalPlan child, + List groupings, + List aggregates, + AggregatorMode mode, +
List intermediateAttributes, + Integer estimatedRowSize, + List order, + Expression limit + ) { + super(source, child, groupings, aggregates, mode, intermediateAttributes, estimatedRowSize); + this.order = order; + this.limit = limit; + } + + /** + * Reads the node from {@code in}: the {@link AggregateExec} state first, then {@code order}, then {@code limit}, + * which is the order {@link #writeTo} writes them in. + */ + protected TopNAggregateExec(StreamInput in) throws IOException { + super(in); + this.order = in.readCollectionAsList(Order::new); + this.limit = in.readNamedWriteable(Expression.class); + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + super.writeTo(out); + out.writeCollection(order); + out.writeNamedWriteable(limit); + } + + @Override + public String getWriteableName() { + return ENTRY.name; + } + + @Override + protected NodeInfo info() { + return NodeInfo.create( + this, + TopNAggregateExec::new, + child(), + groupings(), + aggregates(), + getMode(), + intermediateAttributes(), + estimatedRowSize(), + order, + limit + ); + } + + @Override + public TopNAggregateExec replaceChild(PhysicalPlan newChild) { + return new TopNAggregateExec( + source(), + newChild, + groupings(), + aggregates(), + getMode(), + intermediateAttributes(), + estimatedRowSize(), + order, + limit + ); + } + + /** The sort order carried by this node. */ + public List order() { + return order; + } + + /** The limit carried by this node. */ + public Expression limit() { + return limit; + } + + // The with* methods below return a copy with one property replaced; order and limit are always carried over. + @Override + public TopNAggregateExec withAggregates(List newAggregates) { + return new TopNAggregateExec( + source(), + child(), + groupings(), + newAggregates, + getMode(), + intermediateAttributes(), + estimatedRowSize(), + order, + limit + ); + } + + @Override + public TopNAggregateExec withMode(AggregatorMode newMode) { + return new TopNAggregateExec( + source(), + child(), + groupings(), + aggregates(), + newMode, + intermediateAttributes(), + estimatedRowSize(), + order, + limit + ); + } + + @Override + protected TopNAggregateExec withEstimatedSize(int estimatedRowSize) { + return new TopNAggregateExec( + source(), + child(), + groupings(), + aggregates(), + getMode(), + intermediateAttributes(), + estimatedRowSize, +
order, + limit + ); + } + + // hashCode and equals cover every constructor argument except source(). + @Override + public int hashCode() { + return Objects.hash(groupings(), aggregates(), getMode(), intermediateAttributes(), estimatedRowSize(), order, limit, child()); + } + + @Override + public boolean equals(Object obj) { + if (this == obj) { + return true; + } + + if (obj == null || getClass() != obj.getClass()) { + return false; + } + + TopNAggregateExec other = (TopNAggregateExec) obj; + return Objects.equals(groupings(), other.groupings()) + && Objects.equals(aggregates(), other.aggregates()) + && Objects.equals(getMode(), other.getMode()) + && Objects.equals(intermediateAttributes(), other.intermediateAttributes()) + && Objects.equals(estimatedRowSize(), other.estimatedRowSize()) + && Objects.equals(order, other.order) + && Objects.equals(limit, other.limit) + && Objects.equals(child(), other.child()); + } + +} diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/planner/AbstractPhysicalOperationProviders.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/planner/AbstractPhysicalOperationProviders.java index e45fe2b0e81d8..d26d3d273bee5 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/planner/AbstractPhysicalOperationProviders.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/planner/AbstractPhysicalOperationProviders.java @@ -7,6 +7,7 @@ package org.elasticsearch.xpack.esql.planner; +import org.elasticsearch.common.lucene.BytesRefs; import org.elasticsearch.compute.aggregation.Aggregator; import org.elasticsearch.compute.aggregation.AggregatorFunctionSupplier; import org.elasticsearch.compute.aggregation.AggregatorMode; @@ -18,22 +19,27 @@ import org.elasticsearch.compute.operator.EvalOperator; import org.elasticsearch.compute.operator.HashAggregationOperator.HashAggregationOperatorFactory; import org.elasticsearch.compute.operator.Operator; +import org.elasticsearch.core.Nullable; import org.elasticsearch.index.analysis.AnalysisRegistry;
import org.elasticsearch.xpack.esql.EsqlIllegalArgumentException; import org.elasticsearch.xpack.esql.core.InvalidArgumentException; import org.elasticsearch.xpack.esql.core.expression.Alias; import org.elasticsearch.xpack.esql.core.expression.Attribute; +import org.elasticsearch.xpack.esql.core.expression.AttributeMap; import org.elasticsearch.xpack.esql.core.expression.Expression; import org.elasticsearch.xpack.esql.core.expression.Expressions; import org.elasticsearch.xpack.esql.core.expression.FoldContext; +import org.elasticsearch.xpack.esql.core.expression.Literal; import org.elasticsearch.xpack.esql.core.expression.NameId; import org.elasticsearch.xpack.esql.core.expression.NamedExpression; import org.elasticsearch.xpack.esql.evaluator.EvalMapper; +import org.elasticsearch.xpack.esql.expression.Order; import org.elasticsearch.xpack.esql.expression.function.aggregate.AggregateFunction; import org.elasticsearch.xpack.esql.expression.function.aggregate.Count; import org.elasticsearch.xpack.esql.expression.function.grouping.Categorize; import org.elasticsearch.xpack.esql.plan.physical.AggregateExec; import org.elasticsearch.xpack.esql.plan.physical.TimeSeriesAggregateExec; +import org.elasticsearch.xpack.esql.plan.physical.TopNAggregateExec; import org.elasticsearch.xpack.esql.planner.LocalExecutionPlanner.LocalExecutionPlannerContext; import org.elasticsearch.xpack.esql.planner.LocalExecutionPlanner.PhysicalOperation; @@ -44,6 +50,7 @@ import java.util.function.Consumer; import static java.util.Collections.emptyList; +import static org.elasticsearch.xpack.esql.type.EsqlDataTypeConverter.stringToInt; public abstract class AbstractPhysicalOperationProviders implements PhysicalOperationProviders { @@ -98,6 +105,7 @@ public final PhysicalOperation groupingPhysicalOperation( // grouping List aggregatorFactories = new ArrayList<>(); List groupSpecs = new ArrayList<>(aggregateExec.groupings().size()); + AttributeMap attributesToTopNDef = 
buildAttributesToTopNDefMap(aggregateExec); for (Expression group : aggregateExec.groupings()) { Attribute groupAttribute = Expressions.attribute(group); // In case of `... BY groupAttribute = CATEGORIZE(sourceGroupAttribute)` the actual source attribute is different. @@ -143,7 +151,8 @@ else if (aggregatorMode.isOutputPartial()) { } layout.append(groupAttributeLayout); Layout.ChannelAndType groupInput = source.layout.get(sourceGroupAttribute.id()); - groupSpecs.add(new GroupSpec(groupInput == null ? null : groupInput.channel(), sourceGroupAttribute, group)); + BlockHash.TopNDef topNDef = attributesToTopNDef.get(sourceGroupAttribute); + groupSpecs.add(new GroupSpec(groupInput == null ? null : groupInput.channel(), sourceGroupAttribute, group, topNDef)); } if (aggregatorMode == AggregatorMode.FINAL) { @@ -180,7 +189,8 @@ else if (aggregatorMode.isOutputPartial()) { aggregatorMode, aggregatorFactories, context.pageSize(aggregateExec.estimatedRowSize()), - analysisRegistry + analysisRegistry, + context.queryPragmas().maxTopNAggsLimit() ); } } @@ -190,6 +200,46 @@ else if (aggregatorMode.isOutputPartial()) { throw new EsqlIllegalArgumentException("no operator factory"); } + /** + * Maps every attribute in the {@code order} of a {@link TopNAggregateExec} to the {@link BlockHash.TopNDef} built from its + * position, direction, nulls position and the literal limit. Returns an empty map when {@code aggregateExec} is not a + * {@link TopNAggregateExec}, when it has no order or no limit, or when the number of order entries differs from the + * number of groupings. + * + * @throws EsqlIllegalArgumentException if the limit is not a {@link Literal} or an order expression is not an {@link Attribute} + */ + private AttributeMap buildAttributesToTopNDefMap(AggregateExec aggregateExec) { + if (aggregateExec instanceof TopNAggregateExec == false) { + return AttributeMap.emptyAttributeMap(); + } + + TopNAggregateExec topNAggregateExec = (TopNAggregateExec) aggregateExec; + List order = topNAggregateExec.order(); + Expression limit = topNAggregateExec.limit(); + + if (order.isEmpty() || limit == null || order.size() != aggregateExec.groupings().size()) { + return AttributeMap.emptyAttributeMap(); + } + + // The limit is the same for every order entry: validate and parse it once instead of once per iteration. + if
((limit instanceof Literal) == false) { + throw new EsqlIllegalArgumentException("limit only supported with literal values"); + } + int intLimit = stringToInt(BytesRefs.toString(((Literal) limit).value())); + + AttributeMap.Builder builder = AttributeMap.builder(order.size()); + + for (int i = 0; i < order.size(); i++) { + Order orderEntry = order.get(i); + + if ((orderEntry.child() instanceof Attribute) == false) { + throw new EsqlIllegalArgumentException("order by expression must be an attribute"); + } + + Attribute attribute = (Attribute) orderEntry.child(); + BlockHash.TopNDef topNDef = new BlockHash.TopNDef( + i, + orderEntry.direction().equals(Order.OrderDirection.ASC), + orderEntry.nullsPosition().equals(Order.NullsPosition.FIRST), + intLimit + ); + builder.put(attribute, topNDef); + } + + return builder.build(); + } + /*** * Creates a standard layout for intermediate aggregations, typically used across exchanges. * Puts the group first, followed by each aggregation. @@ -247,7 +297,6 @@ public static List intermediateAttributes(List channels, AggregatorMode mode) {} private void aggregatesToFactory( - List aggregates, AggregatorMode mode, Layout layout, @@ -338,16 +387,17 @@ private static AggregatorFunctionSupplier supplier(AggregateFunction aggregateFu * @param attribute The attribute, source of this group * @param expression The expression being used to group */ - private record GroupSpec(Integer channel, Attribute attribute, Expression expression) { + private record GroupSpec(Integer channel, Attribute attribute, Expression expression, @Nullable BlockHash.TopNDef topNDef) { BlockHash.GroupSpec toHashGroupSpec() { if (channel == null) { throw new EsqlIllegalArgumentException("planned to use ordinals but tried to use the hash instead"); } + return new BlockHash.GroupSpec( channel, elementType(), Alias.unwrap(expression) instanceof Categorize categorize ?
categorize.categorizeDef() : null, - null + topNDef ); } diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/planner/EsPhysicalOperationProviders.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/planner/EsPhysicalOperationProviders.java index 7abe84d99e5f2..794e6f8105368 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/planner/EsPhysicalOperationProviders.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/planner/EsPhysicalOperationProviders.java @@ -361,7 +361,8 @@ public Operator.OperatorFactory timeSeriesAggregatorOperatorFactory( groupSpecs, aggregatorMode, aggregatorFactories, - context.pageSize(ts.estimatedRowSize()) + context.pageSize(ts.estimatedRowSize()), + context.queryPragmas().maxTopNAggsLimit() ); } diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/planner/LocalExecutionPlanner.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/planner/LocalExecutionPlanner.java index 28204e2572842..9538b88cd6fa8 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/planner/LocalExecutionPlanner.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/planner/LocalExecutionPlanner.java @@ -7,7 +7,6 @@ package org.elasticsearch.xpack.esql.planner; -import org.apache.lucene.util.BytesRef; import org.elasticsearch.cluster.ClusterName; import org.elasticsearch.common.lucene.BytesRefs; import org.elasticsearch.common.settings.Settings; @@ -117,6 +116,7 @@ import org.elasticsearch.xpack.esql.plan.physical.SampleExec; import org.elasticsearch.xpack.esql.plan.physical.ShowExec; import org.elasticsearch.xpack.esql.plan.physical.TimeSeriesSourceExec; +import org.elasticsearch.xpack.esql.plan.physical.TopNAggregateExec; import org.elasticsearch.xpack.esql.plan.physical.TopNExec; import org.elasticsearch.xpack.esql.plan.physical.inference.CompletionExec; import 
org.elasticsearch.xpack.esql.plan.physical.inference.RerankExec; @@ -350,7 +350,13 @@ private PhysicalOperation planRrfScoreEvalExec(RrfScoreEvalExec rrf, LocalExecut private PhysicalOperation planAggregation(AggregateExec aggregate, LocalExecutionPlannerContext context) { var source = plan(aggregate.child(), context); - return physicalOperationProviders.groupingPhysicalOperation(aggregate, source, context); + var physicalOperation = physicalOperationProviders.groupingPhysicalOperation(aggregate, source, context); + + // A TopNAggregateExec whose mode does not output partial results gets its order and limit applied right after the aggregation. + if (aggregate instanceof TopNAggregateExec topNAggregate && topNAggregate.getMode().isOutputPartial() == false) { + return planTopN(topNAggregate.order(), topNAggregate.limit(), topNAggregate.estimatedRowSize(), physicalOperation, context); + } + + return physicalOperation; } private PhysicalOperation planEsQueryNode(EsQueryExec esQueryExec, LocalExecutionPlannerContext context) { @@ -475,9 +481,18 @@ private PhysicalOperation planParallelNode(ParallelExec parallelExec, LocalExecu } private PhysicalOperation planTopN(TopNExec topNExec, LocalExecutionPlannerContext context) { - final Integer rowSize = topNExec.estimatedRowSize(); - assert rowSize != null && rowSize > 0 : "estimated row size [" + rowSize + "] wasn't set"; PhysicalOperation source = plan(topNExec.child(), context); + return planTopN(topNExec.order(), topNExec.limit(), topNExec.estimatedRowSize(), source, context); + } + + private PhysicalOperation planTopN( + List order, + Expression limit, + Integer estimatedRowSize, + PhysicalOperation source, + LocalExecutionPlannerContext context + ) { + assert estimatedRowSize != null && estimatedRowSize > 0 : "estimated row size [" + estimatedRowSize + "] wasn't set"; ElementType[] elementTypes = new ElementType[source.layout.numberOfChannels()]; TopNEncoder[] encoders = new TopNEncoder[source.layout.numberOfChannels()]; @@ -496,9 +511,9 @@ private PhysicalOperation planTopN(TopNExec topNExec, LocalExecutionPlannerConte case PARTIAL_AGG,
UNSUPPORTED -> TopNEncoder.UNSUPPORTED; }; } - List orders = topNExec.order().stream().map(order -> { + List orders = order.stream().map(orderEntry -> { int sortByChannel; - if (order.child() instanceof Attribute a) { + if (orderEntry.child() instanceof Attribute a) { sortByChannel = source.layout.get(a.id()).channel(); } else { throw new EsqlIllegalArgumentException("order by expression must be an attribute"); @@ -506,20 +521,21 @@ private PhysicalOperation planTopN(TopNExec topNExec, LocalExecutionPlannerConte return new TopNOperator.SortOrder( sortByChannel, - order.direction().equals(Order.OrderDirection.ASC), - order.nullsPosition().equals(Order.NullsPosition.FIRST) + orderEntry.direction().equals(Order.OrderDirection.ASC), + orderEntry.nullsPosition().equals(Order.NullsPosition.FIRST) ); }).toList(); - int limit; - if (topNExec.limit() instanceof Literal literal) { - Object val = literal.value() instanceof BytesRef br ? BytesRefs.toString(br) : literal.value(); - limit = stringToInt(val.toString()); + int intLimit; + if (limit instanceof Literal literal) { + String val = BytesRefs.toString(literal.value()); + intLimit = stringToInt(val); } else { throw new EsqlIllegalArgumentException("limit only supported with literal values"); } + return source.with( - new TopNOperatorFactory(limit, asList(elementTypes), asList(encoders), orders, context.pageSize(rowSize)), + new TopNOperatorFactory(intLimit, asList(elementTypes), asList(encoders), orders, context.pageSize(estimatedRowSize)), source.layout ); } diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/planner/mapper/LocalMapper.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/planner/mapper/LocalMapper.java index 29f2db102ea7e..628e57ab07984 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/planner/mapper/LocalMapper.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/planner/mapper/LocalMapper.java @@ -19,6 +19,7 @@ import 
org.elasticsearch.xpack.esql.plan.logical.LogicalPlan; import org.elasticsearch.xpack.esql.plan.logical.Sample; import org.elasticsearch.xpack.esql.plan.logical.TopN; +import org.elasticsearch.xpack.esql.plan.logical.TopNAggregate; import org.elasticsearch.xpack.esql.plan.logical.UnaryPlan; import org.elasticsearch.xpack.esql.plan.logical.join.Join; import org.elasticsearch.xpack.esql.plan.logical.join.JoinConfig; @@ -72,6 +73,11 @@ private PhysicalPlan mapUnary(UnaryPlan unary) { // Pipeline breakers // + if (unary instanceof TopNAggregate topNAggregate) { + List intermediate = MapperUtils.intermediateAttributes(topNAggregate); + return MapperUtils.topNAggExec(topNAggregate, mappedChild, AggregatorMode.INITIAL, intermediate); + } + if (unary instanceof Aggregate aggregate) { List intermediate = MapperUtils.intermediateAttributes(aggregate); return MapperUtils.aggExec(aggregate, mappedChild, AggregatorMode.INITIAL, intermediate); diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/planner/mapper/Mapper.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/planner/mapper/Mapper.java index 4d1d65d63932d..2394a91dbd4b7 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/planner/mapper/Mapper.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/planner/mapper/Mapper.java @@ -23,6 +23,7 @@ import org.elasticsearch.xpack.esql.plan.logical.PipelineBreaker; import org.elasticsearch.xpack.esql.plan.logical.Sample; import org.elasticsearch.xpack.esql.plan.logical.TopN; +import org.elasticsearch.xpack.esql.plan.logical.TopNAggregate; import org.elasticsearch.xpack.esql.plan.logical.UnaryPlan; import org.elasticsearch.xpack.esql.plan.logical.inference.Rerank; import org.elasticsearch.xpack.esql.plan.logical.join.Join; @@ -144,6 +145,27 @@ private PhysicalPlan mapUnary(UnaryPlan unary) { // // Pipeline breakers // + if (unary instanceof TopNAggregate aggregate) { + List intermediate = 
MapperUtils.intermediateAttributes(aggregate); + + // create both sides of the aggregate (for parallelism purposes), if no fragment is present + // TODO: might be easier long term to end up with just one node and split if necessary instead of doing that always at this + // stage + mappedChild = addExchangeForFragment(aggregate, mappedChild); + + // exchange was added - use the intermediates for the output + if (mappedChild instanceof ExchangeExec exchange) { + mappedChild = new ExchangeExec(mappedChild.source(), intermediate, true, exchange.child()); + } + // if no exchange was added (aggregation happening on the coordinator), create the initial agg + else { + mappedChild = MapperUtils.topNAggExec(aggregate, mappedChild, AggregatorMode.INITIAL, intermediate); + } + + // always add the final/reduction agg + return MapperUtils.topNAggExec(aggregate, mappedChild, AggregatorMode.FINAL, intermediate); + } + if (unary instanceof Aggregate aggregate) { List intermediate = MapperUtils.intermediateAttributes(aggregate); diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/planner/mapper/MapperUtils.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/planner/mapper/MapperUtils.java index 4851de1616844..d40ea8bdb1258 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/planner/mapper/MapperUtils.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/planner/mapper/MapperUtils.java @@ -26,6 +26,7 @@ import org.elasticsearch.xpack.esql.plan.logical.Project; import org.elasticsearch.xpack.esql.plan.logical.RrfScoreEval; import org.elasticsearch.xpack.esql.plan.logical.TimeSeriesAggregate; +import org.elasticsearch.xpack.esql.plan.logical.TopNAggregate; import org.elasticsearch.xpack.esql.plan.logical.UnaryPlan; import org.elasticsearch.xpack.esql.plan.logical.inference.Completion; import org.elasticsearch.xpack.esql.plan.logical.inference.Rerank; @@ -45,6 +46,7 @@ import 
org.elasticsearch.xpack.esql.plan.physical.RrfScoreEvalExec; import org.elasticsearch.xpack.esql.plan.physical.ShowExec; import org.elasticsearch.xpack.esql.plan.physical.TimeSeriesAggregateExec; +import org.elasticsearch.xpack.esql.plan.physical.TopNAggregateExec; import org.elasticsearch.xpack.esql.plan.physical.inference.CompletionExec; import org.elasticsearch.xpack.esql.plan.physical.inference.RerankExec; import org.elasticsearch.xpack.esql.planner.AbstractPhysicalOperationProviders; @@ -175,6 +177,25 @@ static AggregateExec aggExec(Aggregate aggregate, PhysicalPlan child, Aggregator } } + /** + * Builds the {@link TopNAggregateExec} for {@code aggregate} on top of {@code child}, copying the source, groupings, + * aggregates, order and limit from the logical node. The estimated row size is passed as {@code null}. + */ + static TopNAggregateExec topNAggExec( + TopNAggregate aggregate, + PhysicalPlan child, + AggregatorMode aggMode, + List intermediateAttributes + ) { + final Integer unsetEstimatedRowSize = null; + return new TopNAggregateExec( + aggregate.source(), + child, + aggregate.groupings(), + aggregate.aggregates(), + aggMode, + intermediateAttributes, + unsetEstimatedRowSize, + aggregate.order(), + aggregate.limit() + ); + } + static PhysicalPlan unsupported(LogicalPlan p) { throw new EsqlIllegalArgumentException("unsupported logical plan node [" + p.nodeName() + "]"); } diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plugin/QueryPragmas.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plugin/QueryPragmas.java index 345bf3b8767ef..b257a843832ea 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plugin/QueryPragmas.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plugin/QueryPragmas.java @@ -79,6 +79,8 @@ public final class QueryPragmas implements Writeable { MappedFieldType.FieldExtractPreference.NONE ); + public static final Setting MAX_TOP_N_AGGS_LIMIT = Setting.intSetting("max_top_n_aggs_limit", 0, 0); + public static final QueryPragmas EMPTY = new QueryPragmas(Settings.EMPTY); private final Settings settings; @@ -196,6 +198,14 @@ public MappedFieldType.FieldExtractPreference fieldExtractPreference() { return FIELD_EXTRACT_PREFERENCE.get(settings);
} + /** + * The maximum {@code LIMIT} following an aggregation with groups for which + * the {@link org.elasticsearch.xpack.esql.plan.logical.TopNAggregate} behavior is enabled. + * + * @return the value of {@link #MAX_TOP_N_AGGS_LIMIT} read from the settings of these pragmas + */ + public int maxTopNAggsLimit() { + return MAX_TOP_N_AGGS_LIMIT.get(settings); + } + public boolean isEmpty() { return settings.isEmpty(); } diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/LocalPhysicalPlanOptimizerTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/LocalPhysicalPlanOptimizerTests.java index a604e1d26d313..7f8ba6aeb0b89 100644 --- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/LocalPhysicalPlanOptimizerTests.java +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/LocalPhysicalPlanOptimizerTests.java @@ -94,6 +94,7 @@ import org.elasticsearch.xpack.esql.plan.physical.ProjectExec; import org.elasticsearch.xpack.esql.plan.physical.TimeSeriesAggregateExec; import org.elasticsearch.xpack.esql.plan.physical.TimeSeriesSourceExec; +import org.elasticsearch.xpack.esql.plan.physical.TopNAggregateExec; import org.elasticsearch.xpack.esql.plan.physical.TopNExec; import org.elasticsearch.xpack.esql.planner.FilterTests; import org.elasticsearch.xpack.esql.plugin.EsqlFlags; @@ -136,6 +137,8 @@ import static org.elasticsearch.xpack.esql.EsqlTestUtils.withDefaultLimitWarning; import static org.elasticsearch.xpack.esql.analysis.AnalyzerTestUtils.defaultLookupResolution; import static org.elasticsearch.xpack.esql.analysis.AnalyzerTestUtils.indexWithDateDateNanosUnionType; +import static org.elasticsearch.xpack.esql.core.expression.Expressions.name; +import static org.elasticsearch.xpack.esql.core.expression.Expressions.names; import static org.elasticsearch.xpack.esql.core.querydsl.query.Query.unscore; import static org.elasticsearch.xpack.esql.core.type.DataType.DATE_NANOS; import static org.elasticsearch.xpack.esql.core.type.DataType.INTEGER; @@ -2369,6 +2372,44 @@ public void
testToDateNanosPushDown() { assertThat(expected.toString(), is(esQuery.query().toString())); } + /** + * Expects + * TopNAggregateExec + * \_ExchangeExec + * \_TopNAggregateExec + * \_FieldExtractExec[first_name, last_name] + * \_EsQueryExec + */ + public void testTopNAggregate() { + var stats = EsqlTestUtils.statsForExistingField("first_name", "last_name"); + var plan = plannerOptimizer.plan(""" + from test + | stats x = count(first_name) by first_name, last_name + | sort x DESC, first_name NULLS LAST + | LIMIT 5 + """, stats); + + var aggregate1 = as(plan, TopNAggregateExec.class); + var exchange = as(aggregate1.child(), ExchangeExec.class); + var aggregate2 = as(exchange.child(), TopNAggregateExec.class); + + var extract = as(aggregate2.child(), FieldExtractExec.class); + assertThat(names(extract.attributesToExtract()), equalTo(List.of("first_name", "last_name"))); + var esQuery = as(extract.child(), EsQueryExec.class); + + assertThat(aggregate1.groupings(), hasSize(2)); + assertThat(aggregate1.estimatedRowSize(), equalTo(Long.BYTES + KEYWORD_EST + KEYWORD_EST)); + assertThat(aggregate1.order(), hasSize(2)); + var order1 = aggregate1.order().get(0); + assertThat(name(order1.child()), equalTo("x")); + assertThat(order1.direction(), equalTo(Order.OrderDirection.DESC)); + assertThat(order1.nullsPosition(), equalTo(Order.NullsPosition.FIRST)); + var order2 = aggregate1.order().get(1); + assertThat(name(order2.child()), equalTo("first_name")); + assertThat(order2.direction(), equalTo(Order.OrderDirection.ASC)); + assertThat(order2.nullsPosition(), equalTo(Order.NullsPosition.LAST)); + assertThat(aggregate1.limit().fold(FoldContext.small()), equalTo(5)); + + // Check that both agg nodes carry the same aggregates, groupings, estimated row size, order and limit + assertThat(aggregate1.aggregates(), equalTo(aggregate2.aggregates())); + assertThat(aggregate1.groupings(), equalTo(aggregate2.groupings())); + assertThat(aggregate1.estimatedRowSize(), equalTo(aggregate2.estimatedRowSize())); + assertThat(aggregate1.order(), equalTo(aggregate2.order())); + assertThat(aggregate1.limit(), equalTo(aggregate2.limit())); + } + public void testVerifierOnMissingReferences() throws Exception
{ PhysicalPlan plan = plannerOptimizer.plan(""" diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/LogicalPlanOptimizerTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/LogicalPlanOptimizerTests.java index e301c1610bd7b..44f2c17afe942 100644 --- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/LogicalPlanOptimizerTests.java +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/LogicalPlanOptimizerTests.java @@ -121,6 +121,7 @@ import org.elasticsearch.xpack.esql.plan.logical.Sample; import org.elasticsearch.xpack.esql.plan.logical.TimeSeriesAggregate; import org.elasticsearch.xpack.esql.plan.logical.TopN; +import org.elasticsearch.xpack.esql.plan.logical.TopNAggregate; import org.elasticsearch.xpack.esql.plan.logical.UnaryPlan; import org.elasticsearch.xpack.esql.plan.logical.inference.Completion; import org.elasticsearch.xpack.esql.plan.logical.inference.Rerank; @@ -420,9 +421,8 @@ public void testCombineProjectionWithAggregationAndEval() { /** * Expects - * TopN[[Order[x{r}#10,ASC,LAST]],1000[INTEGER]] - * \_Aggregate[[languages{f}#16],[MAX(emp_no{f}#13) AS x, languages{f}#16]] - * \_EsRelation[test][_meta_field{f}#19, emp_no{f}#13, first_name{f}#14, ..] + * TopNAggregate[[languages{f}#16],[MAX(emp_no{f}#13) AS x, languages{f}#16], [Order[x{r}#10,ASC,LAST]],1000[INTEGER]] + * \_EsRelation[test][_meta_field{f}#19, emp_no{f}#13, first_name{f}#14, ..] 
*/ public void testRemoveOverridesInAggregate() throws Exception { var plan = plan(""" @@ -431,9 +431,8 @@ public void testRemoveOverridesInAggregate() throws Exception { | sort x """); - var topN = as(plan, TopN.class); - var agg = as(topN.child(), Aggregate.class); - var aggregates = agg.aggregates(); + var topNAgg = as(plan, TopNAggregate.class); + var aggregates = topNAgg.aggregates(); assertThat(aggregates, hasSize(2)); assertThat(Expressions.names(aggregates), contains("x", "languages")); var alias = as(aggregates.get(0), Alias.class); @@ -450,9 +449,8 @@ public void testRemoveOverridesInAggregate() throws Exception { /** * Expects - * TopN[[Order[b{r}#10,ASC,LAST]],1000[INTEGER]] - * \_Aggregate[[b{r}#10],[languages{f}#16 AS b]] - * \_EsRelation[test][_meta_field{f}#19, emp_no{f}#13, first_name{f}#14, ..] + * TopNAggregate[[b{r}#10],[languages{f}#16 AS b],[Order[b{r}#10,ASC,LAST]],1000[INTEGER]] + * \_EsRelation[test][_meta_field{f}#19, emp_no{f}#13, first_name{f}#14, ..] */ public void testAggsWithOverridingInputAndGrouping() throws Exception { var plan = plan(""" @@ -461,9 +459,8 @@ public void testAggsWithOverridingInputAndGrouping() throws Exception { | sort b """); - var topN = as(plan, TopN.class); - var agg = as(topN.child(), Aggregate.class); - var aggregates = agg.aggregates(); + var topNAgg = as(plan, TopNAggregate.class); + var aggregates = topNAgg.aggregates(); assertThat(aggregates, hasSize(1)); assertThat(Expressions.names(aggregates), contains("b")); assertWarnings( @@ -6771,24 +6768,30 @@ public void testTranslateMixedAggsWithMathWithoutGrouping() { assertThat(add.right().fold(FoldContext.small()), equalTo(0.2)); } + /** + * TopNAggregate[[cluster{r}#7563],[SUM(sum(rate(network.total_bytes_in)){r}#7574,true[BOOLEAN]) AS sum(rate(network.total_bytes + * _in))#7560, cluster{r}#7563],[Order[cluster{f}#7563,ASC,LAST]],10[INTEGER]] + * \_TimeSeriesAggregate[[_tsid{m}#7575],[RATE(network.total_bytes_in{f}#7570,true[BOOLEAN],@timestamp{f}#7562) AS 
sum(rate(network.tota + * l_bytes_in))#7574, VALUES(cluster{f}#7563,true[BOOLEAN]) AS cluster#7563],null] + * \_EsRelation[k8s][TIME_SERIES][@timestamp{f}#7562, client.ip{f}#7566, cluster{f}#7..] + */ public void testTranslateMetricsGroupedByOneDimension() { assumeTrue("requires snapshot builds", Build.current().isSnapshot()); var query = "TS k8s | STATS sum(rate(network.total_bytes_in)) BY cluster | SORT cluster | LIMIT 10"; var plan = logicalOptimizer.optimize(metricsAnalyzer.analyze(parser.createStatement(query, EsqlTestUtils.TEST_CFG))); - TopN topN = as(plan, TopN.class); - Aggregate aggsByCluster = as(topN.child(), Aggregate.class); - assertThat(aggsByCluster, not(instanceOf(TimeSeriesAggregate.class))); - assertThat(aggsByCluster.aggregates(), hasSize(2)); - TimeSeriesAggregate aggsByTsid = as(aggsByCluster.child(), TimeSeriesAggregate.class); + TopNAggregate topNAggsByCluster = as(plan, TopNAggregate.class); + assertThat(topNAggsByCluster, not(instanceOf(TimeSeriesAggregate.class))); + assertThat(topNAggsByCluster.aggregates(), hasSize(2)); + TimeSeriesAggregate aggsByTsid = as(topNAggsByCluster.child(), TimeSeriesAggregate.class); assertThat(aggsByTsid.aggregates(), hasSize(2)); // _tsid is dropped assertNull(aggsByTsid.timeBucket()); EsRelation relation = as(aggsByTsid.child(), EsRelation.class); assertThat(relation.indexMode(), equalTo(IndexMode.TIME_SERIES)); - Sum sum = as(Alias.unwrap(aggsByCluster.aggregates().get(0)), Sum.class); + Sum sum = as(Alias.unwrap(topNAggsByCluster.aggregates().get(0)), Sum.class); assertThat(Expressions.attribute(sum.field()).id(), equalTo(aggsByTsid.aggregates().get(0).id())); - assertThat(aggsByCluster.groupings(), hasSize(1)); - assertThat(Expressions.attribute(aggsByCluster.groupings().get(0)).id(), equalTo(aggsByTsid.aggregates().get(1).id())); + assertThat(topNAggsByCluster.groupings(), hasSize(1)); + assertThat(Expressions.attribute(topNAggsByCluster.groupings().get(0)).id(), 
equalTo(aggsByTsid.aggregates().get(1).id())); Rate rate = as(Alias.unwrap(aggsByTsid.aggregates().get(0)), Rate.class); assertThat(Expressions.attribute(rate.field()).name(), equalTo("network.total_bytes_in")); @@ -6916,8 +6919,7 @@ public void testTranslateSumOfTwoRates() { | LIMIT 10 """; var plan = logicalOptimizer.optimize(metricsAnalyzer.analyze(parser.createStatement(query, EsqlTestUtils.TEST_CFG))); - TopN topN = as(plan, TopN.class); - Aggregate finalAgg = as(topN.child(), Aggregate.class); + TopNAggregate finalAgg = as(plan, TopNAggregate.class); Eval eval = as(finalAgg.child(), Eval.class); assertThat(eval.fields(), hasSize(1)); Add sum = as(Alias.unwrap(eval.fields().get(0)), Add.class); diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/PhysicalPlanOptimizerTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/PhysicalPlanOptimizerTests.java index a609a1e494e54..2724c8ed12663 100644 --- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/PhysicalPlanOptimizerTests.java +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/PhysicalPlanOptimizerTests.java @@ -107,6 +107,7 @@ import org.elasticsearch.xpack.esql.plan.logical.LogicalPlan; import org.elasticsearch.xpack.esql.plan.logical.Project; import org.elasticsearch.xpack.esql.plan.logical.TopN; +import org.elasticsearch.xpack.esql.plan.logical.TopNAggregate; import org.elasticsearch.xpack.esql.plan.logical.join.Join; import org.elasticsearch.xpack.esql.plan.logical.join.JoinTypes; import org.elasticsearch.xpack.esql.plan.logical.local.EmptyLocalSupplier; @@ -131,6 +132,7 @@ import org.elasticsearch.xpack.esql.plan.physical.MvExpandExec; import org.elasticsearch.xpack.esql.plan.physical.PhysicalPlan; import org.elasticsearch.xpack.esql.plan.physical.ProjectExec; +import org.elasticsearch.xpack.esql.plan.physical.TopNAggregateExec; import org.elasticsearch.xpack.esql.plan.physical.TopNExec; 
import org.elasticsearch.xpack.esql.plan.physical.UnaryExec; import org.elasticsearch.xpack.esql.planner.EsPhysicalOperationProviders; @@ -811,6 +813,32 @@ public void testDoExtractGroupingFields() { assertThat(source.estimatedRowSize(), equalTo(Integer.BYTES * 2 + 50)); } + public void testDoNotExtractGroupingFieldsTopN() { + var plan = physicalPlan(""" + from test + | stats x = sum(salary) by first_name + | sort first_name + """); + + var optimized = optimizedPlan(plan); + var aggregate = as(optimized, TopNAggregateExec.class); + assertThat(aggregate.estimatedRowSize(), equalTo(Long.BYTES + KEYWORD_EST)); + assertThat(aggregate.groupings(), hasSize(1)); + + var exchange = asRemoteExchange(aggregate.child()); + aggregate = as(exchange.child(), TopNAggregateExec.class); + assertThat(aggregate.estimatedRowSize(), equalTo(Long.BYTES + KEYWORD_EST)); + assertThat(aggregate.groupings(), hasSize(1)); + + var extract = as(aggregate.child(), FieldExtractExec.class); + assertThat(names(extract.attributesToExtract()), equalTo(List.of("salary"))); + + var source = source(extract.child()); + // doc id and salary are ints. salary isn't extracted. + // TODO salary kind of is extracted. At least sometimes it is. should it count? 
+ assertThat(source.estimatedRowSize(), equalTo(Integer.BYTES * 2)); + } + public void testExtractGroupingFieldsIfAggd() { var plan = physicalPlan(""" from test @@ -835,6 +863,30 @@ public void testExtractGroupingFieldsIfAggd() { assertThat(source.estimatedRowSize(), equalTo(Integer.BYTES + KEYWORD_EST)); } + public void testExtractGroupingFieldsIfAggdTopN() { + var plan = physicalPlan(""" + from test + | stats x = count(first_name) by first_name + | sort first_name + """); + + var optimized = optimizedPlan(plan); + var aggregate = as(optimized, TopNAggregateExec.class); + assertThat(aggregate.groupings(), hasSize(1)); + assertThat(aggregate.estimatedRowSize(), equalTo(Long.BYTES + KEYWORD_EST)); + + var exchange = asRemoteExchange(aggregate.child()); + aggregate = as(exchange.child(), TopNAggregateExec.class); + assertThat(aggregate.groupings(), hasSize(1)); + assertThat(aggregate.estimatedRowSize(), equalTo(Long.BYTES + KEYWORD_EST)); + + var extract = as(aggregate.child(), FieldExtractExec.class); + assertThat(names(extract.attributesToExtract()), equalTo(List.of("first_name"))); + + var source = source(extract.child()); + assertThat(source.estimatedRowSize(), equalTo(Integer.BYTES + KEYWORD_EST)); + } + public void testExtractGroupingFieldsIfAggdWithEval() { var plan = physicalPlan(""" from test @@ -5578,12 +5630,11 @@ public void testPushSpatialDistanceEvalWithStatsToSource() { | SORT count DESC, country ASC """; var plan = this.physicalPlan(query, airports); - var topN = as(plan, TopNExec.class); - var agg = as(topN.child(), AggregateExec.class); - var exchange = as(agg.child(), ExchangeExec.class); + var topNAgg = as(plan, TopNAggregateExec.class); + var exchange = as(topNAgg.child(), ExchangeExec.class); var fragment = as(exchange.child(), FragmentExec.class); - var agg2 = as(fragment.fragment(), Aggregate.class); - var filter = as(agg2.child(), Filter.class); + var topNAgg2 = as(fragment.fragment(), TopNAggregate.class); + var filter = as(topNAgg2.child(), 
Filter.class); // Validate the filter condition (two distance filters) var and = as(filter.condition(), And.class); @@ -5603,12 +5654,11 @@ public void testPushSpatialDistanceEvalWithStatsToSource() { // Now optimize the plan var optimized = optimizedPlan(plan); - var topLimit = as(optimized, TopNExec.class); - var aggExec = as(topLimit.child(), AggregateExec.class); - var exchangeExec = as(aggExec.child(), ExchangeExec.class); - var aggExec2 = as(exchangeExec.child(), AggregateExec.class); + var topNAggExec = as(optimized, TopNAggregateExec.class); + var exchangeExec = as(topNAggExec.child(), ExchangeExec.class); + var topNAggExec2 = as(exchangeExec.child(), TopNAggregateExec.class); // TODO: Remove the eval entirely, since the distance is no longer required after filter pushdown - var extract = as(aggExec2.child(), FieldExtractExec.class); + var extract = as(topNAggExec2.child(), FieldExtractExec.class); var evalExec = as(extract.child(), EvalExec.class); var stDistance = as(evalExec.fields().get(0).child(), StDistance.class); assertThat("Expect distance function to expect doc-values", stDistance.leftDocValues(), is(true)); @@ -8098,6 +8148,47 @@ public void testSamplePushDown() { assertThat(randomSampling.hash(), equalTo(0)); } + public void testTopNStats() { + var plan = physicalPlan(""" + from test + | stats x = count(first_name) by first_name, last_name + | sort x DESC, first_name NULLS LAST + | LIMIT 5 + """); + + var optimized = optimizedPlan(plan); + var aggregate1 = as(optimized, TopNAggregateExec.class); + + var exchange = asRemoteExchange(aggregate1.child()); + var aggregate2 = as(exchange.child(), TopNAggregateExec.class); + + var extract = as(aggregate2.child(), FieldExtractExec.class); + assertThat(names(extract.attributesToExtract()), equalTo(List.of("first_name", "last_name"))); + + var source = source(extract.child()); + assertThat(source.estimatedRowSize(), equalTo(Integer.BYTES + KEYWORD_EST + KEYWORD_EST)); + + 
assertThat(aggregate1.groupings(), hasSize(2)); + assertThat(aggregate1.estimatedRowSize(), equalTo(Long.BYTES + KEYWORD_EST + KEYWORD_EST)); + assertThat(aggregate1.order(), hasSize(2)); + var order1 = aggregate1.order().get(0); + assertThat(name(order1.child()), equalTo("x")); + assertThat(order1.direction(), equalTo(Order.OrderDirection.DESC)); + assertThat(order1.nullsPosition(), equalTo(Order.NullsPosition.FIRST)); + var order2 = aggregate1.order().get(1); + assertThat(name(order2.child()), equalTo("first_name")); + assertThat(order2.direction(), equalTo(Order.OrderDirection.ASC)); + assertThat(order2.nullsPosition(), equalTo(Order.NullsPosition.LAST)); + assertThat(aggregate1.limit().fold(FoldContext.small()), equalTo(5)); + + // Check that both agg nodes are identical + assertThat(aggregate1.aggregates(), equalTo(aggregate2.aggregates())); + assertThat(aggregate1.groupings(), equalTo(aggregate2.groupings())); + assertThat(aggregate1.estimatedRowSize(), equalTo(aggregate2.estimatedRowSize())); + assertThat(aggregate1.order(), equalTo(aggregate2.order())); + assertThat(aggregate1.limit(), equalTo(aggregate2.limit())); + } + @SuppressWarnings("SameParameterValue") private static void assertFilterCondition( Filter filter, diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/plan/logical/TopNAggregateSerializationTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/plan/logical/TopNAggregateSerializationTests.java new file mode 100644 index 0000000000000..33bbc628cf3d4 --- /dev/null +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/plan/logical/TopNAggregateSerializationTests.java @@ -0,0 +1,69 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. 
+ */ + +package org.elasticsearch.xpack.esql.plan.logical; + +import org.elasticsearch.xpack.esql.core.expression.Expression; +import org.elasticsearch.xpack.esql.core.expression.NamedExpression; +import org.elasticsearch.xpack.esql.core.tree.Source; +import org.elasticsearch.xpack.esql.expression.Order; +import org.elasticsearch.xpack.esql.expression.function.FieldAttributeTests; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.List; + +public class TopNAggregateSerializationTests extends AbstractLogicalPlanSerializationTests { + @Override + protected TopNAggregate createTestInstance() { + Source source = randomSource(); + LogicalPlan child = randomChild(0); + List groupings = randomFieldAttributes(0, 5, false).stream().map(a -> (Expression) a).toList(); + List aggregates = AggregateSerializationTests.randomAggregates(); + List order = randomOrder(); + Expression limit = FieldAttributeTests.createFieldAttribute(1, true); + + return new TopNAggregate(source, child, groupings, aggregates, order, limit); + } + + public static List randomOrder() { + int size = between(1, 5); + List result = new ArrayList<>(size); + for (int i = 0; i < size; i++) { + Expression field = FieldAttributeTests.createFieldAttribute(1, true); + Order.OrderDirection direction = randomFrom(Order.OrderDirection.values()); + Order.NullsPosition nullsPosition = randomFrom(Order.NullsPosition.values()); + result.add(new Order(randomSource(), field, direction, nullsPosition)); + } + return result; + } + + @Override + protected TopNAggregate mutateInstance(TopNAggregate instance) throws IOException { + LogicalPlan child = instance.child(); + List groupings = instance.groupings(); + List aggregates = instance.aggregates(); + List order = instance.order(); + Expression limit = instance.limit(); + switch (between(0, 4)) { + case 0 -> child = randomValueOtherThan(child, () -> randomChild(0)); + case 1 -> groupings = randomValueOtherThan( + groupings, + () -> 
randomFieldAttributes(0, 5, false).stream().map(a -> (Expression) a).toList() + ); + case 2 -> aggregates = randomValueOtherThan(aggregates, AggregateSerializationTests::randomAggregates); + case 3 -> order = randomValueOtherThan(order, TopNAggregateSerializationTests::randomOrder); + case 4 -> limit = randomValueOtherThan(limit, () -> FieldAttributeTests.createFieldAttribute(1, true)); + } + return new TopNAggregate(instance.source(), child, groupings, aggregates, order, limit); + } + + @Override + protected boolean alwaysEmptySource() { + return true; + } +} diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/plan/physical/TopNAggregateExecSerializationTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/plan/physical/TopNAggregateExecSerializationTests.java new file mode 100644 index 0000000000000..7a56d433281ef --- /dev/null +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/plan/physical/TopNAggregateExecSerializationTests.java @@ -0,0 +1,84 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. 
+ */ + +package org.elasticsearch.xpack.esql.plan.physical; + +import org.elasticsearch.compute.aggregation.AggregatorMode; +import org.elasticsearch.xpack.esql.core.expression.Attribute; +import org.elasticsearch.xpack.esql.core.expression.Expression; +import org.elasticsearch.xpack.esql.core.expression.NamedExpression; +import org.elasticsearch.xpack.esql.core.tree.Source; +import org.elasticsearch.xpack.esql.expression.Order; +import org.elasticsearch.xpack.esql.expression.function.FieldAttributeTests; +import org.elasticsearch.xpack.esql.plan.logical.AggregateSerializationTests; +import org.elasticsearch.xpack.esql.plan.logical.TopNAggregateSerializationTests; + +import java.io.IOException; +import java.util.List; + +public class TopNAggregateExecSerializationTests extends AbstractPhysicalPlanSerializationTests { + @Override + protected TopNAggregateExec createTestInstance() { + Source source = randomSource(); + PhysicalPlan child = randomChild(0); + List groupings = randomFieldAttributes(0, 5, false).stream().map(a -> (Expression) a).toList(); + List aggregates = AggregateSerializationTests.randomAggregates(); + AggregatorMode mode = randomFrom(AggregatorMode.values()); + List intermediateAttributes = randomFieldAttributes(0, 5, false); + Integer estimatedRowSize = randomEstimatedRowSize(); + List order = TopNAggregateSerializationTests.randomOrder(); + Expression limit = FieldAttributeTests.createFieldAttribute(1, true); + + return new TopNAggregateExec(source, child, groupings, aggregates, mode, intermediateAttributes, estimatedRowSize, order, limit); + } + + @Override + protected TopNAggregateExec mutateInstance(TopNAggregateExec instance) throws IOException { + PhysicalPlan child = instance.child(); + List groupings = instance.groupings(); + List aggregates = instance.aggregates(); + List intermediateAttributes = instance.intermediateAttributes(); + AggregatorMode mode = instance.getMode(); + Integer estimatedRowSize = instance.estimatedRowSize(); + List 
order = instance.order(); + Expression limit = instance.limit(); + switch (between(0, 7)) { + case 0 -> child = randomValueOtherThan(child, () -> randomChild(0)); + case 1 -> groupings = randomValueOtherThan(groupings, () -> randomFieldAttributes(0, 5, false)); + case 2 -> aggregates = randomValueOtherThan(aggregates, AggregateSerializationTests::randomAggregates); + case 3 -> mode = randomValueOtherThan(mode, () -> randomFrom(AggregatorMode.values())); + case 4 -> intermediateAttributes = randomValueOtherThan(intermediateAttributes, () -> randomFieldAttributes(0, 5, false)); + case 5 -> estimatedRowSize = randomValueOtherThan( + estimatedRowSize, + AbstractPhysicalPlanSerializationTests::randomEstimatedRowSize + ); + case 6 -> { + order = randomValueOtherThan(order, TopNAggregateSerializationTests::randomOrder); + } + case 7 -> { + limit = FieldAttributeTests.createFieldAttribute(1, true); + } + default -> throw new IllegalStateException(); + } + return new TopNAggregateExec( + instance.source(), + child, + groupings, + aggregates, + mode, + intermediateAttributes, + estimatedRowSize, + order, + limit + ); + } + + @Override + protected boolean alwaysEmptySource() { + return true; + } +} diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/planner/TestPhysicalOperationProviders.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/planner/TestPhysicalOperationProviders.java index ff15f3cc1e4ba..cb9933f20a153 100644 --- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/planner/TestPhysicalOperationProviders.java +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/planner/TestPhysicalOperationProviders.java @@ -144,7 +144,8 @@ public Operator.OperatorFactory timeSeriesAggregatorOperatorFactory( groupSpecs, aggregatorMode, aggregatorFactories, - context.pageSize(ts.estimatedRowSize()) + context.pageSize(ts.estimatedRowSize()), + context.queryPragmas().maxTopNAggsLimit() ); }