diff --git a/benchmarks/src/main/java/org/elasticsearch/benchmark/compute/operator/AggregatorBenchmark.java b/benchmarks/src/main/java/org/elasticsearch/benchmark/compute/operator/AggregatorBenchmark.java index 652defa7b39cd..259267c8970b1 100644 --- a/benchmarks/src/main/java/org/elasticsearch/benchmark/compute/operator/AggregatorBenchmark.java +++ b/benchmarks/src/main/java/org/elasticsearch/benchmark/compute/operator/AggregatorBenchmark.java @@ -207,7 +207,7 @@ private static AggregatorFunctionSupplier supplier(String op, String dataType, S default -> throw new IllegalArgumentException("unsupported data type [" + dataType + "]"); }; case SUM -> switch (dataType) { - case LONGS -> new SumLongAggregatorFunctionSupplier(List.of(dataChannel)); + case LONGS -> new SumLongAggregatorFunctionSupplier(-1, -2, "", List.of(dataChannel)); case DOUBLES -> new SumDoubleAggregatorFunctionSupplier(List.of(dataChannel)); default -> throw new IllegalArgumentException("unsupported data type [" + dataType + "]"); }; diff --git a/docs/changelog/116170.yaml b/docs/changelog/116170.yaml new file mode 100644 index 0000000000000..7190fb70e9522 --- /dev/null +++ b/docs/changelog/116170.yaml @@ -0,0 +1,6 @@ +pr: 116170 +summary: "ESQL: Prevent overflow on SUM using multiple aggregators" +area: ES|QL +type: bug +issues: + - 110443 diff --git a/server/src/main/java/org/elasticsearch/TransportVersions.java b/server/src/main/java/org/elasticsearch/TransportVersions.java index 2acf80e426c82..2981e3147b517 100644 --- a/server/src/main/java/org/elasticsearch/TransportVersions.java +++ b/server/src/main/java/org/elasticsearch/TransportVersions.java @@ -187,6 +187,7 @@ static TransportVersion def(int id) { public static final TransportVersion QUERY_RULES_RETRIEVER = def(8_782_00_0); public static final TransportVersion ESQL_CCS_EXEC_INFO_WITH_FAILURES = def(8_783_00_0); public static final TransportVersion LOGSDB_TELEMETRY = def(8_784_00_0); + public static final TransportVersion ESQL_CONFIGURATION_WITH_FEATURES = def(8_785_00_0); /* * STOP! READ THIS FIRST! No, really, diff --git a/test/framework/src/main/java/org/elasticsearch/test/ESTestCase.java b/test/framework/src/main/java/org/elasticsearch/test/ESTestCase.java index 5bfcd54e963b3..af2320ac39857 100644 --- a/test/framework/src/main/java/org/elasticsearch/test/ESTestCase.java +++ b/test/framework/src/main/java/org/elasticsearch/test/ESTestCase.java @@ -442,7 +442,7 @@ private static void setTestSysProps(Random random) { } protected final Logger logger = LogManager.getLogger(getClass()); - private ThreadContext threadContext; + protected ThreadContext threadContext; // ----------------------------------------------------------------- // Suite and test case setup/cleanup. diff --git a/x-pack/plugin/esql/compute/gen/src/main/java/org/elasticsearch/compute/gen/AggregatorFunctionSupplierImplementer.java b/x-pack/plugin/esql/compute/gen/src/main/java/org/elasticsearch/compute/gen/AggregatorFunctionSupplierImplementer.java index f11ccbced6fbe..74849091a63d5 100644 --- a/x-pack/plugin/esql/compute/gen/src/main/java/org/elasticsearch/compute/gen/AggregatorFunctionSupplierImplementer.java +++ b/x-pack/plugin/esql/compute/gen/src/main/java/org/elasticsearch/compute/gen/AggregatorFunctionSupplierImplementer.java @@ -33,6 +33,7 @@ import static org.elasticsearch.compute.gen.Types.DRIVER_CONTEXT; import static org.elasticsearch.compute.gen.Types.LIST_INTEGER; import static org.elasticsearch.compute.gen.Types.STRING; +import static org.elasticsearch.compute.gen.Types.WARNINGS; /** * Implements "AggregationFunctionSupplier" from a class annotated with both @@ -139,8 +140,9 @@ private MethodSpec aggregator() { if (hasWarnings) { builder.addStatement( - "var warnings = Warnings.createWarnings(driverContext.warningsMode(), " - + "warningsLineNumber, warningsColumnNumber, warningsSourceText)" + "var warnings = $T.createWarnings(driverContext.warningsMode(), " + + "warningsLineNumber, warningsColumnNumber, warningsSourceText)", + WARNINGS ); } diff --git a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/OverflowingSumLongAggregatorFunction.java b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/OverflowingSumLongAggregatorFunction.java new file mode 100644 index 0000000000000..87c9f073653b3 --- /dev/null +++ b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/OverflowingSumLongAggregatorFunction.java @@ -0,0 +1,181 @@ +// Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one +// or more contributor license agreements. Licensed under the Elastic License +// 2.0; you may not use this file except in compliance with the Elastic License +// 2.0. +package org.elasticsearch.compute.aggregation; + +import java.lang.Integer; +import java.lang.Override; +import java.lang.String; +import java.lang.StringBuilder; +import java.util.List; +import org.elasticsearch.compute.data.Block; +import org.elasticsearch.compute.data.BooleanBlock; +import org.elasticsearch.compute.data.BooleanVector; +import org.elasticsearch.compute.data.ElementType; +import org.elasticsearch.compute.data.LongBlock; +import org.elasticsearch.compute.data.LongVector; +import org.elasticsearch.compute.data.Page; +import org.elasticsearch.compute.operator.DriverContext; + +/** + * {@link AggregatorFunction} implementation for {@link OverflowingSumLongAggregator}. + * This class is generated. Do not edit it. + */ +public final class OverflowingSumLongAggregatorFunction implements AggregatorFunction { + private static final List INTERMEDIATE_STATE_DESC = List.of( + new IntermediateStateDesc("sum", ElementType.LONG), + new IntermediateStateDesc("seen", ElementType.BOOLEAN) ); + + private final DriverContext driverContext; + + private final LongState state; + + private final List channels; + + public OverflowingSumLongAggregatorFunction(DriverContext driverContext, List channels, + LongState state) { + this.driverContext = driverContext; + this.channels = channels; + this.state = state; + } + + public static OverflowingSumLongAggregatorFunction create(DriverContext driverContext, + List channels) { + return new OverflowingSumLongAggregatorFunction(driverContext, channels, new LongState(OverflowingSumLongAggregator.init())); + } + + public static List intermediateStateDesc() { + return INTERMEDIATE_STATE_DESC; + } + + @Override + public int intermediateBlockCount() { + return INTERMEDIATE_STATE_DESC.size(); + } + + @Override + public void addRawInput(Page page, BooleanVector mask) { + if (mask.allFalse()) { + // Entire page masked away + return; + } + if (mask.allTrue()) { + // No masking + LongBlock block = page.getBlock(channels.get(0)); + LongVector vector = block.asVector(); + if (vector != null) { + addRawVector(vector); + } else { + addRawBlock(block); + } + return; + } + // Some positions masked away, others kept + LongBlock block = page.getBlock(channels.get(0)); + LongVector vector = block.asVector(); + if (vector != null) { + addRawVector(vector, mask); + } else { + addRawBlock(block, mask); + } + } + + private void addRawVector(LongVector vector) { + state.seen(true); + for (int i = 0; i < vector.getPositionCount(); i++) { + state.longValue(OverflowingSumLongAggregator.combine(state.longValue(), vector.getLong(i))); + } + } + + private void addRawVector(LongVector vector, BooleanVector mask) { + state.seen(true); + for (int i = 0; i < vector.getPositionCount(); i++) { + if (mask.getBoolean(i) == false) { + continue; + } + state.longValue(OverflowingSumLongAggregator.combine(state.longValue(), vector.getLong(i))); + } + } + + private void addRawBlock(LongBlock block) { + for (int p = 0; p < block.getPositionCount(); p++) { + if (block.isNull(p)) { + continue; + } + state.seen(true); + int start = block.getFirstValueIndex(p); + int end = start + block.getValueCount(p); + for (int i = start; i < end; i++) { + state.longValue(OverflowingSumLongAggregator.combine(state.longValue(), block.getLong(i))); + } + } + } + + private void addRawBlock(LongBlock block, BooleanVector mask) { + for (int p = 0; p < block.getPositionCount(); p++) { + if (mask.getBoolean(p) == false) { + continue; + } + if (block.isNull(p)) { + continue; + } + state.seen(true); + int start = block.getFirstValueIndex(p); + int end = start + block.getValueCount(p); + for (int i = start; i < end; i++) { + state.longValue(OverflowingSumLongAggregator.combine(state.longValue(), block.getLong(i))); + } + } + } + + @Override + public void addIntermediateInput(Page page) { + assert channels.size() == intermediateBlockCount(); + assert page.getBlockCount() >= channels.get(0) + intermediateStateDesc().size(); + Block sumUncast = page.getBlock(channels.get(0)); + if (sumUncast.areAllValuesNull()) { + return; + } + LongVector sum = ((LongBlock) sumUncast).asVector(); + assert sum.getPositionCount() == 1; + Block seenUncast = page.getBlock(channels.get(1)); + if (seenUncast.areAllValuesNull()) { + return; + } + BooleanVector seen = ((BooleanBlock) seenUncast).asVector(); + assert seen.getPositionCount() == 1; + if (seen.getBoolean(0)) { + state.longValue(OverflowingSumLongAggregator.combine(state.longValue(), sum.getLong(0))); + state.seen(true); + } + } + + @Override + public void evaluateIntermediate(Block[] blocks, int offset, DriverContext driverContext) { + state.toIntermediate(blocks, offset, driverContext); + } + + @Override + public void evaluateFinal(Block[] blocks, int offset, DriverContext driverContext) { + if (state.seen() == false) { + blocks[offset] = driverContext.blockFactory().newConstantNullBlock(1); + return; + } + blocks[offset] = driverContext.blockFactory().newConstantLongBlockWith(state.longValue(), 1); + } + + @Override + public String toString() { + StringBuilder sb = new StringBuilder(); + sb.append(getClass().getSimpleName()).append("["); + sb.append("channels=").append(channels); + sb.append("]"); + return sb.toString(); + } + + @Override + public void close() { + state.close(); + } +} diff --git a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/OverflowingSumLongAggregatorFunctionSupplier.java b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/OverflowingSumLongAggregatorFunctionSupplier.java new file mode 100644 index 0000000000000..8ba67ff08568c --- /dev/null +++ b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/OverflowingSumLongAggregatorFunctionSupplier.java @@ -0,0 +1,39 @@ +// Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one +// or more contributor license agreements. Licensed under the Elastic License +// 2.0; you may not use this file except in compliance with the Elastic License +// 2.0. +package org.elasticsearch.compute.aggregation; + +import java.lang.Integer; +import java.lang.Override; +import java.lang.String; +import java.util.List; +import org.elasticsearch.compute.operator.DriverContext; + +/** + * {@link AggregatorFunctionSupplier} implementation for {@link OverflowingSumLongAggregator}. + * This class is generated. Do not edit it. + */ +public final class OverflowingSumLongAggregatorFunctionSupplier implements AggregatorFunctionSupplier { + private final List channels; + + public OverflowingSumLongAggregatorFunctionSupplier(List channels) { + this.channels = channels; + } + + @Override + public OverflowingSumLongAggregatorFunction aggregator(DriverContext driverContext) { + return OverflowingSumLongAggregatorFunction.create(driverContext, channels); + } + + @Override + public OverflowingSumLongGroupingAggregatorFunction groupingAggregator( + DriverContext driverContext) { + return OverflowingSumLongGroupingAggregatorFunction.create(channels, driverContext); + } + + @Override + public String describe() { + return "overflowing_sum of longs"; + } +} diff --git a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/OverflowingSumLongGroupingAggregatorFunction.java b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/OverflowingSumLongGroupingAggregatorFunction.java new file mode 100644 index 0000000000000..fd2a2895a25c6 --- /dev/null +++ b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/OverflowingSumLongGroupingAggregatorFunction.java @@ -0,0 +1,221 @@ +// Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one +// or more contributor license agreements. Licensed under the Elastic License +// 2.0; you may not use this file except in compliance with the Elastic License +// 2.0. +package org.elasticsearch.compute.aggregation; + +import java.lang.Integer; +import java.lang.Override; +import java.lang.String; +import java.lang.StringBuilder; +import java.util.List; +import org.elasticsearch.compute.data.Block; +import org.elasticsearch.compute.data.BooleanBlock; +import org.elasticsearch.compute.data.BooleanVector; +import org.elasticsearch.compute.data.ElementType; +import org.elasticsearch.compute.data.IntBlock; +import org.elasticsearch.compute.data.IntVector; +import org.elasticsearch.compute.data.LongBlock; +import org.elasticsearch.compute.data.LongVector; +import org.elasticsearch.compute.data.Page; +import org.elasticsearch.compute.operator.DriverContext; + +/** + * {@link GroupingAggregatorFunction} implementation for {@link OverflowingSumLongAggregator}. + * This class is generated. Do not edit it. + */ +public final class OverflowingSumLongGroupingAggregatorFunction implements GroupingAggregatorFunction { + private static final List INTERMEDIATE_STATE_DESC = List.of( + new IntermediateStateDesc("sum", ElementType.LONG), + new IntermediateStateDesc("seen", ElementType.BOOLEAN) ); + + private final LongArrayState state; + + private final List channels; + + private final DriverContext driverContext; + + public OverflowingSumLongGroupingAggregatorFunction(List channels, LongArrayState state, + DriverContext driverContext) { + this.channels = channels; + this.state = state; + this.driverContext = driverContext; + } + + public static OverflowingSumLongGroupingAggregatorFunction create(List channels, + DriverContext driverContext) { + return new OverflowingSumLongGroupingAggregatorFunction(channels, new LongArrayState(driverContext.bigArrays(), OverflowingSumLongAggregator.init()), driverContext); + } + + public static List intermediateStateDesc() { + return INTERMEDIATE_STATE_DESC; + } + + @Override + public int intermediateBlockCount() { + return INTERMEDIATE_STATE_DESC.size(); + } + + @Override + public GroupingAggregatorFunction.AddInput prepareProcessPage(SeenGroupIds seenGroupIds, + Page page) { + LongBlock valuesBlock = page.getBlock(channels.get(0)); + LongVector valuesVector = valuesBlock.asVector(); + if (valuesVector == null) { + if (valuesBlock.mayHaveNulls()) { + state.enableGroupIdTracking(seenGroupIds); + } + return new GroupingAggregatorFunction.AddInput() { + @Override + public void add(int positionOffset, IntBlock groupIds) { + addRawInput(positionOffset, groupIds, valuesBlock); + } + + @Override + public void add(int positionOffset, IntVector groupIds) { + addRawInput(positionOffset, groupIds, valuesBlock); + } + + @Override + public void close() { + } + }; + } + return new GroupingAggregatorFunction.AddInput() { + @Override + public void add(int positionOffset, IntBlock groupIds) { + addRawInput(positionOffset, groupIds, valuesVector); + } + + @Override + public void add(int positionOffset, IntVector groupIds) { + addRawInput(positionOffset, groupIds, valuesVector); + } + + @Override + public void close() { + } + }; + } + + private void addRawInput(int positionOffset, IntVector groups, LongBlock values) { + for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { + int groupId = groups.getInt(groupPosition); + if (values.isNull(groupPosition + positionOffset)) { + continue; + } + int valuesStart = values.getFirstValueIndex(groupPosition + positionOffset); + int valuesEnd = valuesStart + values.getValueCount(groupPosition + positionOffset); + for (int v = valuesStart; v < valuesEnd; v++) { + state.set(groupId, OverflowingSumLongAggregator.combine(state.getOrDefault(groupId), values.getLong(v))); + } + } + } + + private void addRawInput(int positionOffset, IntVector groups, LongVector values) { + for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { + int groupId = groups.getInt(groupPosition); + state.set(groupId, OverflowingSumLongAggregator.combine(state.getOrDefault(groupId), values.getLong(groupPosition + positionOffset))); + } + } + + private void addRawInput(int positionOffset, IntBlock groups, LongBlock values) { + for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { + if (groups.isNull(groupPosition)) { + continue; + } + int groupStart = groups.getFirstValueIndex(groupPosition); + int groupEnd = groupStart + groups.getValueCount(groupPosition); + for (int g = groupStart; g < groupEnd; g++) { + int groupId = groups.getInt(g); + if (values.isNull(groupPosition + positionOffset)) { + continue; + } + int valuesStart = values.getFirstValueIndex(groupPosition + positionOffset); + int valuesEnd = valuesStart + values.getValueCount(groupPosition + positionOffset); + for (int v = valuesStart; v < valuesEnd; v++) { + state.set(groupId, OverflowingSumLongAggregator.combine(state.getOrDefault(groupId), values.getLong(v))); + } + } + } + } + + private void addRawInput(int positionOffset, IntBlock groups, LongVector values) { + for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { + if (groups.isNull(groupPosition)) { + continue; + } + int groupStart = groups.getFirstValueIndex(groupPosition); + int groupEnd = groupStart + groups.getValueCount(groupPosition); + for (int g = groupStart; g < groupEnd; g++) { + int groupId = groups.getInt(g); + state.set(groupId, OverflowingSumLongAggregator.combine(state.getOrDefault(groupId), values.getLong(groupPosition + positionOffset))); + } + } + } + + @Override + public void selectedMayContainUnseenGroups(SeenGroupIds seenGroupIds) { + state.enableGroupIdTracking(seenGroupIds); + } + + @Override + public void addIntermediateInput(int positionOffset, IntVector groups, Page page) { + state.enableGroupIdTracking(new SeenGroupIds.Empty()); + assert channels.size() == intermediateBlockCount(); + Block sumUncast = page.getBlock(channels.get(0)); + if (sumUncast.areAllValuesNull()) { + return; + } + LongVector sum = ((LongBlock) sumUncast).asVector(); + Block seenUncast = page.getBlock(channels.get(1)); + if (seenUncast.areAllValuesNull()) { + return; + } + BooleanVector seen = ((BooleanBlock) seenUncast).asVector(); + assert sum.getPositionCount() == seen.getPositionCount(); + for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { + int groupId = groups.getInt(groupPosition); + if (seen.getBoolean(groupPosition + positionOffset)) { + state.set(groupId, OverflowingSumLongAggregator.combine(state.getOrDefault(groupId), sum.getLong(groupPosition + positionOffset))); + } + } + } + + @Override + public void addIntermediateRowInput(int groupId, GroupingAggregatorFunction input, int position) { + if (input.getClass() != getClass()) { + throw new IllegalArgumentException("expected " + getClass() + "; got " + input.getClass()); + } + LongArrayState inState = ((OverflowingSumLongGroupingAggregatorFunction) input).state; + state.enableGroupIdTracking(new SeenGroupIds.Empty()); + if (inState.hasValue(position)) { + state.set(groupId, OverflowingSumLongAggregator.combine(state.getOrDefault(groupId), inState.get(position))); + } + } + + @Override + public void evaluateIntermediate(Block[] blocks, int offset, IntVector selected) { + state.toIntermediate(blocks, offset, selected, driverContext); + } + + @Override + public void evaluateFinal(Block[] blocks, int offset, IntVector selected, + DriverContext driverContext) { + blocks[offset] = state.toValuesBlock(selected, driverContext); + } + + @Override + public String toString() { + StringBuilder sb = new StringBuilder(); + sb.append(getClass().getSimpleName()).append("["); + sb.append("channels=").append(channels); + sb.append("]"); + return sb.toString(); + } + + @Override + public void close() { + state.close(); + } +} diff --git a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/SumLongAggregatorFunction.java b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/SumLongAggregatorFunction.java index fac21d99bf713..fbe862d94857d 100644 --- a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/SumLongAggregatorFunction.java +++ b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/SumLongAggregatorFunction.java @@ -4,6 +4,7 @@ // 2.0. package org.elasticsearch.compute.aggregation; +import java.lang.ArithmeticException; import java.lang.Integer; import java.lang.Override; import java.lang.String; @@ -17,6 +18,7 @@ import org.elasticsearch.compute.data.LongVector; import org.elasticsearch.compute.data.Page; import org.elasticsearch.compute.operator.DriverContext; +import org.elasticsearch.compute.operator.Warnings; /** * {@link AggregatorFunction} implementation for {@link SumLongAggregator}. @@ -25,24 +27,28 @@ public final class SumLongAggregatorFunction implements AggregatorFunction { private static final List INTERMEDIATE_STATE_DESC = List.of( new IntermediateStateDesc("sum", ElementType.LONG), - new IntermediateStateDesc("seen", ElementType.BOOLEAN) ); + new IntermediateStateDesc("seen", ElementType.BOOLEAN), + new IntermediateStateDesc("failed", ElementType.BOOLEAN) ); + + private final Warnings warnings; private final DriverContext driverContext; - private final LongState state; + private final LongFallibleState state; private final List channels; - public SumLongAggregatorFunction(DriverContext driverContext, List channels, - LongState state) { + public SumLongAggregatorFunction(Warnings warnings, DriverContext driverContext, + List channels, LongFallibleState state) { + this.warnings = warnings; this.driverContext = driverContext; this.channels = channels; this.state = state; } - public static SumLongAggregatorFunction create(DriverContext driverContext, + public static SumLongAggregatorFunction create(Warnings warnings, DriverContext driverContext, List channels) { - return new SumLongAggregatorFunction(driverContext, channels, new LongState(SumLongAggregator.init())); + return new SumLongAggregatorFunction(warnings, driverContext, channels, new LongFallibleState(SumLongAggregator.init())); } public static List intermediateStateDesc() { @@ -56,6 +62,9 @@ public int intermediateBlockCount() { @Override public void addRawInput(Page page, BooleanVector mask) { + if (state.failed()) { + return; + } if (mask.allFalse()) { // Entire page masked away return; @@ -84,7 +93,13 @@ public void addRawInput(Page page, BooleanVector mask) { private void addRawVector(LongVector vector) { state.seen(true); for (int i = 0; i < vector.getPositionCount(); i++) { - state.longValue(SumLongAggregator.combine(state.longValue(), vector.getLong(i))); + try { + state.longValue(SumLongAggregator.combine(state.longValue(), vector.getLong(i))); + } catch (ArithmeticException e) { + warnings.registerException(e); + state.failed(true); + return; + } } } @@ -94,7 +109,13 @@ private void addRawVector(LongVector vector, BooleanVector mask) { if (mask.getBoolean(i) == false) { continue; } - state.longValue(SumLongAggregator.combine(state.longValue(), vector.getLong(i))); + try { + state.longValue(SumLongAggregator.combine(state.longValue(), vector.getLong(i))); + } catch (ArithmeticException e) { + warnings.registerException(e); + state.failed(true); + return; + } } } @@ -107,7 +128,13 @@ private void addRawBlock(LongBlock block) { int start = block.getFirstValueIndex(p); int end = start + block.getValueCount(p); for (int i = start; i < end; i++) { - state.longValue(SumLongAggregator.combine(state.longValue(), block.getLong(i))); + try { + state.longValue(SumLongAggregator.combine(state.longValue(), block.getLong(i))); + } catch (ArithmeticException e) { + warnings.registerException(e); + state.failed(true); + return; + } } } } @@ -124,7 +151,13 @@ private void addRawBlock(LongBlock block, BooleanVector mask) { int start = block.getFirstValueIndex(p); int end = start + block.getValueCount(p); for (int i = start; i < end; i++) { - state.longValue(SumLongAggregator.combine(state.longValue(), block.getLong(i))); + try { + state.longValue(SumLongAggregator.combine(state.longValue(), block.getLong(i))); + } catch (ArithmeticException e) { + warnings.registerException(e); + state.failed(true); + return; + } } } } @@ -145,9 +178,23 @@ public void addIntermediateInput(Page page) { } BooleanVector seen = ((BooleanBlock) seenUncast).asVector(); assert seen.getPositionCount() == 1; - if (seen.getBoolean(0)) { - state.longValue(SumLongAggregator.combine(state.longValue(), sum.getLong(0))); + Block failedUncast = page.getBlock(channels.get(2)); + if (failedUncast.areAllValuesNull()) { + return; + } + BooleanVector failed = ((BooleanBlock) failedUncast).asVector(); + assert failed.getPositionCount() == 1; + if (failed.getBoolean(0)) { + state.failed(true); state.seen(true); + } else if (seen.getBoolean(0)) { + try { + state.longValue(SumLongAggregator.combine(state.longValue(), sum.getLong(0))); + state.seen(true); + } catch (ArithmeticException e) { + warnings.registerException(e); + state.failed(true); + } } } @@ -158,7 +205,7 @@ public void evaluateIntermediate(Block[] blocks, int offset, DriverContext drive @Override public void evaluateFinal(Block[] blocks, int offset, DriverContext driverContext) { - if (state.seen() == false) { + if (state.seen() == false || state.failed()) { blocks[offset] = driverContext.blockFactory().newConstantNullBlock(1); return; } diff --git a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/SumLongAggregatorFunctionSupplier.java b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/SumLongAggregatorFunctionSupplier.java index b4d36aa526075..5183189ee7cfe 100644 --- a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/SumLongAggregatorFunctionSupplier.java +++ b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/SumLongAggregatorFunctionSupplier.java @@ -9,26 +9,39 @@ import java.lang.String; import java.util.List; import org.elasticsearch.compute.operator.DriverContext; +import org.elasticsearch.compute.operator.Warnings; /** * {@link AggregatorFunctionSupplier} implementation for {@link SumLongAggregator}. * This class is generated. Do not edit it. */ public final class SumLongAggregatorFunctionSupplier implements AggregatorFunctionSupplier { + int warningsLineNumber; + + int warningsColumnNumber; + + String warningsSourceText; + private final List channels; - public SumLongAggregatorFunctionSupplier(List channels) { + public SumLongAggregatorFunctionSupplier(int warningsLineNumber, int warningsColumnNumber, + String warningsSourceText, List channels) { + this.warningsLineNumber = warningsLineNumber; + this.warningsColumnNumber = warningsColumnNumber; + this.warningsSourceText = warningsSourceText; this.channels = channels; } @Override public SumLongAggregatorFunction aggregator(DriverContext driverContext) { - return SumLongAggregatorFunction.create(driverContext, channels); + var warnings = Warnings.createWarnings(driverContext.warningsMode(), warningsLineNumber, warningsColumnNumber, warningsSourceText); + return SumLongAggregatorFunction.create(warnings, driverContext, channels); } @Override public SumLongGroupingAggregatorFunction groupingAggregator(DriverContext driverContext) { - return SumLongGroupingAggregatorFunction.create(channels, driverContext); + var warnings = Warnings.createWarnings(driverContext.warningsMode(), warningsLineNumber, warningsColumnNumber, warningsSourceText); + return SumLongGroupingAggregatorFunction.create(warnings, channels, driverContext); } @Override diff --git a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/SumLongGroupingAggregatorFunction.java b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/SumLongGroupingAggregatorFunction.java index c8c0990de4e54..388f157151318 100644 --- a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/SumLongGroupingAggregatorFunction.java +++ b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/SumLongGroupingAggregatorFunction.java @@ -4,6 +4,7 @@ // 2.0. package org.elasticsearch.compute.aggregation; +import java.lang.ArithmeticException; import java.lang.Integer; import java.lang.Override; import java.lang.String; @@ -19,6 +20,7 @@ import org.elasticsearch.compute.data.LongVector; import org.elasticsearch.compute.data.Page; import org.elasticsearch.compute.operator.DriverContext; +import org.elasticsearch.compute.operator.Warnings; /** * {@link GroupingAggregatorFunction} implementation for {@link SumLongAggregator}. @@ -27,24 +29,28 @@ public final class SumLongGroupingAggregatorFunction implements GroupingAggregatorFunction { private static final List INTERMEDIATE_STATE_DESC = List.of( new IntermediateStateDesc("sum", ElementType.LONG), - new IntermediateStateDesc("seen", ElementType.BOOLEAN) ); + new IntermediateStateDesc("seen", ElementType.BOOLEAN), + new IntermediateStateDesc("failed", ElementType.BOOLEAN) ); - private final LongArrayState state; + private final LongFallibleArrayState state; + + private final Warnings warnings; private final List channels; private final DriverContext driverContext; - public SumLongGroupingAggregatorFunction(List channels, LongArrayState state, - DriverContext driverContext) { + public SumLongGroupingAggregatorFunction(Warnings warnings, List channels, + LongFallibleArrayState state, DriverContext driverContext) { + this.warnings = warnings; this.channels = channels; this.state = state; this.driverContext = driverContext; } - public static SumLongGroupingAggregatorFunction create(List channels, + public static SumLongGroupingAggregatorFunction create(Warnings warnings, List channels, DriverContext driverContext) { - return new SumLongGroupingAggregatorFunction(channels, new LongArrayState(driverContext.bigArrays(), SumLongAggregator.init()), driverContext); + return new SumLongGroupingAggregatorFunction(warnings, channels, new LongFallibleArrayState(driverContext.bigArrays(), SumLongAggregator.init()), driverContext); } public static List intermediateStateDesc() { @@ -101,13 +107,21 @@ public void close() { private void addRawInput(int positionOffset, IntVector groups, LongBlock values) { for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { int groupId = groups.getInt(groupPosition); + if (state.hasFailed(groupId)) { + continue; + } if (values.isNull(groupPosition + positionOffset)) { continue; } int valuesStart = values.getFirstValueIndex(groupPosition + positionOffset); int valuesEnd = valuesStart + values.getValueCount(groupPosition + positionOffset); for (int v = valuesStart; v < valuesEnd; v++) { - state.set(groupId, SumLongAggregator.combine(state.getOrDefault(groupId), values.getLong(v))); + try { + state.set(groupId, SumLongAggregator.combine(state.getOrDefault(groupId), values.getLong(v))); + } catch (ArithmeticException e) { + warnings.registerException(e); + state.setFailed(groupId); + } } } } @@ -115,7 +129,15 @@ private void addRawInput(int positionOffset, IntVector groups, LongBlock values) private void addRawInput(int positionOffset, IntVector groups, LongVector values) { for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { int groupId = groups.getInt(groupPosition); - state.set(groupId, SumLongAggregator.combine(state.getOrDefault(groupId), values.getLong(groupPosition + positionOffset))); + if (state.hasFailed(groupId)) { + continue; + } + try { + state.set(groupId, SumLongAggregator.combine(state.getOrDefault(groupId), values.getLong(groupPosition + positionOffset))); + } catch (ArithmeticException e) { + warnings.registerException(e); + state.setFailed(groupId); + } } } @@ -128,13 +150,21 @@ private void addRawInput(int positionOffset, IntBlock groups, LongBlock values) int groupEnd = groupStart + groups.getValueCount(groupPosition); for (int g = groupStart; g < groupEnd; g++) { int groupId = groups.getInt(g); + if (state.hasFailed(groupId)) { + continue; + } if (values.isNull(groupPosition + positionOffset)) { continue; } int valuesStart = values.getFirstValueIndex(groupPosition + positionOffset); int valuesEnd = valuesStart + values.getValueCount(groupPosition + positionOffset); for (int v = valuesStart; v < valuesEnd; v++) { - state.set(groupId, SumLongAggregator.combine(state.getOrDefault(groupId), values.getLong(v))); + try { + state.set(groupId, SumLongAggregator.combine(state.getOrDefault(groupId), values.getLong(v))); + } catch (ArithmeticException e) { + warnings.registerException(e); + state.setFailed(groupId); + } } } } @@ -149,7 +179,15 @@ private void addRawInput(int positionOffset, IntBlock groups, LongVector values) int groupEnd = groupStart + groups.getValueCount(groupPosition); for (int g = groupStart; g < groupEnd; g++) { int groupId = groups.getInt(g); - state.set(groupId, SumLongAggregator.combine(state.getOrDefault(groupId), values.getLong(groupPosition + positionOffset))); + if (state.hasFailed(groupId)) { + continue; + } + try { + state.set(groupId, SumLongAggregator.combine(state.getOrDefault(groupId), values.getLong(groupPosition + positionOffset))); + } catch (ArithmeticException e) { + warnings.registerException(e); + state.setFailed(groupId); + } } } } @@ -173,11 +211,23 @@ public void addIntermediateInput(int positionOffset, IntVector groups, Page page return; } BooleanVector seen = ((BooleanBlock) seenUncast).asVector(); - assert sum.getPositionCount() == seen.getPositionCount(); + Block failedUncast = page.getBlock(channels.get(2)); + if (failedUncast.areAllValuesNull()) { + return; + } + BooleanVector failed = ((BooleanBlock) failedUncast).asVector(); + assert sum.getPositionCount() == seen.getPositionCount() && sum.getPositionCount() == failed.getPositionCount(); for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { int groupId = groups.getInt(groupPosition); - if (seen.getBoolean(groupPosition + positionOffset)) { - state.set(groupId, SumLongAggregator.combine(state.getOrDefault(groupId), sum.getLong(groupPosition + positionOffset))); + if (failed.getBoolean(groupPosition + positionOffset)) { + state.setFailed(groupId); + } else if (seen.getBoolean(groupPosition + positionOffset)) { + try { + state.set(groupId, SumLongAggregator.combine(state.getOrDefault(groupId), sum.getLong(groupPosition + positionOffset))); + } catch (ArithmeticException e) { + warnings.registerException(e); + state.setFailed(groupId); + } } } } @@ -187,7 +237,7 @@ public void addIntermediateRowInput(int groupId, GroupingAggregatorFunction inpu if (input.getClass() != getClass()) { throw new IllegalArgumentException("expected " + getClass() + "; got " + input.getClass()); } - LongArrayState inState = ((SumLongGroupingAggregatorFunction) input).state; + LongFallibleArrayState inState = ((SumLongGroupingAggregatorFunction) input).state; state.enableGroupIdTracking(new SeenGroupIds.Empty()); if (inState.hasValue(position)) { state.set(groupId, SumLongAggregator.combine(state.getOrDefault(groupId), inState.get(position))); diff --git a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/OverflowingSumLongAggregator.java b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/OverflowingSumLongAggregator.java new file mode 100644 index 0000000000000..b3956508d813b --- /dev/null +++ b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/OverflowingSumLongAggregator.java @@ -0,0 +1,34 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.compute.aggregation; + +import org.elasticsearch.compute.ann.Aggregator; +import org.elasticsearch.compute.ann.GroupingAggregator; +import org.elasticsearch.compute.ann.IntermediateState; + +/** + * Sum long aggregator for compatibility with old versions. + *

+ * Replaced by {@link org.elasticsearch.compute.aggregation.SumLongAggregator} since {@code EsqlFeatures#FN_SUM_OVERFLOW_HANDLING}. + *

+ *

+ * Should be kept for as long as we need compatibility with the version this was added on, as the new aggregator's layout is different. + *

+ */ +@Aggregator(value = { @IntermediateState(name = "sum", type = "LONG"), @IntermediateState(name = "seen", type = "BOOLEAN") }) +@GroupingAggregator +class OverflowingSumLongAggregator { + + public static long init() { + return 0; + } + + public static long combine(long current, long v) { + return Math.addExact(current, v); + } +} diff --git a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/SumLongAggregator.java b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/SumLongAggregator.java index cd6a94e518be8..178e4915022e1 100644 --- a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/SumLongAggregator.java +++ b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/SumLongAggregator.java @@ -11,7 +11,13 @@ import org.elasticsearch.compute.ann.GroupingAggregator; import org.elasticsearch.compute.ann.IntermediateState; -@Aggregator({ @IntermediateState(name = "sum", type = "LONG"), @IntermediateState(name = "seen", type = "BOOLEAN") }) +@Aggregator( + value = { + @IntermediateState(name = "sum", type = "LONG"), + @IntermediateState(name = "seen", type = "BOOLEAN"), + @IntermediateState(name = "failed", type = "BOOLEAN") }, + warnExceptions = ArithmeticException.class +) @GroupingAggregator class SumLongAggregator { diff --git a/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/aggregation/OverflowingSumLongAggregatorFunctionTests.java b/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/aggregation/OverflowingSumLongAggregatorFunctionTests.java new file mode 100644 index 0000000000000..79458fc915ce7 --- /dev/null +++ b/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/aggregation/OverflowingSumLongAggregatorFunctionTests.java @@ -0,0 +1,86 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.compute.aggregation; + +import org.elasticsearch.common.collect.Iterators; +import org.elasticsearch.compute.data.Block; +import org.elasticsearch.compute.data.BlockFactory; +import org.elasticsearch.compute.data.LongBlock; +import org.elasticsearch.compute.data.Page; +import org.elasticsearch.compute.operator.CannedSourceOperator; +import org.elasticsearch.compute.operator.Driver; +import org.elasticsearch.compute.operator.DriverContext; +import org.elasticsearch.compute.operator.PageConsumerOperator; +import org.elasticsearch.compute.operator.SequenceLongBlockSourceOperator; +import org.elasticsearch.compute.operator.SourceOperator; +import org.elasticsearch.compute.operator.TestResultPageSinkOperator; + +import java.util.List; +import java.util.stream.LongStream; + +import static org.hamcrest.Matchers.equalTo; + +public class OverflowingSumLongAggregatorFunctionTests extends AggregatorFunctionTestCase { + @Override + protected SourceOperator simpleInput(BlockFactory blockFactory, int size) { + long max = randomLongBetween(1, Long.MAX_VALUE / size); + return new SequenceLongBlockSourceOperator(blockFactory, LongStream.range(0, size).map(l -> randomLongBetween(-max, max))); + } + + @Override + protected AggregatorFunctionSupplier aggregatorFunction(List inputChannels) { + return new OverflowingSumLongAggregatorFunctionSupplier(inputChannels); + } + + @Override + protected String expectedDescriptionOfAggregator() { + return "overflowing_sum of longs"; + } + + @Override + public void assertSimpleOutput(List input, Block result) { + long sum = input.stream().flatMapToLong(b -> allLongs(b)).sum(); + assertThat(((LongBlock) result).getLong(0), equalTo(sum)); + } + + public void testOverflowFails() { + DriverContext driverContext = driverContext(); + + assertThrows(ArithmeticException.class, () -> { + try ( + Driver d = new Driver( + driverContext, + new SequenceLongBlockSourceOperator(driverContext.blockFactory(), LongStream.of(Long.MAX_VALUE - 1, 2)), + List.of(simple().get(driverContext)), + new TestResultPageSinkOperator(r -> {}), + () -> {} + ) + ) { + runDriver(d); + } + }); + + assertDriverContext(driverContext); + } + + public void testRejectsDouble() { + DriverContext driverContext = driverContext(); + BlockFactory blockFactory = driverContext.blockFactory(); + try ( + Driver d = new Driver( + driverContext, + new CannedSourceOperator(Iterators.single(new Page(blockFactory.newDoubleArrayVector(new double[] { 1.0 }, 1).asBlock()))), + List.of(simple().get(driverContext)), + new PageConsumerOperator(page -> fail("shouldn't have made it this far")), + () -> {} + ) + ) { + expectThrows(Exception.class, () -> runDriver(d)); // ### find a more specific exception type + } + } +} diff --git a/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/aggregation/OverflowingSumLongGroupingAggregatorFunctionTests.java b/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/aggregation/OverflowingSumLongGroupingAggregatorFunctionTests.java new file mode 100644 index 0000000000000..24345b35266ff --- /dev/null +++ b/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/aggregation/OverflowingSumLongGroupingAggregatorFunctionTests.java @@ -0,0 +1,68 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.compute.aggregation; + +import org.elasticsearch.compute.data.Block; +import org.elasticsearch.compute.data.BlockFactory; +import org.elasticsearch.compute.data.LongBlock; +import org.elasticsearch.compute.data.Page; +import org.elasticsearch.compute.operator.CannedSourceOperator; +import org.elasticsearch.compute.operator.DriverContext; +import org.elasticsearch.compute.operator.Operator; +import org.elasticsearch.compute.operator.SourceOperator; +import org.elasticsearch.compute.operator.TupleBlockSourceOperator; +import org.elasticsearch.core.Tuple; + +import java.util.List; +import java.util.stream.LongStream; + +import static org.hamcrest.Matchers.equalTo; + +public class OverflowingSumLongGroupingAggregatorFunctionTests extends GroupingAggregatorFunctionTestCase { + @Override + protected AggregatorFunctionSupplier aggregatorFunction(List inputChannels) { + return new OverflowingSumLongAggregatorFunctionSupplier(inputChannels); + } + + @Override + protected String expectedDescriptionOfAggregator() { + return "overflowing_sum of longs"; + } + + @Override + protected SourceOperator simpleInput(BlockFactory blockFactory, int size) { + long max = randomLongBetween(1, Long.MAX_VALUE / size / 5); + return new TupleBlockSourceOperator( + blockFactory, + LongStream.range(0, size).mapToObj(l -> Tuple.tuple(randomLongBetween(0, 4), randomLongBetween(-max, max))) + ); + } + + @Override + public void assertSimpleGroup(List input, Block result, int position, Long group) { + long sum = input.stream().flatMapToLong(p -> allLongs(p, group)).sum(); + assertThat(((LongBlock) result).getLong(position), equalTo(sum)); + } + + public void testOverflowFails() { + DriverContext driverContext = driverContext(); + + assertThrows(ArithmeticException.class, () -> { + Operator.OperatorFactory factory = simpleWithMode(AggregatorMode.SINGLE); + List input = CannedSourceOperator.collectPages( + new TupleBlockSourceOperator( + driverContext.blockFactory(), + LongStream.range(0, 10).mapToObj(l -> Tuple.tuple(randomLongBetween(0, 4), Long.MAX_VALUE - 1)) + ) + ); + drive(factory.get(driverContext), input.iterator(), driverContext); + }); + + assertDriverContext(driverContext); + } +} diff --git a/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/aggregation/SumLongAggregatorFunctionTests.java b/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/aggregation/SumLongAggregatorFunctionTests.java index 7fd3cabb2c91e..fc0ecbe9e9fd7 100644 --- a/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/aggregation/SumLongAggregatorFunctionTests.java +++ b/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/aggregation/SumLongAggregatorFunctionTests.java @@ -18,10 +18,14 @@ import org.elasticsearch.compute.operator.PageConsumerOperator; import org.elasticsearch.compute.operator.SequenceLongBlockSourceOperator; import org.elasticsearch.compute.operator.SourceOperator; +import org.elasticsearch.compute.operator.TestResultPageSinkOperator; +import java.util.ArrayList; import java.util.List; import java.util.stream.LongStream; +import static org.hamcrest.Matchers.contains; +import static org.hamcrest.Matchers.containsString; import static org.hamcrest.Matchers.equalTo; public class SumLongAggregatorFunctionTests extends AggregatorFunctionTestCase { @@ -33,7 +37,7 @@ protected SourceOperator simpleInput(BlockFactory blockFactory, int size) { @Override protected AggregatorFunctionSupplier aggregatorFunction(List inputChannels) { - return new SumLongAggregatorFunctionSupplier(inputChannels); + return new SumLongAggregatorFunctionSupplier(-1, -2, "", inputChannels); } @Override @@ -48,19 +52,37 @@ public void assertSimpleOutput(List input, Block result) { } public void testOverflowFails() { + List results = new ArrayList<>(); DriverContext driverContext = driverContext(); + List warnings = new ArrayList<>(); try ( Driver d = new Driver( driverContext, new SequenceLongBlockSourceOperator(driverContext.blockFactory(), LongStream.of(Long.MAX_VALUE - 1, 2)), List.of(simple().get(driverContext)), - new PageConsumerOperator(page -> fail("shouldn't have made it this far")), - () -> {} + new TestResultPageSinkOperator(results::add), + () -> { + warnings.addAll(threadContext.getResponseHeaders().getOrDefault("Warning", List.of())); + } ) ) { - Exception e = expectThrows(ArithmeticException.class, () -> runDriver(d)); - assertThat(e.getMessage(), equalTo("long overflow")); + runDriver(d); } + + assertDriverContext(driverContext); + + assertThat(results.size(), equalTo(1)); + assertThat(results.get(0).getBlockCount(), equalTo(1)); + assertThat(results.get(0).getPositionCount(), equalTo(1)); + assertThat(results.get(0).getBlock(0).isNull(0), equalTo(true)); + + assertThat( + warnings, + contains( + containsString("\"Line -1:-2: evaluation of [] failed, treating result as null. Only first 20 failures recorded.\""), + containsString("\"Line -1:-2: java.lang.ArithmeticException: long overflow\"") + ) + ); } public void testRejectsDouble() { diff --git a/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/aggregation/SumLongGroupingAggregatorFunctionTests.java b/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/aggregation/SumLongGroupingAggregatorFunctionTests.java index f41a5cbef94fb..6caad2fd06118 100644 --- a/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/aggregation/SumLongGroupingAggregatorFunctionTests.java +++ b/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/aggregation/SumLongGroupingAggregatorFunctionTests.java @@ -23,7 +23,7 @@ public class SumLongGroupingAggregatorFunctionTests extends GroupingAggregatorFunctionTestCase { @Override protected AggregatorFunctionSupplier aggregatorFunction(List inputChannels) { - return new SumLongAggregatorFunctionSupplier(inputChannels); + return new SumLongAggregatorFunctionSupplier(-1, -2, "", inputChannels); } @Override diff --git a/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/data/BlockSerializationTests.java b/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/data/BlockSerializationTests.java index 5a439becd4757..2cc2b7b9ef4d7 100644 --- a/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/data/BlockSerializationTests.java +++ b/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/data/BlockSerializationTests.java @@ -15,7 +15,7 @@ import org.elasticsearch.common.util.BytesRefHash; import org.elasticsearch.common.util.MockBigArrays; import org.elasticsearch.common.util.PageCacheRecycler; -import org.elasticsearch.compute.aggregation.SumLongAggregatorFunction; +import org.elasticsearch.compute.aggregation.MaxLongAggregatorFunction; import org.elasticsearch.compute.operator.DriverContext; import org.elasticsearch.core.Releasables; import org.elasticsearch.indices.breaker.NoneCircuitBreakerService; @@ -236,7 +236,7 @@ public void testConstantNullBlock() throws IOException { public void testSimulateAggs() { DriverContext driverCtx = driverContext(); Page page = new Page(blockFactory.newLongArrayVector(new long[] { 1, 2, 3, 4, 5, 6, 7, 8, 9, 10 }, 10).asBlock()); - var function = SumLongAggregatorFunction.create(driverCtx, List.of(0)); + var function = MaxLongAggregatorFunction.create(driverCtx, List.of(0)); try (BooleanVector noMasking = driverContext().blockFactory().newConstantBooleanVector(true, page.getPositionCount())) { function.addRawInput(page, noMasking); } @@ -249,13 +249,13 @@ public void testSimulateAggs() { IntStream.range(0, blocks.length) .forEach(i -> EqualsHashCodeTestUtils.checkEqualsAndHashCode(blocks[i], unused -> deserBlocks[i])); - var inputChannels = IntStream.range(0, SumLongAggregatorFunction.intermediateStateDesc().size()).boxed().toList(); - try (var finalAggregator = SumLongAggregatorFunction.create(driverCtx, inputChannels)) { + var inputChannels = IntStream.range(0, MaxLongAggregatorFunction.intermediateStateDesc().size()).boxed().toList(); + try (var finalAggregator = MaxLongAggregatorFunction.create(driverCtx, inputChannels)) { finalAggregator.addIntermediateInput(new Page(deserBlocks)); Block[] finalBlocks = new Block[1]; finalAggregator.evaluateFinal(finalBlocks, 0, driverCtx); try (var finalBlock = (LongBlock) finalBlocks[0]) { - assertThat(finalBlock.getLong(0), is(55L)); + assertThat(finalBlock.getLong(0), is(10L)); } } } finally { diff --git a/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/operator/AggregationOperatorTests.java b/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/operator/AggregationOperatorTests.java index 38d83fe894170..bb21f162383d5 100644 --- a/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/operator/AggregationOperatorTests.java +++ b/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/operator/AggregationOperatorTests.java @@ -47,7 +47,7 @@ protected Operator.OperatorFactory simpleWithMode(AggregatorMode mode) { return new AggregationOperator.AggregationOperatorFactory( List.of( - new SumLongAggregatorFunctionSupplier(sumChannels).aggregatorFactory(mode), + new SumLongAggregatorFunctionSupplier(-1, -2, "", sumChannels).aggregatorFactory(mode), new MaxLongAggregatorFunctionSupplier(maxChannels).aggregatorFactory(mode) ), mode diff --git a/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/operator/HashAggregationOperatorTests.java b/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/operator/HashAggregationOperatorTests.java index f2fa94c1feb08..9d376ed9b304e 100644 --- a/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/operator/HashAggregationOperatorTests.java +++ b/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/operator/HashAggregationOperatorTests.java @@ -55,7 +55,7 @@ protected Operator.OperatorFactory simpleWithMode(AggregatorMode mode) { return new HashAggregationOperator.HashAggregationOperatorFactory( List.of(new BlockHash.GroupSpec(0, ElementType.LONG)), List.of( - new SumLongAggregatorFunctionSupplier(sumChannels).groupingAggregatorFactory(mode), + new SumLongAggregatorFunctionSupplier(-1, -2, "", sumChannels).groupingAggregatorFactory(mode), new MaxLongAggregatorFunctionSupplier(maxChannels).groupingAggregatorFactory(mode) ), randomPageSize() diff --git a/x-pack/plugin/esql/qa/testFixtures/src/main/java/org/elasticsearch/xpack/esql/ConfigurationTestUtils.java b/x-pack/plugin/esql/qa/testFixtures/src/main/java/org/elasticsearch/xpack/esql/ConfigurationTestUtils.java index 39e79b33327a9..8d916ebfee01a 100644 --- a/x-pack/plugin/esql/qa/testFixtures/src/main/java/org/elasticsearch/xpack/esql/ConfigurationTestUtils.java +++ b/x-pack/plugin/esql/qa/testFixtures/src/main/java/org/elasticsearch/xpack/esql/ConfigurationTestUtils.java @@ -18,14 +18,17 @@ import org.elasticsearch.compute.data.ElementType; import org.elasticsearch.compute.lucene.DataPartitioning; import org.elasticsearch.core.Releasables; +import org.elasticsearch.features.NodeFeature; import org.elasticsearch.xpack.esql.action.ParseTables; import org.elasticsearch.xpack.esql.core.type.DataType; import org.elasticsearch.xpack.esql.planner.PlannerUtils; +import org.elasticsearch.xpack.esql.plugin.EsqlFeatures; import org.elasticsearch.xpack.esql.plugin.QueryPragmas; import org.elasticsearch.xpack.esql.session.Configuration; import java.util.HashMap; import java.util.Map; +import java.util.stream.Collectors; import static org.apache.lucene.tests.util.LuceneTestCase.random; import static org.apache.lucene.tests.util.LuceneTestCase.randomLocale; @@ -71,7 +74,8 @@ public static Configuration randomConfiguration(String query, Map BiFunction uni(BiFunc return function; } + private static UnaryConfigurationAwareBuilder uniConfig(UnaryConfigurationAwareBuilder function) { + return function; + } + private static BinaryBuilder bi(BinaryBuilder function) { return function; } + private static BinaryConfigurationAwareBuilder biConfig(BinaryConfigurationAwareBuilder function) { + return function; + } + private static TernaryBuilder tri(TernaryBuilder function) { return function; } diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/aggregate/Avg.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/aggregate/Avg.java index 82c0f9d24899e..7c64431c2a8f4 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/aggregate/Avg.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/aggregate/Avg.java @@ -20,6 +20,7 @@ import org.elasticsearch.xpack.esql.expression.function.Param; import org.elasticsearch.xpack.esql.expression.function.scalar.multivalue.MvAvg; import org.elasticsearch.xpack.esql.expression.predicate.operator.arithmetic.Div; +import org.elasticsearch.xpack.esql.session.Configuration; import java.io.IOException; import java.util.List; @@ -28,7 +29,7 @@ import static org.elasticsearch.xpack.esql.core.expression.TypeResolutions.ParamOrdinal.DEFAULT; import static org.elasticsearch.xpack.esql.core.expression.TypeResolutions.isType; -public class Avg extends AggregateFunction implements SurrogateExpression { +public class Avg extends ConfigurationAggregateFunction implements SurrogateExpression { public static final NamedWriteableRegistry.Entry ENTRY = new NamedWriteableRegistry.Entry(Expression.class, "Avg", Avg::new); @FunctionInfo( @@ -45,12 +46,16 @@ public class Avg extends AggregateFunction implements SurrogateExpression { tag = "docsStatsAvgNestedExpression" ) } ) - public Avg(Source source, @Param(name = "number", type = { "double", "integer", "long" }) Expression field) { - this(source, field, Literal.TRUE); + public Avg( + Source source, + @Param(name = "number", type = { "double", "integer", "long" }) Expression field, + Configuration configuration + ) { + this(source, field, Literal.TRUE, configuration); } - public Avg(Source source, Expression field, Expression filter) { - super(source, field, filter, emptyList()); + public Avg(Source source, Expression field, Expression filter, Configuration configuration) { + super(source, field, filter, emptyList(), configuration); } @Override @@ -80,17 +85,17 @@ public DataType dataType() { @Override protected NodeInfo info() { - return NodeInfo.create(this, Avg::new, field(), filter()); + return NodeInfo.create(this, Avg::new, field(), filter(), configuration()); } @Override public Avg replaceChildren(List newChildren) { - return new Avg(source(), newChildren.get(0), newChildren.get(1)); + return new Avg(source(), newChildren.get(0), newChildren.get(1), configuration()); } @Override public Avg withFilter(Expression filter) { - return new Avg(source(), field(), filter); + return new Avg(source(), field(), filter, configuration()); } @Override @@ -100,6 +105,6 @@ public Expression surrogate() { return field().foldable() ? new MvAvg(s, field) - : new Div(s, new Sum(s, field, filter()), new Count(s, field, filter()), dataType()); + : new Div(s, new Sum(s, field, filter(), configuration()), new Count(s, field, filter()), dataType()); } } diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/aggregate/ConfigurationAggregateFunction.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/aggregate/ConfigurationAggregateFunction.java new file mode 100644 index 0000000000000..1693cb61ffc6b --- /dev/null +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/aggregate/ConfigurationAggregateFunction.java @@ -0,0 +1,68 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.esql.expression.function.aggregate; + +import org.elasticsearch.common.io.stream.StreamInput; +import org.elasticsearch.xpack.esql.core.expression.Expression; +import org.elasticsearch.xpack.esql.core.tree.Source; +import org.elasticsearch.xpack.esql.io.stream.PlanStreamInput; +import org.elasticsearch.xpack.esql.session.Configuration; + +import java.io.IOException; +import java.util.List; +import java.util.Objects; + +public abstract class ConfigurationAggregateFunction extends AggregateFunction { + + private final Configuration configuration; + + ConfigurationAggregateFunction(Source source, Expression field, List parameters, Configuration configuration) { + super(source, field, parameters); + this.configuration = configuration; + } + + ConfigurationAggregateFunction( + Source source, + Expression field, + Expression filter, + List parameters, + Configuration configuration + ) { + super(source, field, filter, parameters); + this.configuration = configuration; + } + + ConfigurationAggregateFunction(Source source, Expression field, Configuration configuration) { + super(source, field); + this.configuration = configuration; + } + + ConfigurationAggregateFunction(StreamInput in) throws IOException { + super(in); + this.configuration = ((PlanStreamInput) in).configuration(); + } + + public Configuration configuration() { + return configuration; + } + + @Override + public int hashCode() { + return Objects.hash(getClass(), children(), configuration); + } + + @Override + public boolean equals(Object obj) { + if (super.equals(obj) == false) { + return false; + } + ConfigurationAggregateFunction other = (ConfigurationAggregateFunction) obj; + + return configuration.equals(other.configuration); + } +} diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/aggregate/Sum.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/aggregate/Sum.java index 37c2abaae1e4e..73e919373ef47 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/aggregate/Sum.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/aggregate/Sum.java @@ -9,9 +9,14 @@ import org.elasticsearch.common.io.stream.NamedWriteableRegistry; import org.elasticsearch.common.io.stream.StreamInput; import org.elasticsearch.compute.aggregation.AggregatorFunctionSupplier; +import org.elasticsearch.compute.aggregation.IntermediateStateDesc; +import org.elasticsearch.compute.aggregation.OverflowingSumLongAggregatorFunction; +import org.elasticsearch.compute.aggregation.OverflowingSumLongAggregatorFunctionSupplier; +import org.elasticsearch.compute.aggregation.OverflowingSumLongGroupingAggregatorFunction; import org.elasticsearch.compute.aggregation.SumDoubleAggregatorFunctionSupplier; import org.elasticsearch.compute.aggregation.SumIntAggregatorFunctionSupplier; import org.elasticsearch.compute.aggregation.SumLongAggregatorFunctionSupplier; +import org.elasticsearch.xpack.esql.EsqlIllegalArgumentException; import org.elasticsearch.xpack.esql.core.expression.Expression; import org.elasticsearch.xpack.esql.core.expression.Literal; import org.elasticsearch.xpack.esql.core.tree.NodeInfo; @@ -24,11 +29,17 @@ import org.elasticsearch.xpack.esql.expression.function.Param; import org.elasticsearch.xpack.esql.expression.function.scalar.multivalue.MvSum; import org.elasticsearch.xpack.esql.expression.predicate.operator.arithmetic.Mul; +import org.elasticsearch.xpack.esql.planner.ToAggregator; +import org.elasticsearch.xpack.esql.planner.ToIntermediateState; +import org.elasticsearch.xpack.esql.plugin.EsqlFeatures; +import org.elasticsearch.xpack.esql.session.Configuration; import java.io.IOException; import java.util.List; import static java.util.Collections.emptyList; +import static org.elasticsearch.xpack.esql.core.expression.TypeResolutions.ParamOrdinal.DEFAULT; +import static org.elasticsearch.xpack.esql.core.expression.TypeResolutions.isType; import static org.elasticsearch.xpack.esql.core.type.DataType.DOUBLE; import static org.elasticsearch.xpack.esql.core.type.DataType.LONG; import static org.elasticsearch.xpack.esql.core.type.DataType.UNSIGNED_LONG; @@ -36,7 +47,7 @@ /** * Sum all values of a field in matching documents. */ -public class Sum extends NumericAggregate implements SurrogateExpression { +public class Sum extends ConfigurationAggregateFunction implements ToAggregator, ToIntermediateState, SurrogateExpression { public static final NamedWriteableRegistry.Entry ENTRY = new NamedWriteableRegistry.Entry(Expression.class, "Sum", Sum::new); @FunctionInfo( @@ -53,12 +64,16 @@ public class Sum extends NumericAggregate implements SurrogateExpression { tag = "docsStatsSumNestedExpression" ) } ) - public Sum(Source source, @Param(name = "number", type = { "double", "integer", "long" }) Expression field) { - this(source, field, Literal.TRUE); + public Sum( + Source source, + @Param(name = "number", type = { "double", "integer", "long" }) Expression field, + Configuration configuration + ) { + this(source, field, Literal.TRUE, configuration); } - public Sum(Source source, Expression field, Expression filter) { - super(source, field, filter, emptyList()); + public Sum(Source source, Expression field, Expression filter, Configuration configuration) { + super(source, field, filter, emptyList(), configuration); } private Sum(StreamInput in) throws IOException { @@ -72,38 +87,70 @@ public String getWriteableName() { @Override protected NodeInfo info() { - return NodeInfo.create(this, Sum::new, field(), filter()); + return NodeInfo.create(this, Sum::new, field(), filter(), configuration()); } @Override public Sum replaceChildren(List newChildren) { - return new Sum(source(), newChildren.get(0), newChildren.get(1)); + return new Sum(source(), newChildren.get(0), newChildren.get(1), configuration()); } @Override public Sum withFilter(Expression filter) { - return new Sum(source(), field(), filter); + return new Sum(source(), field(), filter, configuration()); } @Override - public DataType dataType() { - DataType dt = field().dataType(); - return dt.isWholeNumber() == false || dt == UNSIGNED_LONG ? DOUBLE : LONG; + protected TypeResolution resolveType() { + return isType( + field(), + dt -> dt.isNumeric() && dt != DataType.UNSIGNED_LONG, + sourceText(), + DEFAULT, + "numeric except unsigned_long or counter types" + ); } @Override - protected AggregatorFunctionSupplier longSupplier(List inputChannels) { - return new SumLongAggregatorFunctionSupplier(inputChannels); + public DataType dataType() { + DataType dt = field().dataType(); + return dt.isWholeNumber() == false || dt == UNSIGNED_LONG ? DOUBLE : LONG; } @Override - protected AggregatorFunctionSupplier intSupplier(List inputChannels) { - return new SumIntAggregatorFunctionSupplier(inputChannels); + public final AggregatorFunctionSupplier supplier(List inputChannels) { + DataType type = field().dataType(); + if (type == DataType.LONG) { + // Old aggregator without overflow handling + if (configuration().clusterHasFeature(EsqlFeatures.FN_SUM_OVERFLOW_HANDLING) == false) { + return new OverflowingSumLongAggregatorFunctionSupplier(inputChannels); + } + var location = source().source(); + return new SumLongAggregatorFunctionSupplier( + location.getLineNumber(), + location.getColumnNumber(), + source().text(), + inputChannels + ); + } + if (type == DataType.INTEGER) { + return new SumIntAggregatorFunctionSupplier(inputChannels); + } + if (type == DataType.DOUBLE) { + return new SumDoubleAggregatorFunctionSupplier(inputChannels); + } + throw EsqlIllegalArgumentException.illegalDataType(type); } @Override - protected AggregatorFunctionSupplier doubleSupplier(List inputChannels) { - return new SumDoubleAggregatorFunctionSupplier(inputChannels); + public List intermediateState(boolean grouping) { + DataType type = field().dataType(); + if (type == DataType.LONG && configuration().clusterHasFeature(EsqlFeatures.FN_SUM_OVERFLOW_HANDLING) == false) { + return grouping + ? OverflowingSumLongGroupingAggregatorFunction.intermediateStateDesc() + : OverflowingSumLongAggregatorFunction.intermediateStateDesc(); + } + return null; } @Override diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/aggregate/WeightedAvg.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/aggregate/WeightedAvg.java index dbcc50cea3b9b..584109f8719dd 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/aggregate/WeightedAvg.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/aggregate/WeightedAvg.java @@ -25,6 +25,7 @@ import org.elasticsearch.xpack.esql.expression.predicate.operator.arithmetic.Div; import org.elasticsearch.xpack.esql.expression.predicate.operator.arithmetic.Mul; import org.elasticsearch.xpack.esql.io.stream.PlanStreamInput; +import org.elasticsearch.xpack.esql.session.Configuration; import java.io.IOException; import java.util.List; @@ -34,7 +35,7 @@ import static org.elasticsearch.xpack.esql.core.expression.TypeResolutions.ParamOrdinal.SECOND; import static org.elasticsearch.xpack.esql.core.expression.TypeResolutions.isType; -public class WeightedAvg extends AggregateFunction implements SurrogateExpression, Validatable { +public class WeightedAvg extends ConfigurationAggregateFunction implements SurrogateExpression, Validatable { public static final NamedWriteableRegistry.Entry ENTRY = new NamedWriteableRegistry.Entry( Expression.class, "WeightedAvg", @@ -54,13 +55,14 @@ public class WeightedAvg extends AggregateFunction implements SurrogateExpressio public WeightedAvg( Source source, @Param(name = "number", type = { "double", "integer", "long" }, description = "A numeric value.") Expression field, - @Param(name = "weight", type = { "double", "integer", "long" }, description = "A numeric weight.") Expression weight + @Param(name = "weight", type = { "double", "integer", "long" }, description = "A numeric weight.") Expression weight, + Configuration configuration ) { - this(source, field, Literal.TRUE, weight); + this(source, field, Literal.TRUE, weight, configuration); } - public WeightedAvg(Source source, Expression field, Expression filter, Expression weight) { - super(source, field, filter, List.of(weight)); + public WeightedAvg(Source source, Expression field, Expression filter, Expression weight, Configuration configuration) { + super(source, field, filter, List.of(weight), configuration); this.weight = weight; } @@ -73,7 +75,8 @@ private WeightedAvg(StreamInput in) throws IOException { : Literal.TRUE, in.getTransportVersion().onOrAfter(TransportVersions.ESQL_PER_AGGREGATE_FILTER) ? in.readNamedWriteableCollectionAsList(Expression.class).get(0) - : in.readNamedWriteable(Expression.class) + : in.readNamedWriteable(Expression.class), + ((PlanStreamInput) in).configuration() ); } @@ -132,17 +135,17 @@ public DataType dataType() { @Override protected NodeInfo info() { - return NodeInfo.create(this, WeightedAvg::new, field(), filter(), weight); + return NodeInfo.create(this, WeightedAvg::new, field(), filter(), weight, configuration()); } @Override public WeightedAvg replaceChildren(List newChildren) { - return new WeightedAvg(source(), newChildren.get(0), newChildren.get(1), newChildren.get(2)); + return new WeightedAvg(source(), newChildren.get(0), newChildren.get(1), newChildren.get(2), configuration()); } @Override public WeightedAvg withFilter(Expression filter) { - return new WeightedAvg(source(), field(), filter, weight()); + return new WeightedAvg(source(), field(), filter, weight(), configuration()); } @Override @@ -155,9 +158,9 @@ public Expression surrogate() { return new MvAvg(s, field); } if (weight.foldable()) { - return new Div(s, new Sum(s, field), new Count(s, field), dataType()); + return new Div(s, new Sum(s, field, configuration()), new Count(s, field), dataType()); } else { - return new Div(s, new Sum(s, new Mul(s, field, weight)), new Sum(s, weight), dataType()); + return new Div(s, new Sum(s, new Mul(s, field, weight), configuration()), new Sum(s, weight, configuration()), dataType()); } } diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/planner/AggregateMapper.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/planner/AggregateMapper.java index 3e81c2a2c1101..fef9953ad2059 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/planner/AggregateMapper.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/planner/AggregateMapper.java @@ -131,11 +131,16 @@ private Stream map(Expression aggregate, boolean grouping) { } private static List computeEntryForAgg(Expression aggregate, boolean grouping) { + if (aggregate instanceof ToIntermediateState intermediateStateGetter) { + var intermediateStates = intermediateStateGetter.intermediateState(grouping); + if (intermediateStates != null) { + return isToNE(intermediateStates).toList(); + } + } var aggDef = aggDefOrNull(aggregate, grouping); if (aggDef != null) { - var is = getNonNull(aggDef); - var exp = isToNE(is).toList(); - return exp; + var intermediateStates = getNonNull(aggDef); + return isToNE(intermediateStates).toList(); } if (aggregate instanceof FieldAttribute || aggregate instanceof MetadataAttribute || aggregate instanceof ReferenceAttribute) { // This condition is a little pedantic, but do we expected other expressions here? if so, then add them @@ -157,7 +162,7 @@ private static List getNonNull(AggDef aggDef) { private static Stream, Tuple>> typeAndNames(Class clazz) { List types; List extraConfigs = List.of(""); - if (NumericAggregate.class.isAssignableFrom(clazz)) { + if (NumericAggregate.class.isAssignableFrom(clazz) || Sum.class.isAssignableFrom(clazz)) { types = NUMERIC; } else if (Max.class.isAssignableFrom(clazz) || Min.class.isAssignableFrom(clazz)) { types = List.of("Boolean", "Int", "Long", "Double", "Ip", "BytesRef"); diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/planner/ToIntermediateState.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/planner/ToIntermediateState.java new file mode 100644 index 0000000000000..5591ed439f386 --- /dev/null +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/planner/ToIntermediateState.java @@ -0,0 +1,27 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.esql.planner; + +import org.elasticsearch.compute.aggregation.IntermediateStateDesc; + +import java.util.List; + +/** + * Expressions that have a mapping to {@link org.elasticsearch.compute.aggregation.IntermediateStateDesc}s. + */ +public interface ToIntermediateState { + /** + * Returns the intermediate state descriptions for this expression. + *

+ * If null, the default method of {@link AggregateMapper} will be used to get them. + *

+ */ + default List intermediateState(boolean grouping) { + return null; + } +} diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plugin/EsqlFeatures.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plugin/EsqlFeatures.java index 266f07d22eaf5..8db74635794bb 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plugin/EsqlFeatures.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plugin/EsqlFeatures.java @@ -178,6 +178,11 @@ public class EsqlFeatures implements FeatureSpecification { */ public static final NodeFeature RESOLVE_FIELDS_API = new NodeFeature("esql.resolve_fields_api"); + /** + * Overflow handling for Sum aggregation function + */ + public static final NodeFeature FN_SUM_OVERFLOW_HANDLING = new NodeFeature("esql.fn_sum_overflow_handling"); + private Set snapshotBuildFeatures() { assert Build.current().isSnapshot() : Build.current(); return Set.of(METRICS_SYNTAX); @@ -207,7 +212,8 @@ public Set getFeatures() { METADATA_FIELDS, TIMESPAN_ABBREVIATIONS, COUNTER_TYPES, - RESOLVE_FIELDS_API + RESOLVE_FIELDS_API, + FN_SUM_OVERFLOW_HANDLING ); if (Build.current().isSnapshot()) { return Collections.unmodifiableSet(Sets.union(features, snapshotBuildFeatures())); diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plugin/TransportEsqlQueryAction.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plugin/TransportEsqlQueryAction.java index 04e5fdc4b3bd2..85727935a348d 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plugin/TransportEsqlQueryAction.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plugin/TransportEsqlQueryAction.java @@ -19,6 +19,7 @@ import org.elasticsearch.common.util.concurrent.EsExecutors; import org.elasticsearch.compute.data.BlockFactory; import org.elasticsearch.compute.operator.exchange.ExchangeService; +import org.elasticsearch.features.FeatureService; import org.elasticsearch.injection.guice.Inject; import org.elasticsearch.search.SearchService; import org.elasticsearch.tasks.CancellableTask; @@ -62,6 +63,7 @@ public class TransportEsqlQueryAction extends HandledTransportAction> tables; private final long queryStartTimeNanos; + private final Set activeEsqlFeatures; + public Configuration( ZoneId zi, Locale locale, @@ -64,7 +72,8 @@ public Configuration( String query, boolean profile, Map> tables, - long queryStartTimeNanos + long queryStartTimeNanos, + Set activeEsqlFeatures ) { this.zoneId = zi.normalized(); this.now = ZonedDateTime.now(Clock.tick(Clock.system(zoneId), Duration.ofNanos(1))); @@ -79,6 +88,7 @@ public Configuration( this.tables = tables; assert tables != null; this.queryStartTimeNanos = queryStartTimeNanos; + this.activeEsqlFeatures = activeEsqlFeatures; } public Configuration(BlockStreamInput in) throws IOException { @@ -106,6 +116,19 @@ public Configuration(BlockStreamInput in) throws IOException { } else { this.queryStartTimeNanos = -1; } + if (in.getTransportVersion().onOrAfter(TransportVersions.ESQL_CONFIGURATION_WITH_FEATURES)) { + this.activeEsqlFeatures = in.readCollectionAsImmutableSet(StreamInput::readString); + } else { + this.activeEsqlFeatures = Set.of(); + } + } + + public static Set calculateActiveClusterFeatures(FeatureService featureService, ClusterService clusterService) { + return new EsqlFeatures().getFeatures() + .stream() + .filter(f -> featureService.clusterHasFeature(clusterService.state(), f)) + .map(NodeFeature::id) + .collect(Collectors.toSet()); } @Override @@ -130,6 +153,9 @@ public void writeTo(StreamOutput out) throws IOException { if (out.getTransportVersion().onOrAfter(TransportVersions.ESQL_CCS_EXECUTION_INFO)) { out.writeLong(queryStartTimeNanos); } + if (out.getTransportVersion().onOrAfter(TransportVersions.ESQL_CONFIGURATION_WITH_FEATURES)) { + out.writeCollection(activeEsqlFeatures, StreamOutput::writeString); + } } public ZoneId zoneId() { @@ -198,6 +224,14 @@ public boolean profile() { return profile; } + public Set activeEsqlFeatures() { + return activeEsqlFeatures; + } + + public boolean clusterHasFeature(NodeFeature feature) { + return activeEsqlFeatures.contains(feature.id()); + } + private static void writeQuery(StreamOutput out, String query) throws IOException { if (query.length() > QUERY_COMPRESS_THRESHOLD_CHARS) { // compare on chars to avoid UTF-8 encoding unless actually required out.writeBoolean(true); @@ -236,7 +270,8 @@ public boolean equals(Object o) { && Objects.equals(locale, that.locale) && Objects.equals(that.query, query) && profile == that.profile - && tables.equals(that.tables); + && tables.equals(that.tables) + && activeEsqlFeatures.equals(that.activeEsqlFeatures); } @Override @@ -252,7 +287,8 @@ public int hashCode() { locale, query, profile, - tables + tables, + activeEsqlFeatures ); } @@ -274,6 +310,8 @@ public String toString() { + profile + ", tables=" + tables + + ", activeEsqlFeatures=" + + activeEsqlFeatures + '}'; } } diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/AbstractAggregationTestCase.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/AbstractAggregationTestCase.java index db5d8e03458ea..900f7e3f44119 100644 --- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/AbstractAggregationTestCase.java +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/AbstractAggregationTestCase.java @@ -193,9 +193,9 @@ private void aggregateGroupingSingleMode(Expression expression) { assert testCase.getMatcher().matches(Double.NEGATIVE_INFINITY) == false; assertThat(result, not(equalTo(Double.NEGATIVE_INFINITY))); assertThat(result, testCase.getMatcher()); - if (testCase.getExpectedWarnings() != null) { - assertWarnings(testCase.getExpectedWarnings()); - } + } + if (testCase.getExpectedWarnings() != null) { + assertWarnings(testCase.getExpectedWarnings()); } } diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/AbstractConfigurationAggregationTestCase.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/AbstractConfigurationAggregationTestCase.java new file mode 100644 index 0000000000000..4074cbd993c59 --- /dev/null +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/AbstractConfigurationAggregationTestCase.java @@ -0,0 +1,67 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.esql.expression.function; + +import org.elasticsearch.common.settings.Settings; +import org.elasticsearch.features.NodeFeature; +import org.elasticsearch.xpack.esql.EsqlTestUtils; +import org.elasticsearch.xpack.esql.core.expression.Expression; +import org.elasticsearch.xpack.esql.core.tree.Source; +import org.elasticsearch.xpack.esql.core.util.StringUtils; +import org.elasticsearch.xpack.esql.plugin.EsqlFeatures; +import org.elasticsearch.xpack.esql.plugin.EsqlPlugin; +import org.elasticsearch.xpack.esql.plugin.QueryPragmas; +import org.elasticsearch.xpack.esql.session.Configuration; + +import java.util.List; +import java.util.Map; +import java.util.stream.Collectors; + +import static org.elasticsearch.xpack.esql.SerializationTestUtils.assertSerialization; + +public abstract class AbstractConfigurationAggregationTestCase extends AbstractAggregationTestCase { + protected abstract Expression buildWithConfiguration(Source source, List args, Configuration configuration); + + @Override + protected Expression build(Source source, List args) { + return buildWithConfiguration(source, args, EsqlTestUtils.TEST_CFG); + } + + static Configuration randomConfiguration() { + // TODO: Randomize the query and maybe the pragmas. + return new Configuration( + randomZone(), + randomLocale(random()), + randomBoolean() ? null : randomAlphaOfLength(randomInt(64)), + randomBoolean() ? null : randomAlphaOfLength(randomInt(64)), + QueryPragmas.EMPTY, + EsqlPlugin.QUERY_RESULT_TRUNCATION_MAX_SIZE.getDefault(Settings.EMPTY), + EsqlPlugin.QUERY_RESULT_TRUNCATION_DEFAULT_SIZE.getDefault(Settings.EMPTY), + StringUtils.EMPTY, + randomBoolean(), + Map.of(), + System.nanoTime(), + new EsqlFeatures().getFeatures().stream().map(NodeFeature::id).collect(Collectors.toSet()) + ); + } + + public void testSerializationWithConfiguration() { + Configuration config = randomConfiguration(); + Expression expr = buildWithConfiguration(testCase.getSource(), testCase.getDataAsFields(), config); + + assertSerialization(expr, config); + + Configuration differentConfig; + do { + differentConfig = randomConfiguration(); + } while (config.equals(differentConfig)); + + Expression differentExpr = buildWithConfiguration(testCase.getSource(), testCase.getDataAsFields(), differentConfig); + assertFalse(expr.equals(differentExpr)); + } +} diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/TestCaseSupplier.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/TestCaseSupplier.java index c12e0a8684ba9..71ad54daa3166 100644 --- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/TestCaseSupplier.java +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/TestCaseSupplier.java @@ -1538,6 +1538,31 @@ public TestCase withExtra(Object extra) { ); } + public TestCase withWarnings(List warnings) { + String[] newWarnings; + if (expectedWarnings != null) { + newWarnings = Arrays.copyOf(expectedWarnings, expectedWarnings.length + warnings.size()); + for (int i = 0; i < warnings.size(); i++) { + newWarnings[expectedWarnings.length + i] = warnings.get(i); + } + } else { + newWarnings = warnings.toArray(String[]::new); + } + + return new TestCase( + data, + evaluatorToString, + expectedType, + matcher, + newWarnings, + expectedBuildEvaluatorWarnings, + expectedTypeError, + foldingExceptionClass, + foldingExceptionMessage, + extra + ); + } + public TestCase withWarning(String warning) { return new TestCase( data, diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/aggregate/AvgSerializationTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/aggregate/AvgSerializationTests.java index 52d3128af5c1c..d010930177881 100644 --- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/aggregate/AvgSerializationTests.java +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/aggregate/AvgSerializationTests.java @@ -14,12 +14,16 @@ public class AvgSerializationTests extends AbstractExpressionSerializationTests { @Override protected Avg createTestInstance() { - return new Avg(randomSource(), randomChild()); + return new Avg(randomSource(), randomChild(), configuration()); } @Override protected Avg mutateInstance(Avg instance) throws IOException { - return new Avg(instance.source(), randomValueOtherThan(instance.field(), AbstractExpressionSerializationTests::randomChild)); + return new Avg( + instance.source(), + randomValueOtherThan(instance.field(), AbstractExpressionSerializationTests::randomChild), + instance.configuration() + ); } @Override diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/aggregate/AvgTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/aggregate/AvgTests.java index ac599c7ff05f8..b618066a45726 100644 --- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/aggregate/AvgTests.java +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/aggregate/AvgTests.java @@ -13,9 +13,10 @@ import org.elasticsearch.xpack.esql.core.expression.Expression; import org.elasticsearch.xpack.esql.core.tree.Source; import org.elasticsearch.xpack.esql.core.type.DataType; -import org.elasticsearch.xpack.esql.expression.function.AbstractAggregationTestCase; +import org.elasticsearch.xpack.esql.expression.function.AbstractConfigurationAggregationTestCase; import org.elasticsearch.xpack.esql.expression.function.MultiRowTestCaseSupplier; import org.elasticsearch.xpack.esql.expression.function.TestCaseSupplier; +import org.elasticsearch.xpack.esql.session.Configuration; import java.util.ArrayList; import java.util.List; @@ -25,7 +26,7 @@ import static org.hamcrest.Matchers.equalTo; -public class AvgTests extends AbstractAggregationTestCase { +public class AvgTests extends AbstractConfigurationAggregationTestCase { public AvgTests(@Name("TestCase") Supplier testCaseSupplier) { this.testCase = testCaseSupplier.get(); } @@ -57,8 +58,8 @@ public static Iterable parameters() { } @Override - protected Expression build(Source source, List args) { - return new Avg(source, args.get(0)); + protected Expression buildWithConfiguration(Source source, List args, Configuration configuration) { + return new Avg(source, args.get(0), configuration); } private static TestCaseSupplier makeSupplier(TestCaseSupplier.TypedDataSupplier fieldSupplier) { diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/aggregate/SumSerializationTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/aggregate/SumSerializationTests.java index 863392f7eb451..5b50c748f9a4a 100644 --- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/aggregate/SumSerializationTests.java +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/aggregate/SumSerializationTests.java @@ -14,12 +14,16 @@ public class SumSerializationTests extends AbstractExpressionSerializationTests { @Override protected Sum createTestInstance() { - return new Sum(randomSource(), randomChild()); + return new Sum(randomSource(), randomChild(), configuration()); } @Override protected Sum mutateInstance(Sum instance) throws IOException { - return new Sum(instance.source(), randomValueOtherThan(instance.field(), AbstractExpressionSerializationTests::randomChild)); + return new Sum( + instance.source(), + randomValueOtherThan(instance.field(), AbstractExpressionSerializationTests::randomChild), + instance.configuration() + ); } @Override diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/aggregate/SumTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/aggregate/SumTests.java index 4f14dafc8b30d..8e84f12b1d14c 100644 --- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/aggregate/SumTests.java +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/aggregate/SumTests.java @@ -10,12 +10,14 @@ import com.carrotsearch.randomizedtesting.annotations.Name; import com.carrotsearch.randomizedtesting.annotations.ParametersFactory; +import org.elasticsearch.search.aggregations.metrics.CompensatedSum; import org.elasticsearch.xpack.esql.core.expression.Expression; import org.elasticsearch.xpack.esql.core.tree.Source; import org.elasticsearch.xpack.esql.core.type.DataType; -import org.elasticsearch.xpack.esql.expression.function.AbstractAggregationTestCase; +import org.elasticsearch.xpack.esql.expression.function.AbstractConfigurationAggregationTestCase; import org.elasticsearch.xpack.esql.expression.function.MultiRowTestCaseSupplier; import org.elasticsearch.xpack.esql.expression.function.TestCaseSupplier; +import org.elasticsearch.xpack.esql.session.Configuration; import java.util.ArrayList; import java.util.List; @@ -26,7 +28,7 @@ import static org.elasticsearch.xpack.esql.core.type.DataType.UNSIGNED_LONG; import static org.hamcrest.Matchers.equalTo; -public class SumTests extends AbstractAggregationTestCase { +public class SumTests extends AbstractConfigurationAggregationTestCase { public SumTests(@Name("TestCase") Supplier testCaseSupplier) { this.testCase = testCaseSupplier.get(); } @@ -35,13 +37,12 @@ public SumTests(@Name("TestCase") Supplier testCaseSu public static Iterable parameters() { var suppliers = new ArrayList(); - Stream.of(MultiRowTestCaseSupplier.intCases(1, 1000, Integer.MIN_VALUE, Integer.MAX_VALUE, true) - // Longs currently throw on overflow. - // Restore after https://github.com/elastic/elasticsearch/issues/110437 - // MultiRowTestCaseSupplier.longCases(1, 1000, Long.MIN_VALUE, Long.MAX_VALUE, true), - // Doubles currently return +/-Infinity on overflow. - // Restore after https://github.com/elastic/elasticsearch/issues/111026 - // MultiRowTestCaseSupplier.doubleCases(1, 1000, -Double.MAX_VALUE, Double.MAX_VALUE, true) + Stream.of( + MultiRowTestCaseSupplier.intCases(1, 1000, Integer.MIN_VALUE, Integer.MAX_VALUE, true), + MultiRowTestCaseSupplier.longCases(1, 1000, Long.MIN_VALUE, Long.MAX_VALUE, true) + // Doubles currently return +/-Infinity on overflow. + // Restore after https://github.com/elastic/elasticsearch/issues/111026 + // MultiRowTestCaseSupplier.doubleCases(1, 1000, -Double.MAX_VALUE, Double.MAX_VALUE, true) ).flatMap(List::stream).map(SumTests::makeSupplier).collect(Collectors.toCollection(() -> suppliers)); suppliers.addAll( @@ -81,52 +82,45 @@ public static Iterable parameters() { } @Override - protected Expression build(Source source, List args) { - return new Sum(source, args.get(0)); + protected Expression buildWithConfiguration(Source source, List args, Configuration configuration) { + return new Sum(source, args.get(0), configuration); } private static TestCaseSupplier makeSupplier(TestCaseSupplier.TypedDataSupplier fieldSupplier) { return new TestCaseSupplier(List.of(fieldSupplier.type()), () -> { var fieldTypedData = fieldSupplier.get(); + var expectedWarnings = new ArrayList(); Object expected; try { expected = switch (fieldTypedData.type().widenSmallNumeric()) { case INTEGER -> fieldTypedData.multiRowData() .stream() - .map(v -> (Integer) v) - .collect(Collectors.summarizingInt(Integer::intValue)) - .getSum(); - case LONG -> fieldTypedData.multiRowData() - .stream() - .map(v -> (Long) v) - .collect(Collectors.summarizingLong(Long::longValue)) - .getSum(); + .map(v -> ((Integer) v).longValue()) + .reduce(Math::addExact) + .orElse(null); + case LONG -> fieldTypedData.multiRowData().stream().map(v -> (Long) v).reduce(Math::addExact).orElse(null); case DOUBLE -> { - var value = fieldTypedData.multiRowData() - .stream() - .map(v -> (Double) v) - .collect(Collectors.summarizingDouble(Double::doubleValue)) - .getSum(); - - if (Double.isInfinite(value) || Double.isNaN(value)) { - yield null; - } + var sum = new CompensatedSum(); + fieldTypedData.multiRowData().stream().map(v -> (Double) v).forEach(sum::add); - yield value; + yield Double.isFinite(sum.value()) ? sum.value() : null; } default -> throw new IllegalStateException("Unexpected value: " + fieldTypedData.type()); }; - } catch (Exception e) { + } catch (ArithmeticException e) { expected = null; + expectedWarnings.add("Line -1:-1: evaluation of [] failed, treating result as null. Only first 20 failures recorded."); + expectedWarnings.add("Line -1:-1: java.lang.ArithmeticException: long overflow"); } var dataType = fieldTypedData.type().isWholeNumber() == false || fieldTypedData.type() == UNSIGNED_LONG ? DataType.DOUBLE : DataType.LONG; - return new TestCaseSupplier.TestCase(List.of(fieldTypedData), "Sum[field=Attribute[channel=0]]", dataType, equalTo(expected)); + return new TestCaseSupplier.TestCase(List.of(fieldTypedData), "Sum[field=Attribute[channel=0]]", dataType, equalTo(expected)) + .withWarnings(expectedWarnings); }); } } diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/aggregate/WeightedAvgSerializationTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/aggregate/WeightedAvgSerializationTests.java new file mode 100644 index 0000000000000..598bf697c6f33 --- /dev/null +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/aggregate/WeightedAvgSerializationTests.java @@ -0,0 +1,34 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.esql.expression.function.aggregate; + +import org.elasticsearch.xpack.esql.expression.AbstractExpressionSerializationTests; + +import java.io.IOException; + +public class WeightedAvgSerializationTests extends AbstractExpressionSerializationTests { + @Override + protected WeightedAvg createTestInstance() { + return new WeightedAvg(randomSource(), randomChild(), randomChild(), configuration()); + } + + @Override + protected WeightedAvg mutateInstance(WeightedAvg instance) throws IOException { + return new WeightedAvg( + instance.source(), + randomValueOtherThan(instance.field(), AbstractExpressionSerializationTests::randomChild), + randomValueOtherThan(instance.weight(), AbstractExpressionSerializationTests::randomChild), + instance.configuration() + ); + } + + @Override + protected boolean alwaysEmptySource() { + return true; + } +} diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/aggregate/WeightedAvgTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/aggregate/WeightedAvgTests.java index 2c2ffc97f268c..7a30d6f499ee1 100644 --- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/aggregate/WeightedAvgTests.java +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/aggregate/WeightedAvgTests.java @@ -13,9 +13,10 @@ import org.elasticsearch.xpack.esql.core.expression.Expression; import org.elasticsearch.xpack.esql.core.tree.Source; import org.elasticsearch.xpack.esql.core.type.DataType; -import org.elasticsearch.xpack.esql.expression.function.AbstractAggregationTestCase; +import org.elasticsearch.xpack.esql.expression.function.AbstractConfigurationAggregationTestCase; import org.elasticsearch.xpack.esql.expression.function.MultiRowTestCaseSupplier; import org.elasticsearch.xpack.esql.expression.function.TestCaseSupplier; +import org.elasticsearch.xpack.esql.session.Configuration; import java.util.ArrayList; import java.util.List; @@ -25,7 +26,7 @@ import static org.hamcrest.Matchers.equalTo; -public class WeightedAvgTests extends AbstractAggregationTestCase { +public class WeightedAvgTests extends AbstractConfigurationAggregationTestCase { public WeightedAvgTests(@Name("TestCase") Supplier testCaseSupplier) { this.testCase = testCaseSupplier.get(); } @@ -94,8 +95,8 @@ public static Iterable parameters() { } @Override - protected Expression build(Source source, List args) { - return new WeightedAvg(source, args.get(0), args.get(1)); + protected Expression buildWithConfiguration(Source source, List args, Configuration configuration) { + return new WeightedAvg(source, args.get(0), args.get(1), configuration); } private static TestCaseSupplier makeSupplier( diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/scalar/AbstractConfigurationFunctionTestCase.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/scalar/AbstractConfigurationFunctionTestCase.java index a3a18d7a30b59..852fe0daae111 100644 --- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/scalar/AbstractConfigurationFunctionTestCase.java +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/scalar/AbstractConfigurationFunctionTestCase.java @@ -8,17 +8,20 @@ package org.elasticsearch.xpack.esql.expression.function.scalar; import org.elasticsearch.common.settings.Settings; +import org.elasticsearch.features.NodeFeature; import org.elasticsearch.xpack.esql.EsqlTestUtils; import org.elasticsearch.xpack.esql.core.expression.Expression; import org.elasticsearch.xpack.esql.core.tree.Source; import org.elasticsearch.xpack.esql.core.util.StringUtils; import org.elasticsearch.xpack.esql.expression.function.AbstractScalarFunctionTestCase; +import org.elasticsearch.xpack.esql.plugin.EsqlFeatures; import org.elasticsearch.xpack.esql.plugin.EsqlPlugin; import org.elasticsearch.xpack.esql.plugin.QueryPragmas; import org.elasticsearch.xpack.esql.session.Configuration; import java.util.List; import java.util.Map; +import java.util.stream.Collectors; import static org.elasticsearch.xpack.esql.SerializationTestUtils.assertSerialization; @@ -43,7 +46,8 @@ static Configuration randomConfiguration() { StringUtils.EMPTY, randomBoolean(), Map.of(), - System.nanoTime() + System.nanoTime(), + new EsqlFeatures().getFeatures().stream().map(NodeFeature::id).collect(Collectors.toSet()) ); } diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/scalar/string/ToLowerTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/scalar/string/ToLowerTests.java index 69dbe023bde66..86714bffdb0a7 100644 --- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/scalar/string/ToLowerTests.java +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/scalar/string/ToLowerTests.java @@ -28,6 +28,7 @@ import java.util.ArrayList; import java.util.List; import java.util.Map; +import java.util.Set; import java.util.function.Supplier; import static org.hamcrest.Matchers.equalTo; @@ -71,7 +72,8 @@ private Configuration randomLocaleConfig() { "", false, Map.of(), - System.nanoTime() + System.nanoTime(), + Set.of() ); } diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/scalar/string/ToUpperTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/scalar/string/ToUpperTests.java index 33d6f929503b3..bbaef61bb5fc2 100644 --- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/scalar/string/ToUpperTests.java +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/scalar/string/ToUpperTests.java @@ -28,6 +28,7 @@ import java.util.ArrayList; import java.util.List; import java.util.Map; +import java.util.Set; import java.util.function.Supplier; import static org.hamcrest.Matchers.equalTo; @@ -71,7 +72,8 @@ private Configuration randomLocaleConfig() { "", false, Map.of(), - System.nanoTime() + System.nanoTime(), + Set.of() ); } diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/PhysicalPlanOptimizerTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/PhysicalPlanOptimizerTests.java index 9f5d6440e4a06..886917c471417 100644 --- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/PhysicalPlanOptimizerTests.java +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/PhysicalPlanOptimizerTests.java @@ -217,7 +217,7 @@ public PhysicalPlanOptimizerTests(String name, Configuration config) { @Before public void init() { parser = new EsqlParser(); - logicalOptimizer = new LogicalPlanOptimizer(new LogicalOptimizerContext(EsqlTestUtils.TEST_CFG)); + logicalOptimizer = new LogicalPlanOptimizer(new LogicalOptimizerContext(config)); physicalPlanOptimizer = new PhysicalPlanOptimizer(new PhysicalOptimizerContext(config)); EsqlFunctionRegistry functionRegistry = new EsqlFunctionRegistry(); mapper = new Mapper(); @@ -6685,7 +6685,7 @@ private PhysicalPlan physicalPlan(String query, TestDataSource dataSource) { // System.out.println("Logical\n" + logical); var physical = mapper.map(logical); // System.out.println(physical); - assertSerialization(physical); + assertSerialization(physical, config); return physical; } diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/rules/logical/FoldNullTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/rules/logical/FoldNullTests.java index 89117b5d4e729..5b43948a0bed1 100644 --- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/rules/logical/FoldNullTests.java +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/rules/logical/FoldNullTests.java @@ -65,6 +65,7 @@ import java.util.List; import static org.elasticsearch.xpack.esql.EsqlTestUtils.L; +import static org.elasticsearch.xpack.esql.EsqlTestUtils.TEST_CFG; import static org.elasticsearch.xpack.esql.EsqlTestUtils.configuration; import static org.elasticsearch.xpack.esql.EsqlTestUtils.getFieldAttribute; import static org.elasticsearch.xpack.esql.EsqlTestUtils.greaterThanOf; @@ -191,9 +192,9 @@ public void testNullFoldingDoesNotApplyOnAggregate() throws Exception { assertEquals(NULL, rule.rule(conditionalFunction)); } - Avg avg = new Avg(EMPTY, getFieldAttribute("a")); + Avg avg = new Avg(EMPTY, getFieldAttribute("a"), TEST_CFG); assertEquals(avg, rule.rule(avg)); - avg = new Avg(EMPTY, NULL); + avg = new Avg(EMPTY, NULL, TEST_CFG); assertEquals(new Literal(EMPTY, null, DOUBLE), rule.rule(avg)); Count count = new Count(EMPTY, getFieldAttribute("a")); @@ -221,9 +222,9 @@ public void testNullFoldingDoesNotApplyOnAggregate() throws Exception { percentile = new Percentile(EMPTY, NULL, NULL); assertEquals(new Literal(EMPTY, null, DOUBLE), rule.rule(percentile)); - Sum sum = new Sum(EMPTY, getFieldAttribute("a")); + Sum sum = new Sum(EMPTY, getFieldAttribute("a"), TEST_CFG); assertEquals(sum, rule.rule(sum)); - sum = new Sum(EMPTY, NULL); + sum = new Sum(EMPTY, NULL, TEST_CFG); assertEquals(new Literal(EMPTY, null, DOUBLE), rule.rule(sum)); } diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/plan/AbstractNodeSerializationTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/plan/AbstractNodeSerializationTests.java index e6faa9a253d76..f83df6440c5cf 100644 --- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/plan/AbstractNodeSerializationTests.java +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/plan/AbstractNodeSerializationTests.java @@ -35,7 +35,7 @@ public abstract class AbstractNodeSerializationTests> * We use a single random config for all serialization because it's pretty * heavy to build, especially in {@link #testConcurrentSerialization()}. */ - private Configuration config; + private static Configuration config; public static Source randomSource() { int lineNumber = between(0, EXAMPLE_QUERY.length - 1); @@ -77,7 +77,7 @@ protected boolean alwaysEmptySource() { return false; } - public final Configuration configuration() { + public static final Configuration configuration() { return config; } diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/plan/logical/AggregateSerializationTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/plan/logical/AggregateSerializationTests.java index 01f797491103c..34d32477ab610 100644 --- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/plan/logical/AggregateSerializationTests.java +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/plan/logical/AggregateSerializationTests.java @@ -51,7 +51,7 @@ public static List randomAggregates() { new Literal(randomSource(), randomFrom("ASC", "DESC"), DataType.KEYWORD) ); case 4 -> new Values(randomSource(), FieldAttributeTests.createFieldAttribute(1, true)); - case 5 -> new Sum(randomSource(), FieldAttributeTests.createFieldAttribute(1, true)); + case 5 -> new Sum(randomSource(), FieldAttributeTests.createFieldAttribute(1, true), configuration()); default -> throw new IllegalArgumentException(); }; result.add(new Alias(randomSource(), randomAlphaOfLength(5), agg)); diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/planner/EvalMapperTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/planner/EvalMapperTests.java index 0e09809d16902..896ee3cad08c9 100644 --- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/planner/EvalMapperTests.java +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/planner/EvalMapperTests.java @@ -14,6 +14,7 @@ import org.elasticsearch.common.util.PageCacheRecycler; import org.elasticsearch.compute.operator.DriverContext; import org.elasticsearch.compute.operator.EvalOperator; +import org.elasticsearch.features.NodeFeature; import org.elasticsearch.indices.breaker.NoneCircuitBreakerService; import org.elasticsearch.test.ESTestCase; import org.elasticsearch.xpack.esql.SerializationTestUtils; @@ -49,6 +50,7 @@ import org.elasticsearch.xpack.esql.expression.predicate.operator.comparison.GreaterThanOrEqual; import org.elasticsearch.xpack.esql.expression.predicate.operator.comparison.LessThan; import org.elasticsearch.xpack.esql.expression.predicate.operator.comparison.LessThanOrEqual; +import org.elasticsearch.xpack.esql.plugin.EsqlFeatures; import org.elasticsearch.xpack.esql.session.Configuration; import java.time.Duration; @@ -58,6 +60,7 @@ import java.util.List; import java.util.Locale; import java.util.Map; +import java.util.stream.Collectors; public class EvalMapperTests extends ESTestCase { private static final FieldAttribute DOUBLE1 = field("foo", DataType.DOUBLE); @@ -76,7 +79,8 @@ public class EvalMapperTests extends ESTestCase { StringUtils.EMPTY, false, Map.of(), - System.nanoTime() + System.nanoTime(), + new EsqlFeatures().getFeatures().stream().map(NodeFeature::id).collect(Collectors.toSet()) ); @ParametersFactory(argumentFormatting = "%1$s") diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/planner/LocalExecutionPlannerTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/planner/LocalExecutionPlannerTests.java index f60e5384e1a6f..65c31f462afb0 100644 --- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/planner/LocalExecutionPlannerTests.java +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/planner/LocalExecutionPlannerTests.java @@ -25,6 +25,7 @@ import org.elasticsearch.core.IOUtils; import org.elasticsearch.core.Releasable; import org.elasticsearch.core.Releasables; +import org.elasticsearch.features.NodeFeature; import org.elasticsearch.index.IndexMode; import org.elasticsearch.index.cache.query.TrivialQueryCachingPolicy; import org.elasticsearch.index.mapper.MapperServiceTestCase; @@ -40,6 +41,7 @@ import org.elasticsearch.xpack.esql.expression.Order; import org.elasticsearch.xpack.esql.index.EsIndex; import org.elasticsearch.xpack.esql.plan.physical.EsQueryExec; +import org.elasticsearch.xpack.esql.plugin.EsqlFeatures; import org.elasticsearch.xpack.esql.plugin.EsqlPlugin; import org.elasticsearch.xpack.esql.plugin.QueryPragmas; import org.elasticsearch.xpack.esql.session.Configuration; @@ -50,6 +52,7 @@ import java.util.ArrayList; import java.util.List; import java.util.Map; +import java.util.stream.Collectors; import static org.hamcrest.Matchers.equalTo; import static org.hamcrest.Matchers.lessThanOrEqualTo; @@ -162,7 +165,8 @@ private Configuration config() { StringUtils.EMPTY, false, Map.of(), - System.nanoTime() + System.nanoTime(), + new EsqlFeatures().getFeatures().stream().map(NodeFeature::id).collect(Collectors.toSet()) ); } diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/plugin/DataNodeRequestTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/plugin/DataNodeRequestTests.java index 4553551c40cd3..19be4581ded73 100644 --- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/plugin/DataNodeRequestTests.java +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/plugin/DataNodeRequestTests.java @@ -18,7 +18,6 @@ import org.elasticsearch.search.SearchModule; import org.elasticsearch.search.internal.AliasFilter; import org.elasticsearch.test.AbstractWireSerializingTestCase; -import org.elasticsearch.xpack.esql.EsqlTestUtils; import org.elasticsearch.xpack.esql.analysis.Analyzer; import org.elasticsearch.xpack.esql.analysis.AnalyzerContext; import org.elasticsearch.xpack.esql.core.type.EsField; @@ -33,6 +32,7 @@ import org.elasticsearch.xpack.esql.plan.logical.LogicalPlan; import org.elasticsearch.xpack.esql.plan.physical.PhysicalPlan; import org.elasticsearch.xpack.esql.planner.mapper.Mapper; +import org.elasticsearch.xpack.esql.session.Configuration; import java.io.IOException; import java.util.ArrayList; @@ -41,7 +41,6 @@ import static org.elasticsearch.xpack.esql.ConfigurationTestUtils.randomConfiguration; import static org.elasticsearch.xpack.esql.ConfigurationTestUtils.randomTables; -import static org.elasticsearch.xpack.esql.EsqlTestUtils.TEST_CFG; import static org.elasticsearch.xpack.esql.EsqlTestUtils.TEST_VERIFIER; import static org.elasticsearch.xpack.esql.EsqlTestUtils.emptyPolicyResolution; import static org.elasticsearch.xpack.esql.EsqlTestUtils.loadMapping; @@ -79,14 +78,15 @@ protected DataNodeRequest createTestInstance() { | stats x = avg(salary) """); List shardIds = randomList(1, 10, () -> new ShardId("index-" + between(1, 10), "n/a", between(1, 10))); - PhysicalPlan physicalPlan = mapAndMaybeOptimize(parse(query)); + Configuration configuration = randomLimitedConfiguration(query); + PhysicalPlan physicalPlan = mapAndMaybeOptimize(parse(query, configuration), configuration); Map aliasFilters = Map.of( new Index("concrete-index", "n/a"), AliasFilter.of(new TermQueryBuilder("id", "1"), "alias-1") ); DataNodeRequest request = new DataNodeRequest( sessionId, - randomConfiguration(query, randomTables()), + configuration, randomAlphaOfLength(10), shardIds, aliasFilters, @@ -164,7 +164,7 @@ protected DataNodeRequest mutateInstance(DataNodeRequest in) throws IOException in.clusterAlias(), in.shardIds(), in.aliasFilters(), - mapAndMaybeOptimize(parse(newQuery)), + mapAndMaybeOptimize(parse(newQuery, in.configuration()), in.configuration()), in.indices(), in.indicesOptions() ); @@ -260,20 +260,20 @@ protected DataNodeRequest mutateInstance(DataNodeRequest in) throws IOException }; } - static LogicalPlan parse(String query) { + static LogicalPlan parse(String query, Configuration configuration) { Map mapping = loadMapping("mapping-basic.json"); EsIndex test = new EsIndex("test", mapping, Map.of("test", IndexMode.STANDARD)); IndexResolution getIndexResult = IndexResolution.valid(test); - var logicalOptimizer = new LogicalPlanOptimizer(new LogicalOptimizerContext(TEST_CFG)); + var logicalOptimizer = new LogicalPlanOptimizer(new LogicalOptimizerContext(configuration)); var analyzer = new Analyzer( - new AnalyzerContext(EsqlTestUtils.TEST_CFG, new EsqlFunctionRegistry(), getIndexResult, emptyPolicyResolution()), + new AnalyzerContext(configuration, new EsqlFunctionRegistry(), getIndexResult, emptyPolicyResolution()), TEST_VERIFIER ); return logicalOptimizer.optimize(analyzer.analyze(new EsqlParser().createStatement(query))); } - static PhysicalPlan mapAndMaybeOptimize(LogicalPlan logicalPlan) { - var physicalPlanOptimizer = new PhysicalPlanOptimizer(new PhysicalOptimizerContext(TEST_CFG)); + static PhysicalPlan mapAndMaybeOptimize(LogicalPlan logicalPlan, Configuration configuration) { + var physicalPlanOptimizer = new PhysicalPlanOptimizer(new PhysicalOptimizerContext(configuration)); var mapper = new Mapper(); var physical = mapper.map(logicalPlan); if (randomBoolean()) { @@ -286,4 +286,29 @@ static PhysicalPlan mapAndMaybeOptimize(LogicalPlan logicalPlan) { protected List filteredWarnings() { return withDefaultLimitWarning(super.filteredWarnings()); } + + /** + * Builds a configuration with a fixed limit of 10000 rows and 1000 bytes. + *

+ * Without this, warnings would be randomized, and hard to filter from the warnings. + *

+ */ + private Configuration randomLimitedConfiguration(String query) { + Configuration configuration = randomConfiguration(query, randomTables()); + + return new Configuration( + configuration.zoneId(), + configuration.locale(), + configuration.username(), + configuration.clusterName(), + configuration.pragmas(), + 10000, + 1000, + configuration.query(), + configuration.profile(), + configuration.tables(), + configuration.getQueryStartTimeNanos(), + configuration.activeEsqlFeatures() + ); + } } diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/session/ConfigurationSerializationTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/session/ConfigurationSerializationTests.java index 1f35bb5312b20..a72b15faa80cd 100644 --- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/session/ConfigurationSerializationTests.java +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/session/ConfigurationSerializationTests.java @@ -18,13 +18,16 @@ import org.elasticsearch.compute.data.BlockFactory; import org.elasticsearch.compute.data.BlockStreamInput; import org.elasticsearch.core.Releasables; +import org.elasticsearch.features.NodeFeature; import org.elasticsearch.test.AbstractWireSerializingTestCase; import org.elasticsearch.xpack.esql.Column; +import org.elasticsearch.xpack.esql.plugin.EsqlFeatures; import org.elasticsearch.xpack.esql.plugin.QueryPragmas; import java.time.ZoneId; import java.util.Locale; import java.util.Map; +import java.util.stream.Collectors; import static org.elasticsearch.xpack.esql.ConfigurationTestUtils.randomConfiguration; import static org.elasticsearch.xpack.esql.ConfigurationTestUtils.randomTables; @@ -103,7 +106,8 @@ protected Configuration mutateInstance(Configuration in) { query, profile, tables, - System.nanoTime() + System.nanoTime(), + new EsqlFeatures().getFeatures().stream().map(NodeFeature::id).collect(Collectors.toSet()) ); }